init commit
This commit is contained in:
120
server/profiler/profiler.go
Normal file
120
server/profiler/profiler.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package profiler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// Profiler provides HTTP endpoints for memory profiling.
|
||||
type Profiler struct {
|
||||
memStatsLogInterval time.Duration
|
||||
}
|
||||
|
||||
// NewProfiler creates a new profiler.
|
||||
func NewProfiler() *Profiler {
|
||||
return &Profiler{
|
||||
memStatsLogInterval: 1 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes adds profiling endpoints to the Echo server.
|
||||
func (*Profiler) RegisterRoutes(e *echo.Echo) {
|
||||
// Register pprof handlers
|
||||
g := e.Group("/debug/pprof")
|
||||
g.GET("", echo.WrapHandler(http.HandlerFunc(pprof.Index)))
|
||||
g.GET("/cmdline", echo.WrapHandler(http.HandlerFunc(pprof.Cmdline)))
|
||||
g.GET("/profile", echo.WrapHandler(http.HandlerFunc(pprof.Profile)))
|
||||
g.POST("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol)))
|
||||
g.GET("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol)))
|
||||
g.GET("/trace", echo.WrapHandler(http.HandlerFunc(pprof.Trace)))
|
||||
g.GET("/allocs", echo.WrapHandler(http.HandlerFunc(pprof.Handler("allocs").ServeHTTP)))
|
||||
g.GET("/block", echo.WrapHandler(http.HandlerFunc(pprof.Handler("block").ServeHTTP)))
|
||||
g.GET("/goroutine", echo.WrapHandler(http.HandlerFunc(pprof.Handler("goroutine").ServeHTTP)))
|
||||
g.GET("/heap", echo.WrapHandler(http.HandlerFunc(pprof.Handler("heap").ServeHTTP)))
|
||||
g.GET("/mutex", echo.WrapHandler(http.HandlerFunc(pprof.Handler("mutex").ServeHTTP)))
|
||||
g.GET("/threadcreate", echo.WrapHandler(http.HandlerFunc(pprof.Handler("threadcreate").ServeHTTP)))
|
||||
|
||||
// Add a custom memory stats endpoint.
|
||||
g.GET("/memstats", func(c echo.Context) error {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlloc": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"numGC": m.NumGC,
|
||||
"heapAlloc": m.HeapAlloc,
|
||||
"heapSys": m.HeapSys,
|
||||
"heapInuse": m.HeapInuse,
|
||||
"heapObjects": m.HeapObjects,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// StartMemoryMonitor starts a goroutine that periodically logs memory stats.
|
||||
func (p *Profiler) StartMemoryMonitor(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(p.memStatsLogInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Store previous heap allocation to track growth.
|
||||
var lastHeapAlloc uint64
|
||||
var lastNumGC uint32
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
// Calculate heap growth since last check.
|
||||
heapGrowth := int64(m.HeapAlloc) - int64(lastHeapAlloc)
|
||||
gcCount := m.NumGC - lastNumGC
|
||||
|
||||
slog.Info("memory stats",
|
||||
"heapAlloc", byteCountIEC(m.HeapAlloc),
|
||||
"heapSys", byteCountIEC(m.HeapSys),
|
||||
"heapObjects", m.HeapObjects,
|
||||
"heapGrowth", byteCountIEC(uint64(heapGrowth)),
|
||||
"numGoroutine", runtime.NumGoroutine(),
|
||||
"numGC", m.NumGC,
|
||||
"gcSince", gcCount,
|
||||
"nextGC", byteCountIEC(m.NextGC),
|
||||
"gcPause", time.Duration(m.PauseNs[(m.NumGC+255)%256]).String(),
|
||||
)
|
||||
|
||||
// Track values for next iteration.
|
||||
lastHeapAlloc = m.HeapAlloc
|
||||
lastNumGC = m.NumGC
|
||||
|
||||
// Force GC if memory usage is high to see if objects can be reclaimed.
|
||||
if m.HeapAlloc > 500*1024*1024 { // 500 MB threshold
|
||||
slog.Info("forcing garbage collection due to high memory usage")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// byteCountIEC converts bytes to a human-readable string (MiB, GiB).
|
||||
func byteCountIEC(b uint64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := uint64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
262
server/router/api/v1/acl.go
Normal file
262
server/router/api/v1/acl.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// ContextKey is the key type of context value.
|
||||
type ContextKey int
|
||||
|
||||
const (
|
||||
// The key name used to store user's ID in the context (for user-based auth).
|
||||
userIDContextKey ContextKey = iota
|
||||
// The key name used to store session ID in the context (for session-based auth).
|
||||
sessionIDContextKey
|
||||
// The key name used to store access token in the context (for token-based auth).
|
||||
accessTokenContextKey
|
||||
)
|
||||
|
||||
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
|
||||
type GRPCAuthInterceptor struct {
|
||||
Store *store.Store
|
||||
secret string
|
||||
}
|
||||
|
||||
// NewGRPCAuthInterceptor returns a new API auth interceptor.
|
||||
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
|
||||
return &GRPCAuthInterceptor{
|
||||
Store: store,
|
||||
secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationInterceptor is the unary interceptor for gRPC API.
|
||||
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
|
||||
}
|
||||
|
||||
// Try to authenticate via session ID (from cookie) first
|
||||
if sessionCookieValue, err := getSessionIDFromMetadata(md); err == nil && sessionCookieValue != "" {
|
||||
user, err := in.authenticateBySession(ctx, sessionCookieValue)
|
||||
if err == nil && user != nil {
|
||||
// Extract just the sessionID part for context storage
|
||||
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookieValue)
|
||||
if parseErr != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to parse session cookie: %v", parseErr)
|
||||
}
|
||||
return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, sessionID, "")
|
||||
}
|
||||
}
|
||||
|
||||
// Try to authenticate via JWT access token (from Authorization header)
|
||||
if accessToken, err := getAccessTokenFromMetadata(md); err == nil && accessToken != "" {
|
||||
user, err := in.authenticateByJWT(ctx, accessToken)
|
||||
if err == nil && user != nil {
|
||||
return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, "", accessToken)
|
||||
}
|
||||
}
|
||||
|
||||
// If no valid authentication found, check if this method is in the allowlist (public endpoints)
|
||||
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
}
|
||||
|
||||
// If authentication is required but not found, reject the request
|
||||
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
|
||||
}
|
||||
|
||||
// handleAuthenticatedRequest processes an authenticated request with the given user and auth info.
|
||||
func (in *GRPCAuthInterceptor) handleAuthenticatedRequest(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler, user *store.User, sessionID, accessToken string) (any, error) {
|
||||
// Check user status
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, errors.Errorf("user %q is archived", user.Username)
|
||||
}
|
||||
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
|
||||
return nil, errors.Errorf("user %q is not admin", user.Username)
|
||||
}
|
||||
|
||||
// Set context values
|
||||
ctx = context.WithValue(ctx, userIDContextKey, user.ID)
|
||||
|
||||
if sessionID != "" {
|
||||
// Session-based authentication
|
||||
ctx = context.WithValue(ctx, sessionIDContextKey, sessionID)
|
||||
// Update session last accessed time
|
||||
_ = in.updateSessionLastAccessed(ctx, user.ID, sessionID)
|
||||
} else if accessToken != "" {
|
||||
// JWT access token-based authentication
|
||||
ctx = context.WithValue(ctx, accessTokenContextKey, accessToken)
|
||||
}
|
||||
|
||||
return handler(ctx, request)
|
||||
}
|
||||
|
||||
// authenticateByJWT authenticates a user using JWT access token from Authorization header.
|
||||
func (in *GRPCAuthInterceptor) authenticateByJWT(ctx context.Context, accessToken string) (*store.User, error) {
|
||||
if accessToken == "" {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "access token not found")
|
||||
}
|
||||
claims := &ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(in.secret), nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
|
||||
}
|
||||
|
||||
// Get user from JWT claims
|
||||
userID, err := util.ConvertStringToInt32(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "malformed ID in the token")
|
||||
}
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.Errorf("user %q not exists", userID)
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, errors.Errorf("user %q is archived", userID)
|
||||
}
|
||||
|
||||
// Validate that this access token exists in the user's access tokens
|
||||
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to get user access tokens")
|
||||
}
|
||||
if !validateAccessToken(accessToken, accessTokens) {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// authenticateBySession authenticates a user using session ID from cookie.
|
||||
func (in *GRPCAuthInterceptor) authenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) {
|
||||
if sessionCookieValue == "" {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "session cookie value not found")
|
||||
}
|
||||
|
||||
// Parse the cookie value to extract userID and sessionID
|
||||
userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid session cookie format: %v", err)
|
||||
}
|
||||
|
||||
// Get the user directly using the userID from the cookie
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user is archived")
|
||||
}
|
||||
|
||||
// Get user sessions and validate the sessionID
|
||||
sessions, err := in.Store.GetUserSessions(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user sessions")
|
||||
}
|
||||
|
||||
if !validateUserSession(sessionID, sessions) {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid or expired session")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// updateSessionLastAccessed updates the last accessed time for a user session.
|
||||
func (in *GRPCAuthInterceptor) updateSessionLastAccessed(ctx context.Context, userID int32, sessionID string) error {
|
||||
return in.Store.UpdateUserSessionLastAccessed(ctx, userID, sessionID, timestamppb.Now())
|
||||
}
|
||||
|
||||
// validateUserSession checks if a session exists and is still valid using sliding expiration.
|
||||
func validateUserSession(sessionID string, userSessions []*storepb.SessionsUserSetting_Session) bool {
|
||||
for _, session := range userSessions {
|
||||
if sessionID == session.SessionId {
|
||||
// Use sliding expiration: check if last_accessed_time + 2 weeks > current_time
|
||||
if session.LastAccessedTime != nil {
|
||||
expirationTime := session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
|
||||
if expirationTime.Before(time.Now()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getSessionIDFromMetadata extracts session cookie value from cookie.
|
||||
func getSessionIDFromMetadata(md metadata.MD) (string, error) {
|
||||
// Check the cookie header for session cookie value
|
||||
var sessionCookieValue string
|
||||
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
|
||||
header := http.Header{}
|
||||
header.Add("Cookie", t)
|
||||
request := http.Request{Header: header}
|
||||
if v, _ := request.Cookie(SessionCookieName); v != nil {
|
||||
sessionCookieValue = v.Value
|
||||
}
|
||||
}
|
||||
if sessionCookieValue == "" {
|
||||
return "", errors.New("session cookie not found")
|
||||
}
|
||||
return sessionCookieValue, nil
|
||||
}
|
||||
|
||||
// getAccessTokenFromMetadata extracts access token from Authorization header.
|
||||
func getAccessTokenFromMetadata(md metadata.MD) (string, error) {
|
||||
// Check the HTTP request Authorization header.
|
||||
authorizationHeaders := md.Get("Authorization")
|
||||
if len(authorizationHeaders) == 0 {
|
||||
return "", errors.New("authorization header not found")
|
||||
}
|
||||
authHeaderParts := strings.Fields(authorizationHeaders[0])
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("authorization header format must be Bearer {token}")
|
||||
}
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
|
||||
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
if accessTokenString == userAccessToken.AccessToken {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
34
server/router/api/v1/acl_config.go
Normal file
34
server/router/api/v1/acl_config.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package v1
|
||||
|
||||
var authenticationAllowlistMethods = map[string]bool{
|
||||
"/memos.api.v1.WorkspaceService/GetWorkspaceProfile": true,
|
||||
"/memos.api.v1.WorkspaceService/GetWorkspaceSetting": true,
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
|
||||
"/memos.api.v1.AuthService/CreateSession": true,
|
||||
"/memos.api.v1.AuthService/GetCurrentSession": true,
|
||||
"/memos.api.v1.UserService/CreateUser": true,
|
||||
"/memos.api.v1.UserService/GetUser": true,
|
||||
"/memos.api.v1.UserService/GetUserAvatar": true,
|
||||
"/memos.api.v1.UserService/GetUserStats": true,
|
||||
"/memos.api.v1.UserService/ListAllUserStats": true,
|
||||
"/memos.api.v1.UserService/SearchUsers": true,
|
||||
"/memos.api.v1.MemoService/GetMemo": true,
|
||||
"/memos.api.v1.MemoService/ListMemos": true,
|
||||
"/memos.api.v1.MarkdownService/GetLinkMetadata": true,
|
||||
"/memos.api.v1.AttachmentService/GetAttachmentBinary": true,
|
||||
}
|
||||
|
||||
// isUnauthorizeAllowedMethod returns whether the method is exempted from authentication.
|
||||
func isUnauthorizeAllowedMethod(fullMethodName string) bool {
|
||||
return authenticationAllowlistMethods[fullMethodName]
|
||||
}
|
||||
|
||||
var allowedMethodsOnlyForAdmin = map[string]bool{
|
||||
"/memos.api.v1.UserService/CreateUser": true,
|
||||
"/memos.api.v1.WorkspaceService/UpdateWorkspaceSetting": true,
|
||||
}
|
||||
|
||||
// isOnlyForAdminAllowedMethod returns true if the method is allowed to be called only by admin.
|
||||
func isOnlyForAdminAllowedMethod(methodName string) bool {
|
||||
return allowedMethodsOnlyForAdmin[methodName]
|
||||
}
|
||||
126
server/router/api/v1/activity_service.go
Normal file
126
server/router/api/v1/activity_service.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListActivities(ctx context.Context, request *v1pb.ListActivitiesRequest) (*v1pb.ListActivitiesResponse, error) {
|
||||
// Set default page size if not specified
|
||||
pageSize := request.PageSize
|
||||
if pageSize <= 0 || pageSize > 1000 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
// TODO: Implement pagination with page_token and use pageSize for limiting
|
||||
// For now, we'll fetch all activities and the pageSize will be used in future pagination implementation
|
||||
_ = pageSize // Acknowledge pageSize variable to avoid linter warning
|
||||
|
||||
activities, err := s.Store.ListActivities(ctx, &store.FindActivity{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list activities: %v", err)
|
||||
}
|
||||
|
||||
var activityMessages []*v1pb.Activity
|
||||
for _, activity := range activities {
|
||||
activityMessage, err := s.convertActivityFromStore(ctx, activity)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
|
||||
}
|
||||
activityMessages = append(activityMessages, activityMessage)
|
||||
}
|
||||
|
||||
return &v1pb.ListActivitiesResponse{
|
||||
Activities: activityMessages,
|
||||
// TODO: Implement next_page_token for pagination
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
|
||||
activityID, err := ExtractActivityIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid activity name: %v", err)
|
||||
}
|
||||
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
|
||||
ID: &activityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
|
||||
}
|
||||
|
||||
activityMessage, err := s.convertActivityFromStore(ctx, activity)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
|
||||
}
|
||||
return activityMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertActivityFromStore(ctx context.Context, activity *store.Activity) (*v1pb.Activity, error) {
|
||||
payload, err := s.convertActivityPayloadFromStore(ctx, activity.Payload)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert activity payload from store: %v", err)
|
||||
}
|
||||
|
||||
// Convert store activity type to proto enum
|
||||
var activityType v1pb.Activity_Type
|
||||
switch activity.Type {
|
||||
case store.ActivityTypeMemoComment:
|
||||
activityType = v1pb.Activity_MEMO_COMMENT
|
||||
default:
|
||||
activityType = v1pb.Activity_TYPE_UNSPECIFIED
|
||||
}
|
||||
|
||||
// Convert store activity level to proto enum
|
||||
var activityLevel v1pb.Activity_Level
|
||||
switch activity.Level {
|
||||
case store.ActivityLevelInfo:
|
||||
activityLevel = v1pb.Activity_INFO
|
||||
default:
|
||||
activityLevel = v1pb.Activity_LEVEL_UNSPECIFIED
|
||||
}
|
||||
|
||||
return &v1pb.Activity{
|
||||
Name: fmt.Sprintf("%s%d", ActivityNamePrefix, activity.ID),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, activity.CreatorID),
|
||||
Type: activityType,
|
||||
Level: activityLevel,
|
||||
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertActivityPayloadFromStore(ctx context.Context, payload *storepb.ActivityPayload) (*v1pb.ActivityPayload, error) {
|
||||
v2Payload := &v1pb.ActivityPayload{}
|
||||
if payload.MemoComment != nil {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &payload.MemoComment.MemoId,
|
||||
ExcludeContent: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &payload.MemoComment.RelatedMemoId,
|
||||
ExcludeContent: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
|
||||
}
|
||||
v2Payload.Payload = &v1pb.ActivityPayload_MemoComment{
|
||||
MemoComment: &v1pb.ActivityMemoCommentPayload{
|
||||
Memo: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
RelatedMemo: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
|
||||
},
|
||||
}
|
||||
}
|
||||
return v2Payload, nil
|
||||
}
|
||||
673
server/router/api/v1/attachment_service.go
Normal file
673
server/router/api/v1/attachment_service.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/genproto/googleapis/api/httpbody"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// The upload memory buffer is 32 MiB.
|
||||
// It should be kept low, so RAM usage doesn't get out of control.
|
||||
// This is unrelated to maximum upload size limit, which is now set through system setting.
|
||||
MaxUploadBufferSizeBytes = 32 << 20
|
||||
MebiByte = 1024 * 1024
|
||||
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
|
||||
ThumbnailCacheFolder = ".thumbnail_cache"
|
||||
)
|
||||
|
||||
var SupportedThumbnailMimeTypes = []string{
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if request.Attachment == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "attachment is required")
|
||||
}
|
||||
if request.Attachment.Filename == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "filename is required")
|
||||
}
|
||||
if request.Attachment.Type == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "type is required")
|
||||
}
|
||||
|
||||
// Use provided attachment_id or generate a new one
|
||||
attachmentUID := request.AttachmentId
|
||||
if attachmentUID == "" {
|
||||
attachmentUID = shortuuid.New()
|
||||
}
|
||||
|
||||
create := &store.Attachment{
|
||||
UID: attachmentUID,
|
||||
CreatorID: user.ID,
|
||||
Filename: request.Attachment.Filename,
|
||||
Type: request.Attachment.Type,
|
||||
}
|
||||
|
||||
workspaceStorageSetting, err := s.Store.GetWorkspaceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace storage setting: %v", err)
|
||||
}
|
||||
size := binary.Size(request.Attachment.Content)
|
||||
uploadSizeLimit := int(workspaceStorageSetting.UploadSizeLimitMb) * MebiByte
|
||||
if uploadSizeLimit == 0 {
|
||||
uploadSizeLimit = MaxUploadBufferSizeBytes
|
||||
}
|
||||
if size > uploadSizeLimit {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "file size exceeds the limit")
|
||||
}
|
||||
create.Size = int64(size)
|
||||
create.Blob = request.Attachment.Content
|
||||
|
||||
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
|
||||
}
|
||||
|
||||
if request.Attachment.Memo != nil {
|
||||
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
|
||||
}
|
||||
create.MemoID = &memo.ID
|
||||
}
|
||||
attachment, err := s.Store.CreateAttachment(ctx, create)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
|
||||
}
|
||||
|
||||
return s.convertAttachmentFromStore(ctx, attachment), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// Set default page size
|
||||
pageSize := int(request.PageSize)
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50
|
||||
}
|
||||
if pageSize > 1000 {
|
||||
pageSize = 1000
|
||||
}
|
||||
|
||||
// Parse page token for offset
|
||||
offset := 0
|
||||
if request.PageToken != "" {
|
||||
// Simple implementation: page token is the offset as string
|
||||
// In production, you might want to use encrypted tokens
|
||||
if parsed, err := fmt.Sscanf(request.PageToken, "%d", &offset); err != nil || parsed != 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid page token")
|
||||
}
|
||||
}
|
||||
|
||||
findAttachment := &store.FindAttachment{
|
||||
CreatorID: &user.ID,
|
||||
Limit: &pageSize,
|
||||
Offset: &offset,
|
||||
}
|
||||
|
||||
// Basic filter support for common cases
|
||||
if request.Filter != "" {
|
||||
// Simple filter parsing - can be enhanced later
|
||||
// For now, support basic type filtering: "type=image/png"
|
||||
if strings.HasPrefix(request.Filter, "type=") {
|
||||
filterType := strings.TrimPrefix(request.Filter, "type=")
|
||||
// Create a temporary struct to hold type filter
|
||||
// Since FindAttachment doesn't have Type field, we'll apply this post-query
|
||||
_ = filterType // We'll filter after getting results
|
||||
}
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, findAttachment)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
|
||||
}
|
||||
|
||||
// Apply type filter if specified
|
||||
if request.Filter != "" && strings.HasPrefix(request.Filter, "type=") {
|
||||
filterType := strings.TrimPrefix(request.Filter, "type=")
|
||||
filteredAttachments := make([]*store.Attachment, 0)
|
||||
for _, attachment := range attachments {
|
||||
if attachment.Type == filterType {
|
||||
filteredAttachments = append(filteredAttachments, attachment)
|
||||
}
|
||||
}
|
||||
attachments = filteredAttachments
|
||||
}
|
||||
|
||||
response := &v1pb.ListAttachmentsResponse{}
|
||||
|
||||
for _, attachment := range attachments {
|
||||
response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment))
|
||||
}
|
||||
|
||||
// For simplicity, set total size to the number of returned attachments.
|
||||
// In a full implementation, you'd want a separate count query
|
||||
response.TotalSize = int32(len(response.Attachments))
|
||||
|
||||
// Set next page token if we got the full page size (indicating there might be more)
|
||||
if len(attachments) == pageSize {
|
||||
response.NextPageToken = fmt.Sprintf("%d", offset+pageSize)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
return s.convertAttachmentFromStore(ctx, attachment), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetAttachmentBinary(ctx context.Context, request *v1pb.GetAttachmentBinaryRequest) (*httpbody.HttpBody, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
|
||||
GetBlob: true,
|
||||
UID: &attachmentUID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
// Check the related memo visibility.
|
||||
if attachment.MemoID != nil {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: attachment.MemoID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find memo by ID: %v", attachment.MemoID)
|
||||
}
|
||||
if memo != nil && memo.Visibility != store.Public {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unauthorized access")
|
||||
}
|
||||
if memo.Visibility == store.Private && user.ID != attachment.CreatorID {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unauthorized access")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if request.Thumbnail && util.HasPrefixes(attachment.Type, SupportedThumbnailMimeTypes...) {
|
||||
thumbnailBlob, err := s.getOrGenerateThumbnail(attachment)
|
||||
if err != nil {
|
||||
// thumbnail failures are logged as warnings and not cosidered critical failures as
|
||||
// a attachment image can be used in its place.
|
||||
slog.Warn("failed to get attachment thumbnail image", slog.Any("error", err))
|
||||
} else {
|
||||
return &httpbody.HttpBody{
|
||||
ContentType: attachment.Type,
|
||||
Data: thumbnailBlob,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
blob, err := s.GetAttachmentBlob(attachment)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment blob: %v", err)
|
||||
}
|
||||
|
||||
contentType := attachment.Type
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
contentType += "; charset=utf-8"
|
||||
}
|
||||
// Prevent XSS attacks by serving potentially unsafe files with a content type that prevents script execution.
|
||||
if strings.EqualFold(contentType, "image/svg+xml") ||
|
||||
strings.EqualFold(contentType, "text/html") ||
|
||||
strings.EqualFold(contentType, "application/xhtml+xml") {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
// Extract range header from gRPC metadata for iOS Safari video support
|
||||
var rangeHeader string
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
// Check for range header from gRPC-Gateway
|
||||
if ranges := md.Get("grpcgateway-range"); len(ranges) > 0 {
|
||||
rangeHeader = ranges[0]
|
||||
} else if ranges := md.Get("range"); len(ranges) > 0 {
|
||||
rangeHeader = ranges[0]
|
||||
}
|
||||
|
||||
// Log for debugging iOS Safari issues
|
||||
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
|
||||
userAgent := userAgents[0]
|
||||
if strings.Contains(strings.ToLower(userAgent), "safari") && rangeHeader != "" {
|
||||
slog.Debug("Safari range request detected",
|
||||
slog.String("range", rangeHeader),
|
||||
slog.String("user-agent", userAgent),
|
||||
slog.String("content-type", contentType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle range requests for video/audio streaming (iOS Safari requirement)
|
||||
if rangeHeader != "" && (strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/")) {
|
||||
return s.handleRangeRequest(ctx, blob, rangeHeader, contentType)
|
||||
}
|
||||
|
||||
// Set headers for streaming support
|
||||
if strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/") {
|
||||
if err := setResponseHeaders(ctx, map[string]string{
|
||||
"accept-ranges": "bytes",
|
||||
"content-length": fmt.Sprintf("%d", len(blob)),
|
||||
"cache-control": "public, max-age=3600", // 1 hour cache
|
||||
}); err != nil {
|
||||
slog.Warn("failed to set streaming headers", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
return &httpbody.HttpBody{
|
||||
ContentType: contentType,
|
||||
Data: blob,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.UpdateAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Attachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateAttachment{
|
||||
ID: attachment.ID,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "filename" {
|
||||
update.Filename = &request.Attachment.Filename
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.Store.UpdateAttachment(ctx, update); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
}
|
||||
return s.GetAttachment(ctx, &v1pb.GetAttachmentRequest{
|
||||
Name: request.Attachment.Name,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.DeleteAttachmentRequest) (*emptypb.Empty, error) {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
|
||||
}
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
|
||||
UID: &attachmentUID,
|
||||
CreatorID: &user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
// Delete the attachment from the database.
|
||||
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
|
||||
ID: attachment.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertAttachmentFromStore(ctx context.Context, attachment *store.Attachment) *v1pb.Attachment {
|
||||
attachmentMessage := &v1pb.Attachment{
|
||||
Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID),
|
||||
CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)),
|
||||
Filename: attachment.Filename,
|
||||
Type: attachment.Type,
|
||||
Size: attachment.Size,
|
||||
}
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
attachmentMessage.ExternalLink = attachment.Reference
|
||||
}
|
||||
if attachment.MemoID != nil {
|
||||
memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: attachment.MemoID,
|
||||
})
|
||||
if memo != nil {
|
||||
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
attachmentMessage.Memo = &memoName
|
||||
}
|
||||
}
|
||||
|
||||
return attachmentMessage
|
||||
}
|
||||
|
||||
// SaveAttachmentBlob save the blob of attachment based on the storage config.
|
||||
func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Attachment) error {
|
||||
workspaceStorageSetting, err := stores.GetWorkspaceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to find workspace storage setting")
|
||||
}
|
||||
|
||||
if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_LOCAL {
|
||||
filepathTemplate := "assets/{timestamp}_{filename}"
|
||||
if workspaceStorageSetting.FilepathTemplate != "" {
|
||||
filepathTemplate = workspaceStorageSetting.FilepathTemplate
|
||||
}
|
||||
|
||||
internalPath := filepathTemplate
|
||||
if !strings.Contains(internalPath, "{filename}") {
|
||||
internalPath = filepath.Join(internalPath, "{filename}")
|
||||
}
|
||||
internalPath = replaceFilenameWithPathTemplate(internalPath, create.Filename)
|
||||
internalPath = filepath.ToSlash(internalPath)
|
||||
|
||||
// Ensure the directory exists.
|
||||
osPath := filepath.FromSlash(internalPath)
|
||||
if !filepath.IsAbs(osPath) {
|
||||
osPath = filepath.Join(profile.Data, osPath)
|
||||
}
|
||||
dir := filepath.Dir(osPath)
|
||||
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return errors.Wrap(err, "Failed to create directory")
|
||||
}
|
||||
dst, err := os.Create(osPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to create file")
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Write the blob to the file.
|
||||
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
|
||||
return errors.Wrap(err, "Failed to write file")
|
||||
}
|
||||
create.Reference = internalPath
|
||||
create.Blob = nil
|
||||
create.StorageType = storepb.AttachmentStorageType_LOCAL
|
||||
} else if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_S3 {
|
||||
s3Config := workspaceStorageSetting.S3Config
|
||||
if s3Config == nil {
|
||||
return errors.Errorf("No actived external storage found")
|
||||
}
|
||||
s3Client, err := s3.NewClient(ctx, s3Config)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to create s3 client")
|
||||
}
|
||||
|
||||
filepathTemplate := workspaceStorageSetting.FilepathTemplate
|
||||
if !strings.Contains(filepathTemplate, "{filename}") {
|
||||
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
|
||||
}
|
||||
filepathTemplate = replaceFilenameWithPathTemplate(filepathTemplate, create.Filename)
|
||||
key, err := s3Client.UploadObject(ctx, filepathTemplate, create.Type, bytes.NewReader(create.Blob))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to upload via s3 client")
|
||||
}
|
||||
presignURL, err := s3Client.PresignGetObject(ctx, key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to presign via s3 client")
|
||||
}
|
||||
|
||||
create.Reference = presignURL
|
||||
create.Blob = nil
|
||||
create.StorageType = storepb.AttachmentStorageType_S3
|
||||
create.Payload = &storepb.AttachmentPayload{
|
||||
Payload: &storepb.AttachmentPayload_S3Object_{
|
||||
S3Object: &storepb.AttachmentPayload_S3Object{
|
||||
S3Config: s3Config,
|
||||
Key: key,
|
||||
LastPresignedTime: timestamppb.New(time.Now()),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
|
||||
// For local storage, read the file from the local disk.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
|
||||
attachmentPath := filepath.FromSlash(attachment.Reference)
|
||||
if !filepath.IsAbs(attachmentPath) {
|
||||
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
|
||||
}
|
||||
|
||||
file, err := os.Open(attachmentPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open the file")
|
||||
}
|
||||
defer file.Close()
|
||||
blob, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the file")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
// For database storage, return the blob from the database.
|
||||
return attachment.Blob, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// thumbnailRatio is the ratio of the thumbnail image.
|
||||
thumbnailRatio = 0.8
|
||||
)
|
||||
|
||||
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
|
||||
func (s *APIV1Service) getOrGenerateThumbnail(attachment *store.Attachment) ([]byte, error) {
|
||||
thumbnailCacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
|
||||
if err := os.MkdirAll(thumbnailCacheFolder, os.ModePerm); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create thumbnail cache folder")
|
||||
}
|
||||
filePath := filepath.Join(thumbnailCacheFolder, fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename)))
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "failed to check thumbnail image stat")
|
||||
}
|
||||
|
||||
// If thumbnail image does not exist, generate and save the thumbnail image.
|
||||
blob, err := s.GetAttachmentBlob(attachment)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get attachment blob")
|
||||
}
|
||||
img, err := imaging.Decode(bytes.NewReader(blob), imaging.AutoOrientation(true))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode thumbnail image")
|
||||
}
|
||||
|
||||
thumbnailWidth := int(float64(img.Bounds().Dx()) * thumbnailRatio)
|
||||
// Resize the image to the thumbnailWidth.
|
||||
thumbnailImage := imaging.Resize(img, thumbnailWidth, 0, imaging.Lanczos)
|
||||
if err := imaging.Save(thumbnailImage, filePath); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to save thumbnail file")
|
||||
}
|
||||
}
|
||||
|
||||
thumbnailFile, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open thumbnail file")
|
||||
}
|
||||
defer thumbnailFile.Close()
|
||||
blob, err := io.ReadAll(thumbnailFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read thumbnail file")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
|
||||
|
||||
func replaceFilenameWithPathTemplate(path, filename string) string {
|
||||
t := time.Now()
|
||||
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
|
||||
switch s {
|
||||
case "{filename}":
|
||||
return filename
|
||||
case "{timestamp}":
|
||||
return fmt.Sprintf("%d", t.Unix())
|
||||
case "{year}":
|
||||
return fmt.Sprintf("%d", t.Year())
|
||||
case "{month}":
|
||||
return fmt.Sprintf("%02d", t.Month())
|
||||
case "{day}":
|
||||
return fmt.Sprintf("%02d", t.Day())
|
||||
case "{hour}":
|
||||
return fmt.Sprintf("%02d", t.Hour())
|
||||
case "{minute}":
|
||||
return fmt.Sprintf("%02d", t.Minute())
|
||||
case "{second}":
|
||||
return fmt.Sprintf("%02d", t.Second())
|
||||
case "{uuid}":
|
||||
return util.GenUUID()
|
||||
}
|
||||
return s
|
||||
})
|
||||
return path
|
||||
}
|
||||
|
||||
// handleRangeRequest handles HTTP range requests for video/audio streaming (iOS Safari requirement).
|
||||
func (*APIV1Service) handleRangeRequest(ctx context.Context, data []byte, rangeHeader, contentType string) (*httpbody.HttpBody, error) {
|
||||
// Parse "bytes=start-end"
|
||||
if !strings.HasPrefix(rangeHeader, "bytes=") {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid range header format")
|
||||
}
|
||||
|
||||
rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=")
|
||||
parts := strings.Split(rangeSpec, "-")
|
||||
if len(parts) != 2 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid range specification")
|
||||
}
|
||||
|
||||
fileSize := int64(len(data))
|
||||
start, end := int64(0), fileSize-1
|
||||
|
||||
// Parse start position
|
||||
if parts[0] != "" {
|
||||
if s, err := strconv.ParseInt(parts[0], 10, 64); err == nil {
|
||||
start = s
|
||||
} else {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid range start: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse end position
|
||||
if parts[1] != "" {
|
||||
if e, err := strconv.ParseInt(parts[1], 10, 64); err == nil {
|
||||
end = e
|
||||
} else {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid range end: %s", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Validate range
|
||||
if start < 0 || end >= fileSize || start > end {
|
||||
// Set Content-Range header for 416 response
|
||||
if err := setResponseHeaders(ctx, map[string]string{
|
||||
"content-range": fmt.Sprintf("bytes */%d", fileSize),
|
||||
}); err != nil {
|
||||
slog.Warn("failed to set content-range header", slog.Any("error", err))
|
||||
}
|
||||
return nil, status.Errorf(codes.OutOfRange, "requested range not satisfiable")
|
||||
}
|
||||
|
||||
// Set partial content headers (HTTP 206)
|
||||
if err := setResponseHeaders(ctx, map[string]string{
|
||||
"accept-ranges": "bytes",
|
||||
"content-range": fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize),
|
||||
"content-length": fmt.Sprintf("%d", end-start+1),
|
||||
"cache-control": "public, max-age=3600",
|
||||
}); err != nil {
|
||||
slog.Warn("failed to set partial content headers", slog.Any("error", err))
|
||||
}
|
||||
|
||||
// Extract the requested range
|
||||
rangeData := data[start : end+1]
|
||||
|
||||
slog.Debug("serving partial content",
|
||||
slog.Int64("start", start),
|
||||
slog.Int64("end", end),
|
||||
slog.Int64("total", fileSize),
|
||||
slog.Int("chunk_size", len(rangeData)))
|
||||
|
||||
return &httpbody.HttpBody{
|
||||
ContentType: contentType,
|
||||
Data: rangeData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// setResponseHeaders is a helper function to set gRPC response headers.
|
||||
func setResponseHeaders(ctx context.Context, headers map[string]string) error {
|
||||
pairs := make([]string, 0, len(headers)*2)
|
||||
for key, value := range headers {
|
||||
pairs = append(pairs, key, value)
|
||||
}
|
||||
return grpc.SetHeader(ctx, metadata.Pairs(pairs...))
|
||||
}
|
||||
91
server/router/api/v1/auth.go
Normal file
91
server/router/api/v1/auth.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// issuer is the issuer of the jwt token.
|
||||
Issuer = "memos"
|
||||
// Signing key section. For now, this is only used for signing, not for verifying since we only
|
||||
// have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism.
|
||||
KeyID = "v1"
|
||||
// AccessTokenAudienceName is the audience name of the access token.
|
||||
AccessTokenAudienceName = "user.access-token"
|
||||
// SessionSlidingDuration is the sliding expiration duration for user sessions (2 weeks).
|
||||
// Sessions are considered valid if last_accessed_time + SessionSlidingDuration > current_time.
|
||||
SessionSlidingDuration = 14 * 24 * time.Hour
|
||||
|
||||
// SessionCookieName is the cookie name of user session ID.
|
||||
SessionCookieName = "user_session"
|
||||
)
|
||||
|
||||
type ClaimsMessage struct {
|
||||
Name string `json:"name"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token.
|
||||
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
|
||||
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
|
||||
}
|
||||
|
||||
// generateToken generates a jwt token.
|
||||
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
|
||||
registeredClaims := jwt.RegisteredClaims{
|
||||
Issuer: Issuer,
|
||||
Audience: jwt.ClaimStrings{audience},
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Subject: fmt.Sprint(userID),
|
||||
}
|
||||
if !expirationTime.IsZero() {
|
||||
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
|
||||
}
|
||||
|
||||
// Declare the token with the HS256 algorithm used for signing, and the claims.
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
|
||||
Name: username,
|
||||
RegisteredClaims: registeredClaims,
|
||||
})
|
||||
token.Header["kid"] = KeyID
|
||||
|
||||
// Create the JWT string.
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a unique session ID using UUIDv4.
|
||||
func GenerateSessionID() (string, error) {
|
||||
return util.GenUUID(), nil
|
||||
}
|
||||
|
||||
// BuildSessionCookieValue builds the session cookie value in format {userID}-{sessionID}.
|
||||
func BuildSessionCookieValue(userID int32, sessionID string) string {
|
||||
return fmt.Sprintf("%d-%s", userID, sessionID)
|
||||
}
|
||||
|
||||
// ParseSessionCookieValue parses the session cookie value to extract userID and sessionID.
|
||||
func ParseSessionCookieValue(cookieValue string) (int32, string, error) {
|
||||
parts := strings.SplitN(cookieValue, "-", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, "", errors.New("invalid session cookie format")
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[0])
|
||||
if err != nil {
|
||||
return 0, "", errors.Errorf("invalid user ID in session cookie: %v", err)
|
||||
}
|
||||
|
||||
return userID, parts[1], nil
|
||||
}
|
||||
502
server/router/api/v1/auth_service.go
Normal file
502
server/router/api/v1/auth_service.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/plugin/idp/oauth2"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
unmatchedUsernameAndPasswordError = "unmatched username and password"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrentSessionRequest) (*v1pb.GetCurrentSessionResponse, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Clear auth cookies
|
||||
if err := s.clearAuthCookies(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
||||
}
|
||||
|
||||
var lastAccessedAt *timestamppb.Timestamp
|
||||
// Update session last accessed time if we have a session ID and get the current session info
|
||||
if sessionID, ok := ctx.Value(sessionIDContextKey).(string); ok && sessionID != "" {
|
||||
now := timestamppb.Now()
|
||||
if err := s.Store.UpdateUserSessionLastAccessed(ctx, user.ID, sessionID, now); err != nil {
|
||||
// Log error but don't fail the request
|
||||
slog.Error("failed to update session last accessed time", "error", err)
|
||||
}
|
||||
lastAccessedAt = now
|
||||
}
|
||||
|
||||
return &v1pb.GetCurrentSessionResponse{
|
||||
User: convertUserFromStore(user),
|
||||
LastAccessedAt: lastAccessedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSessionRequest) (*v1pb.CreateSessionResponse, error) {
|
||||
var existingUser *store.User
|
||||
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &passwordCredentials.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
|
||||
}
|
||||
// Compare the stored hashed password, with the hashed version of the password that was received.
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
|
||||
}
|
||||
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
|
||||
}
|
||||
// Check if the password auth in is allowed.
|
||||
if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
|
||||
}
|
||||
existingUser = user
|
||||
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &ssoCredentials.IdpId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
|
||||
}
|
||||
|
||||
var userInfo *idp.IdentityProviderUserInfo
|
||||
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
|
||||
}
|
||||
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
|
||||
}
|
||||
userInfo, err = oauth2IdentityProvider.UserInfo(token)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
identifierFilter := identityProvider.IdentifierFilter
|
||||
if identifierFilter != "" {
|
||||
identifierFilterRegex, err := regexp.Compile(identifierFilter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
|
||||
}
|
||||
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &userInfo.Identifier,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Check if the user is allowed to sign up.
|
||||
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
|
||||
}
|
||||
if workspaceGeneralSetting.DisallowUserRegistration {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
|
||||
}
|
||||
|
||||
// Create a new user with the user info from the identity provider.
|
||||
userCreate := &store.User{
|
||||
Username: userInfo.Identifier,
|
||||
// The new signup user should be normal user by default.
|
||||
Role: store.RoleUser,
|
||||
Nickname: userInfo.DisplayName,
|
||||
Email: userInfo.Email,
|
||||
AvatarURL: userInfo.AvatarURL,
|
||||
}
|
||||
password, err := util.RandomString(20)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
|
||||
}
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
|
||||
}
|
||||
userCreate.PasswordHash = string(passwordHash)
|
||||
user, err = s.Store.CreateUser(ctx, userCreate)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
|
||||
}
|
||||
}
|
||||
existingUser = user
|
||||
}
|
||||
|
||||
if existingUser == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
|
||||
}
|
||||
if existingUser.RowStatus == store.Archived {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
|
||||
}
|
||||
|
||||
// Default session expiration time is 100 year
|
||||
expireTime := time.Now().Add(100 * 365 * 24 * time.Hour)
|
||||
if err := s.doSignIn(ctx, existingUser, expireTime); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
|
||||
}
|
||||
|
||||
return &v1pb.CreateSessionResponse{
|
||||
User: convertUserFromStore(existingUser),
|
||||
LastAccessedAt: timestamppb.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
|
||||
// Generate unique session ID for web use
|
||||
sessionID, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to generate session ID, error: %v", err)
|
||||
}
|
||||
|
||||
// Track session in user settings
|
||||
if err := s.trackUserSession(ctx, user.ID, sessionID); err != nil {
|
||||
// Log the error but don't fail the login if session tracking fails
|
||||
// This ensures backward compatibility
|
||||
slog.Error("failed to track user session", "error", err)
|
||||
}
|
||||
|
||||
// Set session cookie for web use (format: userID-sessionID)
|
||||
sessionCookieValue := BuildSessionCookieValue(user.ID, sessionID)
|
||||
sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
|
||||
}
|
||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
"Set-Cookie": sessionCookie,
|
||||
})); err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteSession(ctx context.Context, _ *v1pb.DeleteSessionRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
||||
}
|
||||
|
||||
// Check if we have a session ID (from cookie-based auth)
|
||||
if sessionID, ok := ctx.Value(sessionIDContextKey).(string); ok && sessionID != "" {
|
||||
// Remove session from user settings
|
||||
if err := s.Store.RemoveUserSession(ctx, user.ID, sessionID); err != nil {
|
||||
slog.Error("failed to remove user session", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.clearAuthCookies(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
|
||||
// Clear session cookie
|
||||
sessionCookie, err := s.buildSessionCookie(ctx, "", time.Time{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to build session cookie")
|
||||
}
|
||||
|
||||
// Set both cookies in the response
|
||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
"Set-Cookie": sessionCookie,
|
||||
})); err != nil {
|
||||
return errors.Wrap(err, "failed to set grpc header")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue string, expireTime time.Time) (string, error) {
|
||||
attrs := []string{
|
||||
fmt.Sprintf("%s=%s", SessionCookieName, sessionCookieValue),
|
||||
"Path=/",
|
||||
"HttpOnly",
|
||||
}
|
||||
if expireTime.IsZero() {
|
||||
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
|
||||
} else {
|
||||
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
|
||||
}
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", errors.New("failed to get metadata from context")
|
||||
}
|
||||
var origin string
|
||||
for _, v := range md.Get("origin") {
|
||||
origin = v
|
||||
}
|
||||
isHTTPS := strings.HasPrefix(origin, "https://")
|
||||
if isHTTPS {
|
||||
attrs = append(attrs, "SameSite=None")
|
||||
attrs = append(attrs, "Secure")
|
||||
} else {
|
||||
attrs = append(attrs, "SameSite=Strict")
|
||||
}
|
||||
return strings.Join(attrs, "; "), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
|
||||
userID, ok := ctx.Value(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.Errorf("user %d not found", userID)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Helper function to track user session for session management.
|
||||
func (s *APIV1Service) trackUserSession(ctx context.Context, userID int32, sessionID string) error {
|
||||
// Extract client information from the context
|
||||
clientInfo := s.extractClientInfo(ctx)
|
||||
|
||||
session := &storepb.SessionsUserSetting_Session{
|
||||
SessionId: sessionID,
|
||||
CreateTime: timestamppb.Now(),
|
||||
LastAccessedTime: timestamppb.Now(),
|
||||
ClientInfo: clientInfo,
|
||||
}
|
||||
|
||||
return s.Store.AddUserSession(ctx, userID, session)
|
||||
}
|
||||
|
||||
// Helper function to extract client information from the gRPC context.
|
||||
// extractClientInfo extracts comprehensive client information from the request context.
|
||||
// This includes user agent parsing to determine device type, operating system, browser,
|
||||
// and IP address extraction. This information is used to provide detailed session
|
||||
// tracking and management capabilities in the web UI.
|
||||
//
|
||||
// Fields populated:
|
||||
// - UserAgent: Raw user agent string
|
||||
// - IpAddress: Client IP (from X-Forwarded-For or X-Real-IP headers)
|
||||
// - DeviceType: "mobile", "tablet", or "desktop"
|
||||
// - Os: Operating system name and version (e.g., "iOS 17.1", "Windows 10/11")
|
||||
// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0")
|
||||
// - Country: Geographic location (TODO: implement with GeoIP service).
|
||||
func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo {
|
||||
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
|
||||
|
||||
// Extract user agent from metadata if available
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
|
||||
userAgent := userAgents[0]
|
||||
clientInfo.UserAgent = userAgent
|
||||
|
||||
// Parse user agent to extract device type, OS, browser info
|
||||
s.parseUserAgent(userAgent, clientInfo)
|
||||
}
|
||||
if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 {
|
||||
ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple
|
||||
ipAddress = strings.TrimSpace(ipAddress)
|
||||
clientInfo.IpAddress = ipAddress
|
||||
} else if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
|
||||
clientInfo.IpAddress = realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
return clientInfo
|
||||
}
|
||||
|
||||
// parseUserAgent extracts device type, OS, and browser information from user agent string.
|
||||
func (*APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.SessionsUserSetting_ClientInfo) {
|
||||
if userAgent == "" {
|
||||
return
|
||||
}
|
||||
|
||||
userAgent = strings.ToLower(userAgent)
|
||||
|
||||
// Detect device type
|
||||
if strings.Contains(userAgent, "ipad") {
|
||||
clientInfo.DeviceType = "tablet"
|
||||
} else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") ||
|
||||
strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") ||
|
||||
strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") {
|
||||
clientInfo.DeviceType = "mobile"
|
||||
} else if strings.Contains(userAgent, "tablet") {
|
||||
clientInfo.DeviceType = "tablet"
|
||||
} else {
|
||||
clientInfo.DeviceType = "desktop"
|
||||
}
|
||||
|
||||
// Detect operating system
|
||||
if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") {
|
||||
// Extract iOS version
|
||||
if idx := strings.Index(userAgent, "cpu os "); idx != -1 {
|
||||
versionStart := idx + 7
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd != -1 {
|
||||
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
|
||||
clientInfo.Os = "iOS " + version
|
||||
} else {
|
||||
clientInfo.Os = "iOS"
|
||||
}
|
||||
} else if idx := strings.Index(userAgent, "iphone os "); idx != -1 {
|
||||
versionStart := idx + 10
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd != -1 {
|
||||
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
|
||||
clientInfo.Os = "iOS " + version
|
||||
} else {
|
||||
clientInfo.Os = "iOS"
|
||||
}
|
||||
} else {
|
||||
clientInfo.Os = "iOS"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "android") {
|
||||
// Extract Android version
|
||||
if idx := strings.Index(userAgent, "android "); idx != -1 {
|
||||
versionStart := idx + 8
|
||||
versionEnd := strings.Index(userAgent[versionStart:], ";")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = strings.Index(userAgent[versionStart:], ")")
|
||||
}
|
||||
if versionEnd != -1 {
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Os = "Android " + version
|
||||
} else {
|
||||
clientInfo.Os = "Android"
|
||||
}
|
||||
} else {
|
||||
clientInfo.Os = "Android"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "windows nt 10.0") {
|
||||
clientInfo.Os = "Windows 10/11"
|
||||
} else if strings.Contains(userAgent, "windows nt 6.3") {
|
||||
clientInfo.Os = "Windows 8.1"
|
||||
} else if strings.Contains(userAgent, "windows nt 6.1") {
|
||||
clientInfo.Os = "Windows 7"
|
||||
} else if strings.Contains(userAgent, "windows") {
|
||||
clientInfo.Os = "Windows"
|
||||
} else if strings.Contains(userAgent, "mac os x") {
|
||||
// Extract macOS version
|
||||
if idx := strings.Index(userAgent, "mac os x "); idx != -1 {
|
||||
versionStart := idx + 9
|
||||
versionEnd := strings.Index(userAgent[versionStart:], ";")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = strings.Index(userAgent[versionStart:], ")")
|
||||
}
|
||||
if versionEnd != -1 {
|
||||
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
|
||||
clientInfo.Os = "macOS " + version
|
||||
} else {
|
||||
clientInfo.Os = "macOS"
|
||||
}
|
||||
} else {
|
||||
clientInfo.Os = "macOS"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "linux") {
|
||||
clientInfo.Os = "Linux"
|
||||
} else if strings.Contains(userAgent, "cros") {
|
||||
clientInfo.Os = "Chrome OS"
|
||||
}
|
||||
|
||||
// Detect browser
|
||||
if strings.Contains(userAgent, "edg/") {
|
||||
// Extract Edge version
|
||||
if idx := strings.Index(userAgent, "edg/"); idx != -1 {
|
||||
versionStart := idx + 4
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Edge " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Edge"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") {
|
||||
// Extract Chrome version
|
||||
if idx := strings.Index(userAgent, "chrome/"); idx != -1 {
|
||||
versionStart := idx + 7
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Chrome " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Chrome"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "firefox/") {
|
||||
// Extract Firefox version
|
||||
if idx := strings.Index(userAgent, "firefox/"); idx != -1 {
|
||||
versionStart := idx + 8
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Firefox " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Firefox"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") {
|
||||
// Extract Safari version
|
||||
if idx := strings.Index(userAgent, "version/"); idx != -1 {
|
||||
versionStart := idx + 8
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd == -1 {
|
||||
versionEnd = len(userAgent) - versionStart
|
||||
}
|
||||
version := userAgent[versionStart : versionStart+versionEnd]
|
||||
clientInfo.Browser = "Safari " + version
|
||||
} else {
|
||||
clientInfo.Browser = "Safari"
|
||||
}
|
||||
} else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") {
|
||||
clientInfo.Browser = "Opera"
|
||||
}
|
||||
}
|
||||
179
server/router/api/v1/auth_service_client_info_test.go
Normal file
179
server/router/api/v1/auth_service_client_info_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func TestParseUserAgent(t *testing.T) {
|
||||
service := &APIV1Service{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
expectedDevice string
|
||||
expectedOS string
|
||||
expectedBrowser string
|
||||
}{
|
||||
{
|
||||
name: "Chrome on Windows",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "Windows 10/11",
|
||||
expectedBrowser: "Chrome 119.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "Safari on macOS",
|
||||
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "macOS 10.15.7",
|
||||
expectedBrowser: "Safari 17.0",
|
||||
},
|
||||
{
|
||||
name: "Chrome on Android Mobile",
|
||||
userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36",
|
||||
expectedDevice: "mobile",
|
||||
expectedOS: "Android 13",
|
||||
expectedBrowser: "Chrome 119.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "Safari on iPhone",
|
||||
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
|
||||
expectedDevice: "mobile",
|
||||
expectedOS: "iOS 17.0",
|
||||
expectedBrowser: "Safari 17.0",
|
||||
},
|
||||
{
|
||||
name: "Firefox on Windows",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "Windows 10/11",
|
||||
expectedBrowser: "Firefox 119.0",
|
||||
},
|
||||
{
|
||||
name: "Edge on Windows",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
|
||||
expectedDevice: "desktop",
|
||||
expectedOS: "Windows 10/11",
|
||||
expectedBrowser: "Edge 119.0.0.0",
|
||||
},
|
||||
{
|
||||
name: "iPad Safari",
|
||||
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
|
||||
expectedDevice: "tablet",
|
||||
expectedOS: "iOS 17.0",
|
||||
expectedBrowser: "Safari 17.0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
|
||||
service.parseUserAgent(tt.userAgent, clientInfo)
|
||||
|
||||
if clientInfo.DeviceType != tt.expectedDevice {
|
||||
t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType)
|
||||
}
|
||||
if clientInfo.Os != tt.expectedOS {
|
||||
t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os)
|
||||
}
|
||||
if clientInfo.Browser != tt.expectedBrowser {
|
||||
t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClientInfo(t *testing.T) {
|
||||
service := &APIV1Service{}
|
||||
|
||||
// Test with metadata containing user agent and IP
|
||||
md := metadata.New(map[string]string{
|
||||
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
|
||||
"x-forwarded-for": "203.0.113.1, 198.51.100.1",
|
||||
"x-real-ip": "203.0.113.1",
|
||||
})
|
||||
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
clientInfo := service.extractClientInfo(ctx)
|
||||
|
||||
if clientInfo.UserAgent == "" {
|
||||
t.Error("Expected user agent to be set")
|
||||
}
|
||||
if clientInfo.IpAddress != "203.0.113.1" {
|
||||
t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress)
|
||||
}
|
||||
if clientInfo.DeviceType != "desktop" {
|
||||
t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType)
|
||||
}
|
||||
if clientInfo.Os != "Windows 10/11" {
|
||||
t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os)
|
||||
}
|
||||
if clientInfo.Browser != "Chrome 119.0.0.0" {
|
||||
t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents.
|
||||
func TestClientInfoExamples(t *testing.T) {
|
||||
service := &APIV1Service{}
|
||||
|
||||
examples := []struct {
|
||||
description string
|
||||
userAgent string
|
||||
}{
|
||||
{
|
||||
description: "Modern Chrome on Windows 11",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
},
|
||||
{
|
||||
description: "Safari on iPhone 15 Pro",
|
||||
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
|
||||
},
|
||||
{
|
||||
description: "Chrome on Samsung Galaxy",
|
||||
userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36",
|
||||
},
|
||||
{
|
||||
description: "Firefox on Ubuntu",
|
||||
userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0",
|
||||
},
|
||||
{
|
||||
description: "Edge on Windows 10",
|
||||
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
},
|
||||
{
|
||||
description: "Safari on iPad Air",
|
||||
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, example := range examples {
|
||||
t.Run(example.description, func(t *testing.T) {
|
||||
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
|
||||
service.parseUserAgent(example.userAgent, clientInfo)
|
||||
|
||||
t.Logf("User Agent: %s", example.userAgent)
|
||||
t.Logf("Device Type: %s", clientInfo.DeviceType)
|
||||
t.Logf("Operating System: %s", clientInfo.Os)
|
||||
t.Logf("Browser: %s", clientInfo.Browser)
|
||||
t.Logf("---")
|
||||
|
||||
// Ensure all fields are populated
|
||||
if clientInfo.DeviceType == "" {
|
||||
t.Error("Device type should not be empty")
|
||||
}
|
||||
if clientInfo.Os == "" {
|
||||
t.Error("OS should not be empty")
|
||||
}
|
||||
if clientInfo.Browser == "" {
|
||||
t.Error("Browser should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
70
server/router/api/v1/common.go
Normal file
70
server/router/api/v1/common.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPageSize is the default page size for requests.
|
||||
DefaultPageSize = 10
|
||||
// MaxPageSize is the maximum page size for requests.
|
||||
MaxPageSize = 1000
|
||||
)
|
||||
|
||||
func convertStateFromStore(rowStatus store.RowStatus) v1pb.State {
|
||||
switch rowStatus {
|
||||
case store.Normal:
|
||||
return v1pb.State_NORMAL
|
||||
case store.Archived:
|
||||
return v1pb.State_ARCHIVED
|
||||
default:
|
||||
return v1pb.State_STATE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertStateToStore(state v1pb.State) store.RowStatus {
|
||||
switch state {
|
||||
case v1pb.State_NORMAL:
|
||||
return store.Normal
|
||||
case v1pb.State_ARCHIVED:
|
||||
return store.Archived
|
||||
default:
|
||||
return store.Normal
|
||||
}
|
||||
}
|
||||
|
||||
func getPageToken(limit int, offset int) (string, error) {
|
||||
return marshalPageToken(&v1pb.PageToken{
|
||||
Limit: int32(limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
}
|
||||
|
||||
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
|
||||
b, err := proto.Marshal(pageToken)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to marshal page token")
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
|
||||
b, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to decode page token")
|
||||
}
|
||||
if err := proto.Unmarshal(b, pageToken); err != nil {
|
||||
return errors.Wrapf(err, "failed to unmarshal page token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSuperUser(user *store.User) bool {
|
||||
return user.Role == store.RoleAdmin || user.Role == store.RoleHost
|
||||
}
|
||||
21
server/router/api/v1/health_service.go
Normal file
21
server/router/api/v1/health_service.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) Check(ctx context.Context,
|
||||
_ *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||
history, err := s.Store.GetDriver().FindMigrationHistoryList(ctx, &store.FindMigrationHistory{})
|
||||
if err != nil || len(history) == 0 {
|
||||
return nil, status.Errorf(codes.Unavailable, "not available")
|
||||
}
|
||||
|
||||
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
|
||||
}
|
||||
183
server/router/api/v1/idp_service.go
Normal file
183
server/router/api/v1/idp_service.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
|
||||
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListIdentityProvidersResponse{
|
||||
IdentityProviders: []*v1pb.IdentityProvider{},
|
||||
}
|
||||
for _, identityProvider := range identityProviders {
|
||||
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
update := &store.UpdateIdentityProviderV1{
|
||||
ID: id,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "title":
|
||||
update.Name = &request.IdentityProvider.Title
|
||||
case "identifier_filter":
|
||||
update.IdentifierFilter = &request.IdentityProvider.IdentifierFilter
|
||||
case "config":
|
||||
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
|
||||
}
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
|
||||
// Check if the identity provider exists before trying to delete it
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
|
||||
temp := &v1pb.IdentityProvider{
|
||||
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
|
||||
Title: identityProvider.Name,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
}
|
||||
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := identityProvider.Config.GetOauth2Config()
|
||||
temp.Config = &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
Email: oauth2Config.FieldMapping.Email,
|
||||
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return temp
|
||||
}
|
||||
|
||||
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
|
||||
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
|
||||
|
||||
temp := &storepb.IdentityProvider{
|
||||
Id: id,
|
||||
Name: identityProvider.Title,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
|
||||
}
|
||||
return temp
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
|
||||
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := config.GetOauth2Config()
|
||||
return &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
Email: oauth2Config.FieldMapping.Email,
|
||||
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
226
server/router/api/v1/inbox_service.go
Normal file
226
server/router/api/v1/inbox_service.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListInboxes(ctx context.Context, request *v1pb.ListInboxesRequest) (*v1pb.ListInboxesResponse, error) {
|
||||
// Extract user ID from parent resource name
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent name %q: %v", request.Parent, err)
|
||||
}
|
||||
|
||||
// Get current user for authorization
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Check if current user can access the requested user's inboxes
|
||||
if currentUser.ID != userID {
|
||||
// Only allow hosts and admins to access other users' inboxes
|
||||
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "cannot access inboxes for user %q", request.Parent)
|
||||
}
|
||||
}
|
||||
|
||||
var limit, offset int
|
||||
if request.PageToken != "" {
|
||||
var pageToken v1pb.PageToken
|
||||
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
|
||||
}
|
||||
limit = int(pageToken.Limit)
|
||||
offset = int(pageToken.Offset)
|
||||
} else {
|
||||
limit = int(request.PageSize)
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = DefaultPageSize
|
||||
}
|
||||
if limit > MaxPageSize {
|
||||
limit = MaxPageSize
|
||||
}
|
||||
limitPlusOne := limit + 1
|
||||
|
||||
findInbox := &store.FindInbox{
|
||||
ReceiverID: &userID,
|
||||
Limit: &limitPlusOne,
|
||||
Offset: &offset,
|
||||
}
|
||||
|
||||
inboxes, err := s.Store.ListInboxes(ctx, findInbox)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list inboxes: %v", err)
|
||||
}
|
||||
|
||||
inboxMessages := []*v1pb.Inbox{}
|
||||
nextPageToken := ""
|
||||
if len(inboxes) == limitPlusOne {
|
||||
inboxes = inboxes[:limit]
|
||||
nextPageToken, err = getPageToken(limit, offset+limit)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get next page token: %v", err)
|
||||
}
|
||||
}
|
||||
for _, inbox := range inboxes {
|
||||
inboxMessage := convertInboxFromStore(inbox)
|
||||
if inboxMessage.Type == v1pb.Inbox_TYPE_UNSPECIFIED {
|
||||
continue
|
||||
}
|
||||
inboxMessages = append(inboxMessages, inboxMessage)
|
||||
}
|
||||
|
||||
response := &v1pb.ListInboxesResponse{
|
||||
Inboxes: inboxMessages,
|
||||
NextPageToken: nextPageToken,
|
||||
TotalSize: int32(len(inboxMessages)), // For now, use actual returned count
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateInbox(ctx context.Context, request *v1pb.UpdateInboxRequest) (*v1pb.Inbox, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
inboxID, err := ExtractInboxIDFromName(request.Inbox.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name %q: %v", request.Inbox.Name, err)
|
||||
}
|
||||
|
||||
// Get current user for authorization
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Get the existing inbox to verify ownership
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ID: &inboxID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get inbox: %v", err)
|
||||
}
|
||||
if len(inboxes) == 0 {
|
||||
return nil, status.Errorf(codes.NotFound, "inbox %q not found", request.Inbox.Name)
|
||||
}
|
||||
existingInbox := inboxes[0]
|
||||
|
||||
// Check if current user can update this inbox (must be the receiver)
|
||||
if currentUser.ID != existingInbox.ReceiverID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "cannot update inbox for another user")
|
||||
}
|
||||
|
||||
update := &store.UpdateInbox{
|
||||
ID: inboxID,
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "status" {
|
||||
if request.Inbox.Status == v1pb.Inbox_STATUS_UNSPECIFIED {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "status cannot be unspecified")
|
||||
}
|
||||
update.Status = convertInboxStatusToStore(request.Inbox.Status)
|
||||
} else {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unsupported field in update mask: %q", field)
|
||||
}
|
||||
}
|
||||
|
||||
inbox, err := s.Store.UpdateInbox(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
|
||||
}
|
||||
|
||||
return convertInboxFromStore(inbox), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteInbox(ctx context.Context, request *v1pb.DeleteInboxRequest) (*emptypb.Empty, error) {
|
||||
inboxID, err := ExtractInboxIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name %q: %v", request.Name, err)
|
||||
}
|
||||
|
||||
// Get current user for authorization
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Get the existing inbox to verify ownership
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ID: &inboxID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get inbox: %v", err)
|
||||
}
|
||||
if len(inboxes) == 0 {
|
||||
return nil, status.Errorf(codes.NotFound, "inbox %q not found", request.Name)
|
||||
}
|
||||
existingInbox := inboxes[0]
|
||||
|
||||
// Check if current user can delete this inbox (must be the receiver)
|
||||
if currentUser.ID != existingInbox.ReceiverID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "cannot delete inbox for another user")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteInbox(ctx, &store.DeleteInbox{
|
||||
ID: inboxID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete inbox: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertInboxFromStore(inbox *store.Inbox) *v1pb.Inbox {
|
||||
return &v1pb.Inbox{
|
||||
Name: fmt.Sprintf("%s%d", InboxNamePrefix, inbox.ID),
|
||||
Sender: fmt.Sprintf("%s%d", UserNamePrefix, inbox.SenderID),
|
||||
Receiver: fmt.Sprintf("%s%d", UserNamePrefix, inbox.ReceiverID),
|
||||
Status: convertInboxStatusFromStore(inbox.Status),
|
||||
CreateTime: timestamppb.New(time.Unix(inbox.CreatedTs, 0)),
|
||||
Type: v1pb.Inbox_Type(inbox.Message.Type),
|
||||
ActivityId: inbox.Message.ActivityId,
|
||||
}
|
||||
}
|
||||
|
||||
func convertInboxStatusFromStore(status store.InboxStatus) v1pb.Inbox_Status {
|
||||
switch status {
|
||||
case store.UNREAD:
|
||||
return v1pb.Inbox_UNREAD
|
||||
case store.ARCHIVED:
|
||||
return v1pb.Inbox_ARCHIVED
|
||||
default:
|
||||
return v1pb.Inbox_STATUS_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertInboxStatusToStore(status v1pb.Inbox_Status) store.InboxStatus {
|
||||
switch status {
|
||||
case v1pb.Inbox_UNREAD:
|
||||
return store.UNREAD
|
||||
case v1pb.Inbox_ARCHIVED:
|
||||
return store.ARCHIVED
|
||||
default:
|
||||
return store.UNREAD
|
||||
}
|
||||
}
|
||||
48
server/router/api/v1/logger_interceptor.go
Normal file
48
server/router/api/v1/logger_interceptor.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type LoggerInterceptor struct {
|
||||
}
|
||||
|
||||
func NewLoggerInterceptor() *LoggerInterceptor {
|
||||
return &LoggerInterceptor{}
|
||||
}
|
||||
|
||||
func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
resp, err := handler(ctx, request)
|
||||
in.loggerInterceptorDo(ctx, serverInfo.FullMethod, err)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (*LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) {
|
||||
st := status.Convert(err)
|
||||
var logLevel slog.Level
|
||||
var logMsg string
|
||||
switch st.Code() {
|
||||
case codes.OK:
|
||||
logLevel = slog.LevelInfo
|
||||
logMsg = "OK"
|
||||
case codes.Unauthenticated, codes.OutOfRange, codes.PermissionDenied, codes.NotFound:
|
||||
logLevel = slog.LevelInfo
|
||||
logMsg = "client error"
|
||||
case codes.Internal, codes.Unknown, codes.DataLoss, codes.Unavailable, codes.DeadlineExceeded:
|
||||
logLevel = slog.LevelError
|
||||
logMsg = "server error"
|
||||
default:
|
||||
logLevel = slog.LevelError
|
||||
logMsg = "unknown error"
|
||||
}
|
||||
logAttrs := []slog.Attr{slog.String("method", fullMethod)}
|
||||
if err != nil {
|
||||
logAttrs = append(logAttrs, slog.String("error", err.Error()))
|
||||
}
|
||||
slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...)
|
||||
}
|
||||
279
server/router/api/v1/markdown_service.go
Normal file
279
server/router/api/v1/markdown_service.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/gomark/ast"
|
||||
"github.com/usememos/gomark/parser"
|
||||
"github.com/usememos/gomark/parser/tokenizer"
|
||||
"github.com/usememos/gomark/renderer"
|
||||
"github.com/usememos/gomark/restore"
|
||||
|
||||
"github.com/usememos/memos/plugin/httpgetter"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func (*APIV1Service) ParseMarkdown(_ context.Context, request *v1pb.ParseMarkdownRequest) (*v1pb.ParseMarkdownResponse, error) {
|
||||
rawNodes, err := parser.Parse(tokenizer.Tokenize(request.Markdown))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse memo content")
|
||||
}
|
||||
|
||||
nodes := convertFromASTNodes(rawNodes)
|
||||
return &v1pb.ParseMarkdownResponse{
|
||||
Nodes: nodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) RestoreMarkdownNodes(_ context.Context, request *v1pb.RestoreMarkdownNodesRequest) (*v1pb.RestoreMarkdownNodesResponse, error) {
|
||||
markdown := restore.Restore(convertToASTNodes(request.Nodes))
|
||||
return &v1pb.RestoreMarkdownNodesResponse{
|
||||
Markdown: markdown,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) StringifyMarkdownNodes(_ context.Context, request *v1pb.StringifyMarkdownNodesRequest) (*v1pb.StringifyMarkdownNodesResponse, error) {
|
||||
stringRenderer := renderer.NewStringRenderer()
|
||||
plainText := stringRenderer.Render(convertToASTNodes(request.Nodes))
|
||||
return &v1pb.StringifyMarkdownNodesResponse{
|
||||
PlainText: plainText,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) GetLinkMetadata(_ context.Context, request *v1pb.GetLinkMetadataRequest) (*v1pb.LinkMetadata, error) {
|
||||
htmlMeta, err := httpgetter.GetHTMLMeta(request.Link)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1pb.LinkMetadata{
|
||||
Title: htmlMeta.Title,
|
||||
Description: htmlMeta.Description,
|
||||
Image: htmlMeta.Image,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertFromASTNode(rawNode ast.Node) *v1pb.Node {
|
||||
node := &v1pb.Node{
|
||||
Type: v1pb.NodeType(v1pb.NodeType_value[string(rawNode.Type())]),
|
||||
}
|
||||
|
||||
switch n := rawNode.(type) {
|
||||
case *ast.LineBreak:
|
||||
node.Node = &v1pb.Node_LineBreakNode{}
|
||||
case *ast.Paragraph:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_ParagraphNode{ParagraphNode: &v1pb.ParagraphNode{Children: children}}
|
||||
case *ast.CodeBlock:
|
||||
node.Node = &v1pb.Node_CodeBlockNode{CodeBlockNode: &v1pb.CodeBlockNode{Language: n.Language, Content: n.Content}}
|
||||
case *ast.Heading:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_HeadingNode{HeadingNode: &v1pb.HeadingNode{Level: int32(n.Level), Children: children}}
|
||||
case *ast.HorizontalRule:
|
||||
node.Node = &v1pb.Node_HorizontalRuleNode{HorizontalRuleNode: &v1pb.HorizontalRuleNode{Symbol: n.Symbol}}
|
||||
case *ast.Blockquote:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_BlockquoteNode{BlockquoteNode: &v1pb.BlockquoteNode{Children: children}}
|
||||
case *ast.List:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_ListNode{ListNode: &v1pb.ListNode{Kind: convertListKindFromASTNode(n.Kind), Indent: int32(n.Indent), Children: children}}
|
||||
case *ast.OrderedListItem:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_OrderedListItemNode{OrderedListItemNode: &v1pb.OrderedListItemNode{Number: n.Number, Indent: int32(n.Indent), Children: children}}
|
||||
case *ast.UnorderedListItem:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_UnorderedListItemNode{UnorderedListItemNode: &v1pb.UnorderedListItemNode{Symbol: n.Symbol, Indent: int32(n.Indent), Children: children}}
|
||||
case *ast.TaskListItem:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_TaskListItemNode{TaskListItemNode: &v1pb.TaskListItemNode{Symbol: n.Symbol, Indent: int32(n.Indent), Complete: n.Complete, Children: children}}
|
||||
case *ast.MathBlock:
|
||||
node.Node = &v1pb.Node_MathBlockNode{MathBlockNode: &v1pb.MathBlockNode{Content: n.Content}}
|
||||
case *ast.Table:
|
||||
node.Node = &v1pb.Node_TableNode{TableNode: convertTableFromASTNode(n)}
|
||||
case *ast.EmbeddedContent:
|
||||
node.Node = &v1pb.Node_EmbeddedContentNode{EmbeddedContentNode: &v1pb.EmbeddedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
|
||||
case *ast.Text:
|
||||
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{Content: n.Content}}
|
||||
case *ast.Bold:
|
||||
node.Node = &v1pb.Node_BoldNode{BoldNode: &v1pb.BoldNode{Symbol: n.Symbol, Children: convertFromASTNodes(n.Children)}}
|
||||
case *ast.Italic:
|
||||
node.Node = &v1pb.Node_ItalicNode{ItalicNode: &v1pb.ItalicNode{Symbol: n.Symbol, Children: convertFromASTNodes(n.Children)}}
|
||||
case *ast.BoldItalic:
|
||||
node.Node = &v1pb.Node_BoldItalicNode{BoldItalicNode: &v1pb.BoldItalicNode{Symbol: n.Symbol, Content: n.Content}}
|
||||
case *ast.Code:
|
||||
node.Node = &v1pb.Node_CodeNode{CodeNode: &v1pb.CodeNode{Content: n.Content}}
|
||||
case *ast.Image:
|
||||
node.Node = &v1pb.Node_ImageNode{ImageNode: &v1pb.ImageNode{AltText: n.AltText, Url: n.URL}}
|
||||
case *ast.Link:
|
||||
node.Node = &v1pb.Node_LinkNode{LinkNode: &v1pb.LinkNode{Content: convertFromASTNodes(n.Content), Url: n.URL}}
|
||||
case *ast.AutoLink:
|
||||
node.Node = &v1pb.Node_AutoLinkNode{AutoLinkNode: &v1pb.AutoLinkNode{Url: n.URL, IsRawText: n.IsRawText}}
|
||||
case *ast.Tag:
|
||||
node.Node = &v1pb.Node_TagNode{TagNode: &v1pb.TagNode{Content: n.Content}}
|
||||
case *ast.Strikethrough:
|
||||
node.Node = &v1pb.Node_StrikethroughNode{StrikethroughNode: &v1pb.StrikethroughNode{Content: n.Content}}
|
||||
case *ast.EscapingCharacter:
|
||||
node.Node = &v1pb.Node_EscapingCharacterNode{EscapingCharacterNode: &v1pb.EscapingCharacterNode{Symbol: n.Symbol}}
|
||||
case *ast.Math:
|
||||
node.Node = &v1pb.Node_MathNode{MathNode: &v1pb.MathNode{Content: n.Content}}
|
||||
case *ast.Highlight:
|
||||
node.Node = &v1pb.Node_HighlightNode{HighlightNode: &v1pb.HighlightNode{Content: n.Content}}
|
||||
case *ast.Subscript:
|
||||
node.Node = &v1pb.Node_SubscriptNode{SubscriptNode: &v1pb.SubscriptNode{Content: n.Content}}
|
||||
case *ast.Superscript:
|
||||
node.Node = &v1pb.Node_SuperscriptNode{SuperscriptNode: &v1pb.SuperscriptNode{Content: n.Content}}
|
||||
case *ast.ReferencedContent:
|
||||
node.Node = &v1pb.Node_ReferencedContentNode{ReferencedContentNode: &v1pb.ReferencedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
|
||||
case *ast.Spoiler:
|
||||
node.Node = &v1pb.Node_SpoilerNode{SpoilerNode: &v1pb.SpoilerNode{Content: n.Content}}
|
||||
case *ast.HTMLElement:
|
||||
node.Node = &v1pb.Node_HtmlElementNode{HtmlElementNode: &v1pb.HTMLElementNode{TagName: n.TagName, Attributes: n.Attributes}}
|
||||
default:
|
||||
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{}}
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
func convertFromASTNodes(rawNodes []ast.Node) []*v1pb.Node {
|
||||
nodes := []*v1pb.Node{}
|
||||
for _, rawNode := range rawNodes {
|
||||
node := convertFromASTNode(rawNode)
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
func convertTableFromASTNode(node *ast.Table) *v1pb.TableNode {
|
||||
table := &v1pb.TableNode{
|
||||
Header: convertFromASTNodes(node.Header),
|
||||
Delimiter: node.Delimiter,
|
||||
}
|
||||
for _, row := range node.Rows {
|
||||
table.Rows = append(table.Rows, &v1pb.TableNode_Row{Cells: convertFromASTNodes(row)})
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func convertListKindFromASTNode(node ast.ListKind) v1pb.ListNode_Kind {
|
||||
switch node {
|
||||
case ast.OrderedList:
|
||||
return v1pb.ListNode_ORDERED
|
||||
case ast.UnorderedList:
|
||||
return v1pb.ListNode_UNORDERED
|
||||
case ast.DescrpitionList:
|
||||
return v1pb.ListNode_DESCRIPTION
|
||||
default:
|
||||
return v1pb.ListNode_KIND_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertToASTNode(node *v1pb.Node) ast.Node {
|
||||
switch n := node.Node.(type) {
|
||||
case *v1pb.Node_LineBreakNode:
|
||||
return &ast.LineBreak{}
|
||||
case *v1pb.Node_ParagraphNode:
|
||||
children := convertToASTNodes(n.ParagraphNode.Children)
|
||||
return &ast.Paragraph{Children: children}
|
||||
case *v1pb.Node_CodeBlockNode:
|
||||
return &ast.CodeBlock{Language: n.CodeBlockNode.Language, Content: n.CodeBlockNode.Content}
|
||||
case *v1pb.Node_HeadingNode:
|
||||
children := convertToASTNodes(n.HeadingNode.Children)
|
||||
return &ast.Heading{Level: int(n.HeadingNode.Level), Children: children}
|
||||
case *v1pb.Node_HorizontalRuleNode:
|
||||
return &ast.HorizontalRule{Symbol: n.HorizontalRuleNode.Symbol}
|
||||
case *v1pb.Node_BlockquoteNode:
|
||||
children := convertToASTNodes(n.BlockquoteNode.Children)
|
||||
return &ast.Blockquote{Children: children}
|
||||
case *v1pb.Node_ListNode:
|
||||
children := convertToASTNodes(n.ListNode.Children)
|
||||
return &ast.List{Kind: convertListKindToASTNode(n.ListNode.Kind), Indent: int(n.ListNode.Indent), Children: children}
|
||||
case *v1pb.Node_OrderedListItemNode:
|
||||
children := convertToASTNodes(n.OrderedListItemNode.Children)
|
||||
return &ast.OrderedListItem{Number: n.OrderedListItemNode.Number, Indent: int(n.OrderedListItemNode.Indent), Children: children}
|
||||
case *v1pb.Node_UnorderedListItemNode:
|
||||
children := convertToASTNodes(n.UnorderedListItemNode.Children)
|
||||
return &ast.UnorderedListItem{Symbol: n.UnorderedListItemNode.Symbol, Indent: int(n.UnorderedListItemNode.Indent), Children: children}
|
||||
case *v1pb.Node_TaskListItemNode:
|
||||
children := convertToASTNodes(n.TaskListItemNode.Children)
|
||||
return &ast.TaskListItem{Symbol: n.TaskListItemNode.Symbol, Indent: int(n.TaskListItemNode.Indent), Complete: n.TaskListItemNode.Complete, Children: children}
|
||||
case *v1pb.Node_MathBlockNode:
|
||||
return &ast.MathBlock{Content: n.MathBlockNode.Content}
|
||||
case *v1pb.Node_TableNode:
|
||||
return convertTableToASTNode(n.TableNode)
|
||||
case *v1pb.Node_EmbeddedContentNode:
|
||||
return &ast.EmbeddedContent{ResourceName: n.EmbeddedContentNode.ResourceName, Params: n.EmbeddedContentNode.Params}
|
||||
case *v1pb.Node_TextNode:
|
||||
return &ast.Text{Content: n.TextNode.Content}
|
||||
case *v1pb.Node_BoldNode:
|
||||
return &ast.Bold{Symbol: n.BoldNode.Symbol, Children: convertToASTNodes(n.BoldNode.Children)}
|
||||
case *v1pb.Node_ItalicNode:
|
||||
return &ast.Italic{Symbol: n.ItalicNode.Symbol, Children: convertToASTNodes(n.ItalicNode.Children)}
|
||||
case *v1pb.Node_BoldItalicNode:
|
||||
return &ast.BoldItalic{Symbol: n.BoldItalicNode.Symbol, Content: n.BoldItalicNode.Content}
|
||||
case *v1pb.Node_CodeNode:
|
||||
return &ast.Code{Content: n.CodeNode.Content}
|
||||
case *v1pb.Node_ImageNode:
|
||||
return &ast.Image{AltText: n.ImageNode.AltText, URL: n.ImageNode.Url}
|
||||
case *v1pb.Node_LinkNode:
|
||||
return &ast.Link{Content: convertToASTNodes(n.LinkNode.Content), URL: n.LinkNode.Url}
|
||||
case *v1pb.Node_AutoLinkNode:
|
||||
return &ast.AutoLink{URL: n.AutoLinkNode.Url, IsRawText: n.AutoLinkNode.IsRawText}
|
||||
case *v1pb.Node_TagNode:
|
||||
return &ast.Tag{Content: n.TagNode.Content}
|
||||
case *v1pb.Node_StrikethroughNode:
|
||||
return &ast.Strikethrough{Content: n.StrikethroughNode.Content}
|
||||
case *v1pb.Node_EscapingCharacterNode:
|
||||
return &ast.EscapingCharacter{Symbol: n.EscapingCharacterNode.Symbol}
|
||||
case *v1pb.Node_MathNode:
|
||||
return &ast.Math{Content: n.MathNode.Content}
|
||||
case *v1pb.Node_HighlightNode:
|
||||
return &ast.Highlight{Content: n.HighlightNode.Content}
|
||||
case *v1pb.Node_SubscriptNode:
|
||||
return &ast.Subscript{Content: n.SubscriptNode.Content}
|
||||
case *v1pb.Node_SuperscriptNode:
|
||||
return &ast.Superscript{Content: n.SuperscriptNode.Content}
|
||||
case *v1pb.Node_ReferencedContentNode:
|
||||
return &ast.ReferencedContent{ResourceName: n.ReferencedContentNode.ResourceName, Params: n.ReferencedContentNode.Params}
|
||||
case *v1pb.Node_SpoilerNode:
|
||||
return &ast.Spoiler{Content: n.SpoilerNode.Content}
|
||||
case *v1pb.Node_HtmlElementNode:
|
||||
return &ast.HTMLElement{TagName: n.HtmlElementNode.TagName, Attributes: n.HtmlElementNode.Attributes}
|
||||
default:
|
||||
return &ast.Text{}
|
||||
}
|
||||
}
|
||||
|
||||
func convertToASTNodes(nodes []*v1pb.Node) []ast.Node {
|
||||
rawNodes := []ast.Node{}
|
||||
for _, node := range nodes {
|
||||
rawNode := convertToASTNode(node)
|
||||
rawNodes = append(rawNodes, rawNode)
|
||||
}
|
||||
return rawNodes
|
||||
}
|
||||
|
||||
func convertTableToASTNode(node *v1pb.TableNode) *ast.Table {
|
||||
table := &ast.Table{
|
||||
Header: convertToASTNodes(node.Header),
|
||||
Delimiter: node.Delimiter,
|
||||
}
|
||||
for _, row := range node.Rows {
|
||||
table.Rows = append(table.Rows, convertToASTNodes(row.Cells))
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func convertListKindToASTNode(kind v1pb.ListNode_Kind) ast.ListKind {
|
||||
switch kind {
|
||||
case v1pb.ListNode_ORDERED:
|
||||
return ast.OrderedList
|
||||
case v1pb.ListNode_UNORDERED:
|
||||
return ast.UnorderedList
|
||||
case v1pb.ListNode_DESCRIPTION:
|
||||
return ast.DescrpitionList
|
||||
default:
|
||||
// Default to description list.
|
||||
return ast.DescrpitionList
|
||||
}
|
||||
}
|
||||
102
server/router/api/v1/memo_attachment_service.go
Normal file
102
server/router/api/v1/memo_attachment_service.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
// Delete attachments that are not in the request.
|
||||
for _, attachment := range attachments {
|
||||
found := false
|
||||
for _, requestAttachment := range request.Attachments {
|
||||
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
if attachment.UID == requestAttachmentUID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if err = s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
|
||||
ID: int32(attachment.ID),
|
||||
MemoID: &memo.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(request.Attachments)
|
||||
// Update attachments' memo_id in the request.
|
||||
for index, attachment := range request.Attachments {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
updatedTs := time.Now().Unix() + int64(index)
|
||||
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
ID: tempAttachment.ID,
|
||||
MemoID: &memo.ID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoAttachmentsResponse{
|
||||
Attachments: []*v1pb.Attachment{},
|
||||
}
|
||||
for _, attachment := range attachments {
|
||||
response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
427
server/router/api/v1/memo_export_import.go
Normal file
427
server/router/api/v1/memo_export_import.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/runner/memopayload"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// ExportFormat represents the format for export/import operations
|
||||
type ExportFormat string
|
||||
|
||||
const (
|
||||
FormatJSON ExportFormat = "json"
|
||||
)
|
||||
|
||||
// ExportData represents the structure of exported data
|
||||
type ExportData struct {
|
||||
Version string `json:"version"`
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Memos []ExportMemo `json:"memos"`
|
||||
}
|
||||
|
||||
// ExportMemo represents a memo in the export format
|
||||
type ExportMemo struct {
|
||||
UID string `json:"uid"`
|
||||
Content string `json:"content"`
|
||||
Visibility string `json:"visibility"`
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DisplayTime *time.Time `json:"display_time,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Location *ExportLocation `json:"location,omitempty"`
|
||||
Attachments []ExportAttachment `json:"attachments,omitempty"`
|
||||
Relations []ExportMemoRelation `json:"relations,omitempty"`
|
||||
}
|
||||
|
||||
// ExportLocation represents location data in export format
|
||||
type ExportLocation struct {
|
||||
Placeholder string `json:"placeholder,omitempty"`
|
||||
Latitude float64 `json:"latitude,omitempty"`
|
||||
Longitude float64 `json:"longitude,omitempty"`
|
||||
}
|
||||
|
||||
// ExportAttachment represents attachment data in export format
|
||||
type ExportAttachment struct {
|
||||
UID string `json:"uid"`
|
||||
Filename string `json:"filename"`
|
||||
Type string `json:"type"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// ExportMemoRelation represents memo relations in export format
|
||||
type ExportMemoRelation struct {
|
||||
RelatedMemoUID string `json:"related_memo_uid"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// ExportMemos exports memos for the current user in JSON format
|
||||
func (s *APIV1Service) ExportMemos(ctx context.Context, request *v1pb.ExportMemosRequest) (*v1pb.ExportMemosResponse, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
|
||||
// Validate format (default to JSON)
|
||||
format := request.Format
|
||||
if format == "" {
|
||||
format = string(FormatJSON)
|
||||
}
|
||||
if format != string(FormatJSON) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unsupported export format: %s", format)
|
||||
}
|
||||
|
||||
// Get all memos for the user
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
ExcludeComments: true,
|
||||
}
|
||||
|
||||
// Apply filters if specified
|
||||
if request.Filter != "" {
|
||||
// Use existing filter validation from shortcut service
|
||||
memoFind.Filter = &request.Filter
|
||||
}
|
||||
|
||||
// Include archived memos if requested
|
||||
if request.ExcludeArchived {
|
||||
normalStatus := store.Normal
|
||||
memoFind.RowStatus = &normalStatus
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
// Convert memos to export format
|
||||
exportMemos := make([]ExportMemo, 0, len(memos))
|
||||
for _, memo := range memos {
|
||||
exportMemo, err := s.convertMemoToExport(ctx, memo, request.IncludeAttachments, request.IncludeRelations)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to convert memo to export format", slog.Any("memo_id", memo.ID), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
exportMemos = append(exportMemos, *exportMemo)
|
||||
}
|
||||
|
||||
// Create export data structure
|
||||
exportData := &ExportData{
|
||||
Version: "1.0",
|
||||
ExportedAt: time.Now(),
|
||||
Memos: exportMemos,
|
||||
}
|
||||
|
||||
// Serialize to JSON
|
||||
jsonData, err := json.MarshalIndent(exportData, "", " ")
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to marshal export data: %v", err)
|
||||
}
|
||||
|
||||
return &v1pb.ExportMemosResponse{
|
||||
Data: jsonData,
|
||||
Format: format,
|
||||
Filename: fmt.Sprintf("memos_export_%s.json", time.Now().Format("20060102_150405")),
|
||||
MemoCount: int32(len(exportMemos)),
|
||||
SizeBytes: int64(len(jsonData)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImportMemos imports memos from JSON data
|
||||
func (s *APIV1Service) ImportMemos(ctx context.Context, request *v1pb.ImportMemosRequest) (*v1pb.ImportMemosResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
|
||||
// Validate format (default to JSON)
|
||||
format := request.Format
|
||||
if format == "" {
|
||||
format = string(FormatJSON)
|
||||
}
|
||||
if format != string(FormatJSON) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unsupported import format: %s", format)
|
||||
}
|
||||
|
||||
// Parse the JSON data
|
||||
var importData ExportData
|
||||
if err := json.Unmarshal(request.Data, &importData); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse import data: %v", err)
|
||||
}
|
||||
|
||||
// Validate import data version
|
||||
if importData.Version != "1.0" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unsupported import data version: %s", importData.Version)
|
||||
}
|
||||
|
||||
var importedCount int32
|
||||
var skippedCount int32
|
||||
var createdCount int32
|
||||
var updatedCount int32
|
||||
var validationErrors int32
|
||||
var attachmentsImported int32
|
||||
var relationsImported int32
|
||||
var errors []string
|
||||
var warnings []string
|
||||
|
||||
// Import each memo
|
||||
for _, exportMemo := range importData.Memos {
|
||||
result, err := s.importSingleMemo(ctx, user.ID, &exportMemo, request)
|
||||
if err != nil {
|
||||
errorMsg := fmt.Sprintf("Failed to import memo %s: %v", exportMemo.UID, err)
|
||||
errors = append(errors, errorMsg)
|
||||
skippedCount++
|
||||
if request.ValidateOnly {
|
||||
validationErrors++
|
||||
}
|
||||
slog.Warn("Failed to import memo", slog.String("uid", exportMemo.UID), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
|
||||
importedCount++
|
||||
if result.Created {
|
||||
createdCount++
|
||||
} else {
|
||||
updatedCount++
|
||||
}
|
||||
attachmentsImported += result.AttachmentsImported
|
||||
relationsImported += result.RelationsImported
|
||||
|
||||
if len(result.Warnings) > 0 {
|
||||
warnings = append(warnings, result.Warnings...)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
summary := &v1pb.ImportSummary{
|
||||
TotalMemos: int32(len(importData.Memos)),
|
||||
CreatedCount: createdCount,
|
||||
UpdatedCount: updatedCount,
|
||||
AttachmentsImported: attachmentsImported,
|
||||
RelationsImported: relationsImported,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
}
|
||||
|
||||
return &v1pb.ImportMemosResponse{
|
||||
ImportedCount: importedCount,
|
||||
SkippedCount: skippedCount,
|
||||
ValidationErrors: validationErrors,
|
||||
Errors: errors,
|
||||
Warnings: warnings,
|
||||
Summary: summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMemoToExport converts a store memo to export format
|
||||
func (s *APIV1Service) convertMemoToExport(ctx context.Context, memo *store.Memo, includeAttachments, includeRelations bool) (*ExportMemo, error) {
|
||||
exportMemo := &ExportMemo{
|
||||
UID: memo.UID,
|
||||
Content: memo.Content,
|
||||
Visibility: memo.Visibility.String(),
|
||||
Pinned: memo.Pinned,
|
||||
CreatedAt: time.Unix(memo.CreatedTs, 0),
|
||||
UpdatedAt: time.Unix(memo.UpdatedTs, 0),
|
||||
}
|
||||
|
||||
// Extract tags from payload
|
||||
if memo.Payload != nil && len(memo.Payload.Tags) > 0 {
|
||||
exportMemo.Tags = memo.Payload.Tags
|
||||
}
|
||||
|
||||
// Add location if present
|
||||
if memo.Payload != nil && memo.Payload.Location != nil {
|
||||
exportMemo.Location = &ExportLocation{
|
||||
Placeholder: memo.Payload.Location.Placeholder,
|
||||
Latitude: memo.Payload.Location.Latitude,
|
||||
Longitude: memo.Payload.Location.Longitude,
|
||||
}
|
||||
}
|
||||
|
||||
// Add attachments if requested
|
||||
if includeAttachments {
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list attachments")
|
||||
}
|
||||
|
||||
for _, attachment := range attachments {
|
||||
exportMemo.Attachments = append(exportMemo.Attachments, ExportAttachment{
|
||||
UID: attachment.UID,
|
||||
Filename: attachment.Filename,
|
||||
Type: attachment.Type,
|
||||
Size: attachment.Size,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add relations if requested
|
||||
if includeRelations {
|
||||
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{MemoID: &memo.ID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo relations")
|
||||
}
|
||||
|
||||
for _, relation := range relations {
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &relation.RelatedMemoID})
|
||||
if err != nil || relatedMemo == nil {
|
||||
continue // Skip if related memo not found
|
||||
}
|
||||
|
||||
exportMemo.Relations = append(exportMemo.Relations, ExportMemoRelation{
|
||||
RelatedMemoUID: relatedMemo.UID,
|
||||
Type: string(relation.Type),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exportMemo, nil
|
||||
}
|
||||
|
||||
// ImportResult represents the result of importing a single memo
|
||||
type ImportResult struct {
|
||||
Created bool
|
||||
AttachmentsImported int32
|
||||
RelationsImported int32
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
// importSingleMemo imports a single memo
|
||||
func (s *APIV1Service) importSingleMemo(ctx context.Context, userID int32, exportMemo *ExportMemo, request *v1pb.ImportMemosRequest) (*ImportResult, error) {
|
||||
result := &ImportResult{
|
||||
Warnings: []string{},
|
||||
}
|
||||
|
||||
// Check if memo with this UID already exists
|
||||
existingMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &exportMemo.UID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to check for existing memo")
|
||||
}
|
||||
|
||||
if existingMemo != nil && !request.OverwriteExisting {
|
||||
return nil, fmt.Errorf("memo with UID %s already exists", exportMemo.UID)
|
||||
}
|
||||
|
||||
// Validate memo content length
|
||||
contentLengthLimit, err := s.getContentLengthLimit(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get content length limit")
|
||||
}
|
||||
if len(exportMemo.Content) > contentLengthLimit {
|
||||
return nil, fmt.Errorf("content too long (max %d characters)", contentLengthLimit)
|
||||
}
|
||||
|
||||
// Parse visibility
|
||||
visibility := store.Private
|
||||
switch exportMemo.Visibility {
|
||||
case "PUBLIC":
|
||||
visibility = store.Public
|
||||
case "PROTECTED":
|
||||
visibility = store.Protected
|
||||
case "PRIVATE":
|
||||
visibility = store.Private
|
||||
default:
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("Unknown visibility %s for memo %s, defaulting to PRIVATE", exportMemo.Visibility, exportMemo.UID))
|
||||
}
|
||||
|
||||
// Create memo payload
|
||||
payload := &storepb.MemoPayload{
|
||||
Tags: exportMemo.Tags,
|
||||
}
|
||||
|
||||
if exportMemo.Location != nil {
|
||||
payload.Location = &storepb.MemoPayload_Location{
|
||||
Placeholder: exportMemo.Location.Placeholder,
|
||||
Latitude: exportMemo.Location.Latitude,
|
||||
Longitude: exportMemo.Location.Longitude,
|
||||
}
|
||||
}
|
||||
|
||||
// Set timestamps
|
||||
createdTs := exportMemo.CreatedAt.Unix()
|
||||
updatedTs := exportMemo.UpdatedAt.Unix()
|
||||
if !request.PreserveTimestamps {
|
||||
now := time.Now().Unix()
|
||||
createdTs = now
|
||||
updatedTs = now
|
||||
}
|
||||
|
||||
if request.ValidateOnly {
|
||||
// Just validate, don't actually create/update
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if existingMemo != nil {
|
||||
// Update existing memo
|
||||
update := &store.UpdateMemo{
|
||||
ID: existingMemo.ID,
|
||||
Content: &exportMemo.Content,
|
||||
Visibility: &visibility,
|
||||
Pinned: &exportMemo.Pinned,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
if request.PreserveTimestamps {
|
||||
update.CreatedTs = &createdTs
|
||||
update.UpdatedTs = &updatedTs
|
||||
}
|
||||
|
||||
if err := s.Store.UpdateMemo(ctx, update); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update existing memo")
|
||||
}
|
||||
result.Created = false
|
||||
} else {
|
||||
// Create new memo
|
||||
create := &store.Memo{
|
||||
UID: exportMemo.UID,
|
||||
CreatorID: userID,
|
||||
CreatedTs: createdTs,
|
||||
UpdatedTs: updatedTs,
|
||||
Content: exportMemo.Content,
|
||||
Visibility: visibility,
|
||||
Pinned: exportMemo.Pinned,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
// Rebuild memo payload to extract tags and other metadata
|
||||
if err := memopayload.RebuildMemoPayload(create); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to rebuild memo payload")
|
||||
}
|
||||
|
||||
_, err := s.Store.CreateMemo(ctx, create)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create memo")
|
||||
}
|
||||
result.Created = true
|
||||
}
|
||||
|
||||
// Import attachments if not skipped
|
||||
if !request.SkipAttachments && len(exportMemo.Attachments) > 0 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("Attachments for memo %s were skipped (attachment import not yet implemented)", exportMemo.UID))
|
||||
// TODO: Implement attachment import
|
||||
// This would require handling file uploads and storage
|
||||
}
|
||||
|
||||
// Import relations if not skipped
|
||||
if !request.SkipRelations && len(exportMemo.Relations) > 0 {
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("Relations for memo %s were skipped (relation import not yet implemented)", exportMemo.UID))
|
||||
// TODO: Implement relation import
|
||||
// This would require resolving related memo UIDs and creating relations
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
170
server/router/api/v1/memo_relation_service.go
Normal file
170
server/router/api/v1/memo_relation_service.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
referenceType := store.MemoRelationReference
|
||||
// Delete all reference relations first.
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
Type: &referenceType,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
}
|
||||
|
||||
for _, relation := range request.Relations {
|
||||
// Ignore reflexive relations.
|
||||
if request.Name == relation.RelatedMemo.Name {
|
||||
continue
|
||||
}
|
||||
// Ignore comment relations as there's no need to update a comment's relation.
|
||||
// Inserting/Deleting a comment is handled elsewhere.
|
||||
if relation.Type == v1pb.MemoRelation_COMMENT {
|
||||
continue
|
||||
}
|
||||
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo")
|
||||
}
|
||||
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: convertMemoRelationTypeToStore(relation.Type),
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
var memoFilter string
|
||||
if currentUser == nil {
|
||||
memoFilter = `visibility == "PUBLIC"`
|
||||
} else {
|
||||
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
}
|
||||
relationList := []*v1pb.MemoRelation{}
|
||||
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, raw := range tempList {
|
||||
relation, err := s.convertMemoRelationFromStore(ctx, raw)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
|
||||
}
|
||||
relationList = append(relationList, relation)
|
||||
}
|
||||
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &memo.ID,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, raw := range tempList {
|
||||
relation, err := s.convertMemoRelationFromStore(ctx, raw)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
|
||||
}
|
||||
relationList = append(relationList, relation)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoRelationsResponse{
|
||||
Relations: relationList,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertMemoRelationFromStore(ctx context.Context, memoRelation *store.MemoRelation) (*v1pb.MemoRelation, error) {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.MemoID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memoSnippet, err := getMemoContentSnippet(memo.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo content snippet")
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.RelatedMemoID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relatedMemoSnippet, err := getMemoContentSnippet(relatedMemo.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get related memo content snippet")
|
||||
}
|
||||
return &v1pb.MemoRelation{
|
||||
Memo: &v1pb.MemoRelation_Memo{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
Snippet: memoSnippet,
|
||||
},
|
||||
RelatedMemo: &v1pb.MemoRelation_Memo{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
|
||||
Snippet: relatedMemoSnippet,
|
||||
},
|
||||
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
|
||||
switch relationType {
|
||||
case store.MemoRelationReference:
|
||||
return v1pb.MemoRelation_REFERENCE
|
||||
case store.MemoRelationComment:
|
||||
return v1pb.MemoRelation_COMMENT
|
||||
default:
|
||||
return v1pb.MemoRelation_TYPE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
|
||||
switch relationType {
|
||||
case v1pb.MemoRelation_REFERENCE:
|
||||
return store.MemoRelationReference
|
||||
case v1pb.MemoRelation_COMMENT:
|
||||
return store.MemoRelationComment
|
||||
default:
|
||||
return store.MemoRelationReference
|
||||
}
|
||||
}
|
||||
785
server/router/api/v1/memo_service.go
Normal file
785
server/router/api/v1/memo_service.go
Normal file
@@ -0,0 +1,785 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/gomark/ast"
|
||||
"github.com/usememos/gomark/parser"
|
||||
"github.com/usememos/gomark/parser/tokenizer"
|
||||
"github.com/usememos/gomark/renderer"
|
||||
"github.com/usememos/gomark/restore"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/plugin/webhook"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/runner/memopayload"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
|
||||
create := &store.Memo{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: user.ID,
|
||||
Content: request.Memo.Content,
|
||||
Visibility: convertVisibilityToStore(request.Memo.Visibility),
|
||||
}
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
contentLengthLimit, err := s.getContentLengthLimit(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
|
||||
}
|
||||
if len(create.Content) > contentLengthLimit {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
|
||||
}
|
||||
if err := memopayload.RebuildMemoPayload(create); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
|
||||
}
|
||||
if request.Memo.Location != nil {
|
||||
create.Payload.Location = convertLocationToStore(request.Memo.Location)
|
||||
}
|
||||
|
||||
memo, err := s.Store.CreateMemo(ctx, create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(request.Memo.Attachments) > 0 {
|
||||
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
Attachments: request.Memo.Attachments,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo attachments")
|
||||
}
|
||||
}
|
||||
if len(request.Memo.Relations) > 0 {
|
||||
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
|
||||
Relations: request.Memo.Relations,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo relations")
|
||||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is created.
|
||||
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
|
||||
memoFind := &store.FindMemo{
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
}
|
||||
// Handle deprecated old_filter for backward compatibility
|
||||
if request.OldFilter != "" && request.Filter == "" {
|
||||
//nolint:staticcheck // SA1019: Using deprecated field for backward compatibility
|
||||
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.OldFilter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter: %v", err)
|
||||
}
|
||||
}
|
||||
if request.Parent != "" && request.Parent != "users/-" {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
|
||||
}
|
||||
memoFind.CreatorID = &userID
|
||||
memoFind.OrderByPinned = true
|
||||
}
|
||||
if request.State == v1pb.State_ARCHIVED {
|
||||
state := store.Archived
|
||||
memoFind.RowStatus = &state
|
||||
} else {
|
||||
state := store.Normal
|
||||
memoFind.RowStatus = &state
|
||||
}
|
||||
|
||||
// Parse order_by field (replaces the old sort and direction fields)
|
||||
if request.OrderBy != "" {
|
||||
if err := s.parseMemoOrderBy(request.OrderBy, memoFind); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid order_by: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Default ordering by display_time desc
|
||||
memoFind.OrderByTimeAsc = false
|
||||
}
|
||||
|
||||
if request.Filter != "" {
|
||||
if err := s.validateFilter(ctx, request.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
memoFind.Filter = &request.Filter
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if currentUser == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
if memoFind.CreatorID == nil {
|
||||
internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
if memoFind.Filter != nil {
|
||||
filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter)
|
||||
memoFind.Filter = &filter
|
||||
} else {
|
||||
memoFind.Filter = &internalFilter
|
||||
}
|
||||
} else if *memoFind.CreatorID != currentUser.ID {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
}
|
||||
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
memoFind.OrderByUpdatedTs = true
|
||||
}
|
||||
|
||||
var limit, offset int
|
||||
if request.PageToken != "" {
|
||||
var pageToken v1pb.PageToken
|
||||
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
|
||||
}
|
||||
limit = int(pageToken.Limit)
|
||||
offset = int(pageToken.Offset)
|
||||
} else {
|
||||
limit = int(request.PageSize)
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = DefaultPageSize
|
||||
}
|
||||
limitPlusOne := limit + 1
|
||||
memoFind.Limit = &limitPlusOne
|
||||
memoFind.Offset = &offset
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
memoMessages := []*v1pb.Memo{}
|
||||
nextPageToken := ""
|
||||
if len(memos) == limitPlusOne {
|
||||
memos = memos[:limit]
|
||||
nextPageToken, err = getPageToken(limit, offset+limit)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
|
||||
}
|
||||
}
|
||||
for _, memo := range memos {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
memoMessages = append(memoMessages, memoMessage)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemosResponse{
|
||||
Memos: memoMessages,
|
||||
NextPageToken: nextPageToken,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
UID: &memoUID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Memo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
// Only the creator or admin can update the memo.
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
update := &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
}
|
||||
for _, path := range request.UpdateMask.Paths {
|
||||
if path == "content" {
|
||||
contentLengthLimit, err := s.getContentLengthLimit(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
|
||||
}
|
||||
if len(request.Memo.Content) > contentLengthLimit {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
|
||||
}
|
||||
memo.Content = request.Memo.Content
|
||||
if err := memopayload.RebuildMemoPayload(memo); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
|
||||
}
|
||||
update.Content = &memo.Content
|
||||
update.Payload = memo.Payload
|
||||
} else if path == "visibility" {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
visibility := convertVisibilityToStore(request.Memo.Visibility)
|
||||
if workspaceMemoRelatedSetting.DisallowPublicVisibility && visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
update.Visibility = &visibility
|
||||
} else if path == "pinned" {
|
||||
update.Pinned = &request.Memo.Pinned
|
||||
} else if path == "state" {
|
||||
rowStatus := convertStateToStore(request.Memo.State)
|
||||
update.RowStatus = &rowStatus
|
||||
} else if path == "create_time" {
|
||||
createdTs := request.Memo.CreateTime.AsTime().Unix()
|
||||
update.CreatedTs = &createdTs
|
||||
} else if path == "update_time" {
|
||||
updatedTs := time.Now().Unix()
|
||||
if request.Memo.UpdateTime != nil {
|
||||
updatedTs = request.Memo.UpdateTime.AsTime().Unix()
|
||||
}
|
||||
update.UpdatedTs = &updatedTs
|
||||
} else if path == "display_time" {
|
||||
displayTs := request.Memo.DisplayTime.AsTime().Unix()
|
||||
memoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if memoRelatedSetting.DisplayWithUpdateTime {
|
||||
update.UpdatedTs = &displayTs
|
||||
} else {
|
||||
update.CreatedTs = &displayTs
|
||||
}
|
||||
} else if path == "location" {
|
||||
payload := memo.Payload
|
||||
payload.Location = convertLocationToStore(request.Memo.Location)
|
||||
update.Payload = payload
|
||||
} else if path == "attachments" {
|
||||
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Attachments: request.Memo.Attachments,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo attachments")
|
||||
}
|
||||
} else if path == "relations" {
|
||||
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Relations: request.Memo.Relations,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo relations")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = s.Store.UpdateMemo(ctx, update); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update memo")
|
||||
}
|
||||
|
||||
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is updated.
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
UID: &memoUID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
// Only the creator or admin can update the memo.
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
|
||||
// Try to dispatch webhook when memo is deleted.
|
||||
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
|
||||
// Delete memo relation
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{MemoID: &memo.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relations")
|
||||
}
|
||||
|
||||
// Delete related attachments.
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
for _, attachment := range attachments {
|
||||
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
}
|
||||
}
|
||||
|
||||
// Delete memo comments
|
||||
commentType := store.MemoRelationComment
|
||||
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo comments")
|
||||
}
|
||||
for _, relation := range relations {
|
||||
if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: relation.MemoID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo comment")
|
||||
}
|
||||
}
|
||||
|
||||
// Delete memo references
|
||||
referenceType := store.MemoRelationReference
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{RelatedMemoID: &memo.ID, Type: &referenceType}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo references")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
// Create the memo comment first.
|
||||
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{Memo: request.Comment})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo")
|
||||
}
|
||||
memoUID, err = ExtractMemoUIDFromName(memoComment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
// Build the relation between the comment memo and the original memo.
|
||||
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
|
||||
}
|
||||
creatorID, err := ExtractUserIDFromName(memoComment.Creator)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
if memoComment.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
|
||||
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: creatorID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memo.ID,
|
||||
RelatedMemoId: relatedMemo.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create activity")
|
||||
}
|
||||
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: creatorID,
|
||||
ReceiverID: relatedMemo.CreatorID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
ActivityId: &activity.ID,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create inbox")
|
||||
}
|
||||
}
|
||||
|
||||
return memoComment, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
var memoFilter string
|
||||
if currentUser == nil {
|
||||
memoFilter = `visibility == "PUBLIC"`
|
||||
} else {
|
||||
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
}
|
||||
memoRelationComment := store.MemoRelationComment
|
||||
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &memo.ID,
|
||||
Type: &memoRelationComment,
|
||||
MemoFilter: &memoFilter,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
|
||||
}
|
||||
|
||||
var memos []*v1pb.Memo
|
||||
for _, memoRelation := range memoRelations {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &memoRelation.MemoID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo != nil {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
memos = append(memos, memoMessage)
|
||||
}
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoCommentsResponse{
|
||||
Memos: memos,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) RenameMemoTag(ctx context.Context, request *v1pb.RenameMemoTagRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
PayloadFind: &store.FindMemoPayload{TagSearch: []string{request.OldTag}},
|
||||
ExcludeComments: true,
|
||||
}
|
||||
if (request.Parent) != "memos/-" {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memoFind.UID = &memoUID
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos")
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to parse memo: %v", err)
|
||||
}
|
||||
memopayload.TraverseASTNodes(nodes, func(node ast.Node) {
|
||||
if tag, ok := node.(*ast.Tag); ok && tag.Content == request.OldTag {
|
||||
tag.Content = request.NewTag
|
||||
}
|
||||
})
|
||||
memo.Content = restore.Restore(nodes)
|
||||
if err := memopayload.RebuildMemoPayload(memo); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
|
||||
}
|
||||
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Content: &memo.Content,
|
||||
Payload: memo.Payload,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update memo: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemoTag(ctx context.Context, request *v1pb.DeleteMemoTagRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
PayloadFind: &store.FindMemoPayload{TagSearch: []string{request.Tag}},
|
||||
ExcludeContent: true,
|
||||
ExcludeComments: true,
|
||||
}
|
||||
if request.Parent != "memos/-" {
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memoFind.UID = &memoUID
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos")
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
if request.DeleteRelatedMemos {
|
||||
err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
} else {
|
||||
archived := store.Archived
|
||||
err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
RowStatus: &archived,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update memo")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) getContentLengthLimit(ctx context.Context) (int, error) {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return 0, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
return int(workspaceMemoRelatedSetting.ContentLengthLimit), nil
|
||||
}
|
||||
|
||||
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
|
||||
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
|
||||
}
|
||||
|
||||
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
|
||||
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
|
||||
}
|
||||
|
||||
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
|
||||
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, creatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, hook := range webhooks {
|
||||
payload, err := convertMemoToWebhookPayload(memo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to convert memo to webhook payload")
|
||||
}
|
||||
payload.ActivityType = activityType
|
||||
payload.URL = hook.Url
|
||||
|
||||
// Use asynchronous webhook dispatch
|
||||
webhook.PostAsync(payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid memo creator")
|
||||
}
|
||||
return &webhook.WebhookRequestPayload{
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creatorID),
|
||||
Memo: memo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getMemoContentSnippet(content string) (string, error) {
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(content))
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse content")
|
||||
}
|
||||
|
||||
plainText := renderer.NewStringRenderer().Render(nodes)
|
||||
if len(plainText) > 64 {
|
||||
return substring(plainText, 64) + "...", nil
|
||||
}
|
||||
return plainText, nil
|
||||
}
|
||||
|
||||
func substring(s string, length int) string {
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
runeCount := 0
|
||||
byteIndex := 0
|
||||
for byteIndex < len(s) {
|
||||
_, size := utf8.DecodeRuneInString(s[byteIndex:])
|
||||
byteIndex += size
|
||||
runeCount++
|
||||
if runeCount == length {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return s[:byteIndex]
|
||||
}
|
||||
|
||||
// parseMemoOrderBy parses the order_by field and sets the appropriate ordering in memoFind.
|
||||
func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo) error {
|
||||
// Parse order_by field like "display_time desc" or "create_time asc"
|
||||
parts := strings.Fields(strings.TrimSpace(orderBy))
|
||||
if len(parts) == 0 {
|
||||
return errors.New("empty order_by")
|
||||
}
|
||||
|
||||
field := parts[0]
|
||||
direction := "desc" // default
|
||||
if len(parts) > 1 {
|
||||
direction = strings.ToLower(parts[1])
|
||||
if direction != "asc" && direction != "desc" {
|
||||
return errors.Errorf("invalid order direction: %s, must be 'asc' or 'desc'", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
switch field {
|
||||
case "display_time":
|
||||
memoFind.OrderByTimeAsc = direction == "asc"
|
||||
case "create_time":
|
||||
memoFind.OrderByTimeAsc = direction == "asc"
|
||||
case "update_time":
|
||||
memoFind.OrderByUpdatedTs = true
|
||||
memoFind.OrderByTimeAsc = direction == "asc"
|
||||
case "name":
|
||||
// For ordering by memo name/id - not commonly used but supported
|
||||
memoFind.OrderByTimeAsc = direction == "asc"
|
||||
default:
|
||||
return errors.Errorf("unsupported order field: %s, supported fields are: display_time, create_time, update_time, name", field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
149
server/router/api/v1/memo_service_converter.go
Normal file
149
server/router/api/v1/memo_service_converter.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/gomark/parser"
|
||||
"github.com/usememos/gomark/parser/tokenizer"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
|
||||
displayTs := memo.CreatedTs
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
memoMessage := &v1pb.Memo{
|
||||
Name: name,
|
||||
State: convertStateFromStore(memo.RowStatus),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, memo.CreatorID),
|
||||
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
|
||||
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
|
||||
Content: memo.Content,
|
||||
Visibility: convertVisibilityFromStore(memo.Visibility),
|
||||
Pinned: memo.Pinned,
|
||||
}
|
||||
if memo.Payload != nil {
|
||||
memoMessage.Tags = memo.Payload.Tags
|
||||
memoMessage.Property = convertMemoPropertyFromStore(memo.Payload.Property)
|
||||
memoMessage.Location = convertLocationFromStore(memo.Payload.Location)
|
||||
}
|
||||
if memo.ParentID != nil {
|
||||
parent, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: memo.ParentID,
|
||||
ExcludeContent: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get parent memo")
|
||||
}
|
||||
parentName := fmt.Sprintf("%s%s", MemoNamePrefix, parent.UID)
|
||||
memoMessage.Parent = &parentName
|
||||
}
|
||||
|
||||
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo relations")
|
||||
}
|
||||
memoMessage.Relations = listMemoRelationsResponse.Relations
|
||||
|
||||
listMemoAttachmentsResponse, err := s.ListMemoAttachments(ctx, &v1pb.ListMemoAttachmentsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo attachments")
|
||||
}
|
||||
memoMessage.Attachments = listMemoAttachmentsResponse.Attachments
|
||||
|
||||
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo reactions")
|
||||
}
|
||||
memoMessage.Reactions = listMemoReactionsResponse.Reactions
|
||||
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse content")
|
||||
}
|
||||
memoMessage.Nodes = convertFromASTNodes(nodes)
|
||||
|
||||
snippet, err := getMemoContentSnippet(memo.Content)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo content snippet")
|
||||
}
|
||||
memoMessage.Snippet = snippet
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
|
||||
if property == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.Memo_Property{
|
||||
HasLink: property.HasLink,
|
||||
HasTaskList: property.HasTaskList,
|
||||
HasCode: property.HasCode,
|
||||
HasIncompleteTasks: property.HasIncompleteTasks,
|
||||
}
|
||||
}
|
||||
|
||||
func convertLocationFromStore(location *storepb.MemoPayload_Location) *v1pb.Location {
|
||||
if location == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.Location{
|
||||
Placeholder: location.Placeholder,
|
||||
Latitude: location.Latitude,
|
||||
Longitude: location.Longitude,
|
||||
}
|
||||
}
|
||||
|
||||
func convertLocationToStore(location *v1pb.Location) *storepb.MemoPayload_Location {
|
||||
if location == nil {
|
||||
return nil
|
||||
}
|
||||
return &storepb.MemoPayload_Location{
|
||||
Placeholder: location.Placeholder,
|
||||
Latitude: location.Latitude,
|
||||
Longitude: location.Longitude,
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
|
||||
switch visibility {
|
||||
case store.Private:
|
||||
return v1pb.Visibility_PRIVATE
|
||||
case store.Protected:
|
||||
return v1pb.Visibility_PROTECTED
|
||||
case store.Public:
|
||||
return v1pb.Visibility_PUBLIC
|
||||
default:
|
||||
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
|
||||
switch visibility {
|
||||
case v1pb.Visibility_PRIVATE:
|
||||
return store.Private
|
||||
case v1pb.Visibility_PROTECTED:
|
||||
return store.Protected
|
||||
case v1pb.Visibility_PUBLIC:
|
||||
return store.Public
|
||||
default:
|
||||
return store.Private
|
||||
}
|
||||
}
|
||||
168
server/router/api/v1/memo_service_filter.go
Normal file
168
server/router/api/v1/memo_service_filter.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.FindMemo, filter string) error {
|
||||
if find.PayloadFind == nil {
|
||||
find.PayloadFind = &store.FindMemoPayload{}
|
||||
}
|
||||
if filter != "" {
|
||||
filterExpr, err := parseMemoFilter(filter)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
if len(filterExpr.ContentSearch) > 0 {
|
||||
find.ContentSearch = filterExpr.ContentSearch
|
||||
}
|
||||
if filterExpr.TagSearch != nil {
|
||||
if find.PayloadFind == nil {
|
||||
find.PayloadFind = &store.FindMemoPayload{}
|
||||
}
|
||||
find.PayloadFind.TagSearch = filterExpr.TagSearch
|
||||
}
|
||||
if filterExpr.DisplayTimeAfter != nil {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
find.UpdatedTsAfter = filterExpr.DisplayTimeAfter
|
||||
} else {
|
||||
find.CreatedTsAfter = filterExpr.DisplayTimeAfter
|
||||
}
|
||||
}
|
||||
if filterExpr.DisplayTimeBefore != nil {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
find.UpdatedTsBefore = filterExpr.DisplayTimeBefore
|
||||
} else {
|
||||
find.CreatedTsBefore = filterExpr.DisplayTimeBefore
|
||||
}
|
||||
}
|
||||
if filterExpr.Pinned {
|
||||
pinned := true
|
||||
find.Pinned = &pinned
|
||||
}
|
||||
if filterExpr.HasLink {
|
||||
find.PayloadFind.HasLink = true
|
||||
}
|
||||
if filterExpr.HasTaskList {
|
||||
find.PayloadFind.HasTaskList = true
|
||||
}
|
||||
if filterExpr.HasCode {
|
||||
find.PayloadFind.HasCode = true
|
||||
}
|
||||
if filterExpr.HasIncompleteTasks {
|
||||
find.PayloadFind.HasIncompleteTasks = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MemoFilterCELAttributes are the CEL attributes.
|
||||
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content_search", cel.ListType(cel.StringType)),
|
||||
cel.Variable("tag_search", cel.ListType(cel.StringType)),
|
||||
cel.Variable("display_time_before", cel.IntType),
|
||||
cel.Variable("display_time_after", cel.IntType),
|
||||
cel.Variable("pinned", cel.BoolType),
|
||||
cel.Variable("has_link", cel.BoolType),
|
||||
cel.Variable("has_task_list", cel.BoolType),
|
||||
cel.Variable("has_code", cel.BoolType),
|
||||
cel.Variable("has_incomplete_tasks", cel.BoolType),
|
||||
}
|
||||
|
||||
type MemoFilter struct {
|
||||
ContentSearch []string
|
||||
TagSearch []string
|
||||
DisplayTimeBefore *int64
|
||||
DisplayTimeAfter *int64
|
||||
Pinned bool
|
||||
HasLink bool
|
||||
HasTaskList bool
|
||||
HasCode bool
|
||||
HasIncompleteTasks bool
|
||||
}
|
||||
|
||||
func parseMemoFilter(expression string) (*MemoFilter, error) {
|
||||
e, err := cel.NewEnv(MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ast, issues := e.Compile(expression)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("found issue %v", issues)
|
||||
}
|
||||
filter := &MemoFilter{}
|
||||
parsedExpr, err := cel.AstToParsedExpr(ast)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
callExpr := parsedExpr.GetExpr().GetCallExpr()
|
||||
findMemoField(callExpr, filter)
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func findMemoField(callExpr *exprv1.Expr_Call, filter *MemoFilter) {
|
||||
if len(callExpr.Args) == 2 {
|
||||
idExpr := callExpr.Args[0].GetIdentExpr()
|
||||
if idExpr != nil {
|
||||
if idExpr.Name == "content_search" {
|
||||
contentSearch := []string{}
|
||||
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
|
||||
value := expr.GetConstExpr().GetStringValue()
|
||||
contentSearch = append(contentSearch, value)
|
||||
}
|
||||
filter.ContentSearch = contentSearch
|
||||
} else if idExpr.Name == "tag_search" {
|
||||
tagSearch := []string{}
|
||||
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
|
||||
value := expr.GetConstExpr().GetStringValue()
|
||||
tagSearch = append(tagSearch, value)
|
||||
}
|
||||
filter.TagSearch = tagSearch
|
||||
} else if idExpr.Name == "display_time_before" {
|
||||
displayTimeBefore := callExpr.Args[1].GetConstExpr().GetInt64Value()
|
||||
filter.DisplayTimeBefore = &displayTimeBefore
|
||||
} else if idExpr.Name == "display_time_after" {
|
||||
displayTimeAfter := callExpr.Args[1].GetConstExpr().GetInt64Value()
|
||||
filter.DisplayTimeAfter = &displayTimeAfter
|
||||
} else if idExpr.Name == "pinned" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.Pinned = value
|
||||
} else if idExpr.Name == "has_link" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.HasLink = value
|
||||
} else if idExpr.Name == "has_task_list" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.HasTaskList = value
|
||||
} else if idExpr.Name == "has_code" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.HasCode = value
|
||||
} else if idExpr.Name == "has_incomplete_tasks" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.HasIncompleteTasks = value
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, arg := range callExpr.Args {
|
||||
callExpr := arg.GetCallExpr()
|
||||
if callExpr != nil {
|
||||
findMemoField(callExpr, filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
90
server/router/api/v1/reaction_service.go
Normal file
90
server/router/api/v1/reaction_service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoReactionsResponse{
|
||||
Reactions: []*v1pb.Reaction{},
|
||||
}
|
||||
for _, reaction := range reactions {
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
response.Reactions = append(response.Reactions, reactionMessage)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: request.Reaction.ContentId,
|
||||
ReactionType: request.Reaction.ReactionType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
|
||||
}
|
||||
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
return reactionMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
|
||||
reactionID, err := ExtractReactionIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
|
||||
ID: reactionID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*v1pb.Reaction, error) {
|
||||
creator, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &reaction.CreatorID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reactionUID := fmt.Sprintf("%d", reaction.ID)
|
||||
return &v1pb.Reaction{
|
||||
Name: fmt.Sprintf("%s%s", ReactionNamePrefix, reactionUID),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
|
||||
ContentId: reaction.ContentID,
|
||||
ReactionType: reaction.ReactionType,
|
||||
CreateTime: timestamppb.New(time.Unix(reaction.CreatedTs, 0)),
|
||||
}, nil
|
||||
}
|
||||
162
server/router/api/v1/resource_name.go
Normal file
162
server/router/api/v1/resource_name.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
WorkspaceSettingNamePrefix = "workspace/settings/"
|
||||
UserNamePrefix = "users/"
|
||||
MemoNamePrefix = "memos/"
|
||||
AttachmentNamePrefix = "attachments/"
|
||||
ReactionNamePrefix = "reactions/"
|
||||
InboxNamePrefix = "inboxes/"
|
||||
IdentityProviderNamePrefix = "identityProviders/"
|
||||
ActivityNamePrefix = "activities/"
|
||||
WebhookNamePrefix = "webhooks/"
|
||||
)
|
||||
|
||||
// GetNameParentTokens returns the tokens from a resource name.
|
||||
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 2*len(tokenPrefixes) {
|
||||
return nil, errors.Errorf("invalid request %q", name)
|
||||
}
|
||||
|
||||
var tokens []string
|
||||
for i, tokenPrefix := range tokenPrefixes {
|
||||
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
|
||||
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
|
||||
}
|
||||
if parts[2*i+1] == "" {
|
||||
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
|
||||
}
|
||||
tokens = append(tokens, parts[2*i+1])
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func ExtractWorkspaceSettingKeyFromName(name string) (string, error) {
|
||||
const prefix = "workspace/settings/"
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return "", errors.Errorf("invalid workspace setting name: expected prefix %q, got %q", prefix, name)
|
||||
}
|
||||
|
||||
settingKey := strings.TrimPrefix(name, prefix)
|
||||
if settingKey == "" {
|
||||
return "", errors.Errorf("invalid workspace setting name: empty setting key in %q", name)
|
||||
}
|
||||
|
||||
// Ensure there are no additional path segments
|
||||
if strings.Contains(settingKey, "/") {
|
||||
return "", errors.Errorf("invalid workspace setting name: setting key cannot contain '/' in %q", name)
|
||||
}
|
||||
|
||||
return settingKey, nil
|
||||
}
|
||||
|
||||
// ExtractUserIDFromName returns the uid from a resource name.
|
||||
func ExtractUserIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid user ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractMemoUIDFromName returns the memo UID from a resource name.
|
||||
// e.g., "memos/uuid" -> "uuid".
|
||||
func ExtractMemoUIDFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := tokens[0]
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractAttachmentUIDFromName returns the attachment UID from a resource name.
|
||||
func ExtractAttachmentUIDFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, AttachmentNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := tokens[0]
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractReactionIDFromName returns the reaction ID from a resource name.
|
||||
// e.g., "reactions/123" -> 123.
|
||||
func ExtractReactionIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, ReactionNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid reaction ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractInboxIDFromName returns the inbox ID from a resource name.
|
||||
func ExtractInboxIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func ExtractActivityIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, ActivityNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid activity ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractWebhookIDFromName returns the webhook ID from a resource name.
|
||||
func ExtractWebhookIDFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix, WebhookNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(tokens) != 2 {
|
||||
return "", errors.Errorf("invalid webhook name format: %q", name)
|
||||
}
|
||||
webhookID := tokens[1]
|
||||
if webhookID == "" {
|
||||
return "", errors.Errorf("invalid webhook ID %q", webhookID)
|
||||
}
|
||||
return webhookID, nil
|
||||
}
|
||||
337
server/router/api/v1/shortcut_service.go
Normal file
337
server/router/api/v1/shortcut_service.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Helper function to extract user ID and shortcut ID from shortcut resource name.
|
||||
// Format: users/{user}/shortcuts/{shortcut}.
|
||||
func extractUserAndShortcutIDFromName(name string) (int32, string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "shortcuts" {
|
||||
return 0, "", errors.Errorf("invalid shortcut name format: %s", name)
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[1])
|
||||
if err != nil {
|
||||
return 0, "", errors.Errorf("invalid user ID %q", parts[1])
|
||||
}
|
||||
|
||||
shortcutID := parts[3]
|
||||
if shortcutID == "" {
|
||||
return 0, "", errors.Errorf("empty shortcut ID in name: %s", name)
|
||||
}
|
||||
|
||||
return userID, shortcutID, nil
|
||||
}
|
||||
|
||||
// Helper function to construct shortcut resource name.
|
||||
func constructShortcutName(userID int32, shortcutID string) string {
|
||||
return fmt.Sprintf("users/%d/shortcuts/%s", userID, shortcutID)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShortcutsRequest) (*v1pb.ListShortcutsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return &v1pb.ListShortcutsResponse{
|
||||
Shortcuts: []*v1pb.Shortcut{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := []*v1pb.Shortcut{}
|
||||
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
|
||||
shortcuts = append(shortcuts, &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, shortcut.GetId()),
|
||||
Title: shortcut.GetTitle(),
|
||||
Filter: shortcut.GetFilter(),
|
||||
})
|
||||
}
|
||||
|
||||
return &v1pb.ListShortcutsResponse{
|
||||
Shortcuts: shortcuts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
|
||||
if shortcut.GetId() == shortcutID {
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, shortcut.GetId()),
|
||||
Title: shortcut.GetTitle(),
|
||||
Filter: shortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
newShortcut := &storepb.ShortcutsUserSetting_Shortcut{
|
||||
Id: util.GenUUID(),
|
||||
Title: request.Shortcut.GetTitle(),
|
||||
Filter: request.Shortcut.GetFilter(),
|
||||
}
|
||||
if newShortcut.Title == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "title is required")
|
||||
}
|
||||
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
if request.ValidateOnly {
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, newShortcut.GetId()),
|
||||
Title: newShortcut.GetTitle(),
|
||||
Filter: newShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
userSetting = &storepb.UserSetting{
|
||||
UserId: userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := shortcutsUserSetting.GetShortcuts()
|
||||
shortcuts = append(shortcuts, newShortcut)
|
||||
shortcutsUserSetting.Shortcuts = shortcuts
|
||||
|
||||
userSetting.Value = &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: shortcutsUserSetting,
|
||||
}
|
||||
|
||||
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, newShortcut.GetId()),
|
||||
Title: newShortcut.GetTitle(),
|
||||
Filter: newShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Shortcut.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := shortcutsUserSetting.GetShortcuts()
|
||||
var foundShortcut *storepb.ShortcutsUserSetting_Shortcut
|
||||
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
|
||||
for _, shortcut := range shortcuts {
|
||||
if shortcut.GetId() == shortcutID {
|
||||
foundShortcut = shortcut
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "title" {
|
||||
if request.Shortcut.GetTitle() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "title is required")
|
||||
}
|
||||
shortcut.Title = request.Shortcut.GetTitle()
|
||||
} else if field == "filter" {
|
||||
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
shortcut.Filter = request.Shortcut.GetFilter()
|
||||
}
|
||||
}
|
||||
}
|
||||
newShortcuts = append(newShortcuts, shortcut)
|
||||
}
|
||||
|
||||
if foundShortcut == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting.Shortcuts = newShortcuts
|
||||
userSetting.Value = &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: shortcutsUserSetting,
|
||||
}
|
||||
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, foundShortcut.GetId()),
|
||||
Title: foundShortcut.GetTitle(),
|
||||
Filter: foundShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteShortcutRequest) (*emptypb.Empty, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
|
||||
shortcutsUserSetting := userSetting.GetShortcuts()
|
||||
shortcuts := shortcutsUserSetting.GetShortcuts()
|
||||
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
|
||||
found := false
|
||||
for _, shortcut := range shortcuts {
|
||||
if shortcut.GetId() != shortcutID {
|
||||
newShortcuts = append(newShortcuts, shortcut)
|
||||
} else {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, status.Errorf(codes.NotFound, "shortcut not found")
|
||||
}
|
||||
shortcutsUserSetting.Shortcuts = newShortcuts
|
||||
userSetting.Value = &storepb.UserSetting_Shortcuts{
|
||||
Shortcuts: shortcutsUserSetting,
|
||||
}
|
||||
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error {
|
||||
if filterStr == "" {
|
||||
return errors.New("filter cannot be empty")
|
||||
}
|
||||
// Validate the filter.
|
||||
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to parse filter")
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to convert filter to SQL")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
519
server/router/api/v1/test/idp_service_test.go
Normal file
519
server/router/api/v1/test/idp_service_test.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestCreateIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create OAuth2 identity provider
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test OAuth2 Provider",
|
||||
IdentifierFilter: "",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: "https://example.com/oauth/authorize",
|
||||
TokenUrl: "https://example.com/oauth/token",
|
||||
UserInfoUrl: "https://example.com/oauth/userinfo",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
AvatarUrl: "avatar_url",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "Test OAuth2 Provider", resp.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.Contains(t, resp.Name, "identityProviders/")
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListIdentityProviders(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListIdentityProviders empty", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListIdentityProvidersRequest{}
|
||||
resp, err := ts.Service.ListIdentityProviders(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.IdentityProviders)
|
||||
})
|
||||
|
||||
t.Run("ListIdentityProviders with providers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a couple of identity providers
|
||||
createReq1 := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider 1",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client1",
|
||||
AuthUrl: "https://example1.com/auth",
|
||||
TokenUrl: "https://example1.com/token",
|
||||
UserInfoUrl: "https://example1.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createReq2 := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider 2",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client2",
|
||||
AuthUrl: "https://example2.com/auth",
|
||||
TokenUrl: "https://example2.com/token",
|
||||
UserInfoUrl: "https://example2.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq1)
|
||||
require.NoError(t, err)
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List providers
|
||||
listReq := &v1pb.ListIdentityProvidersRequest{}
|
||||
resp, err := ts.Service.ListIdentityProviders(ctx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.IdentityProviders, 2)
|
||||
|
||||
// Verify response contains expected providers
|
||||
titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title}
|
||||
require.Contains(t, titles, "Provider 1")
|
||||
require.Contains(t, titles, "Provider 2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
AuthUrl: "https://example.com/auth",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/user",
|
||||
Scopes: []string{"openid", "profile"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get identity provider
|
||||
getReq := &v1pb.GetIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, created.Name, resp.Name)
|
||||
require.Equal(t, "Test Provider", resp.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret)
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetIdentityProviderRequest{
|
||||
Name: "identityProviders/999",
|
||||
}
|
||||
|
||||
_, err := ts.Service.GetIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.GetIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Original Provider",
|
||||
IdentifierFilter: "",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "original-client",
|
||||
AuthUrl: "https://original.com/auth",
|
||||
TokenUrl: "https://original.com/token",
|
||||
UserInfoUrl: "https://original.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update identity provider
|
||||
updateReq := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: created.Name,
|
||||
Title: "Updated Provider",
|
||||
IdentifierFilter: "test@example.com",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "updated-client",
|
||||
ClientSecret: "updated-secret",
|
||||
AuthUrl: "https://updated.com/auth",
|
||||
TokenUrl: "https://updated.com/token",
|
||||
UserInfoUrl: "https://updated.com/user",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "given_name",
|
||||
Email: "email",
|
||||
AvatarUrl: "picture",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "identifier_filter", "config"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "Updated Provider", updated.Title)
|
||||
require.Equal(t, "test@example.com", updated.IdentifierFilter)
|
||||
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "identityProviders/1",
|
||||
Title: "Updated Provider",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update_mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "invalid-name",
|
||||
Title: "Updated Provider",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteIdentityProvider success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Provider to Delete",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "client-to-delete",
|
||||
AuthUrl: "https://example.com/auth",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: "id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete identity provider
|
||||
deleteReq := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
getReq := &v1pb.GetIdentityProviderRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.DeleteIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
|
||||
t.Run("DeleteIdentityProvider not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "identityProviders/999",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
// Note: Delete might succeed even if item doesn't exist, depending on store implementation
|
||||
})
|
||||
}
|
||||
|
||||
func TestIdentityProviderPermissions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Only host users can create identity providers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("Authentication required", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Test Provider",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
}
|
||||
559
server/router/api/v1/test/inbox_service_test.go
Normal file
559
server/router/api/v1/test/inbox_service_test.go
Normal file
@@ -0,0 +1,559 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestListInboxes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListInboxes success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List inboxes (should be empty initially)
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListInboxes(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.Inboxes)
|
||||
require.Equal(t, int32(0), resp.TotalSize)
|
||||
})
|
||||
|
||||
t.Run("ListInboxes with pagination", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create some inbox entries
|
||||
const systemBotID int32 = 0
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List inboxes with page size limit
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
PageSize: 2,
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListInboxes(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, 2, len(resp.Inboxes))
|
||||
require.NotEmpty(t, resp.NextPageToken)
|
||||
})
|
||||
|
||||
t.Run("ListInboxes permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to list user2's inboxes
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListInboxes(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot access inboxes")
|
||||
})
|
||||
|
||||
t.Run("ListInboxes host can access other users' inboxes", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user and a regular user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "hostuser")
|
||||
require.NoError(t, err)
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox for the regular user
|
||||
const systemBotID int32 = 0
|
||||
_, err = ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: regularUser.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set host user context and try to list regular user's inboxes
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", regularUser.ID),
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListInboxes(hostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, 1, len(resp.Inboxes))
|
||||
})
|
||||
|
||||
t.Run("ListInboxes invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: "invalid-parent-format",
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListInboxes(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid parent name")
|
||||
})
|
||||
|
||||
t.Run("ListInboxes unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: "users/1",
|
||||
}
|
||||
|
||||
_, err := ts.Service.ListInboxes(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateInbox(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateInbox success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox entry
|
||||
const systemBotID int32 = 0
|
||||
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Update inbox status
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"status"},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.UpdateInbox(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, v1pb.Inbox_ARCHIVED, resp.Status)
|
||||
})
|
||||
|
||||
t.Run("UpdateInbox permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox entry for user2
|
||||
const systemBotID int32 = 0
|
||||
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user2.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to update user2's inbox
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"status"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot update inbox")
|
||||
})
|
||||
|
||||
t.Run("UpdateInbox missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: "inboxes/1",
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateInbox invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: "invalid-inbox-name",
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"status"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid inbox name")
|
||||
})
|
||||
|
||||
t.Run("UpdateInbox not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: "inboxes/99999", // Non-existent inbox
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"status"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
st, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, codes.NotFound, st.Code())
|
||||
})
|
||||
|
||||
t.Run("UpdateInbox unsupported field", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox entry
|
||||
const systemBotID int32 = 0
|
||||
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"unsupported_field"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "unsupported field")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteInbox(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteInbox success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox entry
|
||||
const systemBotID int32 = 0
|
||||
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Delete inbox
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteInbox(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify inbox is deleted
|
||||
inboxes, err := ts.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(inboxes))
|
||||
})
|
||||
|
||||
t.Run("DeleteInbox permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox entry for user2
|
||||
const systemBotID int32 = 0
|
||||
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user2.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to delete user2's inbox
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot delete inbox")
|
||||
})
|
||||
|
||||
t.Run("DeleteInbox invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: "invalid-inbox-name",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid inbox name")
|
||||
})
|
||||
|
||||
t.Run("DeleteInbox not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: "inboxes/99999", // Non-existent inbox
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteInbox(userCtx, req)
|
||||
require.Error(t, err)
|
||||
st, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, codes.NotFound, st.Code())
|
||||
})
|
||||
}
|
||||
|
||||
func TestInboxCRUDComplete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an inbox entry directly in store
|
||||
const systemBotID int32 = 0
|
||||
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: systemBotID,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// 1. List inboxes - should have 1
|
||||
listReq := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
listResp, err := ts.Service.ListInboxes(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(listResp.Inboxes))
|
||||
require.Equal(t, v1pb.Inbox_UNREAD, listResp.Inboxes[0].Status)
|
||||
|
||||
// 2. Update inbox status to ARCHIVED
|
||||
updateReq := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
Status: v1pb.Inbox_ARCHIVED,
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"status"},
|
||||
},
|
||||
}
|
||||
updateResp, err := ts.Service.UpdateInbox(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, v1pb.Inbox_ARCHIVED, updateResp.Status)
|
||||
|
||||
// 3. List inboxes again - should still have 1 but ARCHIVED
|
||||
listResp, err = ts.Service.ListInboxes(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(listResp.Inboxes))
|
||||
require.Equal(t, v1pb.Inbox_ARCHIVED, listResp.Inboxes[0].Status)
|
||||
|
||||
// 4. Delete inbox
|
||||
deleteReq := &v1pb.DeleteInboxRequest{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
}
|
||||
_, err = ts.Service.DeleteInbox(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 5. List inboxes - should be empty
|
||||
listResp, err = ts.Service.ListInboxes(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(listResp.Inboxes))
|
||||
require.Equal(t, int32(0), listResp.TotalSize)
|
||||
})
|
||||
}
|
||||
819
server/router/api/v1/test/shortcut_service_test.go
Normal file
819
server/router/api/v1/test/shortcut_service_test.go
Normal file
@@ -0,0 +1,819 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestListShortcuts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListShortcuts success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List shortcuts (should be empty initially)
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListShortcuts(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.Shortcuts)
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to list user2's shortcuts
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "invalid-parent-format",
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts unauthenticated", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "users/1",
|
||||
}
|
||||
|
||||
_, err := ts.Service.ListShortcuts(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// First create a shortcut
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now get the shortcut
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
resp, err := ts.Service.GetShortcut(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, created.Name, resp.Name)
|
||||
require.Equal(t, "Test Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"test\"]", resp.Filter)
|
||||
})
|
||||
|
||||
t.Run("GetShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(user2Ctx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("GetShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("GetShortcut not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "My Shortcut",
|
||||
Filter: "tag in [\"important\"]",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateShortcut(userCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "My Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"important\"]", resp.Filter)
|
||||
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
|
||||
|
||||
// Verify the shortcut was created by listing
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.Shortcuts, 1)
|
||||
require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title)
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to create shortcut for user2
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Forbidden Shortcut",
|
||||
Filter: "tag in [\"forbidden\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut invalid parent format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: "invalid-parent",
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut invalid filter", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Shortcut",
|
||||
Filter: "invalid||filter))syntax",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid filter")
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut missing title", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "title is required")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Original Title",
|
||||
Filter: "tag in [\"original\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the shortcut
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Title: "Updated Title",
|
||||
Filter: "tag in [\"updated\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "Updated Title", updated.Title)
|
||||
require.Equal(t, "tag in [\"updated\"]", updated.Filter)
|
||||
require.Equal(t, created.Name, updated.Name)
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to update shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Title: "Hacked Title",
|
||||
Filter: "tag in [\"hacked\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(user2Ctx, updateReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user and context for authentication
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
|
||||
Title: "Updated Title",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update mask is required")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: "invalid-shortcut-name",
|
||||
Title: "Updated Title",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateShortcut(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("UpdateShortcut invalid filter", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to update with invalid filter
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
Filter: "invalid||filter))syntax",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"filter"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid filter")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteShortcut(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("DeleteShortcut success", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Shortcut to Delete",
|
||||
Filter: "tag in [\"delete\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the shortcut
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion by listing shortcuts
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, listResp.Shortcuts)
|
||||
|
||||
// Also verify by trying to get the deleted shortcut
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, getReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create two users
|
||||
user1, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
user2, err := ts.CreateRegularUser(ctx, "user2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to delete shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut invalid name format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.DeleteShortcut(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid shortcut name")
|
||||
})
|
||||
|
||||
t.Run("DeleteShortcut not found", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestShortcutFiltering(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CreateShortcut with valid filters", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various valid filter formats
|
||||
validFilters := []string{
|
||||
"tag in [\"work\"]",
|
||||
"content.contains(\"meeting\")",
|
||||
"tag in [\"work\"] && content.contains(\"meeting\")",
|
||||
"tag in [\"work\"] || tag in [\"personal\"]",
|
||||
"creator_id == 1",
|
||||
"visibility == \"PUBLIC\"",
|
||||
"has_task_list == true",
|
||||
"has_task_list == false",
|
||||
}
|
||||
|
||||
for i, filter := range validFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Valid Filter " + string(rune(i)),
|
||||
Filter: filter,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.NoError(t, err, "Filter should be valid: %s", filter)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreateShortcut with invalid filters", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various invalid filter formats
|
||||
invalidFilters := []string{
|
||||
"tag in ", // incomplete expression
|
||||
"invalid_field @in [\"value\"]", // unknown field
|
||||
"tag in [\"work\"] &&", // incomplete expression
|
||||
"tag in [\"work\"] || || tag in [\"test\"]", // double operator
|
||||
"((tag in [\"work\"]", // unmatched parentheses
|
||||
"tag in [\"work\"] && )", // mismatched parentheses
|
||||
"tag == \"work\"", // wrong operator (== not supported for tags)
|
||||
"tag in work", // missing brackets
|
||||
}
|
||||
|
||||
for _, filter := range invalidFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Test",
|
||||
Filter: filter,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateShortcut(userCtx, req)
|
||||
require.Error(t, err, "Filter should be invalid: %s", filter)
|
||||
require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShortcutCRUDComplete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// 1. Create multiple shortcuts
|
||||
shortcut1Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Work Notes",
|
||||
Filter: "tag in [\"work\"]",
|
||||
},
|
||||
}
|
||||
|
||||
shortcut2Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Personal Notes",
|
||||
Filter: "tag in [\"personal\"]",
|
||||
},
|
||||
}
|
||||
|
||||
created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work Notes", created1.Title)
|
||||
|
||||
created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Personal Notes", created2.Title)
|
||||
|
||||
// 2. List shortcuts and verify both exist
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listResp.Shortcuts, 2)
|
||||
|
||||
// 3. Get individual shortcuts
|
||||
getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name}
|
||||
getResp1, err := ts.Service.GetShortcut(userCtx, getReq1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created1.Name, getResp1.Name)
|
||||
require.Equal(t, "Work Notes", getResp1.Title)
|
||||
|
||||
getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name}
|
||||
getResp2, err := ts.Service.GetShortcut(userCtx, getReq2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created2.Name, getResp2.Name)
|
||||
require.Equal(t, "Personal Notes", getResp2.Title)
|
||||
|
||||
// 4. Update one shortcut
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created1.Name,
|
||||
Title: "Work & Meeting Notes",
|
||||
Filter: "tag in [\"work\"] || tag in [\"meeting\"]",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"title", "filter"},
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work & Meeting Notes", updated.Title)
|
||||
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter)
|
||||
|
||||
// 5. Verify update by getting it again
|
||||
getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name}
|
||||
getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title)
|
||||
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter)
|
||||
|
||||
// 6. Delete one shortcut
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created2.Name,
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 7. Verify deletion by listing (should only have 1 left)
|
||||
finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, finalListResp.Shortcuts, 1)
|
||||
require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title)
|
||||
|
||||
// 8. Verify deleted shortcut can't be accessed
|
||||
getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name}
|
||||
_, err = ts.Service.GetShortcut(userCtx, getDeletedReq)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
81
server/router/api/v1/test/test_helper.go
Normal file
81
server/router/api/v1/test/test_helper.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
// TestService holds the test service setup for API v1 services.
|
||||
type TestService struct {
|
||||
Service *apiv1.APIV1Service
|
||||
Store *store.Store
|
||||
Profile *profile.Profile
|
||||
Secret string
|
||||
}
|
||||
|
||||
// NewTestService creates a new test service with SQLite database.
|
||||
func NewTestService(t *testing.T) *TestService {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a test store with SQLite
|
||||
testStore := teststore.NewTestingStore(ctx, t)
|
||||
|
||||
// Create a test profile
|
||||
testProfile := &profile.Profile{
|
||||
Mode: "dev",
|
||||
Version: "test-1.0.0",
|
||||
InstanceURL: "http://localhost:8080",
|
||||
Driver: "sqlite",
|
||||
DSN: ":memory:",
|
||||
}
|
||||
|
||||
// Create APIV1Service with nil grpcServer since we're testing direct calls
|
||||
secret := "test-secret"
|
||||
service := &apiv1.APIV1Service{
|
||||
Secret: secret,
|
||||
Profile: testProfile,
|
||||
Store: testStore,
|
||||
}
|
||||
|
||||
return &TestService{
|
||||
Service: service,
|
||||
Store: testStore,
|
||||
Profile: testProfile,
|
||||
Secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup clears caches and closes resources after test.
|
||||
func (ts *TestService) Cleanup() {
|
||||
ts.Store.Close()
|
||||
// Note: Owner cache is package-level in parent package, cannot clear from test package
|
||||
}
|
||||
|
||||
// CreateHostUser creates a host user for testing.
|
||||
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleHost,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRegularUser creates a regular user for testing.
|
||||
func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleUser,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUserContext creates a context with the given user's ID for authentication.
|
||||
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
|
||||
// Use the real context key from the parent package
|
||||
return apiv1.CreateTestUserContext(ctx, userID)
|
||||
}
|
||||
105
server/router/api/v1/test/user_service_stats_test.go
Normal file
105
server/router/api/v1/test/user_service_stats_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestGetUserStats_TagCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test host user
|
||||
user, err := ts.CreateHostUser(ctx, "test_user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user context for authentication
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a memo with a single tag
|
||||
memo, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "This is a test memo with #test tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
// Test GetUserStats
|
||||
userName := fmt.Sprintf("users/%d", user.ID)
|
||||
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Check that the tag count is exactly 1, not 2
|
||||
require.Contains(t, response.TagCount, "test")
|
||||
require.Equal(t, int32(1), response.TagCount["test"], "Tag count should be 1 for a single occurrence")
|
||||
|
||||
// Create another memo with the same tag
|
||||
memo2, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "Another memo with #test tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo2)
|
||||
|
||||
// Test GetUserStats again
|
||||
response2, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response2)
|
||||
|
||||
// Check that the tag count is exactly 2, not 3
|
||||
require.Contains(t, response2.TagCount, "test")
|
||||
require.Equal(t, int32(2), response2.TagCount["test"], "Tag count should be 2 for two occurrences")
|
||||
|
||||
// Test with a new unique tag
|
||||
memo3, err := ts.Store.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-3",
|
||||
CreatorID: user.ID,
|
||||
Content: "Memo with #unique tag",
|
||||
Visibility: store.Public,
|
||||
Payload: &storepb.MemoPayload{
|
||||
Tags: []string{"unique"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo3)
|
||||
|
||||
// Test GetUserStats for the new tag
|
||||
response3, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response3)
|
||||
|
||||
// Check that the unique tag count is exactly 1
|
||||
require.Contains(t, response3.TagCount, "unique")
|
||||
require.Equal(t, int32(1), response3.TagCount["unique"], "New tag count should be 1 for first occurrence")
|
||||
|
||||
// The original test tag should still be 2
|
||||
require.Contains(t, response3.TagCount, "test")
|
||||
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
|
||||
}
|
||||
406
server/router/api/v1/test/webhook_service_test.go
Normal file
406
server/router/api/v1/test/webhook_service_test.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestCreateWebhook(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("CreateWebhook with host user", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create and authenticate as host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ts.Service.CreateWebhook(userCtx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "Test Webhook", resp.DisplayName)
|
||||
require.Equal(t, "https://example.com/webhook", resp.Url)
|
||||
require.Contains(t, resp.Name, "webhooks/")
|
||||
require.Contains(t, resp.Name, fmt.Sprintf("users/%d", hostUser.ID))
|
||||
})
|
||||
|
||||
t.Run("CreateWebhook fails without authentication", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
// Try to create webhook without authentication
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
Parent: "users/1", // Dummy parent since we don't have a real user
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.CreateWebhook(ctx, req)
|
||||
|
||||
// Should fail with permission denied or unauthenticated
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateWebhook fails with regular user", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create and authenticate as regular user
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
// Try to create webhook as regular user
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", regularUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateWebhook(userCtx, req)
|
||||
|
||||
// Should fail with permission denied
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("CreateWebhook validates required fields", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create and authenticate as host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// Try to create webhook with missing URL
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
// URL missing
|
||||
},
|
||||
}
|
||||
|
||||
_, err = ts.Service.CreateWebhook(userCtx, req)
|
||||
|
||||
// Should fail with validation error
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListWebhooks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ListWebhooks returns empty list initially", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user for authentication
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// List webhooks
|
||||
req := &v1pb.ListWebhooksRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
}
|
||||
resp, err := ts.Service.ListWebhooks(userCtx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Empty(t, resp.Webhooks)
|
||||
})
|
||||
|
||||
t.Run("ListWebhooks returns created webhooks", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List webhooks
|
||||
listReq := &v1pb.ListWebhooksRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
}
|
||||
resp, err := ts.Service.ListWebhooks(userCtx, listReq)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Webhooks, 1)
|
||||
require.Equal(t, createdWebhook.Name, resp.Webhooks[0].Name)
|
||||
require.Equal(t, createdWebhook.Url, resp.Webhooks[0].Url)
|
||||
})
|
||||
|
||||
t.Run("ListWebhooks fails without authentication", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
// Try to list webhooks without authentication
|
||||
req := &v1pb.ListWebhooksRequest{
|
||||
Parent: "users/1", // Dummy parent since we don't have a real user
|
||||
}
|
||||
_, err := ts.Service.ListWebhooks(ctx, req)
|
||||
|
||||
// Should fail with permission denied or unauthenticated
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetWebhook(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetWebhook returns webhook by name", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the webhook
|
||||
getReq := &v1pb.GetWebhookRequest{
|
||||
Name: createdWebhook.Name,
|
||||
}
|
||||
resp, err := ts.Service.GetWebhook(userCtx, getReq)
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, createdWebhook.Name, resp.Name)
|
||||
require.Equal(t, createdWebhook.Url, resp.Url)
|
||||
})
|
||||
|
||||
t.Run("GetWebhook fails with invalid name", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Try to get webhook with invalid name
|
||||
req := &v1pb.GetWebhookRequest{
|
||||
Name: "invalid/webhook/name",
|
||||
}
|
||||
_, err = ts.Service.GetWebhook(userCtx, req)
|
||||
|
||||
// Should return an error
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("GetWebhook fails with non-existent webhook", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// Try to get non-existent webhook
|
||||
req := &v1pb.GetWebhookRequest{
|
||||
Name: fmt.Sprintf("users/%d/webhooks/999", hostUser.ID),
|
||||
}
|
||||
_, err = ts.Service.GetWebhook(userCtx, req)
|
||||
|
||||
// Should return not found error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateWebhook(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateWebhook updates webhook properties", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Original Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the webhook
|
||||
updateReq := &v1pb.UpdateWebhookRequest{
|
||||
Webhook: &v1pb.Webhook{
|
||||
Name: createdWebhook.Name,
|
||||
Url: "https://updated.example.com/webhook",
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"url"},
|
||||
},
|
||||
}
|
||||
resp, err := ts.Service.UpdateWebhook(userCtx, updateReq)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, createdWebhook.Name, resp.Name)
|
||||
require.Equal(t, "https://updated.example.com/webhook", resp.Url)
|
||||
})
|
||||
|
||||
t.Run("UpdateWebhook fails without authentication", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
// Try to update webhook without authentication
|
||||
req := &v1pb.UpdateWebhookRequest{
|
||||
Webhook: &v1pb.Webhook{
|
||||
Name: "users/1/webhooks/1",
|
||||
Url: "https://updated.example.com/webhook",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateWebhook(ctx, req)
|
||||
|
||||
// Should fail with permission denied or unauthenticated
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteWebhook(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("DeleteWebhook removes webhook", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
Parent: fmt.Sprintf("users/%d", hostUser.ID),
|
||||
Webhook: &v1pb.Webhook{
|
||||
DisplayName: "Test Webhook",
|
||||
Url: "https://example.com/webhook",
|
||||
},
|
||||
}
|
||||
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the webhook
|
||||
deleteReq := &v1pb.DeleteWebhookRequest{
|
||||
Name: createdWebhook.Name,
|
||||
}
|
||||
_, err = ts.Service.DeleteWebhook(userCtx, deleteReq)
|
||||
|
||||
// Verify deletion
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get the deleted webhook
|
||||
getReq := &v1pb.GetWebhookRequest{
|
||||
Name: createdWebhook.Name,
|
||||
}
|
||||
_, err = ts.Service.GetWebhook(userCtx, getReq)
|
||||
|
||||
// Should return not found error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
|
||||
t.Run("DeleteWebhook fails without authentication", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
// Try to delete webhook without authentication
|
||||
req := &v1pb.DeleteWebhookRequest{
|
||||
Name: "users/1/webhooks/1",
|
||||
}
|
||||
|
||||
_, err := ts.Service.DeleteWebhook(ctx, req)
|
||||
|
||||
// Should fail with permission denied or unauthenticated
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("DeleteWebhook fails with non-existent webhook", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
// Try to delete non-existent webhook
|
||||
req := &v1pb.DeleteWebhookRequest{
|
||||
Name: fmt.Sprintf("users/%d/webhooks/999", hostUser.ID),
|
||||
}
|
||||
_, err = ts.Service.DeleteWebhook(userCtx, req)
|
||||
|
||||
// Should return not found error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
206
server/router/api/v1/test/workspace_service_test.go
Normal file
206
server/router/api/v1/test/workspace_service_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestGetWorkspaceProfile(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetWorkspaceProfile returns workspace profile", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetWorkspaceProfile directly
|
||||
req := &v1pb.GetWorkspaceProfileRequest{}
|
||||
resp, err := ts.Service.GetWorkspaceProfile(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.Equal(t, "dev", resp.Mode)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// Owner should be empty since no users are created
|
||||
require.Empty(t, resp.Owner)
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceProfile with owner", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user in the store
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, hostUser)
|
||||
|
||||
// Call GetWorkspaceProfile directly
|
||||
req := &v1pb.GetWorkspaceProfileRequest{}
|
||||
resp, err := ts.Service.GetWorkspaceProfile(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data including owner
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.Equal(t, "dev", resp.Mode)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// User name should be "users/{id}" format where id is the user's ID
|
||||
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
|
||||
require.Equal(t, expectedOwnerName, resp.Owner)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetWorkspaceProfile_Concurrency(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Concurrent access to service", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
|
||||
|
||||
// Make concurrent requests
|
||||
numGoroutines := 10
|
||||
results := make(chan *v1pb.WorkspaceProfile, numGoroutines)
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
req := &v1pb.GetWorkspaceProfileRequest{}
|
||||
resp, err := ts.Service.GetWorkspaceProfile(ctx, req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
results <- resp
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
select {
|
||||
case err := <-errors:
|
||||
t.Fatalf("Goroutine returned error: %v", err)
|
||||
case resp := <-results:
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.Equal(t, "dev", resp.Mode)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
require.Equal(t, expectedOwnerName, resp.Owner)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetWorkspaceSetting(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetWorkspaceSetting - general setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetWorkspaceSetting for general setting
|
||||
req := &v1pb.GetWorkspaceSettingRequest{
|
||||
Name: "workspace/settings/GENERAL",
|
||||
}
|
||||
resp, err := ts.Service.GetWorkspaceSetting(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "workspace/settings/GENERAL", resp.Name)
|
||||
|
||||
// The general setting should have a general_setting field
|
||||
generalSetting := resp.GetGeneralSetting()
|
||||
require.NotNil(t, generalSetting)
|
||||
|
||||
// General setting should have default values
|
||||
require.False(t, generalSetting.DisallowUserRegistration)
|
||||
require.False(t, generalSetting.DisallowPasswordAuth)
|
||||
require.Empty(t, generalSetting.AdditionalScript)
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceSetting - storage setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user for storage setting access
|
||||
hostUser, err := ts.CreateHostUser(ctx, "testhost")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add user to context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Call GetWorkspaceSetting for storage setting
|
||||
req := &v1pb.GetWorkspaceSettingRequest{
|
||||
Name: "workspace/settings/STORAGE",
|
||||
}
|
||||
resp, err := ts.Service.GetWorkspaceSetting(userCtx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "workspace/settings/STORAGE", resp.Name)
|
||||
|
||||
// The storage setting should have a storage_setting field
|
||||
storageSetting := resp.GetStorageSetting()
|
||||
require.NotNil(t, storageSetting)
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceSetting - memo related setting", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetWorkspaceSetting for memo related setting
|
||||
req := &v1pb.GetWorkspaceSettingRequest{
|
||||
Name: "workspace/settings/MEMO_RELATED",
|
||||
}
|
||||
resp, err := ts.Service.GetWorkspaceSetting(ctx, req)
|
||||
|
||||
// Verify response
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "workspace/settings/MEMO_RELATED", resp.Name)
|
||||
|
||||
// The memo related setting should have a memo_related_setting field
|
||||
memoRelatedSetting := resp.GetMemoRelatedSetting()
|
||||
require.NotNil(t, memoRelatedSetting)
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceSetting - invalid setting name", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Call GetWorkspaceSetting with invalid name
|
||||
req := &v1pb.GetWorkspaceSettingRequest{
|
||||
Name: "invalid/setting/name",
|
||||
}
|
||||
_, err := ts.Service.GetWorkspaceSetting(ctx, req)
|
||||
|
||||
// Should return an error
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid workspace setting name")
|
||||
})
|
||||
}
|
||||
19
server/router/api/v1/test_auth.go
Normal file
19
server/router/api/v1/test_auth.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// CreateTestUserContext creates a context with user's ID for testing purposes.
|
||||
// This function is only intended for use in tests.
|
||||
func CreateTestUserContext(ctx context.Context, userID int32) context.Context {
|
||||
return context.WithValue(ctx, userIDContextKey, userID)
|
||||
}
|
||||
|
||||
// CreateTestUserContextWithUser creates a context and ensures the user exists for testing.
|
||||
// This function is only intended for use in tests.
|
||||
func CreateTestUserContextWithUser(ctx context.Context, _ *APIV1Service, user *store.User) context.Context {
|
||||
return context.WithValue(ctx, userIDContextKey, user.ID)
|
||||
}
|
||||
831
server/router/api/v1/user_service.go
Normal file
831
server/router/api/v1/user_service.go
Normal file
@@ -0,0 +1,831 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/genproto/googleapis/api/httpbody"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListUsers(ctx context.Context, _ *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
|
||||
}
|
||||
|
||||
// TODO: Implement proper filtering, ordering, and pagination
|
||||
// For now, return all users with basic structure
|
||||
response := &v1pb.ListUsersResponse{
|
||||
Users: []*v1pb.User{},
|
||||
TotalSize: int32(len(users)),
|
||||
}
|
||||
for _, user := range users {
|
||||
response.Users = append(response.Users, convertUserFromStore(user))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUser(ctx context.Context, request *v1pb.GetUserRequest) (*v1pb.User, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
userPb := convertUserFromStore(user)
|
||||
|
||||
// TODO: Implement read_mask field filtering
|
||||
// For now, return all fields
|
||||
|
||||
return userPb, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SearchUsers(ctx context.Context, request *v1pb.SearchUsersRequest) (*v1pb.SearchUsersResponse, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Search users by username, email, or display name
|
||||
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
|
||||
}
|
||||
|
||||
var filteredUsers []*store.User
|
||||
query := strings.ToLower(request.Query)
|
||||
for _, user := range users {
|
||||
if strings.Contains(strings.ToLower(user.Username), query) ||
|
||||
strings.Contains(strings.ToLower(user.Email), query) ||
|
||||
strings.Contains(strings.ToLower(user.Nickname), query) {
|
||||
filteredUsers = append(filteredUsers, user)
|
||||
}
|
||||
}
|
||||
|
||||
response := &v1pb.SearchUsersResponse{
|
||||
Users: []*v1pb.User{},
|
||||
TotalSize: int32(len(filteredUsers)),
|
||||
}
|
||||
for _, user := range filteredUsers {
|
||||
response.Users = append(response.Users, convertUserFromStore(user))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserAvatar(ctx context.Context, request *v1pb.GetUserAvatarRequest) (*httpbody.HttpBody, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
if user.AvatarURL == "" {
|
||||
return nil, status.Errorf(codes.NotFound, "avatar not found")
|
||||
}
|
||||
|
||||
imageType, base64Data, err := extractImageInfo(user.AvatarURL)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to extract image info: %v", err)
|
||||
}
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to decode string: %v", err)
|
||||
}
|
||||
httpBody := &httpbody.HttpBody{
|
||||
ContentType: imageType,
|
||||
Data: imageData,
|
||||
}
|
||||
return httpBody, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserRequest) (*v1pb.User, error) {
|
||||
// Check if there are any existing host users (for first-time setup detection)
|
||||
hostUserType := store.RoleHost
|
||||
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
|
||||
Role: &hostUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list host users: %v", err)
|
||||
}
|
||||
|
||||
// Determine the role to assign and check permissions
|
||||
var roleToAssign store.Role
|
||||
if len(existedHostUsers) == 0 {
|
||||
// First-time setup: create the first user as HOST (no authentication required)
|
||||
roleToAssign = store.RoleHost
|
||||
} else {
|
||||
// Regular user creation: allow unauthenticated creation of normal users
|
||||
// But if authenticated, check if user has HOST permission for any role
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil && currentUser.Role == store.RoleHost {
|
||||
// Authenticated HOST user can create users with any role specified in request
|
||||
if request.User.Role != v1pb.User_ROLE_UNSPECIFIED {
|
||||
roleToAssign = convertUserRoleToStore(request.User.Role)
|
||||
} else {
|
||||
roleToAssign = store.RoleUser
|
||||
}
|
||||
} else {
|
||||
// Unauthenticated or non-HOST users can only create normal users
|
||||
roleToAssign = store.RoleUser
|
||||
}
|
||||
}
|
||||
|
||||
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
|
||||
}
|
||||
|
||||
// If validate_only is true, just validate without creating
|
||||
if request.ValidateOnly {
|
||||
// Perform validation checks without actually creating the user
|
||||
return &v1pb.User{
|
||||
Username: request.User.Username,
|
||||
Email: request.User.Email,
|
||||
DisplayName: request.User.DisplayName,
|
||||
Role: convertUserRoleFromStore(roleToAssign),
|
||||
}, nil
|
||||
}
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
|
||||
}
|
||||
|
||||
user, err := s.Store.CreateUser(ctx, &store.User{
|
||||
Username: request.User.Username,
|
||||
Role: roleToAssign,
|
||||
Email: request.User.Email,
|
||||
Nickname: request.User.DisplayName,
|
||||
PasswordHash: string(passwordHash),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
|
||||
}
|
||||
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserRequest) (*v1pb.User, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
|
||||
}
|
||||
userID, err := ExtractUserIDFromName(request.User.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
// Check permission.
|
||||
// Only allow admin or self to update user.
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Handle allow_missing field
|
||||
if request.AllowMissing {
|
||||
// Could create user if missing, but for now return not found
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting: %v", err)
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "username":
|
||||
if workspaceGeneralSetting.DisallowChangeUsername {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied: disallow change username")
|
||||
}
|
||||
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
|
||||
}
|
||||
update.Username = &request.User.Username
|
||||
case "display_name":
|
||||
if workspaceGeneralSetting.DisallowChangeNickname {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied: disallow change nickname")
|
||||
}
|
||||
update.Nickname = &request.User.DisplayName
|
||||
case "email":
|
||||
update.Email = &request.User.Email
|
||||
case "avatar_url":
|
||||
update.AvatarURL = &request.User.AvatarUrl
|
||||
case "description":
|
||||
update.Description = &request.User.Description
|
||||
case "role":
|
||||
// Only allow admin to update role.
|
||||
if currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
role := convertUserRoleToStore(request.User.Role)
|
||||
update.Role = &role
|
||||
case "password":
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
|
||||
}
|
||||
passwordHashStr := string(passwordHash)
|
||||
update.PasswordHash = &passwordHashStr
|
||||
case "state":
|
||||
rowStatus := convertStateToStore(request.User.State)
|
||||
update.RowStatus = &rowStatus
|
||||
default:
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
updatedUser, err := s.Store.UpdateUser(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
|
||||
}
|
||||
|
||||
return convertUserFromStore(updatedUser), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserRequest) (*emptypb.Empty, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
|
||||
ID: user.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func getDefaultUserSetting() *v1pb.UserSetting {
|
||||
return &v1pb.UserSetting{
|
||||
Name: "", // Will be set by caller
|
||||
Locale: "en",
|
||||
Appearance: "system",
|
||||
MemoVisibility: "PRIVATE",
|
||||
Theme: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserSetting(ctx context.Context, request *v1pb.GetUserSettingRequest) (*v1pb.UserSetting, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// Only allow user to get their own settings
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSettings, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
|
||||
}
|
||||
|
||||
userSettingMessage := getDefaultUserSetting()
|
||||
userSettingMessage.Name = fmt.Sprintf("users/%d", userID)
|
||||
|
||||
for _, setting := range userSettings {
|
||||
if setting.Key == storepb.UserSetting_GENERAL {
|
||||
general := setting.GetGeneral()
|
||||
if general != nil {
|
||||
userSettingMessage.Locale = general.Locale
|
||||
userSettingMessage.Appearance = general.Appearance
|
||||
userSettingMessage.MemoVisibility = general.MemoVisibility
|
||||
userSettingMessage.Theme = general.Theme
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backfill theme if empty: use workspace theme or default to "default"
|
||||
if userSettingMessage.Theme == "" {
|
||||
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting: %v", err)
|
||||
}
|
||||
workspaceTheme := workspaceGeneralSetting.Theme
|
||||
if workspaceTheme == "" {
|
||||
workspaceTheme = "default"
|
||||
}
|
||||
userSettingMessage.Theme = workspaceTheme
|
||||
}
|
||||
|
||||
return userSettingMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.UpdateUserSettingRequest) (*v1pb.UserSetting, error) {
|
||||
// Extract user ID from the setting resource name
|
||||
userID, err := ExtractUserIDFromName(request.Setting.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// Only allow user to update their own settings
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
|
||||
}
|
||||
|
||||
// Get the current general setting
|
||||
existingGeneralSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get existing general setting: %v", err)
|
||||
}
|
||||
|
||||
// Create or update the general setting
|
||||
generalSetting := &storepb.GeneralUserSetting{
|
||||
Locale: "en",
|
||||
Appearance: "system",
|
||||
MemoVisibility: "PRIVATE",
|
||||
Theme: "",
|
||||
}
|
||||
|
||||
// If there's an existing setting, use its values as defaults
|
||||
if existingGeneralSetting != nil && existingGeneralSetting.GetGeneral() != nil {
|
||||
existing := existingGeneralSetting.GetGeneral()
|
||||
generalSetting.Locale = existing.Locale
|
||||
generalSetting.Appearance = existing.Appearance
|
||||
generalSetting.MemoVisibility = existing.MemoVisibility
|
||||
generalSetting.Theme = existing.Theme
|
||||
}
|
||||
|
||||
// Apply updates based on the update mask
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "locale":
|
||||
generalSetting.Locale = request.Setting.Locale
|
||||
case "appearance":
|
||||
generalSetting.Appearance = request.Setting.Appearance
|
||||
case "memo_visibility":
|
||||
generalSetting.MemoVisibility = request.Setting.MemoVisibility
|
||||
case "theme":
|
||||
generalSetting.Theme = request.Setting.Theme
|
||||
default:
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert the general setting
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: userID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{
|
||||
General: generalSetting,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
|
||||
return s.GetUserSetting(ctx, &v1pb.GetUserSettingRequest{Name: request.Setting.Name})
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, request *v1pb.ListUserAccessTokensRequest) (*v1pb.ListUserAccessTokensResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
|
||||
}
|
||||
|
||||
accessTokens := []*v1pb.UserAccessToken{}
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
claims := &ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(s.Secret), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
// If the access token is invalid or expired, just ignore it.
|
||||
continue
|
||||
}
|
||||
|
||||
accessTokenResponse := &v1pb.UserAccessToken{
|
||||
Name: fmt.Sprintf("users/%d/accessTokens/%s", userID, userAccessToken.AccessToken),
|
||||
AccessToken: userAccessToken.AccessToken,
|
||||
Description: userAccessToken.Description,
|
||||
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
accessTokenResponse.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
|
||||
}
|
||||
accessTokens = append(accessTokens, accessTokenResponse)
|
||||
}
|
||||
|
||||
// Sort by issued time in descending order.
|
||||
slices.SortFunc(accessTokens, func(i, j *v1pb.UserAccessToken) int {
|
||||
return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
|
||||
})
|
||||
response := &v1pb.ListUserAccessTokensResponse{
|
||||
AccessTokens: accessTokens,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb.CreateUserAccessTokenRequest) (*v1pb.UserAccessToken, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if request.AccessToken.ExpiresAt != nil {
|
||||
expiresAt = request.AccessToken.ExpiresAt.AsTime()
|
||||
}
|
||||
|
||||
accessToken, err := GenerateAccessToken(currentUser.Username, currentUser.ID, expiresAt, []byte(s.Secret))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
claims := &ClaimsMessage{}
|
||||
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(s.Secret), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
|
||||
}
|
||||
|
||||
// Upsert the access token to user setting store.
|
||||
if err := s.UpsertAccessTokenToStore(ctx, currentUser, accessToken, request.AccessToken.Description); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
|
||||
}
|
||||
|
||||
userAccessToken := &v1pb.UserAccessToken{
|
||||
Name: fmt.Sprintf("users/%d/accessTokens/%s", userID, accessToken),
|
||||
AccessToken: accessToken,
|
||||
Description: request.AccessToken.Description,
|
||||
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
|
||||
}
|
||||
return userAccessToken, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteUserAccessToken(ctx context.Context, request *v1pb.DeleteUserAccessTokenRequest) (*emptypb.Empty, error) {
|
||||
// Extract user ID from the access token resource name
|
||||
// Format: users/{user}/accessTokens/{access_token}
|
||||
parts := strings.Split(request.Name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "accessTokens" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid access token name format: %s", request.Name)
|
||||
}
|
||||
|
||||
userID, err := ExtractUserIDFromName(fmt.Sprintf("users/%s", parts[1]))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
accessTokenToDelete := parts[3]
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
|
||||
}
|
||||
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
if userAccessToken.AccessToken == accessTokenToDelete {
|
||||
continue
|
||||
}
|
||||
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
|
||||
}
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: currentUser.ID,
|
||||
Key: storepb.UserSetting_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||
AccessTokens: updatedUserAccessTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListUserSessions(ctx context.Context, request *v1pb.ListUserSessionsRequest) (*v1pb.ListUserSessionsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userSessions, err := s.Store.GetUserSessions(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list sessions: %v", err)
|
||||
}
|
||||
|
||||
sessions := []*v1pb.UserSession{}
|
||||
for _, userSession := range userSessions {
|
||||
sessionResponse := &v1pb.UserSession{
|
||||
Name: fmt.Sprintf("users/%d/sessions/%s", userID, userSession.SessionId),
|
||||
SessionId: userSession.SessionId,
|
||||
CreateTime: userSession.CreateTime,
|
||||
LastAccessedTime: userSession.LastAccessedTime,
|
||||
}
|
||||
|
||||
if userSession.ClientInfo != nil {
|
||||
sessionResponse.ClientInfo = &v1pb.UserSession_ClientInfo{
|
||||
UserAgent: userSession.ClientInfo.UserAgent,
|
||||
IpAddress: userSession.ClientInfo.IpAddress,
|
||||
DeviceType: userSession.ClientInfo.DeviceType,
|
||||
Os: userSession.ClientInfo.Os,
|
||||
Browser: userSession.ClientInfo.Browser,
|
||||
}
|
||||
}
|
||||
|
||||
sessions = append(sessions, sessionResponse)
|
||||
}
|
||||
|
||||
// Sort by last accessed time in descending order.
|
||||
slices.SortFunc(sessions, func(i, j *v1pb.UserSession) int {
|
||||
return int(j.LastAccessedTime.Seconds - i.LastAccessedTime.Seconds)
|
||||
})
|
||||
|
||||
response := &v1pb.ListUserSessionsResponse{
|
||||
Sessions: sessions,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) RevokeUserSession(ctx context.Context, request *v1pb.RevokeUserSessionRequest) (*emptypb.Empty, error) {
|
||||
// Extract user ID and session ID from the session resource name
|
||||
// Format: users/{user}/sessions/{session}
|
||||
parts := strings.Split(request.Name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "sessions" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid session name format: %s", request.Name)
|
||||
}
|
||||
|
||||
userID, err := ExtractUserIDFromName(fmt.Sprintf("users/%s", parts[1]))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
sessionIDToRevoke := parts[3]
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if err := s.Store.RemoveUserSession(ctx, userID, sessionIDToRevoke); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to revoke session: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// UpsertUserSession adds or updates a user session.
|
||||
func (s *APIV1Service) UpsertUserSession(ctx context.Context, userID int32, sessionID string, clientInfo *storepb.SessionsUserSetting_ClientInfo) error {
|
||||
session := &storepb.SessionsUserSetting_Session{
|
||||
SessionId: sessionID,
|
||||
CreateTime: timestamppb.Now(),
|
||||
LastAccessedTime: timestamppb.Now(),
|
||||
ClientInfo: clientInfo,
|
||||
}
|
||||
|
||||
return s.Store.AddUserSession(ctx, userID, session)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get user access tokens")
|
||||
}
|
||||
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
|
||||
AccessToken: accessToken,
|
||||
Description: description,
|
||||
}
|
||||
userAccessTokens = append(userAccessTokens, &userAccessToken)
|
||||
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||
AccessTokens: userAccessTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert user setting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertUserFromStore(user *store.User) *v1pb.User {
|
||||
userpb := &v1pb.User{
|
||||
Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
|
||||
State: convertStateFromStore(user.RowStatus),
|
||||
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
|
||||
Role: convertUserRoleFromStore(user.Role),
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
DisplayName: user.Nickname,
|
||||
AvatarUrl: user.AvatarURL,
|
||||
Description: user.Description,
|
||||
}
|
||||
// Use the avatar URL instead of raw base64 image data to reduce the response size.
|
||||
if user.AvatarURL != "" {
|
||||
// Check if avatar url is base64 format.
|
||||
_, _, err := extractImageInfo(user.AvatarURL)
|
||||
if err == nil {
|
||||
userpb.AvatarUrl = fmt.Sprintf("/api/v1/%s/avatar", userpb.Name)
|
||||
} else {
|
||||
userpb.AvatarUrl = user.AvatarURL
|
||||
}
|
||||
}
|
||||
return userpb
|
||||
}
|
||||
|
||||
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
|
||||
switch role {
|
||||
case store.RoleHost:
|
||||
return v1pb.User_HOST
|
||||
case store.RoleAdmin:
|
||||
return v1pb.User_ADMIN
|
||||
case store.RoleUser:
|
||||
return v1pb.User_USER
|
||||
default:
|
||||
return v1pb.User_ROLE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertUserRoleToStore(role v1pb.User_Role) store.Role {
|
||||
switch role {
|
||||
case v1pb.User_HOST:
|
||||
return store.RoleHost
|
||||
case v1pb.User_ADMIN:
|
||||
return store.RoleAdmin
|
||||
case v1pb.User_USER:
|
||||
return store.RoleUser
|
||||
default:
|
||||
return store.RoleUser
|
||||
}
|
||||
}
|
||||
|
||||
func extractImageInfo(dataURI string) (string, string, error) {
|
||||
dataURIRegex := regexp.MustCompile(`^data:(?P<type>.+);base64,(?P<base64>.+)`)
|
||||
matches := dataURIRegex.FindStringSubmatch(dataURI)
|
||||
if len(matches) != 3 {
|
||||
return "", "", errors.New("Invalid data URI format")
|
||||
}
|
||||
imageType := matches[1]
|
||||
base64Data := matches[2]
|
||||
return imageType, base64Data, nil
|
||||
}
|
||||
168
server/router/api/v1/user_service_stats.go
Normal file
168
server/router/api/v1/user_service_stats.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUserStatsRequest) (*v1pb.ListAllUserStatsResponse, error) {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
RowStatus: &normalStatus,
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
if memoFind.CreatorID == nil {
|
||||
internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
|
||||
if memoFind.Filter != nil {
|
||||
filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter)
|
||||
memoFind.Filter = &filter
|
||||
} else {
|
||||
memoFind.Filter = &internalFilter
|
||||
}
|
||||
} else if *memoFind.CreatorID != currentUser.ID {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
}
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
userMemoStatMap := make(map[int32]*v1pb.UserStats)
|
||||
for _, memo := range memos {
|
||||
displayTs := memo.CreatedTs
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{
|
||||
Name: fmt.Sprintf("users/%d/stats", memo.CreatorID),
|
||||
}
|
||||
userMemoStatMap[memo.CreatorID].MemoDisplayTimestamps = append(userMemoStatMap[memo.CreatorID].MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
|
||||
}
|
||||
|
||||
userMemoStats := []*v1pb.UserStats{}
|
||||
for _, userMemoStat := range userMemoStatMap {
|
||||
userMemoStats = append(userMemoStats, userMemoStat)
|
||||
}
|
||||
|
||||
response := &v1pb.ListAllUserStatsResponse{
|
||||
UserStats: userMemoStats,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserStatsRequest) (*v1pb.UserStats, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &userID,
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
RowStatus: &normalStatus,
|
||||
}
|
||||
|
||||
if currentUser == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
} else if currentUser.ID != userID {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
|
||||
}
|
||||
|
||||
displayTimestamps := []*timestamppb.Timestamp{}
|
||||
tagCount := make(map[string]int32)
|
||||
linkCount := int32(0)
|
||||
codeCount := int32(0)
|
||||
todoCount := int32(0)
|
||||
undoCount := int32(0)
|
||||
pinnedMemos := []string{}
|
||||
|
||||
for _, memo := range memos {
|
||||
displayTs := memo.CreatedTs
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
|
||||
// Count different memo types based on content.
|
||||
if memo.Payload != nil {
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
tagCount[tag]++
|
||||
}
|
||||
if memo.Payload.Property != nil {
|
||||
if memo.Payload.Property.HasLink {
|
||||
linkCount++
|
||||
}
|
||||
if memo.Payload.Property.HasCode {
|
||||
codeCount++
|
||||
}
|
||||
if memo.Payload.Property.HasTaskList {
|
||||
todoCount++
|
||||
}
|
||||
if memo.Payload.Property.HasIncompleteTasks {
|
||||
undoCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
if memo.Pinned {
|
||||
pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID))
|
||||
}
|
||||
}
|
||||
|
||||
userStats := &v1pb.UserStats{
|
||||
Name: fmt.Sprintf("users/%d/stats", userID),
|
||||
MemoDisplayTimestamps: displayTimestamps,
|
||||
TagCount: tagCount,
|
||||
PinnedMemos: pinnedMemos,
|
||||
TotalMemoCount: int32(len(memos)),
|
||||
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
|
||||
LinkCount: linkCount,
|
||||
CodeCount: codeCount,
|
||||
TodoCount: todoCount,
|
||||
UndoCount: undoCount,
|
||||
},
|
||||
}
|
||||
|
||||
return userStats, nil
|
||||
}
|
||||
137
server/router/api/v1/v1.go
Normal file
137
server/router/api/v1/v1.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/improbable-eng/grpc-web/go/grpcweb"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/reflection"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type APIV1Service struct {
|
||||
grpc_health_v1.UnimplementedHealthServer
|
||||
|
||||
v1pb.UnimplementedWorkspaceServiceServer
|
||||
v1pb.UnimplementedAuthServiceServer
|
||||
v1pb.UnimplementedUserServiceServer
|
||||
v1pb.UnimplementedMemoServiceServer
|
||||
v1pb.UnimplementedAttachmentServiceServer
|
||||
v1pb.UnimplementedShortcutServiceServer
|
||||
v1pb.UnimplementedInboxServiceServer
|
||||
v1pb.UnimplementedActivityServiceServer
|
||||
v1pb.UnimplementedWebhookServiceServer
|
||||
v1pb.UnimplementedMarkdownServiceServer
|
||||
v1pb.UnimplementedIdentityProviderServiceServer
|
||||
|
||||
Secret string
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
|
||||
grpcServer *grpc.Server
|
||||
}
|
||||
|
||||
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store, grpcServer *grpc.Server) *APIV1Service {
|
||||
grpc.EnableTracing = true
|
||||
apiv1Service := &APIV1Service{
|
||||
Secret: secret,
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
grpcServer: grpcServer,
|
||||
}
|
||||
grpc_health_v1.RegisterHealthServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterWorkspaceServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterAuthServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterUserServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterMemoServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterAttachmentServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterShortcutServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterInboxServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterActivityServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterWebhookServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterMarkdownServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterIdentityProviderServiceServer(grpcServer, apiv1Service)
|
||||
reflection.Register(grpcServer)
|
||||
return apiv1Service
|
||||
}
|
||||
|
||||
// RegisterGateway registers the gRPC-Gateway with the given Echo instance.
|
||||
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
|
||||
var target string
|
||||
if len(s.Profile.UNIXSock) == 0 {
|
||||
target = fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port)
|
||||
} else {
|
||||
target = fmt.Sprintf("unix:%s", s.Profile.UNIXSock)
|
||||
}
|
||||
conn, err := grpc.NewClient(
|
||||
target,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
gwMux := runtime.NewServeMux()
|
||||
if err := v1pb.RegisterWorkspaceServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterAuthServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterUserServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterMemoServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterAttachmentServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterShortcutServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterInboxServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterActivityServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterWebhookServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterMarkdownServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterIdentityProviderServiceHandler(ctx, gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
gwGroup := echoServer.Group("")
|
||||
gwGroup.Use(middleware.CORS())
|
||||
handler := echo.WrapHandler(gwMux)
|
||||
|
||||
gwGroup.Any("/api/v1/*", handler)
|
||||
gwGroup.Any("/file/*", handler)
|
||||
|
||||
// GRPC web proxy.
|
||||
options := []grpcweb.Option{
|
||||
grpcweb.WithCorsForRegisteredEndpointsOnly(false),
|
||||
grpcweb.WithOriginFunc(func(_ string) bool {
|
||||
return true
|
||||
}),
|
||||
}
|
||||
wrappedGrpc := grpcweb.WrapServer(s.grpcServer, options...)
|
||||
echoServer.Any("/memos.api.v1.*", echo.WrapHandler(wrappedGrpc))
|
||||
|
||||
return nil
|
||||
}
|
||||
317
server/router/api/v1/webhook_service.go
Normal file
317
server/router/api/v1/webhook_service.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateWebhook(ctx context.Context, request *v1pb.CreateWebhookRequest) (*v1pb.Webhook, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Extract user ID from parent (format: users/{user})
|
||||
parentUserID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
|
||||
}
|
||||
|
||||
// Users can only create webhooks for themselves
|
||||
if parentUserID != currentUser.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Only host users can create webhooks
|
||||
if !isSuperUser(currentUser) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if request.Webhook == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "webhook is required")
|
||||
}
|
||||
if strings.TrimSpace(request.Webhook.Url) == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "webhook URL is required")
|
||||
}
|
||||
|
||||
// Handle validate_only field
|
||||
if request.ValidateOnly {
|
||||
// Perform validation checks without actually creating the webhook
|
||||
return &v1pb.Webhook{
|
||||
Name: fmt.Sprintf("users/%d/webhooks/validate", currentUser.ID),
|
||||
DisplayName: request.Webhook.DisplayName,
|
||||
Url: request.Webhook.Url,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err = s.Store.AddUserWebhook(ctx, currentUser.ID, &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: generateWebhookID(),
|
||||
Title: request.Webhook.DisplayName,
|
||||
Url: strings.TrimSpace(request.Webhook.Url),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create webhook, error: %+v", err)
|
||||
}
|
||||
|
||||
// Return the newly created webhook
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user webhooks, error: %+v", err)
|
||||
}
|
||||
|
||||
// Find the webhook we just created
|
||||
for _, webhook := range webhooks {
|
||||
if webhook.Title == request.Webhook.DisplayName && webhook.Url == strings.TrimSpace(request.Webhook.Url) {
|
||||
return convertWebhookFromUserSetting(webhook, currentUser.ID), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Internal, "failed to find created webhook")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListWebhooks(ctx context.Context, request *v1pb.ListWebhooksRequest) (*v1pb.ListWebhooksResponse, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Extract user ID from parent (format: users/{user})
|
||||
parentUserID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
|
||||
}
|
||||
|
||||
// Users can only list their own webhooks
|
||||
if parentUserID != currentUser.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list webhooks, error: %+v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListWebhooksResponse{
|
||||
Webhooks: []*v1pb.Webhook{},
|
||||
}
|
||||
for _, webhook := range webhooks {
|
||||
response.Webhooks = append(response.Webhooks, convertWebhookFromUserSetting(webhook, currentUser.ID))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetWebhook(ctx context.Context, request *v1pb.GetWebhookRequest) (*v1pb.Webhook, error) {
|
||||
// Extract user ID and webhook ID from name (format: users/{user}/webhooks/{webhook})
|
||||
tokens, err := GetNameParentTokens(request.Name, UserNamePrefix, WebhookNamePrefix)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
|
||||
}
|
||||
if len(tokens) != 2 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name format")
|
||||
}
|
||||
|
||||
userIDStr := tokens[0]
|
||||
webhookID := tokens[1]
|
||||
|
||||
requestedUserID, err := util.ConvertStringToInt32(userIDStr)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID in webhook name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Users can only access their own webhooks
|
||||
if requestedUserID != currentUser.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get webhooks, error: %+v", err)
|
||||
}
|
||||
|
||||
// Find webhook by ID
|
||||
for _, webhook := range webhooks {
|
||||
if webhook.Id == webhookID {
|
||||
return convertWebhookFromUserSetting(webhook, currentUser.ID), nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "webhook not found")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateWebhook(ctx context.Context, request *v1pb.UpdateWebhookRequest) (*v1pb.Webhook, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
||||
// Extract user ID and webhook ID from name (format: users/{user}/webhooks/{webhook})
|
||||
tokens, err := GetNameParentTokens(request.Webhook.Name, UserNamePrefix, WebhookNamePrefix)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
|
||||
}
|
||||
if len(tokens) != 2 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name format")
|
||||
}
|
||||
|
||||
userIDStr := tokens[0]
|
||||
webhookID := tokens[1]
|
||||
|
||||
requestedUserID, err := util.ConvertStringToInt32(userIDStr)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID in webhook name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Users can only update their own webhooks
|
||||
if requestedUserID != currentUser.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Get existing webhooks from user settings
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get webhooks: %v", err)
|
||||
}
|
||||
|
||||
// Find the webhook to update
|
||||
var existingWebhook *storepb.WebhooksUserSetting_Webhook
|
||||
for _, webhook := range webhooks {
|
||||
if webhook.Id == webhookID {
|
||||
existingWebhook = webhook
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if existingWebhook == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "webhook not found")
|
||||
}
|
||||
|
||||
// Create updated webhook
|
||||
updatedWebhook := &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: existingWebhook.Id,
|
||||
Title: existingWebhook.Title,
|
||||
Url: existingWebhook.Url,
|
||||
}
|
||||
|
||||
// Apply updates based on update mask
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "display_name":
|
||||
updatedWebhook.Title = request.Webhook.DisplayName
|
||||
case "url":
|
||||
updatedWebhook.Url = request.Webhook.Url
|
||||
default:
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
// Update the webhook in user settings
|
||||
err = s.Store.UpdateUserWebhook(ctx, currentUser.ID, updatedWebhook)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update webhook: %v", err)
|
||||
}
|
||||
|
||||
return convertWebhookFromUserSetting(updatedWebhook, currentUser.ID), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteWebhook(ctx context.Context, request *v1pb.DeleteWebhookRequest) (*emptypb.Empty, error) {
|
||||
// Extract user ID and webhook ID from name (format: users/{user}/webhooks/{webhook})
|
||||
tokens, err := GetNameParentTokens(request.Name, UserNamePrefix, WebhookNamePrefix)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
|
||||
}
|
||||
if len(tokens) != 2 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name format")
|
||||
}
|
||||
|
||||
userIDStr := tokens[0]
|
||||
webhookID := tokens[1]
|
||||
|
||||
requestedUserID, err := util.ConvertStringToInt32(userIDStr)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID in webhook name: %v", err)
|
||||
}
|
||||
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Users can only delete their own webhooks
|
||||
if requestedUserID != currentUser.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Get existing webhooks from user settings to verify it exists
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get webhooks: %v", err)
|
||||
}
|
||||
|
||||
// Check if webhook exists
|
||||
webhookExists := false
|
||||
for _, webhook := range webhooks {
|
||||
if webhook.Id == webhookID {
|
||||
webhookExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !webhookExists {
|
||||
return nil, status.Errorf(codes.NotFound, "webhook not found")
|
||||
}
|
||||
|
||||
err = s.Store.RemoveUserWebhook(ctx, currentUser.ID, webhookID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete webhook: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertWebhookFromUserSetting(webhook *storepb.WebhooksUserSetting_Webhook, userID int32) *v1pb.Webhook {
|
||||
return &v1pb.Webhook{
|
||||
Name: fmt.Sprintf("users/%d/webhooks/%s", userID, webhook.Id),
|
||||
DisplayName: webhook.Title,
|
||||
Url: webhook.Url,
|
||||
}
|
||||
}
|
||||
|
||||
func generateWebhookID() string {
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
306
server/router/api/v1/workspace_service.go
Normal file
306
server/router/api/v1/workspace_service.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// GetWorkspaceProfile returns the workspace profile.
|
||||
func (s *APIV1Service) GetWorkspaceProfile(ctx context.Context, _ *v1pb.GetWorkspaceProfileRequest) (*v1pb.WorkspaceProfile, error) {
|
||||
workspaceProfile := &v1pb.WorkspaceProfile{
|
||||
Version: s.Profile.Version,
|
||||
Mode: s.Profile.Mode,
|
||||
InstanceUrl: s.Profile.InstanceURL,
|
||||
}
|
||||
owner, err := s.GetInstanceOwner(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err)
|
||||
}
|
||||
if owner != nil {
|
||||
workspaceProfile.Owner = owner.Name
|
||||
}
|
||||
return workspaceProfile, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetWorkspaceSetting(ctx context.Context, request *v1pb.GetWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
|
||||
workspaceSettingKeyString, err := ExtractWorkspaceSettingKeyFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid workspace setting name: %v", err)
|
||||
}
|
||||
|
||||
workspaceSettingKey := storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[workspaceSettingKeyString])
|
||||
// Get workspace setting from store with default value.
|
||||
switch workspaceSettingKey {
|
||||
case storepb.WorkspaceSettingKey_BASIC:
|
||||
_, err = s.Store.GetWorkspaceBasicSetting(ctx)
|
||||
case storepb.WorkspaceSettingKey_GENERAL:
|
||||
_, err = s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
case storepb.WorkspaceSettingKey_MEMO_RELATED:
|
||||
_, err = s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
case storepb.WorkspaceSettingKey_STORAGE:
|
||||
_, err = s.Store.GetWorkspaceStorageSetting(ctx)
|
||||
default:
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unsupported workspace setting key: %v", workspaceSettingKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
|
||||
}
|
||||
|
||||
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
|
||||
Name: workspaceSettingKey.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
|
||||
}
|
||||
if workspaceSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "workspace setting not found")
|
||||
}
|
||||
|
||||
// For storage setting, only host can get it.
|
||||
if workspaceSetting.Key == storepb.WorkspaceSettingKey_STORAGE {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil || user.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
return convertWorkspaceSettingFromStore(workspaceSetting), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateWorkspaceSetting(ctx context.Context, request *v1pb.UpdateWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// TODO: Apply update_mask if specified
|
||||
_ = request.UpdateMask
|
||||
|
||||
updateSetting := convertWorkspaceSettingToStore(request.Setting)
|
||||
workspaceSetting, err := s.Store.UpsertWorkspaceSetting(ctx, updateSetting)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert workspace setting: %v", err)
|
||||
}
|
||||
|
||||
return convertWorkspaceSettingFromStore(workspaceSetting), nil
|
||||
}
|
||||
|
||||
func convertWorkspaceSettingFromStore(setting *storepb.WorkspaceSetting) *v1pb.WorkspaceSetting {
|
||||
workspaceSetting := &v1pb.WorkspaceSetting{
|
||||
Name: fmt.Sprintf("workspace/settings/%s", setting.Key.String()),
|
||||
}
|
||||
switch setting.Value.(type) {
|
||||
case *storepb.WorkspaceSetting_GeneralSetting:
|
||||
workspaceSetting.Value = &v1pb.WorkspaceSetting_GeneralSetting{
|
||||
GeneralSetting: convertWorkspaceGeneralSettingFromStore(setting.GetGeneralSetting()),
|
||||
}
|
||||
case *storepb.WorkspaceSetting_StorageSetting:
|
||||
workspaceSetting.Value = &v1pb.WorkspaceSetting_StorageSetting{
|
||||
StorageSetting: convertWorkspaceStorageSettingFromStore(setting.GetStorageSetting()),
|
||||
}
|
||||
case *storepb.WorkspaceSetting_MemoRelatedSetting:
|
||||
workspaceSetting.Value = &v1pb.WorkspaceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
|
||||
}
|
||||
}
|
||||
return workspaceSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceSettingToStore(setting *v1pb.WorkspaceSetting) *storepb.WorkspaceSetting {
|
||||
settingKeyString, _ := ExtractWorkspaceSettingKeyFromName(setting.Name)
|
||||
workspaceSetting := &storepb.WorkspaceSetting{
|
||||
Key: storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString]),
|
||||
Value: &storepb.WorkspaceSetting_GeneralSetting{
|
||||
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
|
||||
},
|
||||
}
|
||||
switch workspaceSetting.Key {
|
||||
case storepb.WorkspaceSettingKey_GENERAL:
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_GeneralSetting{
|
||||
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
|
||||
}
|
||||
case storepb.WorkspaceSettingKey_STORAGE:
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_StorageSetting{
|
||||
StorageSetting: convertWorkspaceStorageSettingToStore(setting.GetStorageSetting()),
|
||||
}
|
||||
case storepb.WorkspaceSettingKey_MEMO_RELATED:
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
|
||||
}
|
||||
}
|
||||
return workspaceSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceGeneralSettingFromStore(setting *storepb.WorkspaceGeneralSetting) *v1pb.WorkspaceGeneralSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
// Backfill theme if empty
|
||||
theme := setting.Theme
|
||||
if theme == "" {
|
||||
theme = "default"
|
||||
}
|
||||
|
||||
generalSetting := &v1pb.WorkspaceGeneralSetting{
|
||||
Theme: theme,
|
||||
DisallowUserRegistration: setting.DisallowUserRegistration,
|
||||
DisallowPasswordAuth: setting.DisallowPasswordAuth,
|
||||
AdditionalScript: setting.AdditionalScript,
|
||||
AdditionalStyle: setting.AdditionalStyle,
|
||||
WeekStartDayOffset: setting.WeekStartDayOffset,
|
||||
DisallowChangeUsername: setting.DisallowChangeUsername,
|
||||
DisallowChangeNickname: setting.DisallowChangeNickname,
|
||||
}
|
||||
if setting.CustomProfile != nil {
|
||||
generalSetting.CustomProfile = &v1pb.WorkspaceCustomProfile{
|
||||
Title: setting.CustomProfile.Title,
|
||||
Description: setting.CustomProfile.Description,
|
||||
LogoUrl: setting.CustomProfile.LogoUrl,
|
||||
Locale: setting.CustomProfile.Locale,
|
||||
Appearance: setting.CustomProfile.Appearance,
|
||||
}
|
||||
}
|
||||
return generalSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceGeneralSettingToStore(setting *v1pb.WorkspaceGeneralSetting) *storepb.WorkspaceGeneralSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
generalSetting := &storepb.WorkspaceGeneralSetting{
|
||||
Theme: setting.Theme,
|
||||
DisallowUserRegistration: setting.DisallowUserRegistration,
|
||||
DisallowPasswordAuth: setting.DisallowPasswordAuth,
|
||||
AdditionalScript: setting.AdditionalScript,
|
||||
AdditionalStyle: setting.AdditionalStyle,
|
||||
WeekStartDayOffset: setting.WeekStartDayOffset,
|
||||
DisallowChangeUsername: setting.DisallowChangeUsername,
|
||||
DisallowChangeNickname: setting.DisallowChangeNickname,
|
||||
}
|
||||
if setting.CustomProfile != nil {
|
||||
generalSetting.CustomProfile = &storepb.WorkspaceCustomProfile{
|
||||
Title: setting.CustomProfile.Title,
|
||||
Description: setting.CustomProfile.Description,
|
||||
LogoUrl: setting.CustomProfile.LogoUrl,
|
||||
Locale: setting.CustomProfile.Locale,
|
||||
Appearance: setting.CustomProfile.Appearance,
|
||||
}
|
||||
}
|
||||
return generalSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceStorageSettingFromStore(settingpb *storepb.WorkspaceStorageSetting) *v1pb.WorkspaceStorageSetting {
|
||||
if settingpb == nil {
|
||||
return nil
|
||||
}
|
||||
setting := &v1pb.WorkspaceStorageSetting{
|
||||
StorageType: v1pb.WorkspaceStorageSetting_StorageType(settingpb.StorageType),
|
||||
FilepathTemplate: settingpb.FilepathTemplate,
|
||||
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
|
||||
}
|
||||
if settingpb.S3Config != nil {
|
||||
setting.S3Config = &v1pb.WorkspaceStorageSetting_S3Config{
|
||||
AccessKeyId: settingpb.S3Config.AccessKeyId,
|
||||
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
|
||||
Endpoint: settingpb.S3Config.Endpoint,
|
||||
Region: settingpb.S3Config.Region,
|
||||
Bucket: settingpb.S3Config.Bucket,
|
||||
UsePathStyle: settingpb.S3Config.UsePathStyle,
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
||||
func convertWorkspaceStorageSettingToStore(setting *v1pb.WorkspaceStorageSetting) *storepb.WorkspaceStorageSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
settingpb := &storepb.WorkspaceStorageSetting{
|
||||
StorageType: storepb.WorkspaceStorageSetting_StorageType(setting.StorageType),
|
||||
FilepathTemplate: setting.FilepathTemplate,
|
||||
UploadSizeLimitMb: setting.UploadSizeLimitMb,
|
||||
}
|
||||
if setting.S3Config != nil {
|
||||
settingpb.S3Config = &storepb.StorageS3Config{
|
||||
AccessKeyId: setting.S3Config.AccessKeyId,
|
||||
AccessKeySecret: setting.S3Config.AccessKeySecret,
|
||||
Endpoint: setting.S3Config.Endpoint,
|
||||
Region: setting.S3Config.Region,
|
||||
Bucket: setting.S3Config.Bucket,
|
||||
UsePathStyle: setting.S3Config.UsePathStyle,
|
||||
}
|
||||
}
|
||||
return settingpb
|
||||
}
|
||||
|
||||
func convertWorkspaceMemoRelatedSettingFromStore(setting *storepb.WorkspaceMemoRelatedSetting) *v1pb.WorkspaceMemoRelatedSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.WorkspaceMemoRelatedSetting{
|
||||
DisallowPublicVisibility: setting.DisallowPublicVisibility,
|
||||
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
|
||||
ContentLengthLimit: setting.ContentLengthLimit,
|
||||
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
|
||||
EnableLinkPreview: setting.EnableLinkPreview,
|
||||
EnableComment: setting.EnableComment,
|
||||
Reactions: setting.Reactions,
|
||||
DisableMarkdownShortcuts: setting.DisableMarkdownShortcuts,
|
||||
EnableBlurNsfwContent: setting.EnableBlurNsfwContent,
|
||||
NsfwTags: setting.NsfwTags,
|
||||
}
|
||||
}
|
||||
|
||||
func convertWorkspaceMemoRelatedSettingToStore(setting *v1pb.WorkspaceMemoRelatedSetting) *storepb.WorkspaceMemoRelatedSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
return &storepb.WorkspaceMemoRelatedSetting{
|
||||
DisallowPublicVisibility: setting.DisallowPublicVisibility,
|
||||
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
|
||||
ContentLengthLimit: setting.ContentLengthLimit,
|
||||
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
|
||||
EnableLinkPreview: setting.EnableLinkPreview,
|
||||
EnableComment: setting.EnableComment,
|
||||
Reactions: setting.Reactions,
|
||||
DisableMarkdownShortcuts: setting.DisableMarkdownShortcuts,
|
||||
EnableBlurNsfwContent: setting.EnableBlurNsfwContent,
|
||||
NsfwTags: setting.NsfwTags,
|
||||
}
|
||||
}
|
||||
|
||||
var ownerCache *v1pb.User
|
||||
|
||||
func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) {
|
||||
if ownerCache != nil {
|
||||
return ownerCache, nil
|
||||
}
|
||||
|
||||
hostUserType := store.RoleHost
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Role: &hostUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find owner")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ownerCache = convertUserFromStore(user)
|
||||
return ownerCache, nil
|
||||
}
|
||||
61
server/router/frontend/frontend.go
Normal file
61
server/router/frontend/frontend.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package frontend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
//go:embed dist/*
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
type FrontendService struct {
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
func NewFrontendService(profile *profile.Profile, store *store.Store) *FrontendService {
|
||||
return &FrontendService{
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (*FrontendService) Serve(_ context.Context, e *echo.Echo) {
|
||||
skipper := func(c echo.Context) bool {
|
||||
// Skip API routes.
|
||||
if util.HasPrefixes(c.Path(), "/api", "/memos.api.v1") {
|
||||
return true
|
||||
}
|
||||
// Skip setting cache headers for index.html
|
||||
if c.Path() == "/" || c.Path() == "/index.html" {
|
||||
return false
|
||||
}
|
||||
// Set Cache-Control header to allow public caching with a max-age of 7 days.
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, "public, max-age=604800") // 7 days
|
||||
return false
|
||||
}
|
||||
|
||||
// Route to serve the main app with HTML5 fallback for SPA behavior.
|
||||
e.Use(middleware.StaticWithConfig(middleware.StaticConfig{
|
||||
Filesystem: getFileSystem("dist"),
|
||||
HTML5: true, // Enable fallback to index.html
|
||||
Skipper: skipper,
|
||||
}))
|
||||
}
|
||||
|
||||
func getFileSystem(path string) http.FileSystem {
|
||||
fs, err := fs.Sub(embeddedFiles, path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return http.FS(fs)
|
||||
}
|
||||
179
server/router/rss/rss.go
Normal file
179
server/router/rss/rss.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package rss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/feeds"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/gomark"
|
||||
"github.com/usememos/gomark/renderer"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
maxRSSItemCount = 100
|
||||
)
|
||||
|
||||
type RSSService struct {
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
type RSSHeading struct {
|
||||
Title string
|
||||
Description string
|
||||
}
|
||||
|
||||
func NewRSSService(profile *profile.Profile, store *store.Store) *RSSService {
|
||||
return &RSSService{
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RSSService) RegisterRoutes(g *echo.Group) {
|
||||
g.GET("/explore/rss.xml", s.GetExploreRSS)
|
||||
g.GET("/u/:username/rss.xml", s.GetUserRSS)
|
||||
}
|
||||
|
||||
func (s *RSSService) GetExploreRSS(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
normalStatus := store.Normal
|
||||
memoFind := store.FindMemo{
|
||||
RowStatus: &normalStatus,
|
||||
VisibilityList: []store.Visibility{store.Public},
|
||||
}
|
||||
memoList, err := s.Store.ListMemos(ctx, &memoFind)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
|
||||
}
|
||||
|
||||
baseURL := c.Scheme() + "://" + c.Request().Host
|
||||
rss, err := s.generateRSSFromMemoList(ctx, memoList, baseURL)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").SetInternal(err)
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationXMLCharsetUTF8)
|
||||
return c.String(http.StatusOK, rss)
|
||||
}
|
||||
|
||||
func (s *RSSService) GetUserRSS(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
username := c.Param("username")
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||
}
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "User not found")
|
||||
}
|
||||
|
||||
normalStatus := store.Normal
|
||||
memoFind := store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
RowStatus: &normalStatus,
|
||||
VisibilityList: []store.Visibility{store.Public},
|
||||
}
|
||||
memoList, err := s.Store.ListMemos(ctx, &memoFind)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
|
||||
}
|
||||
|
||||
baseURL := c.Scheme() + "://" + c.Request().Host
|
||||
rss, err := s.generateRSSFromMemoList(ctx, memoList, baseURL)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").SetInternal(err)
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationXMLCharsetUTF8)
|
||||
return c.String(http.StatusOK, rss)
|
||||
}
|
||||
|
||||
func (s *RSSService) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string) (string, error) {
|
||||
rssHeading, err := getRSSHeading(ctx, s.Store)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
feed := &feeds.Feed{
|
||||
Title: rssHeading.Title,
|
||||
Link: &feeds.Link{Href: baseURL},
|
||||
Description: rssHeading.Description,
|
||||
Created: time.Now(),
|
||||
}
|
||||
|
||||
var itemCountLimit = min(len(memoList), maxRSSItemCount)
|
||||
feed.Items = make([]*feeds.Item, itemCountLimit)
|
||||
for i := 0; i < itemCountLimit; i++ {
|
||||
memo := memoList[i]
|
||||
description, err := getRSSItemDescription(memo.Content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
link := &feeds.Link{Href: baseURL + "/memos/" + memo.UID}
|
||||
feed.Items[i] = &feeds.Item{
|
||||
Link: link,
|
||||
Description: description,
|
||||
Created: time.Unix(memo.CreatedTs, 0),
|
||||
Id: link.Href,
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(attachments) > 0 {
|
||||
attachment := attachments[0]
|
||||
enclosure := feeds.Enclosure{}
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
enclosure.Url = attachment.Reference
|
||||
} else {
|
||||
enclosure.Url = fmt.Sprintf("%s/file/attachments/%s/%s", baseURL, attachment.UID, attachment.Filename)
|
||||
}
|
||||
enclosure.Length = strconv.Itoa(int(attachment.Size))
|
||||
enclosure.Type = attachment.Type
|
||||
feed.Items[i].Enclosure = &enclosure
|
||||
}
|
||||
}
|
||||
|
||||
rss, err := feed.ToRss()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return rss, nil
|
||||
}
|
||||
|
||||
func getRSSItemDescription(content string) (string, error) {
|
||||
nodes, err := gomark.Parse(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result := renderer.NewHTMLRenderer().Render(nodes)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func getRSSHeading(ctx context.Context, stores *store.Store) (RSSHeading, error) {
|
||||
settings, err := stores.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return RSSHeading{}, err
|
||||
}
|
||||
if settings == nil || settings.CustomProfile == nil {
|
||||
return RSSHeading{
|
||||
Title: "Memos",
|
||||
Description: "An open source, lightweight note-taking service. Easily capture and share your great thoughts.",
|
||||
}, nil
|
||||
}
|
||||
customProfile := settings.CustomProfile
|
||||
return RSSHeading{
|
||||
Title: customProfile.Title,
|
||||
Description: customProfile.Description,
|
||||
}, nil
|
||||
}
|
||||
134
server/runner/memopayload/runner.go
Normal file
134
server/runner/memopayload/runner.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package memopayload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"slices"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/gomark/ast"
|
||||
"github.com/usememos/gomark/parser"
|
||||
"github.com/usememos/gomark/parser/tokenizer"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type Runner struct {
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
func NewRunner(store *store.Store) *Runner {
|
||||
return &Runner{
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// RunOnce rebuilds the payload of all memos.
|
||||
func (r *Runner) RunOnce(ctx context.Context) {
|
||||
// Process memos in batches to avoid loading all memos into memory at once
|
||||
const batchSize = 100
|
||||
offset := 0
|
||||
processed := 0
|
||||
|
||||
for {
|
||||
limit := batchSize
|
||||
memos, err := r.Store.ListMemos(ctx, &store.FindMemo{
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to list memos", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Break if no more memos
|
||||
if len(memos) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Process batch
|
||||
batchSuccessCount := 0
|
||||
for _, memo := range memos {
|
||||
if err := RebuildMemoPayload(memo); err != nil {
|
||||
slog.Error("failed to rebuild memo payload", "err", err, "memoID", memo.ID)
|
||||
continue
|
||||
}
|
||||
if err := r.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Payload: memo.Payload,
|
||||
}); err != nil {
|
||||
slog.Error("failed to update memo", "err", err, "memoID", memo.ID)
|
||||
continue
|
||||
}
|
||||
batchSuccessCount++
|
||||
}
|
||||
|
||||
processed += len(memos)
|
||||
slog.Info("Processed memo batch", "batchSize", len(memos), "successCount", batchSuccessCount, "totalProcessed", processed)
|
||||
|
||||
// Move to next batch
|
||||
offset += len(memos)
|
||||
}
|
||||
}
|
||||
|
||||
func RebuildMemoPayload(memo *store.Memo) error {
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to parse content")
|
||||
}
|
||||
|
||||
if memo.Payload == nil {
|
||||
memo.Payload = &storepb.MemoPayload{}
|
||||
}
|
||||
tags := []string{}
|
||||
property := &storepb.MemoPayload_Property{}
|
||||
TraverseASTNodes(nodes, func(node ast.Node) {
|
||||
switch n := node.(type) {
|
||||
case *ast.Tag:
|
||||
tag := n.Content
|
||||
if !slices.Contains(tags, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
case *ast.Link, *ast.AutoLink:
|
||||
property.HasLink = true
|
||||
case *ast.TaskListItem:
|
||||
property.HasTaskList = true
|
||||
if !n.Complete {
|
||||
property.HasIncompleteTasks = true
|
||||
}
|
||||
case *ast.CodeBlock:
|
||||
property.HasCode = true
|
||||
case *ast.EmbeddedContent:
|
||||
// TODO: validate references.
|
||||
property.References = append(property.References, n.ResourceName)
|
||||
}
|
||||
})
|
||||
memo.Payload.Tags = tags
|
||||
memo.Payload.Property = property
|
||||
return nil
|
||||
}
|
||||
|
||||
func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
|
||||
for _, node := range nodes {
|
||||
fn(node)
|
||||
switch n := node.(type) {
|
||||
case *ast.Paragraph:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.Heading:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.Blockquote:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.List:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.OrderedListItem:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.UnorderedListItem:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.TaskListItem:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.Bold:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
134
server/runner/s3presign/runner.go
Normal file
134
server/runner/s3presign/runner.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package s3presign
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type Runner struct {
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
func NewRunner(store *store.Store) *Runner {
|
||||
return &Runner{
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule runner every 12 hours.
|
||||
const runnerInterval = time.Hour * 12
|
||||
|
||||
func (r *Runner) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(runnerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.RunOnce(ctx)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) RunOnce(ctx context.Context) {
|
||||
r.CheckAndPresign(ctx)
|
||||
}
|
||||
|
||||
func (r *Runner) CheckAndPresign(ctx context.Context) {
|
||||
workspaceStorageSetting, err := r.Store.GetWorkspaceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s3StorageType := storepb.AttachmentStorageType_S3
|
||||
// Limit attachments to a reasonable batch size
|
||||
const batchSize = 100
|
||||
offset := 0
|
||||
|
||||
for {
|
||||
limit := batchSize
|
||||
attachments, err := r.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
GetBlob: false,
|
||||
StorageType: &s3StorageType,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to list attachments for presigning", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Break if no more attachments
|
||||
if len(attachments) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Process batch of attachments
|
||||
presignCount := 0
|
||||
for _, attachment := range attachments {
|
||||
s3ObjectPayload := attachment.Payload.GetS3Object()
|
||||
if s3ObjectPayload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if s3ObjectPayload.LastPresignedTime != nil {
|
||||
// Skip if the presigned URL is still valid for the next 4 days.
|
||||
// The expiration time is set to 5 days.
|
||||
if time.Now().Before(s3ObjectPayload.LastPresignedTime.AsTime().Add(4 * 24 * time.Hour)) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s3Config := workspaceStorageSetting.GetS3Config()
|
||||
if s3ObjectPayload.S3Config != nil {
|
||||
s3Config = s3ObjectPayload.S3Config
|
||||
}
|
||||
if s3Config == nil {
|
||||
slog.Error("S3 config is not found")
|
||||
continue
|
||||
}
|
||||
|
||||
s3Client, err := s3.NewClient(ctx, s3Config)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create S3 client", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
presignURL, err := s3Client.PresignGetObject(ctx, s3ObjectPayload.Key)
|
||||
if err != nil {
|
||||
slog.Error("Failed to presign URL", "error", err, "attachmentID", attachment.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
s3ObjectPayload.S3Config = s3Config
|
||||
s3ObjectPayload.LastPresignedTime = timestamppb.New(time.Now())
|
||||
if err := r.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
ID: attachment.ID,
|
||||
Reference: &presignURL,
|
||||
Payload: &storepb.AttachmentPayload{
|
||||
Payload: &storepb.AttachmentPayload_S3Object_{
|
||||
S3Object: s3ObjectPayload,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Error("Failed to update attachment", "error", err, "attachmentID", attachment.ID)
|
||||
continue
|
||||
}
|
||||
presignCount++
|
||||
}
|
||||
|
||||
slog.Info("Presigned batch of S3 attachments", "batchSize", len(attachments), "presigned", presignCount)
|
||||
|
||||
// Move to next batch
|
||||
offset += len(attachments)
|
||||
}
|
||||
}
|
||||
227
server/server.go
Normal file
227
server/server.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
grpcrecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/soheilhy/cmux"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/profiler"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/server/router/frontend"
|
||||
"github.com/usememos/memos/server/router/rss"
|
||||
"github.com/usememos/memos/server/runner/s3presign"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Secret string
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
|
||||
echoServer *echo.Echo
|
||||
grpcServer *grpc.Server
|
||||
profiler *profiler.Profiler
|
||||
runnerCancelFuncs []context.CancelFunc
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
|
||||
s := &Server{
|
||||
Store: store,
|
||||
Profile: profile,
|
||||
}
|
||||
|
||||
echoServer := echo.New()
|
||||
echoServer.Debug = true
|
||||
echoServer.HideBanner = true
|
||||
echoServer.HidePort = true
|
||||
echoServer.Use(middleware.Recover())
|
||||
s.echoServer = echoServer
|
||||
|
||||
// Initialize profiler
|
||||
s.profiler = profiler.NewProfiler()
|
||||
s.profiler.RegisterRoutes(echoServer)
|
||||
s.profiler.StartMemoryMonitor(ctx)
|
||||
|
||||
workspaceBasicSetting, err := s.getOrUpsertWorkspaceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace basic setting")
|
||||
}
|
||||
|
||||
secret := "usememos"
|
||||
if profile.Mode == "prod" {
|
||||
secret = workspaceBasicSetting.SecretKey
|
||||
}
|
||||
s.Secret = secret
|
||||
|
||||
// Register healthz endpoint.
|
||||
echoServer.GET("/healthz", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "Service ready.")
|
||||
})
|
||||
|
||||
// Serve frontend static files.
|
||||
frontend.NewFrontendService(profile, store).Serve(ctx, echoServer)
|
||||
|
||||
rootGroup := echoServer.Group("")
|
||||
|
||||
// Create and register RSS routes.
|
||||
rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup)
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
// Override the maximum receiving message size to math.MaxInt32 for uploading large attachments.
|
||||
grpc.MaxRecvMsgSize(math.MaxInt32),
|
||||
grpc.ChainUnaryInterceptor(
|
||||
apiv1.NewLoggerInterceptor().LoggerInterceptor,
|
||||
grpcrecovery.UnaryServerInterceptor(),
|
||||
apiv1.NewGRPCAuthInterceptor(store, secret).AuthenticationInterceptor,
|
||||
))
|
||||
s.grpcServer = grpcServer
|
||||
|
||||
apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store, grpcServer)
|
||||
// Register gRPC gateway as api v1.
|
||||
if err := apiV1Service.RegisterGateway(ctx, echoServer); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to register gRPC gateway")
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
var address, network string
|
||||
if len(s.Profile.UNIXSock) == 0 {
|
||||
address = fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port)
|
||||
network = "tcp"
|
||||
} else {
|
||||
address = s.Profile.UNIXSock
|
||||
network = "unix"
|
||||
}
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to listen")
|
||||
}
|
||||
|
||||
muxServer := cmux.New(listener)
|
||||
go func() {
|
||||
grpcListener := muxServer.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
|
||||
if err := s.grpcServer.Serve(grpcListener); err != nil {
|
||||
slog.Error("failed to serve gRPC", "error", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
httpListener := muxServer.Match(cmux.HTTP1Fast(http.MethodPatch))
|
||||
s.echoServer.Listener = httpListener
|
||||
if err := s.echoServer.Start(address); err != nil {
|
||||
slog.Error("failed to start echo server", "error", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := muxServer.Serve(); err != nil {
|
||||
slog.Error("mux server listen error", "error", err)
|
||||
}
|
||||
}()
|
||||
s.StartBackgroundRunners(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
slog.Info("server shutting down")
|
||||
|
||||
// Cancel all background runners
|
||||
for _, cancelFunc := range s.runnerCancelFuncs {
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown echo server.
|
||||
if err := s.echoServer.Shutdown(ctx); err != nil {
|
||||
slog.Error("failed to shutdown server", slog.String("error", err.Error()))
|
||||
}
|
||||
|
||||
// Shutdown gRPC server.
|
||||
s.grpcServer.GracefulStop()
|
||||
|
||||
// Stop the profiler
|
||||
if s.profiler != nil {
|
||||
slog.Info("stopping profiler")
|
||||
// Log final memory stats
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
slog.Info("final memory stats before exit",
|
||||
"heapAlloc", m.Alloc,
|
||||
"heapSys", m.Sys,
|
||||
"heapObjects", m.HeapObjects,
|
||||
"numGoroutine", runtime.NumGoroutine(),
|
||||
)
|
||||
}
|
||||
|
||||
// Close database connection.
|
||||
if err := s.Store.Close(); err != nil {
|
||||
slog.Error("failed to close database", slog.String("error", err.Error()))
|
||||
}
|
||||
|
||||
slog.Info("memos stopped properly")
|
||||
}
|
||||
|
||||
func (s *Server) StartBackgroundRunners(ctx context.Context) {
|
||||
// Create a separate context for each background runner
|
||||
// This allows us to control cancellation for each runner independently
|
||||
s3Context, s3Cancel := context.WithCancel(ctx)
|
||||
|
||||
// Store the cancel function so we can properly shut down runners
|
||||
s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel)
|
||||
|
||||
// Create and start S3 presign runner
|
||||
s3presignRunner := s3presign.NewRunner(s.Store)
|
||||
s3presignRunner.RunOnce(ctx)
|
||||
|
||||
// Start continuous S3 presign runner
|
||||
go func() {
|
||||
s3presignRunner.Run(s3Context)
|
||||
slog.Info("s3presign runner stopped")
|
||||
}()
|
||||
|
||||
// Log the number of goroutines running
|
||||
slog.Info("background runners started", "goroutines", runtime.NumGoroutine())
|
||||
}
|
||||
|
||||
func (s *Server) getOrUpsertWorkspaceBasicSetting(ctx context.Context) (*storepb.WorkspaceBasicSetting, error) {
|
||||
workspaceBasicSetting, err := s.Store.GetWorkspaceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace basic setting")
|
||||
}
|
||||
modified := false
|
||||
if workspaceBasicSetting.SecretKey == "" {
|
||||
workspaceBasicSetting.SecretKey = uuid.NewString()
|
||||
modified = true
|
||||
}
|
||||
if modified {
|
||||
workspaceSetting, err := s.Store.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
|
||||
Key: storepb.WorkspaceSettingKey_BASIC,
|
||||
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to upsert workspace setting")
|
||||
}
|
||||
workspaceBasicSetting = workspaceSetting.GetBasicSetting()
|
||||
}
|
||||
return workspaceBasicSetting, nil
|
||||
}
|
||||
Reference in New Issue
Block a user