init commit

This commit is contained in:
2025-11-30 13:01:24 -05:00
parent f4596a372d
commit 29355260ed
607 changed files with 136371 additions and 234 deletions

120
server/profiler/profiler.go Normal file
View File

@@ -0,0 +1,120 @@
package profiler
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/http/pprof"
"runtime"
"time"
"github.com/labstack/echo/v4"
)
// Profiler provides HTTP endpoints for memory profiling.
type Profiler struct {
memStatsLogInterval time.Duration
}
// NewProfiler creates a new profiler.
func NewProfiler() *Profiler {
return &Profiler{
memStatsLogInterval: 1 * time.Minute,
}
}
// RegisterRoutes adds profiling endpoints to the Echo server.
func (*Profiler) RegisterRoutes(e *echo.Echo) {
// Register pprof handlers
g := e.Group("/debug/pprof")
g.GET("", echo.WrapHandler(http.HandlerFunc(pprof.Index)))
g.GET("/cmdline", echo.WrapHandler(http.HandlerFunc(pprof.Cmdline)))
g.GET("/profile", echo.WrapHandler(http.HandlerFunc(pprof.Profile)))
g.POST("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol)))
g.GET("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol)))
g.GET("/trace", echo.WrapHandler(http.HandlerFunc(pprof.Trace)))
g.GET("/allocs", echo.WrapHandler(http.HandlerFunc(pprof.Handler("allocs").ServeHTTP)))
g.GET("/block", echo.WrapHandler(http.HandlerFunc(pprof.Handler("block").ServeHTTP)))
g.GET("/goroutine", echo.WrapHandler(http.HandlerFunc(pprof.Handler("goroutine").ServeHTTP)))
g.GET("/heap", echo.WrapHandler(http.HandlerFunc(pprof.Handler("heap").ServeHTTP)))
g.GET("/mutex", echo.WrapHandler(http.HandlerFunc(pprof.Handler("mutex").ServeHTTP)))
g.GET("/threadcreate", echo.WrapHandler(http.HandlerFunc(pprof.Handler("threadcreate").ServeHTTP)))
// Add a custom memory stats endpoint.
g.GET("/memstats", func(c echo.Context) error {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return c.JSON(http.StatusOK, map[string]interface{}{
"alloc": m.Alloc,
"totalAlloc": m.TotalAlloc,
"sys": m.Sys,
"numGC": m.NumGC,
"heapAlloc": m.HeapAlloc,
"heapSys": m.HeapSys,
"heapInuse": m.HeapInuse,
"heapObjects": m.HeapObjects,
})
})
}
// StartMemoryMonitor starts a goroutine that periodically logs memory stats.
func (p *Profiler) StartMemoryMonitor(ctx context.Context) {
go func() {
ticker := time.NewTicker(p.memStatsLogInterval)
defer ticker.Stop()
// Store previous heap allocation to track growth.
var lastHeapAlloc uint64
var lastNumGC uint32
for {
select {
case <-ticker.C:
var m runtime.MemStats
runtime.ReadMemStats(&m)
// Calculate heap growth since last check.
heapGrowth := int64(m.HeapAlloc) - int64(lastHeapAlloc)
gcCount := m.NumGC - lastNumGC
slog.Info("memory stats",
"heapAlloc", byteCountIEC(m.HeapAlloc),
"heapSys", byteCountIEC(m.HeapSys),
"heapObjects", m.HeapObjects,
"heapGrowth", byteCountIEC(uint64(heapGrowth)),
"numGoroutine", runtime.NumGoroutine(),
"numGC", m.NumGC,
"gcSince", gcCount,
"nextGC", byteCountIEC(m.NextGC),
"gcPause", time.Duration(m.PauseNs[(m.NumGC+255)%256]).String(),
)
// Track values for next iteration.
lastHeapAlloc = m.HeapAlloc
lastNumGC = m.NumGC
// Force GC if memory usage is high to see if objects can be reclaimed.
if m.HeapAlloc > 500*1024*1024 { // 500 MB threshold
slog.Info("forcing garbage collection due to high memory usage")
}
case <-ctx.Done():
return
}
}
}()
}
// byteCountIEC converts bytes to a human-readable string (MiB, GiB).
func byteCountIEC(b uint64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := uint64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
}

262
server/router/api/v1/acl.go Normal file
View File

@@ -0,0 +1,262 @@
package v1
import (
"context"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// ContextKey is the key type of context value.
type ContextKey int
const (
// The key name used to store user's ID in the context (for user-based auth).
userIDContextKey ContextKey = iota
// The key name used to store session ID in the context (for session-based auth).
sessionIDContextKey
// The key name used to store access token in the context (for token-based auth).
accessTokenContextKey
)
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct {
Store *store.Store
secret string
}
// NewGRPCAuthInterceptor returns a new API auth interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{
Store: store,
secret: secret,
}
}
// AuthenticationInterceptor is the unary interceptor for gRPC API.
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
// Try to authenticate via session ID (from cookie) first
if sessionCookieValue, err := getSessionIDFromMetadata(md); err == nil && sessionCookieValue != "" {
user, err := in.authenticateBySession(ctx, sessionCookieValue)
if err == nil && user != nil {
// Extract just the sessionID part for context storage
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookieValue)
if parseErr != nil {
return nil, status.Errorf(codes.Internal, "failed to parse session cookie: %v", parseErr)
}
return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, sessionID, "")
}
}
// Try to authenticate via JWT access token (from Authorization header)
if accessToken, err := getAccessTokenFromMetadata(md); err == nil && accessToken != "" {
user, err := in.authenticateByJWT(ctx, accessToken)
if err == nil && user != nil {
return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, "", accessToken)
}
}
// If no valid authentication found, check if this method is in the allowlist (public endpoints)
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
// If authentication is required but not found, reject the request
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
}
// handleAuthenticatedRequest processes an authenticated request with the given user and auth info.
func (in *GRPCAuthInterceptor) handleAuthenticatedRequest(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler, user *store.User, sessionID, accessToken string) (any, error) {
// Check user status
if user.RowStatus == store.Archived {
return nil, errors.Errorf("user %q is archived", user.Username)
}
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
return nil, errors.Errorf("user %q is not admin", user.Username)
}
// Set context values
ctx = context.WithValue(ctx, userIDContextKey, user.ID)
if sessionID != "" {
// Session-based authentication
ctx = context.WithValue(ctx, sessionIDContextKey, sessionID)
// Update session last accessed time
_ = in.updateSessionLastAccessed(ctx, user.ID, sessionID)
} else if accessToken != "" {
// JWT access token-based authentication
ctx = context.WithValue(ctx, accessTokenContextKey, accessToken)
}
return handler(ctx, request)
}
// authenticateByJWT authenticates a user using JWT access token from Authorization header.
func (in *GRPCAuthInterceptor) authenticateByJWT(ctx context.Context, accessToken string) (*store.User, error) {
if accessToken == "" {
return nil, status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(in.secret), nil
}
}
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
// Get user from JWT claims
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return nil, errors.Wrap(err, "malformed ID in the token")
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, errors.Errorf("user %q not exists", userID)
}
if user.RowStatus == store.Archived {
return nil, errors.Errorf("user %q is archived", userID)
}
// Validate that this access token exists in the user's access tokens
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, errors.Wrapf(err, "failed to get user access tokens")
}
if !validateAccessToken(accessToken, accessTokens) {
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
}
return user, nil
}
// authenticateBySession authenticates a user using session ID from cookie.
func (in *GRPCAuthInterceptor) authenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) {
if sessionCookieValue == "" {
return nil, status.Errorf(codes.Unauthenticated, "session cookie value not found")
}
// Parse the cookie value to extract userID and sessionID
userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid session cookie format: %v", err)
}
// Get the user directly using the userID from the cookie
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.Unauthenticated, "user is archived")
}
// Get user sessions and validate the sessionID
sessions, err := in.Store.GetUserSessions(ctx, userID)
if err != nil {
return nil, errors.Wrap(err, "failed to get user sessions")
}
if !validateUserSession(sessionID, sessions) {
return nil, status.Errorf(codes.Unauthenticated, "invalid or expired session")
}
return user, nil
}
// updateSessionLastAccessed updates the last accessed time for a user session.
func (in *GRPCAuthInterceptor) updateSessionLastAccessed(ctx context.Context, userID int32, sessionID string) error {
return in.Store.UpdateUserSessionLastAccessed(ctx, userID, sessionID, timestamppb.Now())
}
// validateUserSession checks if a session exists and is still valid using sliding expiration.
func validateUserSession(sessionID string, userSessions []*storepb.SessionsUserSetting_Session) bool {
for _, session := range userSessions {
if sessionID == session.SessionId {
// Use sliding expiration: check if last_accessed_time + 2 weeks > current_time
if session.LastAccessedTime != nil {
expirationTime := session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
if expirationTime.Before(time.Now()) {
return false
}
}
return true
}
}
return false
}
// getSessionIDFromMetadata extracts session cookie value from cookie.
func getSessionIDFromMetadata(md metadata.MD) (string, error) {
// Check the cookie header for session cookie value
var sessionCookieValue string
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(SessionCookieName); v != nil {
sessionCookieValue = v.Value
}
}
if sessionCookieValue == "" {
return "", errors.New("session cookie not found")
}
return sessionCookieValue, nil
}
// getAccessTokenFromMetadata extracts access token from Authorization header.
func getAccessTokenFromMetadata(md metadata.MD) (string, error) {
// Check the HTTP request Authorization header.
authorizationHeaders := md.Get("Authorization")
if len(authorizationHeaders) == 0 {
return "", errors.New("authorization header not found")
}
authHeaderParts := strings.Fields(authorizationHeaders[0])
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken {
return true
}
}
return false
}

View File

@@ -0,0 +1,34 @@
package v1
var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v1.WorkspaceService/GetWorkspaceProfile": true,
"/memos.api.v1.WorkspaceService/GetWorkspaceSetting": true,
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
"/memos.api.v1.AuthService/CreateSession": true,
"/memos.api.v1.AuthService/GetCurrentSession": true,
"/memos.api.v1.UserService/CreateUser": true,
"/memos.api.v1.UserService/GetUser": true,
"/memos.api.v1.UserService/GetUserAvatar": true,
"/memos.api.v1.UserService/GetUserStats": true,
"/memos.api.v1.UserService/ListAllUserStats": true,
"/memos.api.v1.UserService/SearchUsers": true,
"/memos.api.v1.MemoService/GetMemo": true,
"/memos.api.v1.MemoService/ListMemos": true,
"/memos.api.v1.MarkdownService/GetLinkMetadata": true,
"/memos.api.v1.AttachmentService/GetAttachmentBinary": true,
}
// isUnauthorizeAllowedMethod returns whether the method is exempted from authentication.
func isUnauthorizeAllowedMethod(fullMethodName string) bool {
return authenticationAllowlistMethods[fullMethodName]
}
var allowedMethodsOnlyForAdmin = map[string]bool{
"/memos.api.v1.UserService/CreateUser": true,
"/memos.api.v1.WorkspaceService/UpdateWorkspaceSetting": true,
}
// isOnlyForAdminAllowedMethod returns true if the method is allowed to be called only by admin.
func isOnlyForAdminAllowedMethod(methodName string) bool {
return allowedMethodsOnlyForAdmin[methodName]
}

View File

@@ -0,0 +1,126 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListActivities(ctx context.Context, request *v1pb.ListActivitiesRequest) (*v1pb.ListActivitiesResponse, error) {
// Set default page size if not specified
pageSize := request.PageSize
if pageSize <= 0 || pageSize > 1000 {
pageSize = 100
}
// TODO: Implement pagination with page_token and use pageSize for limiting
// For now, we'll fetch all activities and the pageSize will be used in future pagination implementation
_ = pageSize // Acknowledge pageSize variable to avoid linter warning
activities, err := s.Store.ListActivities(ctx, &store.FindActivity{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list activities: %v", err)
}
var activityMessages []*v1pb.Activity
for _, activity := range activities {
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
}
activityMessages = append(activityMessages, activityMessage)
}
return &v1pb.ListActivitiesResponse{
Activities: activityMessages,
// TODO: Implement next_page_token for pagination
}, nil
}
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
activityID, err := ExtractActivityIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid activity name: %v", err)
}
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
ID: &activityID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
}
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
}
return activityMessage, nil
}
func (s *APIV1Service) convertActivityFromStore(ctx context.Context, activity *store.Activity) (*v1pb.Activity, error) {
payload, err := s.convertActivityPayloadFromStore(ctx, activity.Payload)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity payload from store: %v", err)
}
// Convert store activity type to proto enum
var activityType v1pb.Activity_Type
switch activity.Type {
case store.ActivityTypeMemoComment:
activityType = v1pb.Activity_MEMO_COMMENT
default:
activityType = v1pb.Activity_TYPE_UNSPECIFIED
}
// Convert store activity level to proto enum
var activityLevel v1pb.Activity_Level
switch activity.Level {
case store.ActivityLevelInfo:
activityLevel = v1pb.Activity_INFO
default:
activityLevel = v1pb.Activity_LEVEL_UNSPECIFIED
}
return &v1pb.Activity{
Name: fmt.Sprintf("%s%d", ActivityNamePrefix, activity.ID),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, activity.CreatorID),
Type: activityType,
Level: activityLevel,
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
Payload: payload,
}, nil
}
func (s *APIV1Service) convertActivityPayloadFromStore(ctx context.Context, payload *storepb.ActivityPayload) (*v1pb.ActivityPayload, error) {
v2Payload := &v1pb.ActivityPayload{}
if payload.MemoComment != nil {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &payload.MemoComment.MemoId,
ExcludeContent: true,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &payload.MemoComment.RelatedMemoId,
ExcludeContent: true,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo: %v", err)
}
v2Payload.Payload = &v1pb.ActivityPayload_MemoComment{
MemoComment: &v1pb.ActivityMemoCommentPayload{
Memo: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
RelatedMemo: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
},
}
}
return v2Payload, nil
}

View File

@@ -0,0 +1,673 @@
package v1
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/storage/s3"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
// The upload memory buffer is 32 MiB.
// It should be kept low, so RAM usage doesn't get out of control.
// This is unrelated to maximum upload size limit, which is now set through system setting.
MaxUploadBufferSizeBytes = 32 << 20
MebiByte = 1024 * 1024
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
ThumbnailCacheFolder = ".thumbnail_cache"
)
var SupportedThumbnailMimeTypes = []string{
"image/png",
"image/jpeg",
}
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
// Validate required fields
if request.Attachment == nil {
return nil, status.Errorf(codes.InvalidArgument, "attachment is required")
}
if request.Attachment.Filename == "" {
return nil, status.Errorf(codes.InvalidArgument, "filename is required")
}
if request.Attachment.Type == "" {
return nil, status.Errorf(codes.InvalidArgument, "type is required")
}
// Use provided attachment_id or generate a new one
attachmentUID := request.AttachmentId
if attachmentUID == "" {
attachmentUID = shortuuid.New()
}
create := &store.Attachment{
UID: attachmentUID,
CreatorID: user.ID,
Filename: request.Attachment.Filename,
Type: request.Attachment.Type,
}
workspaceStorageSetting, err := s.Store.GetWorkspaceStorageSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace storage setting: %v", err)
}
size := binary.Size(request.Attachment.Content)
uploadSizeLimit := int(workspaceStorageSetting.UploadSizeLimitMb) * MebiByte
if uploadSizeLimit == 0 {
uploadSizeLimit = MaxUploadBufferSizeBytes
}
if size > uploadSizeLimit {
return nil, status.Errorf(codes.InvalidArgument, "file size exceeds the limit")
}
create.Size = int64(size)
create.Blob = request.Attachment.Content
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
}
if request.Attachment.Memo != nil {
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
}
create.MemoID = &memo.ID
}
attachment, err := s.Store.CreateAttachment(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
}
return s.convertAttachmentFromStore(ctx, attachment), nil
}
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
// Set default page size
pageSize := int(request.PageSize)
if pageSize <= 0 {
pageSize = 50
}
if pageSize > 1000 {
pageSize = 1000
}
// Parse page token for offset
offset := 0
if request.PageToken != "" {
// Simple implementation: page token is the offset as string
// In production, you might want to use encrypted tokens
if parsed, err := fmt.Sscanf(request.PageToken, "%d", &offset); err != nil || parsed != 1 {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token")
}
}
findAttachment := &store.FindAttachment{
CreatorID: &user.ID,
Limit: &pageSize,
Offset: &offset,
}
// Basic filter support for common cases
if request.Filter != "" {
// Simple filter parsing - can be enhanced later
// For now, support basic type filtering: "type=image/png"
if strings.HasPrefix(request.Filter, "type=") {
filterType := strings.TrimPrefix(request.Filter, "type=")
// Create a temporary struct to hold type filter
// Since FindAttachment doesn't have Type field, we'll apply this post-query
_ = filterType // We'll filter after getting results
}
}
attachments, err := s.Store.ListAttachments(ctx, findAttachment)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
}
// Apply type filter if specified
if request.Filter != "" && strings.HasPrefix(request.Filter, "type=") {
filterType := strings.TrimPrefix(request.Filter, "type=")
filteredAttachments := make([]*store.Attachment, 0)
for _, attachment := range attachments {
if attachment.Type == filterType {
filteredAttachments = append(filteredAttachments, attachment)
}
}
attachments = filteredAttachments
}
response := &v1pb.ListAttachmentsResponse{}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment))
}
// For simplicity, set total size to the number of returned attachments.
// In a full implementation, you'd want a separate count query
response.TotalSize = int32(len(response.Attachments))
// Set next page token if we got the full page size (indicating there might be more)
if len(attachments) == pageSize {
response.NextPageToken = fmt.Sprintf("%d", offset+pageSize)
}
return response, nil
}
func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttachmentRequest) (*v1pb.Attachment, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
return s.convertAttachmentFromStore(ctx, attachment), nil
}
func (s *APIV1Service) GetAttachmentBinary(ctx context.Context, request *v1pb.GetAttachmentBinaryRequest) (*httpbody.HttpBody, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
GetBlob: true,
UID: &attachmentUID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Check the related memo visibility.
if attachment.MemoID != nil {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: attachment.MemoID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo by ID: %v", attachment.MemoID)
}
if memo != nil && memo.Visibility != store.Public {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "unauthorized access")
}
if memo.Visibility == store.Private && user.ID != attachment.CreatorID {
return nil, status.Errorf(codes.Unauthenticated, "unauthorized access")
}
}
}
if request.Thumbnail && util.HasPrefixes(attachment.Type, SupportedThumbnailMimeTypes...) {
thumbnailBlob, err := s.getOrGenerateThumbnail(attachment)
if err != nil {
// thumbnail failures are logged as warnings and not cosidered critical failures as
// a attachment image can be used in its place.
slog.Warn("failed to get attachment thumbnail image", slog.Any("error", err))
} else {
return &httpbody.HttpBody{
ContentType: attachment.Type,
Data: thumbnailBlob,
}, nil
}
}
blob, err := s.GetAttachmentBlob(attachment)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment blob: %v", err)
}
contentType := attachment.Type
if strings.HasPrefix(contentType, "text/") {
contentType += "; charset=utf-8"
}
// Prevent XSS attacks by serving potentially unsafe files with a content type that prevents script execution.
if strings.EqualFold(contentType, "image/svg+xml") ||
strings.EqualFold(contentType, "text/html") ||
strings.EqualFold(contentType, "application/xhtml+xml") {
contentType = "application/octet-stream"
}
// Extract range header from gRPC metadata for iOS Safari video support
var rangeHeader string
if md, ok := metadata.FromIncomingContext(ctx); ok {
// Check for range header from gRPC-Gateway
if ranges := md.Get("grpcgateway-range"); len(ranges) > 0 {
rangeHeader = ranges[0]
} else if ranges := md.Get("range"); len(ranges) > 0 {
rangeHeader = ranges[0]
}
// Log for debugging iOS Safari issues
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
userAgent := userAgents[0]
if strings.Contains(strings.ToLower(userAgent), "safari") && rangeHeader != "" {
slog.Debug("Safari range request detected",
slog.String("range", rangeHeader),
slog.String("user-agent", userAgent),
slog.String("content-type", contentType))
}
}
}
// Handle range requests for video/audio streaming (iOS Safari requirement)
if rangeHeader != "" && (strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/")) {
return s.handleRangeRequest(ctx, blob, rangeHeader, contentType)
}
// Set headers for streaming support
if strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/") {
if err := setResponseHeaders(ctx, map[string]string{
"accept-ranges": "bytes",
"content-length": fmt.Sprintf("%d", len(blob)),
"cache-control": "public, max-age=3600", // 1 hour cache
}); err != nil {
slog.Warn("failed to set streaming headers", slog.Any("error", err))
}
}
return &httpbody.HttpBody{
ContentType: contentType,
Data: blob,
}, nil
}
func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.UpdateAttachmentRequest) (*v1pb.Attachment, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Attachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
currentTs := time.Now().Unix()
update := &store.UpdateAttachment{
ID: attachment.ID,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "filename" {
update.Filename = &request.Attachment.Filename
}
}
if err := s.Store.UpdateAttachment(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
}
return s.GetAttachment(ctx, &v1pb.GetAttachmentRequest{
Name: request.Attachment.Name,
})
}
func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.DeleteAttachmentRequest) (*emptypb.Empty, error) {
attachmentUID, err := ExtractAttachmentUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
UID: &attachmentUID,
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Delete the attachment from the database.
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: attachment.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) convertAttachmentFromStore(ctx context.Context, attachment *store.Attachment) *v1pb.Attachment {
attachmentMessage := &v1pb.Attachment{
Name: fmt.Sprintf("%s%s", AttachmentNamePrefix, attachment.UID),
CreateTime: timestamppb.New(time.Unix(attachment.CreatedTs, 0)),
Filename: attachment.Filename,
Type: attachment.Type,
Size: attachment.Size,
}
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
attachmentMessage.ExternalLink = attachment.Reference
}
if attachment.MemoID != nil {
memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{
ID: attachment.MemoID,
})
if memo != nil {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
attachmentMessage.Memo = &memoName
}
}
return attachmentMessage
}
// SaveAttachmentBlob save the blob of attachment based on the storage config.
func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Attachment) error {
workspaceStorageSetting, err := stores.GetWorkspaceStorageSetting(ctx)
if err != nil {
return errors.Wrap(err, "Failed to find workspace storage setting")
}
if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_LOCAL {
filepathTemplate := "assets/{timestamp}_{filename}"
if workspaceStorageSetting.FilepathTemplate != "" {
filepathTemplate = workspaceStorageSetting.FilepathTemplate
}
internalPath := filepathTemplate
if !strings.Contains(internalPath, "{filename}") {
internalPath = filepath.Join(internalPath, "{filename}")
}
internalPath = replaceFilenameWithPathTemplate(internalPath, create.Filename)
internalPath = filepath.ToSlash(internalPath)
// Ensure the directory exists.
osPath := filepath.FromSlash(internalPath)
if !filepath.IsAbs(osPath) {
osPath = filepath.Join(profile.Data, osPath)
}
dir := filepath.Dir(osPath)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return errors.Wrap(err, "Failed to create directory")
}
dst, err := os.Create(osPath)
if err != nil {
return errors.Wrap(err, "Failed to create file")
}
defer dst.Close()
// Write the blob to the file.
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
return errors.Wrap(err, "Failed to write file")
}
create.Reference = internalPath
create.Blob = nil
create.StorageType = storepb.AttachmentStorageType_LOCAL
} else if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_S3 {
s3Config := workspaceStorageSetting.S3Config
if s3Config == nil {
return errors.Errorf("No actived external storage found")
}
s3Client, err := s3.NewClient(ctx, s3Config)
if err != nil {
return errors.Wrap(err, "Failed to create s3 client")
}
filepathTemplate := workspaceStorageSetting.FilepathTemplate
if !strings.Contains(filepathTemplate, "{filename}") {
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
}
filepathTemplate = replaceFilenameWithPathTemplate(filepathTemplate, create.Filename)
key, err := s3Client.UploadObject(ctx, filepathTemplate, create.Type, bytes.NewReader(create.Blob))
if err != nil {
return errors.Wrap(err, "Failed to upload via s3 client")
}
presignURL, err := s3Client.PresignGetObject(ctx, key)
if err != nil {
return errors.Wrap(err, "Failed to presign via s3 client")
}
create.Reference = presignURL
create.Blob = nil
create.StorageType = storepb.AttachmentStorageType_S3
create.Payload = &storepb.AttachmentPayload{
Payload: &storepb.AttachmentPayload_S3Object_{
S3Object: &storepb.AttachmentPayload_S3Object{
S3Config: s3Config,
Key: key,
LastPresignedTime: timestamppb.New(time.Now()),
},
},
}
}
return nil
}
func (s *APIV1Service) GetAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
// For local storage, read the file from the local disk.
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
attachmentPath := filepath.FromSlash(attachment.Reference)
if !filepath.IsAbs(attachmentPath) {
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
}
file, err := os.Open(attachmentPath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open the file")
}
defer file.Close()
blob, err := io.ReadAll(file)
if err != nil {
return nil, errors.Wrap(err, "failed to read the file")
}
return blob, nil
}
// For database storage, return the blob from the database.
return attachment.Blob, nil
}
const (
// thumbnailRatio is the ratio of the thumbnail image.
thumbnailRatio = 0.8
)
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
func (s *APIV1Service) getOrGenerateThumbnail(attachment *store.Attachment) ([]byte, error) {
thumbnailCacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
if err := os.MkdirAll(thumbnailCacheFolder, os.ModePerm); err != nil {
return nil, errors.Wrap(err, "failed to create thumbnail cache folder")
}
filePath := filepath.Join(thumbnailCacheFolder, fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename)))
if _, err := os.Stat(filePath); err != nil {
if !os.IsNotExist(err) {
return nil, errors.Wrap(err, "failed to check thumbnail image stat")
}
// If thumbnail image does not exist, generate and save the thumbnail image.
blob, err := s.GetAttachmentBlob(attachment)
if err != nil {
return nil, errors.Wrap(err, "failed to get attachment blob")
}
img, err := imaging.Decode(bytes.NewReader(blob), imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode thumbnail image")
}
thumbnailWidth := int(float64(img.Bounds().Dx()) * thumbnailRatio)
// Resize the image to the thumbnailWidth.
thumbnailImage := imaging.Resize(img, thumbnailWidth, 0, imaging.Lanczos)
if err := imaging.Save(thumbnailImage, filePath); err != nil {
return nil, errors.Wrap(err, "failed to save thumbnail file")
}
}
thumbnailFile, err := os.Open(filePath)
if err != nil {
return nil, errors.Wrap(err, "failed to open thumbnail file")
}
defer thumbnailFile.Close()
blob, err := io.ReadAll(thumbnailFile)
if err != nil {
return nil, errors.Wrap(err, "failed to read thumbnail file")
}
return blob, nil
}
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
func replaceFilenameWithPathTemplate(path, filename string) string {
t := time.Now()
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
switch s {
case "{filename}":
return filename
case "{timestamp}":
return fmt.Sprintf("%d", t.Unix())
case "{year}":
return fmt.Sprintf("%d", t.Year())
case "{month}":
return fmt.Sprintf("%02d", t.Month())
case "{day}":
return fmt.Sprintf("%02d", t.Day())
case "{hour}":
return fmt.Sprintf("%02d", t.Hour())
case "{minute}":
return fmt.Sprintf("%02d", t.Minute())
case "{second}":
return fmt.Sprintf("%02d", t.Second())
case "{uuid}":
return util.GenUUID()
}
return s
})
return path
}
// handleRangeRequest handles HTTP range requests for video/audio streaming (iOS Safari requirement).
func (*APIV1Service) handleRangeRequest(ctx context.Context, data []byte, rangeHeader, contentType string) (*httpbody.HttpBody, error) {
// Parse "bytes=start-end"
if !strings.HasPrefix(rangeHeader, "bytes=") {
return nil, status.Errorf(codes.InvalidArgument, "invalid range header format")
}
rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=")
parts := strings.Split(rangeSpec, "-")
if len(parts) != 2 {
return nil, status.Errorf(codes.InvalidArgument, "invalid range specification")
}
fileSize := int64(len(data))
start, end := int64(0), fileSize-1
// Parse start position
if parts[0] != "" {
if s, err := strconv.ParseInt(parts[0], 10, 64); err == nil {
start = s
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid range start: %s", parts[0])
}
}
// Parse end position
if parts[1] != "" {
if e, err := strconv.ParseInt(parts[1], 10, 64); err == nil {
end = e
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid range end: %s", parts[1])
}
}
// Validate range
if start < 0 || end >= fileSize || start > end {
// Set Content-Range header for 416 response
if err := setResponseHeaders(ctx, map[string]string{
"content-range": fmt.Sprintf("bytes */%d", fileSize),
}); err != nil {
slog.Warn("failed to set content-range header", slog.Any("error", err))
}
return nil, status.Errorf(codes.OutOfRange, "requested range not satisfiable")
}
// Set partial content headers (HTTP 206)
if err := setResponseHeaders(ctx, map[string]string{
"accept-ranges": "bytes",
"content-range": fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize),
"content-length": fmt.Sprintf("%d", end-start+1),
"cache-control": "public, max-age=3600",
}); err != nil {
slog.Warn("failed to set partial content headers", slog.Any("error", err))
}
// Extract the requested range
rangeData := data[start : end+1]
slog.Debug("serving partial content",
slog.Int64("start", start),
slog.Int64("end", end),
slog.Int64("total", fileSize),
slog.Int("chunk_size", len(rangeData)))
return &httpbody.HttpBody{
ContentType: contentType,
Data: rangeData,
}, nil
}
// setResponseHeaders is a helper function to set gRPC response headers.
func setResponseHeaders(ctx context.Context, headers map[string]string) error {
pairs := make([]string, 0, len(headers)*2)
for key, value := range headers {
pairs = append(pairs, key, value)
}
return grpc.SetHeader(ctx, metadata.Pairs(pairs...))
}

View File

@@ -0,0 +1,91 @@
package v1
import (
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
// issuer is the issuer of the jwt token.
Issuer = "memos"
// Signing key section. For now, this is only used for signing, not for verifying since we only
// have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism.
KeyID = "v1"
// AccessTokenAudienceName is the audience name of the access token.
AccessTokenAudienceName = "user.access-token"
// SessionSlidingDuration is the sliding expiration duration for user sessions (2 weeks).
// Sessions are considered valid if last_accessed_time + SessionSlidingDuration > current_time.
SessionSlidingDuration = 14 * 24 * time.Hour
// SessionCookieName is the cookie name of user session ID.
SessionCookieName = "user_session"
)
type ClaimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token.
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
}
// generateToken generates a jwt token.
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
registeredClaims := jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{audience},
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: fmt.Sprint(userID),
}
if !expirationTime.IsZero() {
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
Name: username,
RegisteredClaims: registeredClaims,
})
token.Header["kid"] = KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}
// GenerateSessionID generates a unique session ID using UUIDv4.
func GenerateSessionID() (string, error) {
return util.GenUUID(), nil
}
// BuildSessionCookieValue builds the session cookie value in format {userID}-{sessionID}.
func BuildSessionCookieValue(userID int32, sessionID string) string {
return fmt.Sprintf("%d-%s", userID, sessionID)
}
// ParseSessionCookieValue parses the session cookie value to extract userID and sessionID.
func ParseSessionCookieValue(cookieValue string) (int32, string, error) {
parts := strings.SplitN(cookieValue, "-", 2)
if len(parts) != 2 {
return 0, "", errors.New("invalid session cookie format")
}
userID, err := util.ConvertStringToInt32(parts[0])
if err != nil {
return 0, "", errors.Errorf("invalid user ID in session cookie: %v", err)
}
return userID, parts[1], nil
}

View File

@@ -0,0 +1,502 @@
package v1
import (
"context"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
unmatchedUsernameAndPasswordError = "unmatched username and password"
)
func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrentSessionRequest) (*v1pb.GetCurrentSessionResponse, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
// Clear auth cookies
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err)
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
var lastAccessedAt *timestamppb.Timestamp
// Update session last accessed time if we have a session ID and get the current session info
if sessionID, ok := ctx.Value(sessionIDContextKey).(string); ok && sessionID != "" {
now := timestamppb.Now()
if err := s.Store.UpdateUserSessionLastAccessed(ctx, user.ID, sessionID, now); err != nil {
// Log error but don't fail the request
slog.Error("failed to update session last accessed time", "error", err)
}
lastAccessedAt = now
}
return &v1pb.GetCurrentSessionResponse{
User: convertUserFromStore(user),
LastAccessedAt: lastAccessedAt,
}, nil
}
func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSessionRequest) (*v1pb.CreateSessionResponse, error) {
var existingUser *store.User
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &passwordCredentials.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
}
// Check if the password auth in is allowed.
if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
}
existingUser = user
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &ssoCredentials.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
// Check if the user is allowed to sign up.
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
}
if workspaceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
}
// Create a new user with the user info from the identity provider.
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
AvatarURL: userInfo.AvatarURL,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
}
}
existingUser = user
}
if existingUser == nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
}
if existingUser.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
}
// Default session expiration time is 100 year
expireTime := time.Now().Add(100 * 365 * 24 * time.Hour)
if err := s.doSignIn(ctx, existingUser, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
}
return &v1pb.CreateSessionResponse{
User: convertUserFromStore(existingUser),
LastAccessedAt: timestamppb.Now(),
}, nil
}
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
// Generate unique session ID for web use
sessionID, err := GenerateSessionID()
if err != nil {
return status.Errorf(codes.Internal, "failed to generate session ID, error: %v", err)
}
// Track session in user settings
if err := s.trackUserSession(ctx, user.ID, sessionID); err != nil {
// Log the error but don't fail the login if session tracking fails
// This ensures backward compatibility
slog.Error("failed to track user session", "error", err)
}
// Set session cookie for web use (format: userID-sessionID)
sessionCookieValue := BuildSessionCookieValue(user.ID, sessionID)
sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime)
if err != nil {
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": sessionCookie,
})); err != nil {
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
return nil
}
func (s *APIV1Service) DeleteSession(ctx context.Context, _ *v1pb.DeleteSessionRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
// Check if we have a session ID (from cookie-based auth)
if sessionID, ok := ctx.Value(sessionIDContextKey).(string); ok && sessionID != "" {
// Remove session from user settings
if err := s.Store.RemoveUserSession(ctx, user.ID, sessionID); err != nil {
slog.Error("failed to remove user session", "error", err)
}
}
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
// Clear session cookie
sessionCookie, err := s.buildSessionCookie(ctx, "", time.Time{})
if err != nil {
return errors.Wrap(err, "failed to build session cookie")
}
// Set both cookies in the response
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": sessionCookie,
})); err != nil {
return errors.Wrap(err, "failed to set grpc header")
}
return nil
}
func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue string, expireTime time.Time) (string, error) {
attrs := []string{
fmt.Sprintf("%s=%s", SessionCookieName, sessionCookieValue),
"Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", errors.New("failed to get metadata from context")
}
var origin string
for _, v := range md.Get("origin") {
origin = v
}
isHTTPS := strings.HasPrefix(origin, "https://")
if isHTTPS {
attrs = append(attrs, "SameSite=None")
attrs = append(attrs, "Secure")
} else {
attrs = append(attrs, "SameSite=Strict")
}
return strings.Join(attrs, "; "), nil
}
func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
userID, ok := ctx.Value(userIDContextKey).(int32)
if !ok {
return nil, nil
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.Errorf("user %d not found", userID)
}
return user, nil
}
// Helper function to track user session for session management.
func (s *APIV1Service) trackUserSession(ctx context.Context, userID int32, sessionID string) error {
// Extract client information from the context
clientInfo := s.extractClientInfo(ctx)
session := &storepb.SessionsUserSetting_Session{
SessionId: sessionID,
CreateTime: timestamppb.Now(),
LastAccessedTime: timestamppb.Now(),
ClientInfo: clientInfo,
}
return s.Store.AddUserSession(ctx, userID, session)
}
// Helper function to extract client information from the gRPC context.
// extractClientInfo extracts comprehensive client information from the request context.
// This includes user agent parsing to determine device type, operating system, browser,
// and IP address extraction. This information is used to provide detailed session
// tracking and management capabilities in the web UI.
//
// Fields populated:
// - UserAgent: Raw user agent string
// - IpAddress: Client IP (from X-Forwarded-For or X-Real-IP headers)
// - DeviceType: "mobile", "tablet", or "desktop"
// - Os: Operating system name and version (e.g., "iOS 17.1", "Windows 10/11")
// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0")
// - Country: Geographic location (TODO: implement with GeoIP service).
func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo {
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
// Extract user agent from metadata if available
if md, ok := metadata.FromIncomingContext(ctx); ok {
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
userAgent := userAgents[0]
clientInfo.UserAgent = userAgent
// Parse user agent to extract device type, OS, browser info
s.parseUserAgent(userAgent, clientInfo)
}
if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 {
ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple
ipAddress = strings.TrimSpace(ipAddress)
clientInfo.IpAddress = ipAddress
} else if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
clientInfo.IpAddress = realIP[0]
}
}
return clientInfo
}
// parseUserAgent extracts device type, OS, and browser information from user agent string.
func (*APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.SessionsUserSetting_ClientInfo) {
if userAgent == "" {
return
}
userAgent = strings.ToLower(userAgent)
// Detect device type
if strings.Contains(userAgent, "ipad") {
clientInfo.DeviceType = "tablet"
} else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") ||
strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") ||
strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") {
clientInfo.DeviceType = "mobile"
} else if strings.Contains(userAgent, "tablet") {
clientInfo.DeviceType = "tablet"
} else {
clientInfo.DeviceType = "desktop"
}
// Detect operating system
if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") {
// Extract iOS version
if idx := strings.Index(userAgent, "cpu os "); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else if idx := strings.Index(userAgent, "iphone os "); idx != -1 {
versionStart := idx + 10
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else {
clientInfo.Os = "iOS"
}
} else if strings.Contains(userAgent, "android") {
// Extract Android version
if idx := strings.Index(userAgent, "android "); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Os = "Android " + version
} else {
clientInfo.Os = "Android"
}
} else {
clientInfo.Os = "Android"
}
} else if strings.Contains(userAgent, "windows nt 10.0") {
clientInfo.Os = "Windows 10/11"
} else if strings.Contains(userAgent, "windows nt 6.3") {
clientInfo.Os = "Windows 8.1"
} else if strings.Contains(userAgent, "windows nt 6.1") {
clientInfo.Os = "Windows 7"
} else if strings.Contains(userAgent, "windows") {
clientInfo.Os = "Windows"
} else if strings.Contains(userAgent, "mac os x") {
// Extract macOS version
if idx := strings.Index(userAgent, "mac os x "); idx != -1 {
versionStart := idx + 9
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := strings.ReplaceAll(userAgent[versionStart:versionStart+versionEnd], "_", ".")
clientInfo.Os = "macOS " + version
} else {
clientInfo.Os = "macOS"
}
} else {
clientInfo.Os = "macOS"
}
} else if strings.Contains(userAgent, "linux") {
clientInfo.Os = "Linux"
} else if strings.Contains(userAgent, "cros") {
clientInfo.Os = "Chrome OS"
}
// Detect browser
if strings.Contains(userAgent, "edg/") {
// Extract Edge version
if idx := strings.Index(userAgent, "edg/"); idx != -1 {
versionStart := idx + 4
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Edge " + version
} else {
clientInfo.Browser = "Edge"
}
} else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") {
// Extract Chrome version
if idx := strings.Index(userAgent, "chrome/"); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Chrome " + version
} else {
clientInfo.Browser = "Chrome"
}
} else if strings.Contains(userAgent, "firefox/") {
// Extract Firefox version
if idx := strings.Index(userAgent, "firefox/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Firefox " + version
} else {
clientInfo.Browser = "Firefox"
}
} else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") {
// Extract Safari version
if idx := strings.Index(userAgent, "version/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Safari " + version
} else {
clientInfo.Browser = "Safari"
}
} else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") {
clientInfo.Browser = "Opera"
}
}

View File

@@ -0,0 +1,179 @@
package v1
import (
"context"
"testing"
"google.golang.org/grpc/metadata"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestParseUserAgent(t *testing.T) {
service := &APIV1Service{}
tests := []struct {
name string
userAgent string
expectedDevice string
expectedOS string
expectedBrowser string
}{
{
name: "Chrome on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on macOS",
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15",
expectedDevice: "desktop",
expectedOS: "macOS 10.15.7",
expectedBrowser: "Safari 17.0",
},
{
name: "Chrome on Android Mobile",
userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36",
expectedDevice: "mobile",
expectedOS: "Android 13",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on iPhone",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "mobile",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
{
name: "Firefox on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Firefox 119.0",
},
{
name: "Edge on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Edge 119.0.0.0",
},
{
name: "iPad Safari",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "tablet",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
service.parseUserAgent(tt.userAgent, clientInfo)
if clientInfo.DeviceType != tt.expectedDevice {
t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType)
}
if clientInfo.Os != tt.expectedOS {
t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os)
}
if clientInfo.Browser != tt.expectedBrowser {
t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser)
}
})
}
}
func TestExtractClientInfo(t *testing.T) {
service := &APIV1Service{}
// Test with metadata containing user agent and IP
md := metadata.New(map[string]string{
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
"x-forwarded-for": "203.0.113.1, 198.51.100.1",
"x-real-ip": "203.0.113.1",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
clientInfo := service.extractClientInfo(ctx)
if clientInfo.UserAgent == "" {
t.Error("Expected user agent to be set")
}
if clientInfo.IpAddress != "203.0.113.1" {
t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress)
}
if clientInfo.DeviceType != "desktop" {
t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType)
}
if clientInfo.Os != "Windows 10/11" {
t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os)
}
if clientInfo.Browser != "Chrome 119.0.0.0" {
t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser)
}
}
// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents.
func TestClientInfoExamples(t *testing.T) {
service := &APIV1Service{}
examples := []struct {
description string
userAgent string
}{
{
description: "Modern Chrome on Windows 11",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
},
{
description: "Safari on iPhone 15 Pro",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
{
description: "Chrome on Samsung Galaxy",
userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36",
},
{
description: "Firefox on Ubuntu",
userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0",
},
{
description: "Edge on Windows 10",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
},
{
description: "Safari on iPad Air",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
}
for _, example := range examples {
t.Run(example.description, func(t *testing.T) {
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
service.parseUserAgent(example.userAgent, clientInfo)
t.Logf("User Agent: %s", example.userAgent)
t.Logf("Device Type: %s", clientInfo.DeviceType)
t.Logf("Operating System: %s", clientInfo.Os)
t.Logf("Browser: %s", clientInfo.Browser)
t.Logf("---")
// Ensure all fields are populated
if clientInfo.DeviceType == "" {
t.Error("Device type should not be empty")
}
if clientInfo.Os == "" {
t.Error("OS should not be empty")
}
if clientInfo.Browser == "" {
t.Error("Browser should not be empty")
}
})
}
}

View File

@@ -0,0 +1,70 @@
package v1
import (
"encoding/base64"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
const (
// DefaultPageSize is the default page size for requests.
DefaultPageSize = 10
// MaxPageSize is the maximum page size for requests.
MaxPageSize = 1000
)
func convertStateFromStore(rowStatus store.RowStatus) v1pb.State {
switch rowStatus {
case store.Normal:
return v1pb.State_NORMAL
case store.Archived:
return v1pb.State_ARCHIVED
default:
return v1pb.State_STATE_UNSPECIFIED
}
}
func convertStateToStore(state v1pb.State) store.RowStatus {
switch state {
case v1pb.State_NORMAL:
return store.Normal
case v1pb.State_ARCHIVED:
return store.Archived
default:
return store.Normal
}
}
func getPageToken(limit int, offset int) (string, error) {
return marshalPageToken(&v1pb.PageToken{
Limit: int32(limit),
Offset: int32(offset),
})
}
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken)
if err != nil {
return "", errors.Wrapf(err, "failed to marshal page token")
}
return base64.StdEncoding.EncodeToString(b), nil
}
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrapf(err, "failed to decode page token")
}
if err := proto.Unmarshal(b, pageToken); err != nil {
return errors.Wrapf(err, "failed to unmarshal page token")
}
return nil
}
func isSuperUser(user *store.User) bool {
return user.Role == store.RoleAdmin || user.Role == store.RoleHost
}

View File

@@ -0,0 +1,21 @@
package v1
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) Check(ctx context.Context,
_ *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
history, err := s.Store.GetDriver().FindMigrationHistoryList(ctx, &store.FindMigrationHistory{})
if err != nil || len(history) == 0 {
return nil, status.Errorf(codes.Unavailable, "not available")
}
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
}

View File

@@ -0,0 +1,183 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
}
response := &v1pb.ListIdentityProvidersResponse{
IdentityProviders: []*v1pb.IdentityProvider{},
}
for _, identityProvider := range identityProviders {
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
}
return response, nil
}
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
update := &store.UpdateIdentityProviderV1{
ID: id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "title":
update.Name = &request.IdentityProvider.Title
case "identifier_filter":
update.IdentifierFilter = &request.IdentityProvider.IdentifierFilter
case "config":
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
}
}
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
// Check if the identity provider exists before trying to delete it
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
return &emptypb.Empty{}, nil
}
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
temp := &v1pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
}
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2Config := identityProvider.Config.GetOauth2Config()
temp.Config = &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &v1pb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
},
},
},
}
}
return temp
}
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
}
return temp
}
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
oauth2Config := config.GetOauth2Config()
return &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &storepb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
AvatarUrl: oauth2Config.FieldMapping.AvatarUrl,
},
},
},
}
}
return nil
}

View File

@@ -0,0 +1,226 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListInboxes(ctx context.Context, request *v1pb.ListInboxesRequest) (*v1pb.ListInboxesResponse, error) {
// Extract user ID from parent resource name
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid parent name %q: %v", request.Parent, err)
}
// Get current user for authorization
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Check if current user can access the requested user's inboxes
if currentUser.ID != userID {
// Only allow hosts and admins to access other users' inboxes
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "cannot access inboxes for user %q", request.Parent)
}
}
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
offset = int(pageToken.Offset)
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
}
if limit > MaxPageSize {
limit = MaxPageSize
}
limitPlusOne := limit + 1
findInbox := &store.FindInbox{
ReceiverID: &userID,
Limit: &limitPlusOne,
Offset: &offset,
}
inboxes, err := s.Store.ListInboxes(ctx, findInbox)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list inboxes: %v", err)
}
inboxMessages := []*v1pb.Inbox{}
nextPageToken := ""
if len(inboxes) == limitPlusOne {
inboxes = inboxes[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token: %v", err)
}
}
for _, inbox := range inboxes {
inboxMessage := convertInboxFromStore(inbox)
if inboxMessage.Type == v1pb.Inbox_TYPE_UNSPECIFIED {
continue
}
inboxMessages = append(inboxMessages, inboxMessage)
}
response := &v1pb.ListInboxesResponse{
Inboxes: inboxMessages,
NextPageToken: nextPageToken,
TotalSize: int32(len(inboxMessages)), // For now, use actual returned count
}
return response, nil
}
func (s *APIV1Service) UpdateInbox(ctx context.Context, request *v1pb.UpdateInboxRequest) (*v1pb.Inbox, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
inboxID, err := ExtractInboxIDFromName(request.Inbox.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name %q: %v", request.Inbox.Name, err)
}
// Get current user for authorization
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Get the existing inbox to verify ownership
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
ID: &inboxID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get inbox: %v", err)
}
if len(inboxes) == 0 {
return nil, status.Errorf(codes.NotFound, "inbox %q not found", request.Inbox.Name)
}
existingInbox := inboxes[0]
// Check if current user can update this inbox (must be the receiver)
if currentUser.ID != existingInbox.ReceiverID {
return nil, status.Errorf(codes.PermissionDenied, "cannot update inbox for another user")
}
update := &store.UpdateInbox{
ID: inboxID,
}
for _, field := range request.UpdateMask.Paths {
if field == "status" {
if request.Inbox.Status == v1pb.Inbox_STATUS_UNSPECIFIED {
return nil, status.Errorf(codes.InvalidArgument, "status cannot be unspecified")
}
update.Status = convertInboxStatusToStore(request.Inbox.Status)
} else {
return nil, status.Errorf(codes.InvalidArgument, "unsupported field in update mask: %q", field)
}
}
inbox, err := s.Store.UpdateInbox(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
}
return convertInboxFromStore(inbox), nil
}
func (s *APIV1Service) DeleteInbox(ctx context.Context, request *v1pb.DeleteInboxRequest) (*emptypb.Empty, error) {
inboxID, err := ExtractInboxIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name %q: %v", request.Name, err)
}
// Get current user for authorization
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Get the existing inbox to verify ownership
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
ID: &inboxID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get inbox: %v", err)
}
if len(inboxes) == 0 {
return nil, status.Errorf(codes.NotFound, "inbox %q not found", request.Name)
}
existingInbox := inboxes[0]
// Check if current user can delete this inbox (must be the receiver)
if currentUser.ID != existingInbox.ReceiverID {
return nil, status.Errorf(codes.PermissionDenied, "cannot delete inbox for another user")
}
if err := s.Store.DeleteInbox(ctx, &store.DeleteInbox{
ID: inboxID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete inbox: %v", err)
}
return &emptypb.Empty{}, nil
}
func convertInboxFromStore(inbox *store.Inbox) *v1pb.Inbox {
return &v1pb.Inbox{
Name: fmt.Sprintf("%s%d", InboxNamePrefix, inbox.ID),
Sender: fmt.Sprintf("%s%d", UserNamePrefix, inbox.SenderID),
Receiver: fmt.Sprintf("%s%d", UserNamePrefix, inbox.ReceiverID),
Status: convertInboxStatusFromStore(inbox.Status),
CreateTime: timestamppb.New(time.Unix(inbox.CreatedTs, 0)),
Type: v1pb.Inbox_Type(inbox.Message.Type),
ActivityId: inbox.Message.ActivityId,
}
}
func convertInboxStatusFromStore(status store.InboxStatus) v1pb.Inbox_Status {
switch status {
case store.UNREAD:
return v1pb.Inbox_UNREAD
case store.ARCHIVED:
return v1pb.Inbox_ARCHIVED
default:
return v1pb.Inbox_STATUS_UNSPECIFIED
}
}
func convertInboxStatusToStore(status v1pb.Inbox_Status) store.InboxStatus {
switch status {
case v1pb.Inbox_UNREAD:
return store.UNREAD
case v1pb.Inbox_ARCHIVED:
return store.ARCHIVED
default:
return store.UNREAD
}
}

View File

@@ -0,0 +1,48 @@
package v1
import (
"context"
"log/slog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type LoggerInterceptor struct {
}
func NewLoggerInterceptor() *LoggerInterceptor {
return &LoggerInterceptor{}
}
func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
resp, err := handler(ctx, request)
in.loggerInterceptorDo(ctx, serverInfo.FullMethod, err)
return resp, err
}
func (*LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) {
st := status.Convert(err)
var logLevel slog.Level
var logMsg string
switch st.Code() {
case codes.OK:
logLevel = slog.LevelInfo
logMsg = "OK"
case codes.Unauthenticated, codes.OutOfRange, codes.PermissionDenied, codes.NotFound:
logLevel = slog.LevelInfo
logMsg = "client error"
case codes.Internal, codes.Unknown, codes.DataLoss, codes.Unavailable, codes.DeadlineExceeded:
logLevel = slog.LevelError
logMsg = "server error"
default:
logLevel = slog.LevelError
logMsg = "unknown error"
}
logAttrs := []slog.Attr{slog.String("method", fullMethod)}
if err != nil {
logAttrs = append(logAttrs, slog.String("error", err.Error()))
}
slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...)
}

View File

@@ -0,0 +1,279 @@
package v1
import (
"context"
"github.com/pkg/errors"
"github.com/usememos/gomark/ast"
"github.com/usememos/gomark/parser"
"github.com/usememos/gomark/parser/tokenizer"
"github.com/usememos/gomark/renderer"
"github.com/usememos/gomark/restore"
"github.com/usememos/memos/plugin/httpgetter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func (*APIV1Service) ParseMarkdown(_ context.Context, request *v1pb.ParseMarkdownRequest) (*v1pb.ParseMarkdownResponse, error) {
rawNodes, err := parser.Parse(tokenizer.Tokenize(request.Markdown))
if err != nil {
return nil, errors.Wrap(err, "failed to parse memo content")
}
nodes := convertFromASTNodes(rawNodes)
return &v1pb.ParseMarkdownResponse{
Nodes: nodes,
}, nil
}
func (*APIV1Service) RestoreMarkdownNodes(_ context.Context, request *v1pb.RestoreMarkdownNodesRequest) (*v1pb.RestoreMarkdownNodesResponse, error) {
markdown := restore.Restore(convertToASTNodes(request.Nodes))
return &v1pb.RestoreMarkdownNodesResponse{
Markdown: markdown,
}, nil
}
func (*APIV1Service) StringifyMarkdownNodes(_ context.Context, request *v1pb.StringifyMarkdownNodesRequest) (*v1pb.StringifyMarkdownNodesResponse, error) {
stringRenderer := renderer.NewStringRenderer()
plainText := stringRenderer.Render(convertToASTNodes(request.Nodes))
return &v1pb.StringifyMarkdownNodesResponse{
PlainText: plainText,
}, nil
}
func (*APIV1Service) GetLinkMetadata(_ context.Context, request *v1pb.GetLinkMetadataRequest) (*v1pb.LinkMetadata, error) {
htmlMeta, err := httpgetter.GetHTMLMeta(request.Link)
if err != nil {
return nil, err
}
return &v1pb.LinkMetadata{
Title: htmlMeta.Title,
Description: htmlMeta.Description,
Image: htmlMeta.Image,
}, nil
}
func convertFromASTNode(rawNode ast.Node) *v1pb.Node {
node := &v1pb.Node{
Type: v1pb.NodeType(v1pb.NodeType_value[string(rawNode.Type())]),
}
switch n := rawNode.(type) {
case *ast.LineBreak:
node.Node = &v1pb.Node_LineBreakNode{}
case *ast.Paragraph:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_ParagraphNode{ParagraphNode: &v1pb.ParagraphNode{Children: children}}
case *ast.CodeBlock:
node.Node = &v1pb.Node_CodeBlockNode{CodeBlockNode: &v1pb.CodeBlockNode{Language: n.Language, Content: n.Content}}
case *ast.Heading:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_HeadingNode{HeadingNode: &v1pb.HeadingNode{Level: int32(n.Level), Children: children}}
case *ast.HorizontalRule:
node.Node = &v1pb.Node_HorizontalRuleNode{HorizontalRuleNode: &v1pb.HorizontalRuleNode{Symbol: n.Symbol}}
case *ast.Blockquote:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_BlockquoteNode{BlockquoteNode: &v1pb.BlockquoteNode{Children: children}}
case *ast.List:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_ListNode{ListNode: &v1pb.ListNode{Kind: convertListKindFromASTNode(n.Kind), Indent: int32(n.Indent), Children: children}}
case *ast.OrderedListItem:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_OrderedListItemNode{OrderedListItemNode: &v1pb.OrderedListItemNode{Number: n.Number, Indent: int32(n.Indent), Children: children}}
case *ast.UnorderedListItem:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_UnorderedListItemNode{UnorderedListItemNode: &v1pb.UnorderedListItemNode{Symbol: n.Symbol, Indent: int32(n.Indent), Children: children}}
case *ast.TaskListItem:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_TaskListItemNode{TaskListItemNode: &v1pb.TaskListItemNode{Symbol: n.Symbol, Indent: int32(n.Indent), Complete: n.Complete, Children: children}}
case *ast.MathBlock:
node.Node = &v1pb.Node_MathBlockNode{MathBlockNode: &v1pb.MathBlockNode{Content: n.Content}}
case *ast.Table:
node.Node = &v1pb.Node_TableNode{TableNode: convertTableFromASTNode(n)}
case *ast.EmbeddedContent:
node.Node = &v1pb.Node_EmbeddedContentNode{EmbeddedContentNode: &v1pb.EmbeddedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
case *ast.Text:
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{Content: n.Content}}
case *ast.Bold:
node.Node = &v1pb.Node_BoldNode{BoldNode: &v1pb.BoldNode{Symbol: n.Symbol, Children: convertFromASTNodes(n.Children)}}
case *ast.Italic:
node.Node = &v1pb.Node_ItalicNode{ItalicNode: &v1pb.ItalicNode{Symbol: n.Symbol, Children: convertFromASTNodes(n.Children)}}
case *ast.BoldItalic:
node.Node = &v1pb.Node_BoldItalicNode{BoldItalicNode: &v1pb.BoldItalicNode{Symbol: n.Symbol, Content: n.Content}}
case *ast.Code:
node.Node = &v1pb.Node_CodeNode{CodeNode: &v1pb.CodeNode{Content: n.Content}}
case *ast.Image:
node.Node = &v1pb.Node_ImageNode{ImageNode: &v1pb.ImageNode{AltText: n.AltText, Url: n.URL}}
case *ast.Link:
node.Node = &v1pb.Node_LinkNode{LinkNode: &v1pb.LinkNode{Content: convertFromASTNodes(n.Content), Url: n.URL}}
case *ast.AutoLink:
node.Node = &v1pb.Node_AutoLinkNode{AutoLinkNode: &v1pb.AutoLinkNode{Url: n.URL, IsRawText: n.IsRawText}}
case *ast.Tag:
node.Node = &v1pb.Node_TagNode{TagNode: &v1pb.TagNode{Content: n.Content}}
case *ast.Strikethrough:
node.Node = &v1pb.Node_StrikethroughNode{StrikethroughNode: &v1pb.StrikethroughNode{Content: n.Content}}
case *ast.EscapingCharacter:
node.Node = &v1pb.Node_EscapingCharacterNode{EscapingCharacterNode: &v1pb.EscapingCharacterNode{Symbol: n.Symbol}}
case *ast.Math:
node.Node = &v1pb.Node_MathNode{MathNode: &v1pb.MathNode{Content: n.Content}}
case *ast.Highlight:
node.Node = &v1pb.Node_HighlightNode{HighlightNode: &v1pb.HighlightNode{Content: n.Content}}
case *ast.Subscript:
node.Node = &v1pb.Node_SubscriptNode{SubscriptNode: &v1pb.SubscriptNode{Content: n.Content}}
case *ast.Superscript:
node.Node = &v1pb.Node_SuperscriptNode{SuperscriptNode: &v1pb.SuperscriptNode{Content: n.Content}}
case *ast.ReferencedContent:
node.Node = &v1pb.Node_ReferencedContentNode{ReferencedContentNode: &v1pb.ReferencedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
case *ast.Spoiler:
node.Node = &v1pb.Node_SpoilerNode{SpoilerNode: &v1pb.SpoilerNode{Content: n.Content}}
case *ast.HTMLElement:
node.Node = &v1pb.Node_HtmlElementNode{HtmlElementNode: &v1pb.HTMLElementNode{TagName: n.TagName, Attributes: n.Attributes}}
default:
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{}}
}
return node
}
func convertFromASTNodes(rawNodes []ast.Node) []*v1pb.Node {
nodes := []*v1pb.Node{}
for _, rawNode := range rawNodes {
node := convertFromASTNode(rawNode)
nodes = append(nodes, node)
}
return nodes
}
func convertTableFromASTNode(node *ast.Table) *v1pb.TableNode {
table := &v1pb.TableNode{
Header: convertFromASTNodes(node.Header),
Delimiter: node.Delimiter,
}
for _, row := range node.Rows {
table.Rows = append(table.Rows, &v1pb.TableNode_Row{Cells: convertFromASTNodes(row)})
}
return table
}
func convertListKindFromASTNode(node ast.ListKind) v1pb.ListNode_Kind {
switch node {
case ast.OrderedList:
return v1pb.ListNode_ORDERED
case ast.UnorderedList:
return v1pb.ListNode_UNORDERED
case ast.DescrpitionList:
return v1pb.ListNode_DESCRIPTION
default:
return v1pb.ListNode_KIND_UNSPECIFIED
}
}
func convertToASTNode(node *v1pb.Node) ast.Node {
switch n := node.Node.(type) {
case *v1pb.Node_LineBreakNode:
return &ast.LineBreak{}
case *v1pb.Node_ParagraphNode:
children := convertToASTNodes(n.ParagraphNode.Children)
return &ast.Paragraph{Children: children}
case *v1pb.Node_CodeBlockNode:
return &ast.CodeBlock{Language: n.CodeBlockNode.Language, Content: n.CodeBlockNode.Content}
case *v1pb.Node_HeadingNode:
children := convertToASTNodes(n.HeadingNode.Children)
return &ast.Heading{Level: int(n.HeadingNode.Level), Children: children}
case *v1pb.Node_HorizontalRuleNode:
return &ast.HorizontalRule{Symbol: n.HorizontalRuleNode.Symbol}
case *v1pb.Node_BlockquoteNode:
children := convertToASTNodes(n.BlockquoteNode.Children)
return &ast.Blockquote{Children: children}
case *v1pb.Node_ListNode:
children := convertToASTNodes(n.ListNode.Children)
return &ast.List{Kind: convertListKindToASTNode(n.ListNode.Kind), Indent: int(n.ListNode.Indent), Children: children}
case *v1pb.Node_OrderedListItemNode:
children := convertToASTNodes(n.OrderedListItemNode.Children)
return &ast.OrderedListItem{Number: n.OrderedListItemNode.Number, Indent: int(n.OrderedListItemNode.Indent), Children: children}
case *v1pb.Node_UnorderedListItemNode:
children := convertToASTNodes(n.UnorderedListItemNode.Children)
return &ast.UnorderedListItem{Symbol: n.UnorderedListItemNode.Symbol, Indent: int(n.UnorderedListItemNode.Indent), Children: children}
case *v1pb.Node_TaskListItemNode:
children := convertToASTNodes(n.TaskListItemNode.Children)
return &ast.TaskListItem{Symbol: n.TaskListItemNode.Symbol, Indent: int(n.TaskListItemNode.Indent), Complete: n.TaskListItemNode.Complete, Children: children}
case *v1pb.Node_MathBlockNode:
return &ast.MathBlock{Content: n.MathBlockNode.Content}
case *v1pb.Node_TableNode:
return convertTableToASTNode(n.TableNode)
case *v1pb.Node_EmbeddedContentNode:
return &ast.EmbeddedContent{ResourceName: n.EmbeddedContentNode.ResourceName, Params: n.EmbeddedContentNode.Params}
case *v1pb.Node_TextNode:
return &ast.Text{Content: n.TextNode.Content}
case *v1pb.Node_BoldNode:
return &ast.Bold{Symbol: n.BoldNode.Symbol, Children: convertToASTNodes(n.BoldNode.Children)}
case *v1pb.Node_ItalicNode:
return &ast.Italic{Symbol: n.ItalicNode.Symbol, Children: convertToASTNodes(n.ItalicNode.Children)}
case *v1pb.Node_BoldItalicNode:
return &ast.BoldItalic{Symbol: n.BoldItalicNode.Symbol, Content: n.BoldItalicNode.Content}
case *v1pb.Node_CodeNode:
return &ast.Code{Content: n.CodeNode.Content}
case *v1pb.Node_ImageNode:
return &ast.Image{AltText: n.ImageNode.AltText, URL: n.ImageNode.Url}
case *v1pb.Node_LinkNode:
return &ast.Link{Content: convertToASTNodes(n.LinkNode.Content), URL: n.LinkNode.Url}
case *v1pb.Node_AutoLinkNode:
return &ast.AutoLink{URL: n.AutoLinkNode.Url, IsRawText: n.AutoLinkNode.IsRawText}
case *v1pb.Node_TagNode:
return &ast.Tag{Content: n.TagNode.Content}
case *v1pb.Node_StrikethroughNode:
return &ast.Strikethrough{Content: n.StrikethroughNode.Content}
case *v1pb.Node_EscapingCharacterNode:
return &ast.EscapingCharacter{Symbol: n.EscapingCharacterNode.Symbol}
case *v1pb.Node_MathNode:
return &ast.Math{Content: n.MathNode.Content}
case *v1pb.Node_HighlightNode:
return &ast.Highlight{Content: n.HighlightNode.Content}
case *v1pb.Node_SubscriptNode:
return &ast.Subscript{Content: n.SubscriptNode.Content}
case *v1pb.Node_SuperscriptNode:
return &ast.Superscript{Content: n.SuperscriptNode.Content}
case *v1pb.Node_ReferencedContentNode:
return &ast.ReferencedContent{ResourceName: n.ReferencedContentNode.ResourceName, Params: n.ReferencedContentNode.Params}
case *v1pb.Node_SpoilerNode:
return &ast.Spoiler{Content: n.SpoilerNode.Content}
case *v1pb.Node_HtmlElementNode:
return &ast.HTMLElement{TagName: n.HtmlElementNode.TagName, Attributes: n.HtmlElementNode.Attributes}
default:
return &ast.Text{}
}
}
func convertToASTNodes(nodes []*v1pb.Node) []ast.Node {
rawNodes := []ast.Node{}
for _, node := range nodes {
rawNode := convertToASTNode(node)
rawNodes = append(rawNodes, rawNode)
}
return rawNodes
}
func convertTableToASTNode(node *v1pb.TableNode) *ast.Table {
table := &ast.Table{
Header: convertToASTNodes(node.Header),
Delimiter: node.Delimiter,
}
for _, row := range node.Rows {
table.Rows = append(table.Rows, convertToASTNodes(row.Cells))
}
return table
}
func convertListKindToASTNode(kind v1pb.ListNode_Kind) ast.ListKind {
switch kind {
case v1pb.ListNode_ORDERED:
return ast.OrderedList
case v1pb.ListNode_UNORDERED:
return ast.UnorderedList
case v1pb.ListNode_DESCRIPTION:
return ast.DescrpitionList
default:
// Default to description list.
return ast.DescrpitionList
}
}

View File

@@ -0,0 +1,102 @@
package v1
import (
"context"
"slices"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
// Delete attachments that are not in the request.
for _, attachment := range attachments {
found := false
for _, requestAttachment := range request.Attachments {
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
}
if attachment.UID == requestAttachmentUID {
found = true
break
}
}
if !found {
if err = s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{
ID: int32(attachment.ID),
MemoID: &memo.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
}
}
}
slices.Reverse(request.Attachments)
// Update attachments' memo_id in the request.
for index, attachment := range request.Attachments {
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
}
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
updatedTs := time.Now().Unix() + int64(index)
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
ID: tempAttachment.ID,
MemoID: &memo.ID,
UpdatedTs: &updatedTs,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments: %v", err)
}
response := &v1pb.ListMemoAttachmentsResponse{
Attachments: []*v1pb.Attachment{},
}
for _, attachment := range attachments {
response.Attachments = append(response.Attachments, s.convertAttachmentFromStore(ctx, attachment))
}
return response, nil
}

View File

@@ -0,0 +1,427 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/runner/memopayload"
"github.com/usememos/memos/store"
)
// ExportFormat represents the format for export/import operations
type ExportFormat string
const (
FormatJSON ExportFormat = "json"
)
// ExportData represents the structure of exported data
type ExportData struct {
Version string `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Memos []ExportMemo `json:"memos"`
}
// ExportMemo represents a memo in the export format
type ExportMemo struct {
UID string `json:"uid"`
Content string `json:"content"`
Visibility string `json:"visibility"`
Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DisplayTime *time.Time `json:"display_time,omitempty"`
Tags []string `json:"tags,omitempty"`
Location *ExportLocation `json:"location,omitempty"`
Attachments []ExportAttachment `json:"attachments,omitempty"`
Relations []ExportMemoRelation `json:"relations,omitempty"`
}
// ExportLocation represents location data in export format
type ExportLocation struct {
Placeholder string `json:"placeholder,omitempty"`
Latitude float64 `json:"latitude,omitempty"`
Longitude float64 `json:"longitude,omitempty"`
}
// ExportAttachment represents attachment data in export format
type ExportAttachment struct {
UID string `json:"uid"`
Filename string `json:"filename"`
Type string `json:"type"`
Size int64 `json:"size"`
}
// ExportMemoRelation represents memo relations in export format
type ExportMemoRelation struct {
RelatedMemoUID string `json:"related_memo_uid"`
Type string `json:"type"`
}
// ExportMemos exports memos for the current user in JSON format
func (s *APIV1Service) ExportMemos(ctx context.Context, request *v1pb.ExportMemosRequest) (*v1pb.ExportMemosResponse, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
// Validate format (default to JSON)
format := request.Format
if format == "" {
format = string(FormatJSON)
}
if format != string(FormatJSON) {
return nil, status.Errorf(codes.InvalidArgument, "unsupported export format: %s", format)
}
// Get all memos for the user
memoFind := &store.FindMemo{
CreatorID: &user.ID,
ExcludeComments: true,
}
// Apply filters if specified
if request.Filter != "" {
// Use existing filter validation from shortcut service
memoFind.Filter = &request.Filter
}
// Include archived memos if requested
if request.ExcludeArchived {
normalStatus := store.Normal
memoFind.RowStatus = &normalStatus
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
// Convert memos to export format
exportMemos := make([]ExportMemo, 0, len(memos))
for _, memo := range memos {
exportMemo, err := s.convertMemoToExport(ctx, memo, request.IncludeAttachments, request.IncludeRelations)
if err != nil {
slog.Warn("Failed to convert memo to export format", slog.Any("memo_id", memo.ID), slog.Any("error", err))
continue
}
exportMemos = append(exportMemos, *exportMemo)
}
// Create export data structure
exportData := &ExportData{
Version: "1.0",
ExportedAt: time.Now(),
Memos: exportMemos,
}
// Serialize to JSON
jsonData, err := json.MarshalIndent(exportData, "", " ")
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to marshal export data: %v", err)
}
return &v1pb.ExportMemosResponse{
Data: jsonData,
Format: format,
Filename: fmt.Sprintf("memos_export_%s.json", time.Now().Format("20060102_150405")),
MemoCount: int32(len(exportMemos)),
SizeBytes: int64(len(jsonData)),
}, nil
}
// ImportMemos imports memos from JSON data
func (s *APIV1Service) ImportMemos(ctx context.Context, request *v1pb.ImportMemosRequest) (*v1pb.ImportMemosResponse, error) {
startTime := time.Now()
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
// Validate format (default to JSON)
format := request.Format
if format == "" {
format = string(FormatJSON)
}
if format != string(FormatJSON) {
return nil, status.Errorf(codes.InvalidArgument, "unsupported import format: %s", format)
}
// Parse the JSON data
var importData ExportData
if err := json.Unmarshal(request.Data, &importData); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse import data: %v", err)
}
// Validate import data version
if importData.Version != "1.0" {
return nil, status.Errorf(codes.InvalidArgument, "unsupported import data version: %s", importData.Version)
}
var importedCount int32
var skippedCount int32
var createdCount int32
var updatedCount int32
var validationErrors int32
var attachmentsImported int32
var relationsImported int32
var errors []string
var warnings []string
// Import each memo
for _, exportMemo := range importData.Memos {
result, err := s.importSingleMemo(ctx, user.ID, &exportMemo, request)
if err != nil {
errorMsg := fmt.Sprintf("Failed to import memo %s: %v", exportMemo.UID, err)
errors = append(errors, errorMsg)
skippedCount++
if request.ValidateOnly {
validationErrors++
}
slog.Warn("Failed to import memo", slog.String("uid", exportMemo.UID), slog.Any("error", err))
continue
}
importedCount++
if result.Created {
createdCount++
} else {
updatedCount++
}
attachmentsImported += result.AttachmentsImported
relationsImported += result.RelationsImported
if len(result.Warnings) > 0 {
warnings = append(warnings, result.Warnings...)
}
}
duration := time.Since(startTime)
summary := &v1pb.ImportSummary{
TotalMemos: int32(len(importData.Memos)),
CreatedCount: createdCount,
UpdatedCount: updatedCount,
AttachmentsImported: attachmentsImported,
RelationsImported: relationsImported,
DurationMs: duration.Milliseconds(),
}
return &v1pb.ImportMemosResponse{
ImportedCount: importedCount,
SkippedCount: skippedCount,
ValidationErrors: validationErrors,
Errors: errors,
Warnings: warnings,
Summary: summary,
}, nil
}
// convertMemoToExport converts a store memo to export format
func (s *APIV1Service) convertMemoToExport(ctx context.Context, memo *store.Memo, includeAttachments, includeRelations bool) (*ExportMemo, error) {
exportMemo := &ExportMemo{
UID: memo.UID,
Content: memo.Content,
Visibility: memo.Visibility.String(),
Pinned: memo.Pinned,
CreatedAt: time.Unix(memo.CreatedTs, 0),
UpdatedAt: time.Unix(memo.UpdatedTs, 0),
}
// Extract tags from payload
if memo.Payload != nil && len(memo.Payload.Tags) > 0 {
exportMemo.Tags = memo.Payload.Tags
}
// Add location if present
if memo.Payload != nil && memo.Payload.Location != nil {
exportMemo.Location = &ExportLocation{
Placeholder: memo.Payload.Location.Placeholder,
Latitude: memo.Payload.Location.Latitude,
Longitude: memo.Payload.Location.Longitude,
}
}
// Add attachments if requested
if includeAttachments {
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
if err != nil {
return nil, errors.Wrap(err, "failed to list attachments")
}
for _, attachment := range attachments {
exportMemo.Attachments = append(exportMemo.Attachments, ExportAttachment{
UID: attachment.UID,
Filename: attachment.Filename,
Type: attachment.Type,
Size: attachment.Size,
})
}
}
// Add relations if requested
if includeRelations {
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{MemoID: &memo.ID})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo relations")
}
for _, relation := range relations {
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &relation.RelatedMemoID})
if err != nil || relatedMemo == nil {
continue // Skip if related memo not found
}
exportMemo.Relations = append(exportMemo.Relations, ExportMemoRelation{
RelatedMemoUID: relatedMemo.UID,
Type: string(relation.Type),
})
}
}
return exportMemo, nil
}
// ImportResult represents the result of importing a single memo
type ImportResult struct {
Created bool
AttachmentsImported int32
RelationsImported int32
Warnings []string
}
// importSingleMemo imports a single memo
func (s *APIV1Service) importSingleMemo(ctx context.Context, userID int32, exportMemo *ExportMemo, request *v1pb.ImportMemosRequest) (*ImportResult, error) {
result := &ImportResult{
Warnings: []string{},
}
// Check if memo with this UID already exists
existingMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &exportMemo.UID})
if err != nil {
return nil, errors.Wrap(err, "failed to check for existing memo")
}
if existingMemo != nil && !request.OverwriteExisting {
return nil, fmt.Errorf("memo with UID %s already exists", exportMemo.UID)
}
// Validate memo content length
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get content length limit")
}
if len(exportMemo.Content) > contentLengthLimit {
return nil, fmt.Errorf("content too long (max %d characters)", contentLengthLimit)
}
// Parse visibility
visibility := store.Private
switch exportMemo.Visibility {
case "PUBLIC":
visibility = store.Public
case "PROTECTED":
visibility = store.Protected
case "PRIVATE":
visibility = store.Private
default:
result.Warnings = append(result.Warnings, fmt.Sprintf("Unknown visibility %s for memo %s, defaulting to PRIVATE", exportMemo.Visibility, exportMemo.UID))
}
// Create memo payload
payload := &storepb.MemoPayload{
Tags: exportMemo.Tags,
}
if exportMemo.Location != nil {
payload.Location = &storepb.MemoPayload_Location{
Placeholder: exportMemo.Location.Placeholder,
Latitude: exportMemo.Location.Latitude,
Longitude: exportMemo.Location.Longitude,
}
}
// Set timestamps
createdTs := exportMemo.CreatedAt.Unix()
updatedTs := exportMemo.UpdatedAt.Unix()
if !request.PreserveTimestamps {
now := time.Now().Unix()
createdTs = now
updatedTs = now
}
if request.ValidateOnly {
// Just validate, don't actually create/update
return result, nil
}
if existingMemo != nil {
// Update existing memo
update := &store.UpdateMemo{
ID: existingMemo.ID,
Content: &exportMemo.Content,
Visibility: &visibility,
Pinned: &exportMemo.Pinned,
Payload: payload,
}
if request.PreserveTimestamps {
update.CreatedTs = &createdTs
update.UpdatedTs = &updatedTs
}
if err := s.Store.UpdateMemo(ctx, update); err != nil {
return nil, errors.Wrap(err, "failed to update existing memo")
}
result.Created = false
} else {
// Create new memo
create := &store.Memo{
UID: exportMemo.UID,
CreatorID: userID,
CreatedTs: createdTs,
UpdatedTs: updatedTs,
Content: exportMemo.Content,
Visibility: visibility,
Pinned: exportMemo.Pinned,
Payload: payload,
}
// Rebuild memo payload to extract tags and other metadata
if err := memopayload.RebuildMemoPayload(create); err != nil {
return nil, errors.Wrap(err, "failed to rebuild memo payload")
}
_, err := s.Store.CreateMemo(ctx, create)
if err != nil {
return nil, errors.Wrap(err, "failed to create memo")
}
result.Created = true
}
// Import attachments if not skipped
if !request.SkipAttachments && len(exportMemo.Attachments) > 0 {
result.Warnings = append(result.Warnings, fmt.Sprintf("Attachments for memo %s were skipped (attachment import not yet implemented)", exportMemo.UID))
// TODO: Implement attachment import
// This would require handling file uploads and storage
}
// Import relations if not skipped
if !request.SkipRelations && len(exportMemo.Relations) > 0 {
result.Warnings = append(result.Warnings, fmt.Sprintf("Relations for memo %s were skipped (relation import not yet implemented)", exportMemo.UID))
// TODO: Implement relation import
// This would require resolving related memo UIDs and creating relations
}
return result, nil
}

View File

@@ -0,0 +1,170 @@
package v1
import (
"context"
"fmt"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
referenceType := store.MemoRelationReference
// Delete all reference relations first.
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &memo.ID,
Type: &referenceType,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
}
for _, relation := range request.Relations {
// Ignore reflexive relations.
if request.Name == relation.RelatedMemo.Name {
continue
}
// Ignore comment relations as there's no need to update a comment's relation.
// Inserting/Deleting a comment is handled elsewhere.
if relation.Type == v1pb.MemoRelation_COMMENT {
continue
}
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get related memo")
}
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: convertMemoRelationTypeToStore(relation.Type),
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
var memoFilter string
if currentUser == nil {
memoFilter = `visibility == "PUBLIC"`
} else {
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
relationList := []*v1pb.MemoRelation{}
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo.ID,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, err
}
for _, raw := range tempList {
relation, err := s.convertMemoRelationFromStore(ctx, raw)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
}
relationList = append(relationList, relation)
}
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, err
}
for _, raw := range tempList {
relation, err := s.convertMemoRelationFromStore(ctx, raw)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert memo relation")
}
relationList = append(relationList, relation)
}
response := &v1pb.ListMemoRelationsResponse{
Relations: relationList,
}
return response, nil
}
func (s *APIV1Service) convertMemoRelationFromStore(ctx context.Context, memoRelation *store.MemoRelation) (*v1pb.MemoRelation, error) {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.MemoID})
if err != nil {
return nil, err
}
memoSnippet, err := getMemoContentSnippet(memo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoRelation.RelatedMemoID})
if err != nil {
return nil, err
}
relatedMemoSnippet, err := getMemoContentSnippet(relatedMemo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get related memo content snippet")
}
return &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Snippet: memoSnippet,
},
RelatedMemo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
Snippet: relatedMemoSnippet,
},
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
}, nil
}
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
switch relationType {
case store.MemoRelationReference:
return v1pb.MemoRelation_REFERENCE
case store.MemoRelationComment:
return v1pb.MemoRelation_COMMENT
default:
return v1pb.MemoRelation_TYPE_UNSPECIFIED
}
}
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
switch relationType {
case v1pb.MemoRelation_REFERENCE:
return store.MemoRelationReference
case v1pb.MemoRelation_COMMENT:
return store.MemoRelationComment
default:
return store.MemoRelationReference
}
}

View File

@@ -0,0 +1,785 @@
package v1
import (
"context"
"fmt"
"log/slog"
"strings"
"time"
"unicode/utf8"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"github.com/usememos/gomark/ast"
"github.com/usememos/gomark/parser"
"github.com/usememos/gomark/parser/tokenizer"
"github.com/usememos/gomark/renderer"
"github.com/usememos/gomark/restore"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/runner/memopayload"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
create := &store.Memo{
UID: shortuuid.New(),
CreatorID: user.ID,
Content: request.Memo.Content,
Visibility: convertVisibilityToStore(request.Memo.Visibility),
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(create.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
if err := memopayload.RebuildMemoPayload(create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
if request.Memo.Location != nil {
create.Payload.Location = convertLocationToStore(request.Memo.Location)
}
memo, err := s.Store.CreateMemo(ctx, create)
if err != nil {
return nil, err
}
if len(request.Memo.Attachments) > 0 {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Attachments: request.Memo.Attachments,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
}
if len(request.Memo.Relations) > 0 {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID),
Relations: request.Memo.Relations,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo relations")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is created.
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
}
return memoMessage, nil
}
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
}
// Handle deprecated old_filter for backward compatibility
if request.OldFilter != "" && request.Filter == "" {
//nolint:staticcheck // SA1019: Using deprecated field for backward compatibility
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.OldFilter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter: %v", err)
}
}
if request.Parent != "" && request.Parent != "users/-" {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
}
memoFind.CreatorID = &userID
memoFind.OrderByPinned = true
}
if request.State == v1pb.State_ARCHIVED {
state := store.Archived
memoFind.RowStatus = &state
} else {
state := store.Normal
memoFind.RowStatus = &state
}
// Parse order_by field (replaces the old sort and direction fields)
if request.OrderBy != "" {
if err := s.parseMemoOrderBy(request.OrderBy, memoFind); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid order_by: %v", err)
}
} else {
// Default ordering by display_time desc
memoFind.OrderByTimeAsc = false
}
if request.Filter != "" {
if err := s.validateFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
memoFind.Filter = &request.Filter
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else {
if memoFind.CreatorID == nil {
internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
if memoFind.Filter != nil {
filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter)
memoFind.Filter = &filter
} else {
memoFind.Filter = &internalFilter
}
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
memoFind.OrderByUpdatedTs = true
}
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
offset = int(pageToken.Offset)
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
}
limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
memoMessages := []*v1pb.Memo{}
nextPageToken := ""
if len(memos) == limitPlusOne {
memos = memos[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
for _, memo := range memos {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memoMessages = append(memoMessages, memoMessage)
}
response := &v1pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
UID: &memoUID,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
return memoMessage, nil
}
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Memo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
// Only the creator or admin can update the memo.
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
update := &store.UpdateMemo{
ID: memo.ID,
}
for _, path := range request.UpdateMask.Paths {
if path == "content" {
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(request.Memo.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
memo.Content = request.Memo.Content
if err := memopayload.RebuildMemoPayload(memo); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
update.Content = &memo.Content
update.Payload = memo.Payload
} else if path == "visibility" {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
visibility := convertVisibilityToStore(request.Memo.Visibility)
if workspaceMemoRelatedSetting.DisallowPublicVisibility && visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
update.Visibility = &visibility
} else if path == "pinned" {
update.Pinned = &request.Memo.Pinned
} else if path == "state" {
rowStatus := convertStateToStore(request.Memo.State)
update.RowStatus = &rowStatus
} else if path == "create_time" {
createdTs := request.Memo.CreateTime.AsTime().Unix()
update.CreatedTs = &createdTs
} else if path == "update_time" {
updatedTs := time.Now().Unix()
if request.Memo.UpdateTime != nil {
updatedTs = request.Memo.UpdateTime.AsTime().Unix()
}
update.UpdatedTs = &updatedTs
} else if path == "display_time" {
displayTs := request.Memo.DisplayTime.AsTime().Unix()
memoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if memoRelatedSetting.DisplayWithUpdateTime {
update.UpdatedTs = &displayTs
} else {
update.CreatedTs = &displayTs
}
} else if path == "location" {
payload := memo.Payload
payload.Location = convertLocationToStore(request.Memo.Location)
update.Payload = payload
} else if path == "attachments" {
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: request.Memo.Name,
Attachments: request.Memo.Attachments,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo attachments")
}
} else if path == "relations" {
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: request.Memo.Name,
Relations: request.Memo.Relations,
})
if err != nil {
return nil, errors.Wrap(err, "failed to set memo relations")
}
}
}
if err = s.Store.UpdateMemo(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memo.ID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is updated.
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
}
return memoMessage, nil
}
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
UID: &memoUID,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
// Only the creator or admin can update the memo.
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
// Try to dispatch webhook when memo is deleted.
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo deleted webhook", slog.Any("err", err))
}
}
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
// Delete memo relation
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{MemoID: &memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relations")
}
// Delete related attachments.
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
for _, attachment := range attachments {
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
}
}
// Delete memo comments
commentType := store.MemoRelationComment
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo comments")
}
for _, relation := range relations {
if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: relation.MemoID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo comment")
}
}
// Delete memo references
referenceType := store.MemoRelationReference
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{RelatedMemoID: &memo.ID, Type: &referenceType}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo references")
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Create the memo comment first.
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{Memo: request.Comment})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo")
}
memoUID, err = ExtractMemoUIDFromName(memoComment.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Build the relation between the comment memo and the original memo.
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
}
creatorID, err := ExtractUserIDFromName(memoComment.Creator)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
if memoComment.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
CreatorID: creatorID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memo.ID,
RelatedMemoId: relatedMemo.ID,
},
},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create activity")
}
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
SenderID: creatorID,
ReceiverID: relatedMemo.CreatorID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activity.ID,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to create inbox")
}
}
return memoComment, nil
}
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
var memoFilter string
if currentUser == nil {
memoFilter = `visibility == "PUBLIC"`
} else {
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
memoRelationComment := store.MemoRelationComment
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
Type: &memoRelationComment,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
}
var memos []*v1pb.Memo
for _, memoRelation := range memoRelations {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memoRelation.MemoID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo != nil {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memos = append(memos, memoMessage)
}
}
response := &v1pb.ListMemoCommentsResponse{
Memos: memos,
}
return response, nil
}
func (s *APIV1Service) RenameMemoTag(ctx context.Context, request *v1pb.RenameMemoTagRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
memoFind := &store.FindMemo{
CreatorID: &user.ID,
PayloadFind: &store.FindMemoPayload{TagSearch: []string{request.OldTag}},
ExcludeComments: true,
}
if (request.Parent) != "memos/-" {
memoUID, err := ExtractMemoUIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memoFind.UID = &memoUID
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
for _, memo := range memos {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse memo: %v", err)
}
memopayload.TraverseASTNodes(nodes, func(node ast.Node) {
if tag, ok := node.(*ast.Tag); ok && tag.Content == request.OldTag {
tag.Content = request.NewTag
}
})
memo.Content = restore.Restore(nodes)
if err := memopayload.RebuildMemoPayload(memo); err != nil {
return nil, status.Errorf(codes.Internal, "failed to rebuild memo payload: %v", err)
}
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Content: &memo.Content,
Payload: memo.Payload,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) DeleteMemoTag(ctx context.Context, request *v1pb.DeleteMemoTagRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
memoFind := &store.FindMemo{
CreatorID: &user.ID,
PayloadFind: &store.FindMemoPayload{TagSearch: []string{request.Tag}},
ExcludeContent: true,
ExcludeComments: true,
}
if request.Parent != "memos/-" {
memoUID, err := ExtractMemoUIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memoFind.UID = &memoUID
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
for _, memo := range memos {
if request.DeleteRelatedMemos {
err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
} else {
archived := store.Archived
err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
RowStatus: &archived,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) getContentLengthLimit(ctx context.Context) (int, error) {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return 0, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
return int(workspaceMemoRelatedSetting.ContentLengthLimit), nil
}
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
}
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
}
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
}
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
webhooks, err := s.Store.GetUserWebhooks(ctx, creatorID)
if err != nil {
return err
}
for _, hook := range webhooks {
payload, err := convertMemoToWebhookPayload(memo)
if err != nil {
return errors.Wrap(err, "failed to convert memo to webhook payload")
}
payload.ActivityType = activityType
payload.URL = hook.Url
// Use asynchronous webhook dispatch
webhook.PostAsync(payload)
}
return nil
}
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return nil, errors.Wrap(err, "invalid memo creator")
}
return &webhook.WebhookRequestPayload{
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creatorID),
Memo: memo,
}, nil
}
func getMemoContentSnippet(content string) (string, error) {
nodes, err := parser.Parse(tokenizer.Tokenize(content))
if err != nil {
return "", errors.Wrap(err, "failed to parse content")
}
plainText := renderer.NewStringRenderer().Render(nodes)
if len(plainText) > 64 {
return substring(plainText, 64) + "...", nil
}
return plainText, nil
}
func substring(s string, length int) string {
if length <= 0 {
return ""
}
runeCount := 0
byteIndex := 0
for byteIndex < len(s) {
_, size := utf8.DecodeRuneInString(s[byteIndex:])
byteIndex += size
runeCount++
if runeCount == length {
break
}
}
return s[:byteIndex]
}
// parseMemoOrderBy parses the order_by field and sets the appropriate ordering in memoFind.
func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo) error {
// Parse order_by field like "display_time desc" or "create_time asc"
parts := strings.Fields(strings.TrimSpace(orderBy))
if len(parts) == 0 {
return errors.New("empty order_by")
}
field := parts[0]
direction := "desc" // default
if len(parts) > 1 {
direction = strings.ToLower(parts[1])
if direction != "asc" && direction != "desc" {
return errors.Errorf("invalid order direction: %s, must be 'asc' or 'desc'", parts[1])
}
}
switch field {
case "display_time":
memoFind.OrderByTimeAsc = direction == "asc"
case "create_time":
memoFind.OrderByTimeAsc = direction == "asc"
case "update_time":
memoFind.OrderByUpdatedTs = true
memoFind.OrderByTimeAsc = direction == "asc"
case "name":
// For ordering by memo name/id - not commonly used but supported
memoFind.OrderByTimeAsc = direction == "asc"
default:
return errors.Errorf("unsupported order field: %s, supported fields are: display_time, create_time, update_time, name", field)
}
return nil
}

View File

@@ -0,0 +1,149 @@
package v1
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/gomark/parser"
"github.com/usememos/gomark/parser/tokenizer"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoMessage := &v1pb.Memo{
Name: name,
State: convertStateFromStore(memo.RowStatus),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, memo.CreatorID),
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
Content: memo.Content,
Visibility: convertVisibilityFromStore(memo.Visibility),
Pinned: memo.Pinned,
}
if memo.Payload != nil {
memoMessage.Tags = memo.Payload.Tags
memoMessage.Property = convertMemoPropertyFromStore(memo.Payload.Property)
memoMessage.Location = convertLocationFromStore(memo.Payload.Location)
}
if memo.ParentID != nil {
parent, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: memo.ParentID,
ExcludeContent: true,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get parent memo")
}
parentName := fmt.Sprintf("%s%s", MemoNamePrefix, parent.UID)
memoMessage.Parent = &parentName
}
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo relations")
}
memoMessage.Relations = listMemoRelationsResponse.Relations
listMemoAttachmentsResponse, err := s.ListMemoAttachments(ctx, &v1pb.ListMemoAttachmentsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo attachments")
}
memoMessage.Attachments = listMemoAttachmentsResponse.Attachments
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo reactions")
}
memoMessage.Reactions = listMemoReactionsResponse.Reactions
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, errors.Wrap(err, "failed to parse content")
}
memoMessage.Nodes = convertFromASTNodes(nodes)
snippet, err := getMemoContentSnippet(memo.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
memoMessage.Snippet = snippet
return memoMessage, nil
}
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
if property == nil {
return nil
}
return &v1pb.Memo_Property{
HasLink: property.HasLink,
HasTaskList: property.HasTaskList,
HasCode: property.HasCode,
HasIncompleteTasks: property.HasIncompleteTasks,
}
}
func convertLocationFromStore(location *storepb.MemoPayload_Location) *v1pb.Location {
if location == nil {
return nil
}
return &v1pb.Location{
Placeholder: location.Placeholder,
Latitude: location.Latitude,
Longitude: location.Longitude,
}
}
func convertLocationToStore(location *v1pb.Location) *storepb.MemoPayload_Location {
if location == nil {
return nil
}
return &storepb.MemoPayload_Location{
Placeholder: location.Placeholder,
Latitude: location.Latitude,
Longitude: location.Longitude,
}
}
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Private:
return v1pb.Visibility_PRIVATE
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
}
}
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
switch visibility {
case v1pb.Visibility_PRIVATE:
return store.Private
case v1pb.Visibility_PROTECTED:
return store.Protected
case v1pb.Visibility_PUBLIC:
return store.Public
default:
return store.Private
}
}

View File

@@ -0,0 +1,168 @@
package v1
import (
"context"
"github.com/google/cel-go/cel"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.FindMemo, filter string) error {
if find.PayloadFind == nil {
find.PayloadFind = &store.FindMemoPayload{}
}
if filter != "" {
filterExpr, err := parseMemoFilter(filter)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if len(filterExpr.ContentSearch) > 0 {
find.ContentSearch = filterExpr.ContentSearch
}
if filterExpr.TagSearch != nil {
if find.PayloadFind == nil {
find.PayloadFind = &store.FindMemoPayload{}
}
find.PayloadFind.TagSearch = filterExpr.TagSearch
}
if filterExpr.DisplayTimeAfter != nil {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
find.UpdatedTsAfter = filterExpr.DisplayTimeAfter
} else {
find.CreatedTsAfter = filterExpr.DisplayTimeAfter
}
}
if filterExpr.DisplayTimeBefore != nil {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
find.UpdatedTsBefore = filterExpr.DisplayTimeBefore
} else {
find.CreatedTsBefore = filterExpr.DisplayTimeBefore
}
}
if filterExpr.Pinned {
pinned := true
find.Pinned = &pinned
}
if filterExpr.HasLink {
find.PayloadFind.HasLink = true
}
if filterExpr.HasTaskList {
find.PayloadFind.HasTaskList = true
}
if filterExpr.HasCode {
find.PayloadFind.HasCode = true
}
if filterExpr.HasIncompleteTasks {
find.PayloadFind.HasIncompleteTasks = true
}
}
return nil
}
// MemoFilterCELAttributes are the CEL attributes.
var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("content_search", cel.ListType(cel.StringType)),
cel.Variable("tag_search", cel.ListType(cel.StringType)),
cel.Variable("display_time_before", cel.IntType),
cel.Variable("display_time_after", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("has_link", cel.BoolType),
cel.Variable("has_task_list", cel.BoolType),
cel.Variable("has_code", cel.BoolType),
cel.Variable("has_incomplete_tasks", cel.BoolType),
}
type MemoFilter struct {
ContentSearch []string
TagSearch []string
DisplayTimeBefore *int64
DisplayTimeAfter *int64
Pinned bool
HasLink bool
HasTaskList bool
HasCode bool
HasIncompleteTasks bool
}
func parseMemoFilter(expression string) (*MemoFilter, error) {
e, err := cel.NewEnv(MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
ast, issues := e.Compile(expression)
if issues != nil {
return nil, errors.Errorf("found issue %v", issues)
}
filter := &MemoFilter{}
parsedExpr, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, err
}
callExpr := parsedExpr.GetExpr().GetCallExpr()
findMemoField(callExpr, filter)
return filter, nil
}
func findMemoField(callExpr *exprv1.Expr_Call, filter *MemoFilter) {
if len(callExpr.Args) == 2 {
idExpr := callExpr.Args[0].GetIdentExpr()
if idExpr != nil {
if idExpr.Name == "content_search" {
contentSearch := []string{}
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
value := expr.GetConstExpr().GetStringValue()
contentSearch = append(contentSearch, value)
}
filter.ContentSearch = contentSearch
} else if idExpr.Name == "tag_search" {
tagSearch := []string{}
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
value := expr.GetConstExpr().GetStringValue()
tagSearch = append(tagSearch, value)
}
filter.TagSearch = tagSearch
} else if idExpr.Name == "display_time_before" {
displayTimeBefore := callExpr.Args[1].GetConstExpr().GetInt64Value()
filter.DisplayTimeBefore = &displayTimeBefore
} else if idExpr.Name == "display_time_after" {
displayTimeAfter := callExpr.Args[1].GetConstExpr().GetInt64Value()
filter.DisplayTimeAfter = &displayTimeAfter
} else if idExpr.Name == "pinned" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.Pinned = value
} else if idExpr.Name == "has_link" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.HasLink = value
} else if idExpr.Name == "has_task_list" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.HasTaskList = value
} else if idExpr.Name == "has_code" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.HasCode = value
} else if idExpr.Name == "has_incomplete_tasks" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.HasIncompleteTasks = value
}
return
}
}
for _, arg := range callExpr.Args {
callExpr := arg.GetCallExpr()
if callExpr != nil {
findMemoField(callExpr, filter)
}
}
}

View File

@@ -0,0 +1,90 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
response := &v1pb.ListMemoReactionsResponse{
Reactions: []*v1pb.Reaction{},
}
for _, reaction := range reactions {
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
}
response.Reactions = append(response.Reactions, reactionMessage)
}
return response, nil
}
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: request.Reaction.ContentId,
ReactionType: request.Reaction.ReactionType,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
}
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
}
return reactionMessage, nil
}
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
reactionID, err := ExtractReactionIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
}
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
ID: reactionID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*v1pb.Reaction, error) {
creator, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &reaction.CreatorID,
})
if err != nil {
return nil, err
}
reactionUID := fmt.Sprintf("%d", reaction.ID)
return &v1pb.Reaction{
Name: fmt.Sprintf("%s%s", ReactionNamePrefix, reactionUID),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
ContentId: reaction.ContentID,
ReactionType: reaction.ReactionType,
CreateTime: timestamppb.New(time.Unix(reaction.CreatedTs, 0)),
}, nil
}

View File

@@ -0,0 +1,162 @@
package v1
import (
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
WorkspaceSettingNamePrefix = "workspace/settings/"
UserNamePrefix = "users/"
MemoNamePrefix = "memos/"
AttachmentNamePrefix = "attachments/"
ReactionNamePrefix = "reactions/"
InboxNamePrefix = "inboxes/"
IdentityProviderNamePrefix = "identityProviders/"
ActivityNamePrefix = "activities/"
WebhookNamePrefix = "webhooks/"
)
// GetNameParentTokens returns the tokens from a resource name.
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
parts := strings.Split(name, "/")
if len(parts) != 2*len(tokenPrefixes) {
return nil, errors.Errorf("invalid request %q", name)
}
var tokens []string
for i, tokenPrefix := range tokenPrefixes {
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
}
if parts[2*i+1] == "" {
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
}
tokens = append(tokens, parts[2*i+1])
}
return tokens, nil
}
func ExtractWorkspaceSettingKeyFromName(name string) (string, error) {
const prefix = "workspace/settings/"
if !strings.HasPrefix(name, prefix) {
return "", errors.Errorf("invalid workspace setting name: expected prefix %q, got %q", prefix, name)
}
settingKey := strings.TrimPrefix(name, prefix)
if settingKey == "" {
return "", errors.Errorf("invalid workspace setting name: empty setting key in %q", name)
}
// Ensure there are no additional path segments
if strings.Contains(settingKey, "/") {
return "", errors.Errorf("invalid workspace setting name: setting key cannot contain '/' in %q", name)
}
return settingKey, nil
}
// ExtractUserIDFromName returns the uid from a resource name.
func ExtractUserIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid user ID %q", tokens[0])
}
return id, nil
}
// ExtractMemoUIDFromName returns the memo UID from a resource name.
// e.g., "memos/uuid" -> "uuid".
func ExtractMemoUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
if err != nil {
return "", err
}
id := tokens[0]
return id, nil
}
// ExtractAttachmentUIDFromName returns the attachment UID from a resource name.
func ExtractAttachmentUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, AttachmentNamePrefix)
if err != nil {
return "", err
}
id := tokens[0]
return id, nil
}
// ExtractReactionIDFromName returns the reaction ID from a resource name.
// e.g., "reactions/123" -> 123.
func ExtractReactionIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, ReactionNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid reaction ID %q", tokens[0])
}
return id, nil
}
// ExtractInboxIDFromName returns the inbox ID from a resource name.
func ExtractInboxIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
}
return id, nil
}
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
}
return id, nil
}
func ExtractActivityIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, ActivityNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid activity ID %q", tokens[0])
}
return id, nil
}
// ExtractWebhookIDFromName returns the webhook ID from a resource name.
func ExtractWebhookIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix, WebhookNamePrefix)
if err != nil {
return "", err
}
if len(tokens) != 2 {
return "", errors.Errorf("invalid webhook name format: %q", name)
}
webhookID := tokens[1]
if webhookID == "" {
return "", errors.Errorf("invalid webhook ID %q", webhookID)
}
return webhookID, nil
}

View File

@@ -0,0 +1,337 @@
package v1
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// Helper function to extract user ID and shortcut ID from shortcut resource name.
// Format: users/{user}/shortcuts/{shortcut}.
func extractUserAndShortcutIDFromName(name string) (int32, string, error) {
parts := strings.Split(name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "shortcuts" {
return 0, "", errors.Errorf("invalid shortcut name format: %s", name)
}
userID, err := util.ConvertStringToInt32(parts[1])
if err != nil {
return 0, "", errors.Errorf("invalid user ID %q", parts[1])
}
shortcutID := parts[3]
if shortcutID == "" {
return 0, "", errors.Errorf("empty shortcut ID in name: %s", name)
}
return userID, shortcutID, nil
}
// Helper function to construct shortcut resource name.
func constructShortcutName(userID int32, shortcutID string) string {
return fmt.Sprintf("users/%d/shortcuts/%s", userID, shortcutID)
}
func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShortcutsRequest) (*v1pb.ListShortcutsResponse, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return &v1pb.ListShortcutsResponse{
Shortcuts: []*v1pb.Shortcut{},
}, nil
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := []*v1pb.Shortcut{}
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
shortcuts = append(shortcuts, &v1pb.Shortcut{
Name: constructShortcutName(userID, shortcut.GetId()),
Title: shortcut.GetTitle(),
Filter: shortcut.GetFilter(),
})
}
return &v1pb.ListShortcutsResponse{
Shortcuts: shortcuts,
}, nil
}
func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcutRequest) (*v1pb.Shortcut, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
if shortcut.GetId() == shortcutID {
return &v1pb.Shortcut{
Name: constructShortcutName(userID, shortcut.GetId()),
Title: shortcut.GetTitle(),
Filter: shortcut.GetFilter(),
}, nil
}
}
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateShortcutRequest) (*v1pb.Shortcut, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
newShortcut := &storepb.ShortcutsUserSetting_Shortcut{
Id: util.GenUUID(),
Title: request.Shortcut.GetTitle(),
Filter: request.Shortcut.GetFilter(),
}
if newShortcut.Title == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if request.ValidateOnly {
return &v1pb.Shortcut{
Name: constructShortcutName(userID, newShortcut.GetId()),
Title: newShortcut.GetTitle(),
Filter: newShortcut.GetFilter(),
}, nil
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
userSetting = &storepb.UserSetting{
UserId: userID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{
Shortcuts: &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
},
},
}
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
shortcuts = append(shortcuts, newShortcut)
shortcutsUserSetting.Shortcuts = shortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, err
}
return &v1pb.Shortcut{
Name: constructShortcutName(userID, newShortcut.GetId()),
Title: newShortcut.GetTitle(),
Filter: newShortcut.GetFilter(),
}, nil
}
func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateShortcutRequest) (*v1pb.Shortcut, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Shortcut.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
var foundShortcut *storepb.ShortcutsUserSetting_Shortcut
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
for _, shortcut := range shortcuts {
if shortcut.GetId() == shortcutID {
foundShortcut = shortcut
for _, field := range request.UpdateMask.Paths {
if field == "title" {
if request.Shortcut.GetTitle() == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
shortcut.Title = request.Shortcut.GetTitle()
} else if field == "filter" {
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
shortcut.Filter = request.Shortcut.GetFilter()
}
}
}
newShortcuts = append(newShortcuts, shortcut)
}
if foundShortcut == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting.Shortcuts = newShortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, err
}
return &v1pb.Shortcut{
Name: constructShortcutName(userID, foundShortcut.GetId()),
Title: foundShortcut.GetTitle(),
Filter: foundShortcut.GetFilter(),
}, nil
}
func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteShortcutRequest) (*emptypb.Empty, error) {
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil || currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_SHORTCUTS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting := userSetting.GetShortcuts()
shortcuts := shortcutsUserSetting.GetShortcuts()
newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts))
found := false
for _, shortcut := range shortcuts {
if shortcut.GetId() != shortcutID {
newShortcuts = append(newShortcuts, shortcut)
} else {
found = true
}
}
if !found {
return nil, status.Errorf(codes.NotFound, "shortcut not found")
}
shortcutsUserSetting.Shortcuts = newShortcuts
userSetting.Value = &storepb.UserSetting_Shortcuts{
Shortcuts: shortcutsUserSetting,
}
_, err = s.Store.UpsertUserSetting(ctx, userSetting)
if err != nil {
return nil, err
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
// Validate the filter.
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
if err != nil {
return errors.Wrap(err, "failed to parse filter")
}
convertCtx := filter.NewConvertContext()
err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil {
return errors.Wrap(err, "failed to convert filter to SQL")
}
return nil
}

View File

@@ -0,0 +1,519 @@
package v1
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestCreateIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("CreateIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, hostUser.ID)
// Create OAuth2 identity provider
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test OAuth2 Provider",
IdentifierFilter: "",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "https://example.com/oauth/authorize",
TokenUrl: "https://example.com/oauth/token",
UserInfoUrl: "https://example.com/oauth/userinfo",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
DisplayName: "name",
Email: "email",
AvatarUrl: "avatar_url",
},
},
},
},
},
}
resp, err := ts.Service.CreateIdentityProvider(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "Test OAuth2 Provider", resp.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.Contains(t, resp.Name, "identityProviders/")
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId)
})
t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
// Set user context
ctx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err = ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
}
func TestListIdentityProviders(t *testing.T) {
ctx := context.Background()
t.Run("ListIdentityProviders empty", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListIdentityProvidersRequest{}
resp, err := ts.Service.ListIdentityProviders(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.IdentityProviders)
})
t.Run("ListIdentityProviders with providers", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a couple of identity providers
createReq1 := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider 1",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client1",
AuthUrl: "https://example1.com/auth",
TokenUrl: "https://example1.com/token",
UserInfoUrl: "https://example1.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
createReq2 := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider 2",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client2",
AuthUrl: "https://example2.com/auth",
TokenUrl: "https://example2.com/token",
UserInfoUrl: "https://example2.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq1)
require.NoError(t, err)
_, err = ts.Service.CreateIdentityProvider(userCtx, createReq2)
require.NoError(t, err)
// List providers
listReq := &v1pb.ListIdentityProvidersRequest{}
resp, err := ts.Service.ListIdentityProviders(ctx, listReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.IdentityProviders, 2)
// Verify response contains expected providers
titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title}
require.Contains(t, titles, "Provider 1")
require.Contains(t, titles, "Provider 2")
})
}
func TestGetIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("GetIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "test-client",
ClientSecret: "test-secret",
AuthUrl: "https://example.com/auth",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/user",
Scopes: []string{"openid", "profile"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
DisplayName: "name",
Email: "email",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Get identity provider
getReq := &v1pb.GetIdentityProviderRequest{
Name: created.Name,
}
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, created.Name, resp.Name)
require.Equal(t, "Test Provider", resp.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret)
})
t.Run("GetIdentityProvider not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.GetIdentityProviderRequest{
Name: "identityProviders/999",
}
_, err := ts.Service.GetIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("GetIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.GetIdentityProviderRequest{
Name: "invalid-name",
}
_, err := ts.Service.GetIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
}
func TestUpdateIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("UpdateIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Original Provider",
IdentifierFilter: "",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "original-client",
AuthUrl: "https://original.com/auth",
TokenUrl: "https://original.com/token",
UserInfoUrl: "https://original.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Update identity provider
updateReq := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: created.Name,
Title: "Updated Provider",
IdentifierFilter: "test@example.com",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "updated-client",
ClientSecret: "updated-secret",
AuthUrl: "https://updated.com/auth",
TokenUrl: "https://updated.com/token",
UserInfoUrl: "https://updated.com/user",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &v1pb.FieldMapping{
Identifier: "sub",
DisplayName: "given_name",
Email: "email",
AvatarUrl: "picture",
},
},
},
},
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "identifier_filter", "config"},
},
}
updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "Updated Provider", updated.Title)
require.Equal(t, "test@example.com", updated.IdentifierFilter)
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
})
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "identityProviders/1",
Title: "Updated Provider",
},
}
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update_mask is required")
})
t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "invalid-name",
Title: "Updated Provider",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title"},
},
}
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
}
func TestDeleteIdentityProvider(t *testing.T) {
ctx := context.Background()
t.Run("DeleteIdentityProvider success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create identity provider
createReq := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Provider to Delete",
Type: v1pb.IdentityProvider_OAUTH2,
Config: &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: "client-to-delete",
AuthUrl: "https://example.com/auth",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/user",
FieldMapping: &v1pb.FieldMapping{
Identifier: "id",
},
},
},
},
},
}
created, err := ts.Service.CreateIdentityProvider(userCtx, createReq)
require.NoError(t, err)
// Delete identity provider
deleteReq := &v1pb.DeleteIdentityProviderRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq)
require.NoError(t, err)
// Verify deletion
getReq := &v1pb.GetIdentityProviderRequest{
Name: created.Name,
}
_, err = ts.Service.GetIdentityProvider(ctx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.DeleteIdentityProviderRequest{
Name: "invalid-name",
}
_, err := ts.Service.DeleteIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
t.Run("DeleteIdentityProvider not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.DeleteIdentityProviderRequest{
Name: "identityProviders/999",
}
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
require.Error(t, err)
// Note: Delete might succeed even if item doesn't exist, depending on store implementation
})
}
func TestIdentityProviderPermissions(t *testing.T) {
ctx := context.Background()
t.Run("Only host users can create identity providers", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err = ts.Service.CreateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("Authentication required", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.CreateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Title: "Test Provider",
Type: v1pb.IdentityProvider_OAUTH2,
},
}
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
}

View File

@@ -0,0 +1,559 @@
package v1
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestListInboxes(t *testing.T) {
ctx := context.Background()
t.Run("ListInboxes success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// List inboxes (should be empty initially)
req := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
resp, err := ts.Service.ListInboxes(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.Inboxes)
require.Equal(t, int32(0), resp.TotalSize)
})
t.Run("ListInboxes with pagination", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create some inbox entries
const systemBotID int32 = 0
for i := 0; i < 3; i++ {
_, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
}
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// List inboxes with page size limit
req := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
PageSize: 2,
}
resp, err := ts.Service.ListInboxes(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, 2, len(resp.Inboxes))
require.NotEmpty(t, resp.NextPageToken)
})
t.Run("ListInboxes permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to list user2's inboxes
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
}
_, err = ts.Service.ListInboxes(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot access inboxes")
})
t.Run("ListInboxes host can access other users' inboxes", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user and a regular user
hostUser, err := ts.CreateHostUser(ctx, "hostuser")
require.NoError(t, err)
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
// Create an inbox for the regular user
const systemBotID int32 = 0
_, err = ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: regularUser.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set host user context and try to list regular user's inboxes
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
req := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", regularUser.ID),
}
resp, err := ts.Service.ListInboxes(hostCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, 1, len(resp.Inboxes))
})
t.Run("ListInboxes invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.ListInboxesRequest{
Parent: "invalid-parent-format",
}
_, err = ts.Service.ListInboxes(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid parent name")
})
t.Run("ListInboxes unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListInboxesRequest{
Parent: "users/1",
}
_, err := ts.Service.ListInboxes(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
}
func TestUpdateInbox(t *testing.T) {
ctx := context.Background()
t.Run("UpdateInbox success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create an inbox entry
const systemBotID int32 = 0
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Update inbox status
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
Status: v1pb.Inbox_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"status"},
},
}
resp, err := ts.Service.UpdateInbox(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, v1pb.Inbox_ARCHIVED, resp.Status)
})
t.Run("UpdateInbox permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create an inbox entry for user2
const systemBotID int32 = 0
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user2.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set user1 context but try to update user2's inbox
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
Status: v1pb.Inbox_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"status"},
},
}
_, err = ts.Service.UpdateInbox(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot update inbox")
})
t.Run("UpdateInbox missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: "inboxes/1",
Status: v1pb.Inbox_ARCHIVED,
},
}
_, err = ts.Service.UpdateInbox(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update mask is required")
})
t.Run("UpdateInbox invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: "invalid-inbox-name",
Status: v1pb.Inbox_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"status"},
},
}
_, err = ts.Service.UpdateInbox(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid inbox name")
})
t.Run("UpdateInbox not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: "inboxes/99999", // Non-existent inbox
Status: v1pb.Inbox_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"status"},
},
}
_, err = ts.Service.UpdateInbox(userCtx, req)
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.NotFound, st.Code())
})
t.Run("UpdateInbox unsupported field", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create an inbox entry
const systemBotID int32 = 0
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
Status: v1pb.Inbox_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"unsupported_field"},
},
}
_, err = ts.Service.UpdateInbox(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "unsupported field")
})
}
func TestDeleteInbox(t *testing.T) {
ctx := context.Background()
t.Run("DeleteInbox success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create an inbox entry
const systemBotID int32 = 0
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Delete inbox
req := &v1pb.DeleteInboxRequest{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
}
_, err = ts.Service.DeleteInbox(userCtx, req)
require.NoError(t, err)
// Verify inbox is deleted
inboxes, err := ts.Store.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
})
require.NoError(t, err)
require.Equal(t, 0, len(inboxes))
})
t.Run("DeleteInbox permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create an inbox entry for user2
const systemBotID int32 = 0
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user2.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set user1 context but try to delete user2's inbox
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.DeleteInboxRequest{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
}
_, err = ts.Service.DeleteInbox(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot delete inbox")
})
t.Run("DeleteInbox invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteInboxRequest{
Name: "invalid-inbox-name",
}
_, err = ts.Service.DeleteInbox(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid inbox name")
})
t.Run("DeleteInbox not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteInboxRequest{
Name: "inboxes/99999", // Non-existent inbox
}
_, err = ts.Service.DeleteInbox(userCtx, req)
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.NotFound, st.Code())
})
}
func TestInboxCRUDComplete(t *testing.T) {
ctx := context.Background()
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Create an inbox entry directly in store
const systemBotID int32 = 0
inbox, err := ts.Store.CreateInbox(ctx, &store.Inbox{
SenderID: systemBotID,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
},
})
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// 1. List inboxes - should have 1
listReq := &v1pb.ListInboxesRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListInboxes(userCtx, listReq)
require.NoError(t, err)
require.Equal(t, 1, len(listResp.Inboxes))
require.Equal(t, v1pb.Inbox_UNREAD, listResp.Inboxes[0].Status)
// 2. Update inbox status to ARCHIVED
updateReq := &v1pb.UpdateInboxRequest{
Inbox: &v1pb.Inbox{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
Status: v1pb.Inbox_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"status"},
},
}
updateResp, err := ts.Service.UpdateInbox(userCtx, updateReq)
require.NoError(t, err)
require.Equal(t, v1pb.Inbox_ARCHIVED, updateResp.Status)
// 3. List inboxes again - should still have 1 but ARCHIVED
listResp, err = ts.Service.ListInboxes(userCtx, listReq)
require.NoError(t, err)
require.Equal(t, 1, len(listResp.Inboxes))
require.Equal(t, v1pb.Inbox_ARCHIVED, listResp.Inboxes[0].Status)
// 4. Delete inbox
deleteReq := &v1pb.DeleteInboxRequest{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
}
_, err = ts.Service.DeleteInbox(userCtx, deleteReq)
require.NoError(t, err)
// 5. List inboxes - should be empty
listResp, err = ts.Service.ListInboxes(userCtx, listReq)
require.NoError(t, err)
require.Equal(t, 0, len(listResp.Inboxes))
require.Equal(t, int32(0), listResp.TotalSize)
})
}

View File

@@ -0,0 +1,819 @@
package v1
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestListShortcuts(t *testing.T) {
ctx := context.Background()
t.Run("ListShortcuts success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// List shortcuts (should be empty initially)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
resp, err := ts.Service.ListShortcuts(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.Shortcuts)
})
t.Run("ListShortcuts permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to list user2's shortcuts
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
}
_, err = ts.Service.ListShortcuts(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("ListShortcuts invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.ListShortcutsRequest{
Parent: "invalid-parent-format",
}
_, err = ts.Service.ListShortcuts(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name")
})
t.Run("ListShortcuts unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.ListShortcutsRequest{
Parent: "users/1",
}
_, err := ts.Service.ListShortcuts(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
}
func TestGetShortcut(t *testing.T) {
ctx := context.Background()
t.Run("GetShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// First create a shortcut
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Now get the shortcut
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
resp, err := ts.Service.GetShortcut(userCtx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, created.Name, resp.Name)
require.Equal(t, "Test Shortcut", resp.Title)
require.Equal(t, "tag in [\"test\"]", resp.Filter)
})
t.Run("GetShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to get shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.GetShortcut(user2Ctx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("GetShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: "invalid-shortcut-name",
}
_, err = ts.Service.GetShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("GetShortcut not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.GetShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
}
_, err = ts.Service.GetShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestCreateShortcut(t *testing.T) {
ctx := context.Background()
t.Run("CreateShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "My Shortcut",
Filter: "tag in [\"important\"]",
},
}
resp, err := ts.Service.CreateShortcut(userCtx, req)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "My Shortcut", resp.Title)
require.Equal(t, "tag in [\"important\"]", resp.Filter)
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
// Verify the shortcut was created by listing
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, listResp.Shortcuts, 1)
require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title)
})
t.Run("CreateShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Set user1 context but try to create shortcut for user2
userCtx := ts.CreateUserContext(ctx, user1.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user2.ID),
Shortcut: &v1pb.Shortcut{
Title: "Forbidden Shortcut",
Filter: "tag in [\"forbidden\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateShortcut invalid parent format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: "invalid-parent",
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name")
})
t.Run("CreateShortcut invalid filter", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Invalid Filter Shortcut",
Filter: "invalid||filter))syntax",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid filter")
})
t.Run("CreateShortcut missing title", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Filter: "tag in [\"test\"]",
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "title is required")
})
}
func TestUpdateShortcut(t *testing.T) {
ctx := context.Background()
t.Run("UpdateShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Original Title",
Filter: "tag in [\"original\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Update the shortcut
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Title: "Updated Title",
Filter: "tag in [\"updated\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "Updated Title", updated.Title)
require.Equal(t, "tag in [\"updated\"]", updated.Filter)
require.Equal(t, created.Name, updated.Name)
})
t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to update shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Title: "Hacked Title",
Filter: "tag in [\"hacked\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
_, err = ts.Service.UpdateShortcut(user2Ctx, updateReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("UpdateShortcut missing update mask", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user and context for authentication
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
Title: "Updated Title",
},
}
_, err = ts.Service.UpdateShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update mask is required")
})
t.Run("UpdateShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: "invalid-shortcut-name",
Title: "Updated Title",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title"},
},
}
_, err := ts.Service.UpdateShortcut(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("UpdateShortcut invalid filter", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Test Shortcut",
Filter: "tag in [\"test\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Try to update with invalid filter
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created.Name,
Filter: "invalid||filter))syntax",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"filter"},
},
}
_, err = ts.Service.UpdateShortcut(userCtx, updateReq)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid filter")
})
}
func TestDeleteShortcut(t *testing.T) {
ctx := context.Background()
t.Run("DeleteShortcut success", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create a user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a shortcut first
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Shortcut to Delete",
Filter: "tag in [\"delete\"]",
},
}
created, err := ts.Service.CreateShortcut(userCtx, createReq)
require.NoError(t, err)
// Delete the shortcut
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
require.NoError(t, err)
// Verify deletion by listing shortcuts
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Empty(t, listResp.Shortcuts)
// Also verify by trying to get the deleted shortcut
getReq := &v1pb.GetShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.GetShortcut(userCtx, getReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create two users
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
// Create shortcut as user1
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
createReq := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user1.ID),
Shortcut: &v1pb.Shortcut{
Title: "User1 Shortcut",
Filter: "tag in [\"user1\"]",
},
}
created, err := ts.Service.CreateShortcut(user1Ctx, createReq)
require.NoError(t, err)
// Try to delete shortcut as user2
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created.Name,
}
_, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("DeleteShortcut invalid name format", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
req := &v1pb.DeleteShortcutRequest{
Name: "invalid-shortcut-name",
}
_, err := ts.Service.DeleteShortcut(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid shortcut name")
})
t.Run("DeleteShortcut not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
req := &v1pb.DeleteShortcutRequest{
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
}
_, err = ts.Service.DeleteShortcut(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestShortcutFiltering(t *testing.T) {
ctx := context.Background()
t.Run("CreateShortcut with valid filters", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various valid filter formats
validFilters := []string{
"tag in [\"work\"]",
"content.contains(\"meeting\")",
"tag in [\"work\"] && content.contains(\"meeting\")",
"tag in [\"work\"] || tag in [\"personal\"]",
"creator_id == 1",
"visibility == \"PUBLIC\"",
"has_task_list == true",
"has_task_list == false",
}
for i, filter := range validFilters {
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Valid Filter " + string(rune(i)),
Filter: filter,
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.NoError(t, err, "Filter should be valid: %s", filter)
}
})
t.Run("CreateShortcut with invalid filters", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// Test various invalid filter formats
invalidFilters := []string{
"tag in ", // incomplete expression
"invalid_field @in [\"value\"]", // unknown field
"tag in [\"work\"] &&", // incomplete expression
"tag in [\"work\"] || || tag in [\"test\"]", // double operator
"((tag in [\"work\"]", // unmatched parentheses
"tag in [\"work\"] && )", // mismatched parentheses
"tag == \"work\"", // wrong operator (== not supported for tags)
"tag in work", // missing brackets
}
for _, filter := range invalidFilters {
req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Invalid Filter Test",
Filter: filter,
},
}
_, err = ts.Service.CreateShortcut(userCtx, req)
require.Error(t, err, "Filter should be invalid: %s", filter)
require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter)
}
})
}
func TestShortcutCRUDComplete(t *testing.T) {
ctx := context.Background()
t.Run("Complete CRUD lifecycle", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "testuser")
require.NoError(t, err)
// Set user context
userCtx := ts.CreateUserContext(ctx, user.ID)
// 1. Create multiple shortcuts
shortcut1Req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Work Notes",
Filter: "tag in [\"work\"]",
},
}
shortcut2Req := &v1pb.CreateShortcutRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
Shortcut: &v1pb.Shortcut{
Title: "Personal Notes",
Filter: "tag in [\"personal\"]",
},
}
created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req)
require.NoError(t, err)
require.Equal(t, "Work Notes", created1.Title)
created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req)
require.NoError(t, err)
require.Equal(t, "Personal Notes", created2.Title)
// 2. List shortcuts and verify both exist
listReq := &v1pb.ListShortcutsRequest{
Parent: fmt.Sprintf("users/%d", user.ID),
}
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, listResp.Shortcuts, 2)
// 3. Get individual shortcuts
getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name}
getResp1, err := ts.Service.GetShortcut(userCtx, getReq1)
require.NoError(t, err)
require.Equal(t, created1.Name, getResp1.Name)
require.Equal(t, "Work Notes", getResp1.Title)
getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name}
getResp2, err := ts.Service.GetShortcut(userCtx, getReq2)
require.NoError(t, err)
require.Equal(t, created2.Name, getResp2.Name)
require.Equal(t, "Personal Notes", getResp2.Title)
// 4. Update one shortcut
updateReq := &v1pb.UpdateShortcutRequest{
Shortcut: &v1pb.Shortcut{
Name: created1.Name,
Title: "Work & Meeting Notes",
Filter: "tag in [\"work\"] || tag in [\"meeting\"]",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"title", "filter"},
},
}
updated, err := ts.Service.UpdateShortcut(userCtx, updateReq)
require.NoError(t, err)
require.Equal(t, "Work & Meeting Notes", updated.Title)
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter)
// 5. Verify update by getting it again
getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name}
getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq)
require.NoError(t, err)
require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title)
require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter)
// 6. Delete one shortcut
deleteReq := &v1pb.DeleteShortcutRequest{
Name: created2.Name,
}
_, err = ts.Service.DeleteShortcut(userCtx, deleteReq)
require.NoError(t, err)
// 7. Verify deletion by listing (should only have 1 left)
finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq)
require.NoError(t, err)
require.Len(t, finalListResp.Shortcuts, 1)
require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title)
// 8. Verify deleted shortcut can't be accessed
getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name}
_, err = ts.Service.GetShortcut(userCtx, getDeletedReq)
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,81 @@
package v1
import (
"context"
"testing"
"github.com/usememos/memos/internal/profile"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test"
)
// TestService holds the test service setup for API v1 services.
type TestService struct {
Service *apiv1.APIV1Service
Store *store.Store
Profile *profile.Profile
Secret string
}
// NewTestService creates a new test service with SQLite database.
func NewTestService(t *testing.T) *TestService {
ctx := context.Background()
// Create a test store with SQLite
testStore := teststore.NewTestingStore(ctx, t)
// Create a test profile
testProfile := &profile.Profile{
Mode: "dev",
Version: "test-1.0.0",
InstanceURL: "http://localhost:8080",
Driver: "sqlite",
DSN: ":memory:",
}
// Create APIV1Service with nil grpcServer since we're testing direct calls
secret := "test-secret"
service := &apiv1.APIV1Service{
Secret: secret,
Profile: testProfile,
Store: testStore,
}
return &TestService{
Service: service,
Store: testStore,
Profile: testProfile,
Secret: secret,
}
}
// Cleanup clears caches and closes resources after test.
func (ts *TestService) Cleanup() {
ts.Store.Close()
// Note: Owner cache is package-level in parent package, cannot clear from test package
}
// CreateHostUser creates a host user for testing.
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleHost,
Email: username + "@example.com",
})
}
// CreateRegularUser creates a regular user for testing.
func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleUser,
Email: username + "@example.com",
})
}
// CreateUserContext creates a context with the given user's ID for authentication.
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
// Use the real context key from the parent package
return apiv1.CreateTestUserContext(ctx, userID)
}

View File

@@ -0,0 +1,105 @@
package v1
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestGetUserStats_TagCount(t *testing.T) {
ctx := context.Background()
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test host user
user, err := ts.CreateHostUser(ctx, "test_user")
require.NoError(t, err)
// Create user context for authentication
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create a memo with a single tag
memo, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-1",
CreatorID: user.ID,
Content: "This is a test memo with #test tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test"},
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Test GetUserStats
userName := fmt.Sprintf("users/%d", user.ID)
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response)
// Check that the tag count is exactly 1, not 2
require.Contains(t, response.TagCount, "test")
require.Equal(t, int32(1), response.TagCount["test"], "Tag count should be 1 for a single occurrence")
// Create another memo with the same tag
memo2, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-2",
CreatorID: user.ID,
Content: "Another memo with #test tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"test"},
},
})
require.NoError(t, err)
require.NotNil(t, memo2)
// Test GetUserStats again
response2, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response2)
// Check that the tag count is exactly 2, not 3
require.Contains(t, response2.TagCount, "test")
require.Equal(t, int32(2), response2.TagCount["test"], "Tag count should be 2 for two occurrences")
// Test with a new unique tag
memo3, err := ts.Store.CreateMemo(ctx, &store.Memo{
UID: "test-memo-3",
CreatorID: user.ID,
Content: "Memo with #unique tag",
Visibility: store.Public,
Payload: &storepb.MemoPayload{
Tags: []string{"unique"},
},
})
require.NoError(t, err)
require.NotNil(t, memo3)
// Test GetUserStats for the new tag
response3, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
Name: userName,
})
require.NoError(t, err)
require.NotNil(t, response3)
// Check that the unique tag count is exactly 1
require.Contains(t, response3.TagCount, "unique")
require.Equal(t, int32(1), response3.TagCount["unique"], "New tag count should be 1 for first occurrence")
// The original test tag should still be 2
require.Contains(t, response3.TagCount, "test")
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
}

View File

@@ -0,0 +1,406 @@
package v1
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestCreateWebhook(t *testing.T) {
ctx := context.Background()
t.Run("CreateWebhook with host user", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create and authenticate as host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
req := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
Url: "https://example.com/webhook",
},
}
resp, err := ts.Service.CreateWebhook(userCtx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "Test Webhook", resp.DisplayName)
require.Equal(t, "https://example.com/webhook", resp.Url)
require.Contains(t, resp.Name, "webhooks/")
require.Contains(t, resp.Name, fmt.Sprintf("users/%d", hostUser.ID))
})
t.Run("CreateWebhook fails without authentication", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Try to create webhook without authentication
req := &v1pb.CreateWebhookRequest{
Parent: "users/1", // Dummy parent since we don't have a real user
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
Url: "https://example.com/webhook",
},
}
_, err := ts.Service.CreateWebhook(ctx, req)
// Should fail with permission denied or unauthenticated
require.Error(t, err)
})
t.Run("CreateWebhook fails with regular user", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create and authenticate as regular user
regularUser, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Try to create webhook as regular user
req := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", regularUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
Url: "https://example.com/webhook",
},
}
_, err = ts.Service.CreateWebhook(userCtx, req)
// Should fail with permission denied
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("CreateWebhook validates required fields", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create and authenticate as host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to create webhook with missing URL
req := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
// URL missing
},
}
_, err = ts.Service.CreateWebhook(userCtx, req)
// Should fail with validation error
require.Error(t, err)
})
}
func TestListWebhooks(t *testing.T) {
ctx := context.Background()
t.Run("ListWebhooks returns empty list initially", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user for authentication
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// List webhooks
req := &v1pb.ListWebhooksRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
}
resp, err := ts.Service.ListWebhooks(userCtx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Empty(t, resp.Webhooks)
})
t.Run("ListWebhooks returns created webhooks", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
Url: "https://example.com/webhook",
},
}
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
require.NoError(t, err)
// List webhooks
listReq := &v1pb.ListWebhooksRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
}
resp, err := ts.Service.ListWebhooks(userCtx, listReq)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Webhooks, 1)
require.Equal(t, createdWebhook.Name, resp.Webhooks[0].Name)
require.Equal(t, createdWebhook.Url, resp.Webhooks[0].Url)
})
t.Run("ListWebhooks fails without authentication", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Try to list webhooks without authentication
req := &v1pb.ListWebhooksRequest{
Parent: "users/1", // Dummy parent since we don't have a real user
}
_, err := ts.Service.ListWebhooks(ctx, req)
// Should fail with permission denied or unauthenticated
require.Error(t, err)
})
}
func TestGetWebhook(t *testing.T) {
ctx := context.Background()
t.Run("GetWebhook returns webhook by name", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
Url: "https://example.com/webhook",
},
}
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
require.NoError(t, err)
// Get the webhook
getReq := &v1pb.GetWebhookRequest{
Name: createdWebhook.Name,
}
resp, err := ts.Service.GetWebhook(userCtx, getReq)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, createdWebhook.Name, resp.Name)
require.Equal(t, createdWebhook.Url, resp.Url)
})
t.Run("GetWebhook fails with invalid name", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to get webhook with invalid name
req := &v1pb.GetWebhookRequest{
Name: "invalid/webhook/name",
}
_, err = ts.Service.GetWebhook(userCtx, req)
// Should return an error
require.Error(t, err)
})
t.Run("GetWebhook fails with non-existent webhook", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to get non-existent webhook
req := &v1pb.GetWebhookRequest{
Name: fmt.Sprintf("users/%d/webhooks/999", hostUser.ID),
}
_, err = ts.Service.GetWebhook(userCtx, req)
// Should return not found error
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}
func TestUpdateWebhook(t *testing.T) {
ctx := context.Background()
t.Run("UpdateWebhook updates webhook properties", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Original Webhook",
Url: "https://example.com/webhook",
},
}
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
require.NoError(t, err)
// Update the webhook
updateReq := &v1pb.UpdateWebhookRequest{
Webhook: &v1pb.Webhook{
Name: createdWebhook.Name,
Url: "https://updated.example.com/webhook",
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{"url"},
},
}
resp, err := ts.Service.UpdateWebhook(userCtx, updateReq)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, createdWebhook.Name, resp.Name)
require.Equal(t, "https://updated.example.com/webhook", resp.Url)
})
t.Run("UpdateWebhook fails without authentication", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Try to update webhook without authentication
req := &v1pb.UpdateWebhookRequest{
Webhook: &v1pb.Webhook{
Name: "users/1/webhooks/1",
Url: "https://updated.example.com/webhook",
},
}
_, err := ts.Service.UpdateWebhook(ctx, req)
// Should fail with permission denied or unauthenticated
require.Error(t, err)
})
}
func TestDeleteWebhook(t *testing.T) {
ctx := context.Background()
t.Run("DeleteWebhook removes webhook", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create a webhook
createReq := &v1pb.CreateWebhookRequest{
Parent: fmt.Sprintf("users/%d", hostUser.ID),
Webhook: &v1pb.Webhook{
DisplayName: "Test Webhook",
Url: "https://example.com/webhook",
},
}
createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq)
require.NoError(t, err)
// Delete the webhook
deleteReq := &v1pb.DeleteWebhookRequest{
Name: createdWebhook.Name,
}
_, err = ts.Service.DeleteWebhook(userCtx, deleteReq)
// Verify deletion
require.NoError(t, err)
// Try to get the deleted webhook
getReq := &v1pb.GetWebhookRequest{
Name: createdWebhook.Name,
}
_, err = ts.Service.GetWebhook(userCtx, getReq)
// Should return not found error
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
t.Run("DeleteWebhook fails without authentication", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Try to delete webhook without authentication
req := &v1pb.DeleteWebhookRequest{
Name: "users/1/webhooks/1",
}
_, err := ts.Service.DeleteWebhook(ctx, req)
// Should fail with permission denied or unauthenticated
require.Error(t, err)
})
t.Run("DeleteWebhook fails with non-existent webhook", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user and authenticate
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Try to delete non-existent webhook
req := &v1pb.DeleteWebhookRequest{
Name: fmt.Sprintf("users/%d/webhooks/999", hostUser.ID),
}
_, err = ts.Service.DeleteWebhook(userCtx, req)
// Should return not found error
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

View File

@@ -0,0 +1,206 @@
package v1
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestGetWorkspaceProfile(t *testing.T) {
ctx := context.Background()
t.Run("GetWorkspaceProfile returns workspace profile", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetWorkspaceProfile directly
req := &v1pb.GetWorkspaceProfileRequest{}
resp, err := ts.Service.GetWorkspaceProfile(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data
require.Equal(t, "test-1.0.0", resp.Version)
require.Equal(t, "dev", resp.Mode)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// Owner should be empty since no users are created
require.Empty(t, resp.Owner)
})
t.Run("GetWorkspaceProfile with owner", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user in the store
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
require.NotNil(t, hostUser)
// Call GetWorkspaceProfile directly
req := &v1pb.GetWorkspaceProfileRequest{}
resp, err := ts.Service.GetWorkspaceProfile(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data including owner
require.Equal(t, "test-1.0.0", resp.Version)
require.Equal(t, "dev", resp.Mode)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// User name should be "users/{id}" format where id is the user's ID
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
require.Equal(t, expectedOwnerName, resp.Owner)
})
}
func TestGetWorkspaceProfile_Concurrency(t *testing.T) {
ctx := context.Background()
t.Run("Concurrent access to service", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
// Make concurrent requests
numGoroutines := 10
results := make(chan *v1pb.WorkspaceProfile, numGoroutines)
errors := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
req := &v1pb.GetWorkspaceProfileRequest{}
resp, err := ts.Service.GetWorkspaceProfile(ctx, req)
if err != nil {
errors <- err
return
}
results <- resp
}()
}
// Collect all results
for i := 0; i < numGoroutines; i++ {
select {
case err := <-errors:
t.Fatalf("Goroutine returned error: %v", err)
case resp := <-results:
require.NotNil(t, resp)
require.Equal(t, "test-1.0.0", resp.Version)
require.Equal(t, "dev", resp.Mode)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
require.Equal(t, expectedOwnerName, resp.Owner)
}
}
})
}
func TestGetWorkspaceSetting(t *testing.T) {
ctx := context.Background()
t.Run("GetWorkspaceSetting - general setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetWorkspaceSetting for general setting
req := &v1pb.GetWorkspaceSettingRequest{
Name: "workspace/settings/GENERAL",
}
resp, err := ts.Service.GetWorkspaceSetting(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "workspace/settings/GENERAL", resp.Name)
// The general setting should have a general_setting field
generalSetting := resp.GetGeneralSetting()
require.NotNil(t, generalSetting)
// General setting should have default values
require.False(t, generalSetting.DisallowUserRegistration)
require.False(t, generalSetting.DisallowPasswordAuth)
require.Empty(t, generalSetting.AdditionalScript)
})
t.Run("GetWorkspaceSetting - storage setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Create a host user for storage setting access
hostUser, err := ts.CreateHostUser(ctx, "testhost")
require.NoError(t, err)
// Add user to context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Call GetWorkspaceSetting for storage setting
req := &v1pb.GetWorkspaceSettingRequest{
Name: "workspace/settings/STORAGE",
}
resp, err := ts.Service.GetWorkspaceSetting(userCtx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "workspace/settings/STORAGE", resp.Name)
// The storage setting should have a storage_setting field
storageSetting := resp.GetStorageSetting()
require.NotNil(t, storageSetting)
})
t.Run("GetWorkspaceSetting - memo related setting", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetWorkspaceSetting for memo related setting
req := &v1pb.GetWorkspaceSettingRequest{
Name: "workspace/settings/MEMO_RELATED",
}
resp, err := ts.Service.GetWorkspaceSetting(ctx, req)
// Verify response
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "workspace/settings/MEMO_RELATED", resp.Name)
// The memo related setting should have a memo_related_setting field
memoRelatedSetting := resp.GetMemoRelatedSetting()
require.NotNil(t, memoRelatedSetting)
})
t.Run("GetWorkspaceSetting - invalid setting name", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
// Call GetWorkspaceSetting with invalid name
req := &v1pb.GetWorkspaceSettingRequest{
Name: "invalid/setting/name",
}
_, err := ts.Service.GetWorkspaceSetting(ctx, req)
// Should return an error
require.Error(t, err)
require.Contains(t, err.Error(), "invalid workspace setting name")
})
}

View File

@@ -0,0 +1,19 @@
package v1
import (
"context"
"github.com/usememos/memos/store"
)
// CreateTestUserContext creates a context with user's ID for testing purposes.
// This function is only intended for use in tests.
func CreateTestUserContext(ctx context.Context, userID int32) context.Context {
return context.WithValue(ctx, userIDContextKey, userID)
}
// CreateTestUserContextWithUser creates a context and ensures the user exists for testing.
// This function is only intended for use in tests.
func CreateTestUserContextWithUser(ctx context.Context, _ *APIV1Service, user *store.User) context.Context {
return context.WithValue(ctx, userIDContextKey, user.ID)
}

View File

@@ -0,0 +1,831 @@
package v1
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"regexp"
"slices"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/base"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListUsers(ctx context.Context, _ *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
}
// TODO: Implement proper filtering, ordering, and pagination
// For now, return all users with basic structure
response := &v1pb.ListUsersResponse{
Users: []*v1pb.User{},
TotalSize: int32(len(users)),
}
for _, user := range users {
response.Users = append(response.Users, convertUserFromStore(user))
}
return response, nil
}
func (s *APIV1Service) GetUser(ctx context.Context, request *v1pb.GetUserRequest) (*v1pb.User, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
userPb := convertUserFromStore(user)
// TODO: Implement read_mask field filtering
// For now, return all fields
return userPb, nil
}
func (s *APIV1Service) SearchUsers(ctx context.Context, request *v1pb.SearchUsersRequest) (*v1pb.SearchUsersResponse, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Search users by username, email, or display name
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
}
var filteredUsers []*store.User
query := strings.ToLower(request.Query)
for _, user := range users {
if strings.Contains(strings.ToLower(user.Username), query) ||
strings.Contains(strings.ToLower(user.Email), query) ||
strings.Contains(strings.ToLower(user.Nickname), query) {
filteredUsers = append(filteredUsers, user)
}
}
response := &v1pb.SearchUsersResponse{
Users: []*v1pb.User{},
TotalSize: int32(len(filteredUsers)),
}
for _, user := range filteredUsers {
response.Users = append(response.Users, convertUserFromStore(user))
}
return response, nil
}
func (s *APIV1Service) GetUserAvatar(ctx context.Context, request *v1pb.GetUserAvatarRequest) (*httpbody.HttpBody, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if user.AvatarURL == "" {
return nil, status.Errorf(codes.NotFound, "avatar not found")
}
imageType, base64Data, err := extractImageInfo(user.AvatarURL)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to extract image info: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to decode string: %v", err)
}
httpBody := &httpbody.HttpBody{
ContentType: imageType,
Data: imageData,
}
return httpBody, nil
}
func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserRequest) (*v1pb.User, error) {
// Check if there are any existing host users (for first-time setup detection)
hostUserType := store.RoleHost
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list host users: %v", err)
}
// Determine the role to assign and check permissions
var roleToAssign store.Role
if len(existedHostUsers) == 0 {
// First-time setup: create the first user as HOST (no authentication required)
roleToAssign = store.RoleHost
} else {
// Regular user creation: allow unauthenticated creation of normal users
// But if authenticated, check if user has HOST permission for any role
currentUser, err := s.GetCurrentUser(ctx)
if err == nil && currentUser != nil && currentUser.Role == store.RoleHost {
// Authenticated HOST user can create users with any role specified in request
if request.User.Role != v1pb.User_ROLE_UNSPECIFIED {
roleToAssign = convertUserRoleToStore(request.User.Role)
} else {
roleToAssign = store.RoleUser
}
} else {
// Unauthenticated or non-HOST users can only create normal users
roleToAssign = store.RoleUser
}
}
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
}
// If validate_only is true, just validate without creating
if request.ValidateOnly {
// Perform validation checks without actually creating the user
return &v1pb.User{
Username: request.User.Username,
Email: request.User.Email,
DisplayName: request.User.DisplayName,
Role: convertUserRoleFromStore(roleToAssign),
}, nil
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
user, err := s.Store.CreateUser(ctx, &store.User{
Username: request.User.Username,
Role: roleToAssign,
Email: request.User.Email,
Nickname: request.User.DisplayName,
PasswordHash: string(passwordHash),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserRequest) (*v1pb.User, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
userID, err := ExtractUserIDFromName(request.User.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
// Check permission.
// Only allow admin or self to update user.
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
// Handle allow_missing field
if request.AllowMissing {
// Could create user if missing, but for now return not found
return nil, status.Errorf(codes.NotFound, "user not found")
}
return nil, status.Errorf(codes.NotFound, "user not found")
}
currentTs := time.Now().Unix()
update := &store.UpdateUser{
ID: user.ID,
UpdatedTs: &currentTs,
}
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting: %v", err)
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "username":
if workspaceGeneralSetting.DisallowChangeUsername {
return nil, status.Errorf(codes.PermissionDenied, "permission denied: disallow change username")
}
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
}
update.Username = &request.User.Username
case "display_name":
if workspaceGeneralSetting.DisallowChangeNickname {
return nil, status.Errorf(codes.PermissionDenied, "permission denied: disallow change nickname")
}
update.Nickname = &request.User.DisplayName
case "email":
update.Email = &request.User.Email
case "avatar_url":
update.AvatarURL = &request.User.AvatarUrl
case "description":
update.Description = &request.User.Description
case "role":
// Only allow admin to update role.
if currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
role := convertUserRoleToStore(request.User.Role)
update.Role = &role
case "password":
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
update.PasswordHash = &passwordHashStr
case "state":
rowStatus := convertStateToStore(request.User.State)
update.RowStatus = &rowStatus
default:
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
updatedUser, err := s.Store.UpdateUser(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
}
return convertUserFromStore(updatedUser), nil
}
func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserRequest) (*emptypb.Empty, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
}
return &emptypb.Empty{}, nil
}
func getDefaultUserSetting() *v1pb.UserSetting {
return &v1pb.UserSetting{
Name: "", // Will be set by caller
Locale: "en",
Appearance: "system",
MemoVisibility: "PRIVATE",
Theme: "",
}
}
func (s *APIV1Service) GetUserSetting(ctx context.Context, request *v1pb.GetUserSettingRequest) (*v1pb.UserSetting, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
// Only allow user to get their own settings
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSettings, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
}
userSettingMessage := getDefaultUserSetting()
userSettingMessage.Name = fmt.Sprintf("users/%d", userID)
for _, setting := range userSettings {
if setting.Key == storepb.UserSetting_GENERAL {
general := setting.GetGeneral()
if general != nil {
userSettingMessage.Locale = general.Locale
userSettingMessage.Appearance = general.Appearance
userSettingMessage.MemoVisibility = general.MemoVisibility
userSettingMessage.Theme = general.Theme
}
}
}
// Backfill theme if empty: use workspace theme or default to "default"
if userSettingMessage.Theme == "" {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting: %v", err)
}
workspaceTheme := workspaceGeneralSetting.Theme
if workspaceTheme == "" {
workspaceTheme = "default"
}
userSettingMessage.Theme = workspaceTheme
}
return userSettingMessage, nil
}
func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.UpdateUserSettingRequest) (*v1pb.UserSetting, error) {
// Extract user ID from the setting resource name
userID, err := ExtractUserIDFromName(request.Setting.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
// Only allow user to update their own settings
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
// Get the current general setting
existingGeneralSetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &userID,
Key: storepb.UserSetting_GENERAL,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get existing general setting: %v", err)
}
// Create or update the general setting
generalSetting := &storepb.GeneralUserSetting{
Locale: "en",
Appearance: "system",
MemoVisibility: "PRIVATE",
Theme: "",
}
// If there's an existing setting, use its values as defaults
if existingGeneralSetting != nil && existingGeneralSetting.GetGeneral() != nil {
existing := existingGeneralSetting.GetGeneral()
generalSetting.Locale = existing.Locale
generalSetting.Appearance = existing.Appearance
generalSetting.MemoVisibility = existing.MemoVisibility
generalSetting.Theme = existing.Theme
}
// Apply updates based on the update mask
for _, field := range request.UpdateMask.Paths {
switch field {
case "locale":
generalSetting.Locale = request.Setting.Locale
case "appearance":
generalSetting.Appearance = request.Setting.Appearance
case "memo_visibility":
generalSetting.MemoVisibility = request.Setting.MemoVisibility
case "theme":
generalSetting.Theme = request.Setting.Theme
default:
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
// Upsert the general setting
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: userID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{
General: generalSetting,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return s.GetUserSetting(ctx, &v1pb.GetUserSettingRequest{Name: request.Setting.Name})
}
func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, request *v1pb.ListUserAccessTokensRequest) (*v1pb.ListUserAccessTokensResponse, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
accessTokens := []*v1pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
// If the access token is invalid or expired, just ignore it.
continue
}
accessTokenResponse := &v1pb.UserAccessToken{
Name: fmt.Sprintf("users/%d/accessTokens/%s", userID, userAccessToken.AccessToken),
AccessToken: userAccessToken.AccessToken,
Description: userAccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
accessTokenResponse.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
accessTokens = append(accessTokens, accessTokenResponse)
}
// Sort by issued time in descending order.
slices.SortFunc(accessTokens, func(i, j *v1pb.UserAccessToken) int {
return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
})
response := &v1pb.ListUserAccessTokensResponse{
AccessTokens: accessTokens,
}
return response, nil
}
func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb.CreateUserAccessTokenRequest) (*v1pb.UserAccessToken, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
expiresAt := time.Time{}
if request.AccessToken.ExpiresAt != nil {
expiresAt = request.AccessToken.ExpiresAt.AsTime()
}
accessToken, err := GenerateAccessToken(currentUser.Username, currentUser.ID, expiresAt, []byte(s.Secret))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
claims := &ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
}
// Upsert the access token to user setting store.
if err := s.UpsertAccessTokenToStore(ctx, currentUser, accessToken, request.AccessToken.Description); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
}
userAccessToken := &v1pb.UserAccessToken{
Name: fmt.Sprintf("users/%d/accessTokens/%s", userID, accessToken),
AccessToken: accessToken,
Description: request.AccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
return userAccessToken, nil
}
func (s *APIV1Service) DeleteUserAccessToken(ctx context.Context, request *v1pb.DeleteUserAccessTokenRequest) (*emptypb.Empty, error) {
// Extract user ID from the access token resource name
// Format: users/{user}/accessTokens/{access_token}
parts := strings.Split(request.Name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "accessTokens" {
return nil, status.Errorf(codes.InvalidArgument, "invalid access token name format: %s", request.Name)
}
userID, err := ExtractUserIDFromName(fmt.Sprintf("users/%s", parts[1]))
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
accessTokenToDelete := parts[3]
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
for _, userAccessToken := range userAccessTokens {
if userAccessToken.AccessToken == accessTokenToDelete {
continue
}
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
}
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: currentUser.ID,
Key: storepb.UserSetting_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: updatedUserAccessTokens,
},
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListUserSessions(ctx context.Context, request *v1pb.ListUserSessionsRequest) (*v1pb.ListUserSessionsResponse, error) {
userID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userSessions, err := s.Store.GetUserSessions(ctx, userID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list sessions: %v", err)
}
sessions := []*v1pb.UserSession{}
for _, userSession := range userSessions {
sessionResponse := &v1pb.UserSession{
Name: fmt.Sprintf("users/%d/sessions/%s", userID, userSession.SessionId),
SessionId: userSession.SessionId,
CreateTime: userSession.CreateTime,
LastAccessedTime: userSession.LastAccessedTime,
}
if userSession.ClientInfo != nil {
sessionResponse.ClientInfo = &v1pb.UserSession_ClientInfo{
UserAgent: userSession.ClientInfo.UserAgent,
IpAddress: userSession.ClientInfo.IpAddress,
DeviceType: userSession.ClientInfo.DeviceType,
Os: userSession.ClientInfo.Os,
Browser: userSession.ClientInfo.Browser,
}
}
sessions = append(sessions, sessionResponse)
}
// Sort by last accessed time in descending order.
slices.SortFunc(sessions, func(i, j *v1pb.UserSession) int {
return int(j.LastAccessedTime.Seconds - i.LastAccessedTime.Seconds)
})
response := &v1pb.ListUserSessionsResponse{
Sessions: sessions,
}
return response, nil
}
func (s *APIV1Service) RevokeUserSession(ctx context.Context, request *v1pb.RevokeUserSessionRequest) (*emptypb.Empty, error) {
// Extract user ID and session ID from the session resource name
// Format: users/{user}/sessions/{session}
parts := strings.Split(request.Name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "sessions" {
return nil, status.Errorf(codes.InvalidArgument, "invalid session name format: %s", request.Name)
}
userID, err := ExtractUserIDFromName(fmt.Sprintf("users/%s", parts[1]))
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
sessionIDToRevoke := parts[3]
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if err := s.Store.RemoveUserSession(ctx, userID, sessionIDToRevoke); err != nil {
return nil, status.Errorf(codes.Internal, "failed to revoke session: %v", err)
}
return &emptypb.Empty{}, nil
}
// UpsertUserSession adds or updates a user session.
func (s *APIV1Service) UpsertUserSession(ctx context.Context, userID int32, sessionID string, clientInfo *storepb.SessionsUserSetting_ClientInfo) error {
session := &storepb.SessionsUserSetting_Session{
SessionId: sessionID,
CreateTime: timestamppb.Now(),
LastAccessedTime: timestamppb.Now(),
ClientInfo: clientInfo,
}
return s.Store.AddUserSession(ctx, userID, session)
}
func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: description,
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return errors.Wrap(err, "failed to upsert user setting")
}
return nil
}
func convertUserFromStore(user *store.User) *v1pb.User {
userpb := &v1pb.User{
Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
State: convertStateFromStore(user.RowStatus),
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
Role: convertUserRoleFromStore(user.Role),
Username: user.Username,
Email: user.Email,
DisplayName: user.Nickname,
AvatarUrl: user.AvatarURL,
Description: user.Description,
}
// Use the avatar URL instead of raw base64 image data to reduce the response size.
if user.AvatarURL != "" {
// Check if avatar url is base64 format.
_, _, err := extractImageInfo(user.AvatarURL)
if err == nil {
userpb.AvatarUrl = fmt.Sprintf("/api/v1/%s/avatar", userpb.Name)
} else {
userpb.AvatarUrl = user.AvatarURL
}
}
return userpb
}
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
switch role {
case store.RoleHost:
return v1pb.User_HOST
case store.RoleAdmin:
return v1pb.User_ADMIN
case store.RoleUser:
return v1pb.User_USER
default:
return v1pb.User_ROLE_UNSPECIFIED
}
}
func convertUserRoleToStore(role v1pb.User_Role) store.Role {
switch role {
case v1pb.User_HOST:
return store.RoleHost
case v1pb.User_ADMIN:
return store.RoleAdmin
case v1pb.User_USER:
return store.RoleUser
default:
return store.RoleUser
}
}
func extractImageInfo(dataURI string) (string, string, error) {
dataURIRegex := regexp.MustCompile(`^data:(?P<type>.+);base64,(?P<base64>.+)`)
matches := dataURIRegex.FindStringSubmatch(dataURI)
if len(matches) != 3 {
return "", "", errors.New("Invalid data URI format")
}
imageType := matches[1]
base64Data := matches[2]
return imageType, base64Data, nil
}

View File

@@ -0,0 +1,168 @@
package v1
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUserStatsRequest) (*v1pb.ListAllUserStatsResponse, error) {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
}
normalStatus := store.Normal
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &normalStatus,
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else {
if memoFind.CreatorID == nil {
internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
if memoFind.Filter != nil {
filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter)
memoFind.Filter = &filter
} else {
memoFind.Filter = &internalFilter
}
} else if *memoFind.CreatorID != currentUser.ID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
userMemoStatMap := make(map[int32]*v1pb.UserStats)
for _, memo := range memos {
displayTs := memo.CreatedTs
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{
Name: fmt.Sprintf("users/%d/stats", memo.CreatorID),
}
userMemoStatMap[memo.CreatorID].MemoDisplayTimestamps = append(userMemoStatMap[memo.CreatorID].MemoDisplayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
}
userMemoStats := []*v1pb.UserStats{}
for _, userMemoStat := range userMemoStatMap {
userMemoStats = append(userMemoStats, userMemoStat)
}
response := &v1pb.ListAllUserStatsResponse{
UserStats: userMemoStats,
}
return response, nil
}
func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserStatsRequest) (*v1pb.UserStats, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
normalStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &userID,
// Exclude comments by default.
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &normalStatus,
}
if currentUser == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
} else if currentUser.ID != userID {
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
}
displayTimestamps := []*timestamppb.Timestamp{}
tagCount := make(map[string]int32)
linkCount := int32(0)
codeCount := int32(0)
todoCount := int32(0)
undoCount := int32(0)
pinnedMemos := []string{}
for _, memo := range memos {
displayTs := memo.CreatedTs
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
displayTimestamps = append(displayTimestamps, timestamppb.New(time.Unix(displayTs, 0)))
// Count different memo types based on content.
if memo.Payload != nil {
for _, tag := range memo.Payload.Tags {
tagCount[tag]++
}
if memo.Payload.Property != nil {
if memo.Payload.Property.HasLink {
linkCount++
}
if memo.Payload.Property.HasCode {
codeCount++
}
if memo.Payload.Property.HasTaskList {
todoCount++
}
if memo.Payload.Property.HasIncompleteTasks {
undoCount++
}
}
}
if memo.Pinned {
pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID))
}
}
userStats := &v1pb.UserStats{
Name: fmt.Sprintf("users/%d/stats", userID),
MemoDisplayTimestamps: displayTimestamps,
TagCount: tagCount,
PinnedMemos: pinnedMemos,
TotalMemoCount: int32(len(memos)),
MemoTypeStats: &v1pb.UserStats_MemoTypeStats{
LinkCount: linkCount,
CodeCount: codeCount,
TodoCount: todoCount,
UndoCount: undoCount,
},
}
return userStats, nil
}

137
server/router/api/v1/v1.go Normal file
View File

@@ -0,0 +1,137 @@
package v1
import (
"context"
"fmt"
"math"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"github.com/usememos/memos/internal/profile"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
type APIV1Service struct {
grpc_health_v1.UnimplementedHealthServer
v1pb.UnimplementedWorkspaceServiceServer
v1pb.UnimplementedAuthServiceServer
v1pb.UnimplementedUserServiceServer
v1pb.UnimplementedMemoServiceServer
v1pb.UnimplementedAttachmentServiceServer
v1pb.UnimplementedShortcutServiceServer
v1pb.UnimplementedInboxServiceServer
v1pb.UnimplementedActivityServiceServer
v1pb.UnimplementedWebhookServiceServer
v1pb.UnimplementedMarkdownServiceServer
v1pb.UnimplementedIdentityProviderServiceServer
Secret string
Profile *profile.Profile
Store *store.Store
grpcServer *grpc.Server
}
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store, grpcServer *grpc.Server) *APIV1Service {
grpc.EnableTracing = true
apiv1Service := &APIV1Service{
Secret: secret,
Profile: profile,
Store: store,
grpcServer: grpcServer,
}
grpc_health_v1.RegisterHealthServer(grpcServer, apiv1Service)
v1pb.RegisterWorkspaceServiceServer(grpcServer, apiv1Service)
v1pb.RegisterAuthServiceServer(grpcServer, apiv1Service)
v1pb.RegisterUserServiceServer(grpcServer, apiv1Service)
v1pb.RegisterMemoServiceServer(grpcServer, apiv1Service)
v1pb.RegisterAttachmentServiceServer(grpcServer, apiv1Service)
v1pb.RegisterShortcutServiceServer(grpcServer, apiv1Service)
v1pb.RegisterInboxServiceServer(grpcServer, apiv1Service)
v1pb.RegisterActivityServiceServer(grpcServer, apiv1Service)
v1pb.RegisterWebhookServiceServer(grpcServer, apiv1Service)
v1pb.RegisterMarkdownServiceServer(grpcServer, apiv1Service)
v1pb.RegisterIdentityProviderServiceServer(grpcServer, apiv1Service)
reflection.Register(grpcServer)
return apiv1Service
}
// RegisterGateway registers the gRPC-Gateway with the given Echo instance.
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
var target string
if len(s.Profile.UNIXSock) == 0 {
target = fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port)
} else {
target = fmt.Sprintf("unix:%s", s.Profile.UNIXSock)
}
conn, err := grpc.NewClient(
target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
)
if err != nil {
return err
}
gwMux := runtime.NewServeMux()
if err := v1pb.RegisterWorkspaceServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterAuthServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterUserServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterMemoServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterAttachmentServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterShortcutServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterInboxServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterActivityServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterWebhookServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterMarkdownServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterIdentityProviderServiceHandler(ctx, gwMux, conn); err != nil {
return err
}
gwGroup := echoServer.Group("")
gwGroup.Use(middleware.CORS())
handler := echo.WrapHandler(gwMux)
gwGroup.Any("/api/v1/*", handler)
gwGroup.Any("/file/*", handler)
// GRPC web proxy.
options := []grpcweb.Option{
grpcweb.WithCorsForRegisteredEndpointsOnly(false),
grpcweb.WithOriginFunc(func(_ string) bool {
return true
}),
}
wrappedGrpc := grpcweb.WrapServer(s.grpcServer, options...)
echoServer.Any("/memos.api.v1.*", echo.WrapHandler(wrappedGrpc))
return nil
}

View File

@@ -0,0 +1,317 @@
package v1
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/util"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
func (s *APIV1Service) CreateWebhook(ctx context.Context, request *v1pb.CreateWebhookRequest) (*v1pb.Webhook, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Extract user ID from parent (format: users/{user})
parentUserID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
}
// Users can only create webhooks for themselves
if parentUserID != currentUser.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Only host users can create webhooks
if !isSuperUser(currentUser) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Validate required fields
if request.Webhook == nil {
return nil, status.Errorf(codes.InvalidArgument, "webhook is required")
}
if strings.TrimSpace(request.Webhook.Url) == "" {
return nil, status.Errorf(codes.InvalidArgument, "webhook URL is required")
}
// Handle validate_only field
if request.ValidateOnly {
// Perform validation checks without actually creating the webhook
return &v1pb.Webhook{
Name: fmt.Sprintf("users/%d/webhooks/validate", currentUser.ID),
DisplayName: request.Webhook.DisplayName,
Url: request.Webhook.Url,
}, nil
}
err = s.Store.AddUserWebhook(ctx, currentUser.ID, &storepb.WebhooksUserSetting_Webhook{
Id: generateWebhookID(),
Title: request.Webhook.DisplayName,
Url: strings.TrimSpace(request.Webhook.Url),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create webhook, error: %+v", err)
}
// Return the newly created webhook
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user webhooks, error: %+v", err)
}
// Find the webhook we just created
for _, webhook := range webhooks {
if webhook.Title == request.Webhook.DisplayName && webhook.Url == strings.TrimSpace(request.Webhook.Url) {
return convertWebhookFromUserSetting(webhook, currentUser.ID), nil
}
}
return nil, status.Errorf(codes.Internal, "failed to find created webhook")
}
func (s *APIV1Service) ListWebhooks(ctx context.Context, request *v1pb.ListWebhooksRequest) (*v1pb.ListWebhooksResponse, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Extract user ID from parent (format: users/{user})
parentUserID, err := ExtractUserIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
}
// Users can only list their own webhooks
if parentUserID != currentUser.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list webhooks, error: %+v", err)
}
response := &v1pb.ListWebhooksResponse{
Webhooks: []*v1pb.Webhook{},
}
for _, webhook := range webhooks {
response.Webhooks = append(response.Webhooks, convertWebhookFromUserSetting(webhook, currentUser.ID))
}
return response, nil
}
func (s *APIV1Service) GetWebhook(ctx context.Context, request *v1pb.GetWebhookRequest) (*v1pb.Webhook, error) {
// Extract user ID and webhook ID from name (format: users/{user}/webhooks/{webhook})
tokens, err := GetNameParentTokens(request.Name, UserNamePrefix, WebhookNamePrefix)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
}
if len(tokens) != 2 {
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name format")
}
userIDStr := tokens[0]
webhookID := tokens[1]
requestedUserID, err := util.ConvertStringToInt32(userIDStr)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID in webhook name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Users can only access their own webhooks
if requestedUserID != currentUser.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get webhooks, error: %+v", err)
}
// Find webhook by ID
for _, webhook := range webhooks {
if webhook.Id == webhookID {
return convertWebhookFromUserSetting(webhook, currentUser.ID), nil
}
}
return nil, status.Errorf(codes.NotFound, "webhook not found")
}
func (s *APIV1Service) UpdateWebhook(ctx context.Context, request *v1pb.UpdateWebhookRequest) (*v1pb.Webhook, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
// Extract user ID and webhook ID from name (format: users/{user}/webhooks/{webhook})
tokens, err := GetNameParentTokens(request.Webhook.Name, UserNamePrefix, WebhookNamePrefix)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
}
if len(tokens) != 2 {
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name format")
}
userIDStr := tokens[0]
webhookID := tokens[1]
requestedUserID, err := util.ConvertStringToInt32(userIDStr)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID in webhook name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Users can only update their own webhooks
if requestedUserID != currentUser.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Get existing webhooks from user settings
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get webhooks: %v", err)
}
// Find the webhook to update
var existingWebhook *storepb.WebhooksUserSetting_Webhook
for _, webhook := range webhooks {
if webhook.Id == webhookID {
existingWebhook = webhook
break
}
}
if existingWebhook == nil {
return nil, status.Errorf(codes.NotFound, "webhook not found")
}
// Create updated webhook
updatedWebhook := &storepb.WebhooksUserSetting_Webhook{
Id: existingWebhook.Id,
Title: existingWebhook.Title,
Url: existingWebhook.Url,
}
// Apply updates based on update mask
for _, field := range request.UpdateMask.Paths {
switch field {
case "display_name":
updatedWebhook.Title = request.Webhook.DisplayName
case "url":
updatedWebhook.Url = request.Webhook.Url
default:
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
// Update the webhook in user settings
err = s.Store.UpdateUserWebhook(ctx, currentUser.ID, updatedWebhook)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update webhook: %v", err)
}
return convertWebhookFromUserSetting(updatedWebhook, currentUser.ID), nil
}
func (s *APIV1Service) DeleteWebhook(ctx context.Context, request *v1pb.DeleteWebhookRequest) (*emptypb.Empty, error) {
// Extract user ID and webhook ID from name (format: users/{user}/webhooks/{webhook})
tokens, err := GetNameParentTokens(request.Name, UserNamePrefix, WebhookNamePrefix)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
}
if len(tokens) != 2 {
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name format")
}
userIDStr := tokens[0]
webhookID := tokens[1]
requestedUserID, err := util.ConvertStringToInt32(userIDStr)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID in webhook name: %v", err)
}
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Users can only delete their own webhooks
if requestedUserID != currentUser.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Get existing webhooks from user settings to verify it exists
webhooks, err := s.Store.GetUserWebhooks(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get webhooks: %v", err)
}
// Check if webhook exists
webhookExists := false
for _, webhook := range webhooks {
if webhook.Id == webhookID {
webhookExists = true
break
}
}
if !webhookExists {
return nil, status.Errorf(codes.NotFound, "webhook not found")
}
err = s.Store.RemoveUserWebhook(ctx, currentUser.ID, webhookID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete webhook: %v", err)
}
return &emptypb.Empty{}, nil
}
func convertWebhookFromUserSetting(webhook *storepb.WebhooksUserSetting_Webhook, userID int32) *v1pb.Webhook {
return &v1pb.Webhook{
Name: fmt.Sprintf("users/%d/webhooks/%s", userID, webhook.Id),
DisplayName: webhook.Title,
Url: webhook.Url,
}
}
func generateWebhookID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}

View File

@@ -0,0 +1,306 @@
package v1
import (
"context"
"fmt"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// GetWorkspaceProfile returns the workspace profile.
func (s *APIV1Service) GetWorkspaceProfile(ctx context.Context, _ *v1pb.GetWorkspaceProfileRequest) (*v1pb.WorkspaceProfile, error) {
workspaceProfile := &v1pb.WorkspaceProfile{
Version: s.Profile.Version,
Mode: s.Profile.Mode,
InstanceUrl: s.Profile.InstanceURL,
}
owner, err := s.GetInstanceOwner(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err)
}
if owner != nil {
workspaceProfile.Owner = owner.Name
}
return workspaceProfile, nil
}
func (s *APIV1Service) GetWorkspaceSetting(ctx context.Context, request *v1pb.GetWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
workspaceSettingKeyString, err := ExtractWorkspaceSettingKeyFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid workspace setting name: %v", err)
}
workspaceSettingKey := storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[workspaceSettingKeyString])
// Get workspace setting from store with default value.
switch workspaceSettingKey {
case storepb.WorkspaceSettingKey_BASIC:
_, err = s.Store.GetWorkspaceBasicSetting(ctx)
case storepb.WorkspaceSettingKey_GENERAL:
_, err = s.Store.GetWorkspaceGeneralSetting(ctx)
case storepb.WorkspaceSettingKey_MEMO_RELATED:
_, err = s.Store.GetWorkspaceMemoRelatedSetting(ctx)
case storepb.WorkspaceSettingKey_STORAGE:
_, err = s.Store.GetWorkspaceStorageSetting(ctx)
default:
return nil, status.Errorf(codes.InvalidArgument, "unsupported workspace setting key: %v", workspaceSettingKey)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
}
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: workspaceSettingKey.String(),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
}
if workspaceSetting == nil {
return nil, status.Errorf(codes.NotFound, "workspace setting not found")
}
// For storage setting, only host can get it.
if workspaceSetting.Key == storepb.WorkspaceSettingKey_STORAGE {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil || user.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
return convertWorkspaceSettingFromStore(workspaceSetting), nil
}
func (s *APIV1Service) UpdateWorkspaceSetting(ctx context.Context, request *v1pb.UpdateWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// TODO: Apply update_mask if specified
_ = request.UpdateMask
updateSetting := convertWorkspaceSettingToStore(request.Setting)
workspaceSetting, err := s.Store.UpsertWorkspaceSetting(ctx, updateSetting)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert workspace setting: %v", err)
}
return convertWorkspaceSettingFromStore(workspaceSetting), nil
}
func convertWorkspaceSettingFromStore(setting *storepb.WorkspaceSetting) *v1pb.WorkspaceSetting {
workspaceSetting := &v1pb.WorkspaceSetting{
Name: fmt.Sprintf("workspace/settings/%s", setting.Key.String()),
}
switch setting.Value.(type) {
case *storepb.WorkspaceSetting_GeneralSetting:
workspaceSetting.Value = &v1pb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingFromStore(setting.GetGeneralSetting()),
}
case *storepb.WorkspaceSetting_StorageSetting:
workspaceSetting.Value = &v1pb.WorkspaceSetting_StorageSetting{
StorageSetting: convertWorkspaceStorageSettingFromStore(setting.GetStorageSetting()),
}
case *storepb.WorkspaceSetting_MemoRelatedSetting:
workspaceSetting.Value = &v1pb.WorkspaceSetting_MemoRelatedSetting{
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
}
}
return workspaceSetting
}
func convertWorkspaceSettingToStore(setting *v1pb.WorkspaceSetting) *storepb.WorkspaceSetting {
settingKeyString, _ := ExtractWorkspaceSettingKeyFromName(setting.Name)
workspaceSetting := &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString]),
Value: &storepb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
},
}
switch workspaceSetting.Key {
case storepb.WorkspaceSettingKey_GENERAL:
workspaceSetting.Value = &storepb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
}
case storepb.WorkspaceSettingKey_STORAGE:
workspaceSetting.Value = &storepb.WorkspaceSetting_StorageSetting{
StorageSetting: convertWorkspaceStorageSettingToStore(setting.GetStorageSetting()),
}
case storepb.WorkspaceSettingKey_MEMO_RELATED:
workspaceSetting.Value = &storepb.WorkspaceSetting_MemoRelatedSetting{
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
}
}
return workspaceSetting
}
func convertWorkspaceGeneralSettingFromStore(setting *storepb.WorkspaceGeneralSetting) *v1pb.WorkspaceGeneralSetting {
if setting == nil {
return nil
}
// Backfill theme if empty
theme := setting.Theme
if theme == "" {
theme = "default"
}
generalSetting := &v1pb.WorkspaceGeneralSetting{
Theme: theme,
DisallowUserRegistration: setting.DisallowUserRegistration,
DisallowPasswordAuth: setting.DisallowPasswordAuth,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
WeekStartDayOffset: setting.WeekStartDayOffset,
DisallowChangeUsername: setting.DisallowChangeUsername,
DisallowChangeNickname: setting.DisallowChangeNickname,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &v1pb.WorkspaceCustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
Locale: setting.CustomProfile.Locale,
Appearance: setting.CustomProfile.Appearance,
}
}
return generalSetting
}
func convertWorkspaceGeneralSettingToStore(setting *v1pb.WorkspaceGeneralSetting) *storepb.WorkspaceGeneralSetting {
if setting == nil {
return nil
}
generalSetting := &storepb.WorkspaceGeneralSetting{
Theme: setting.Theme,
DisallowUserRegistration: setting.DisallowUserRegistration,
DisallowPasswordAuth: setting.DisallowPasswordAuth,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
WeekStartDayOffset: setting.WeekStartDayOffset,
DisallowChangeUsername: setting.DisallowChangeUsername,
DisallowChangeNickname: setting.DisallowChangeNickname,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &storepb.WorkspaceCustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
Locale: setting.CustomProfile.Locale,
Appearance: setting.CustomProfile.Appearance,
}
}
return generalSetting
}
func convertWorkspaceStorageSettingFromStore(settingpb *storepb.WorkspaceStorageSetting) *v1pb.WorkspaceStorageSetting {
if settingpb == nil {
return nil
}
setting := &v1pb.WorkspaceStorageSetting{
StorageType: v1pb.WorkspaceStorageSetting_StorageType(settingpb.StorageType),
FilepathTemplate: settingpb.FilepathTemplate,
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
}
if settingpb.S3Config != nil {
setting.S3Config = &v1pb.WorkspaceStorageSetting_S3Config{
AccessKeyId: settingpb.S3Config.AccessKeyId,
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
Endpoint: settingpb.S3Config.Endpoint,
Region: settingpb.S3Config.Region,
Bucket: settingpb.S3Config.Bucket,
UsePathStyle: settingpb.S3Config.UsePathStyle,
}
}
return setting
}
func convertWorkspaceStorageSettingToStore(setting *v1pb.WorkspaceStorageSetting) *storepb.WorkspaceStorageSetting {
if setting == nil {
return nil
}
settingpb := &storepb.WorkspaceStorageSetting{
StorageType: storepb.WorkspaceStorageSetting_StorageType(setting.StorageType),
FilepathTemplate: setting.FilepathTemplate,
UploadSizeLimitMb: setting.UploadSizeLimitMb,
}
if setting.S3Config != nil {
settingpb.S3Config = &storepb.StorageS3Config{
AccessKeyId: setting.S3Config.AccessKeyId,
AccessKeySecret: setting.S3Config.AccessKeySecret,
Endpoint: setting.S3Config.Endpoint,
Region: setting.S3Config.Region,
Bucket: setting.S3Config.Bucket,
UsePathStyle: setting.S3Config.UsePathStyle,
}
}
return settingpb
}
func convertWorkspaceMemoRelatedSettingFromStore(setting *storepb.WorkspaceMemoRelatedSetting) *v1pb.WorkspaceMemoRelatedSetting {
if setting == nil {
return nil
}
return &v1pb.WorkspaceMemoRelatedSetting{
DisallowPublicVisibility: setting.DisallowPublicVisibility,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
ContentLengthLimit: setting.ContentLengthLimit,
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
EnableLinkPreview: setting.EnableLinkPreview,
EnableComment: setting.EnableComment,
Reactions: setting.Reactions,
DisableMarkdownShortcuts: setting.DisableMarkdownShortcuts,
EnableBlurNsfwContent: setting.EnableBlurNsfwContent,
NsfwTags: setting.NsfwTags,
}
}
func convertWorkspaceMemoRelatedSettingToStore(setting *v1pb.WorkspaceMemoRelatedSetting) *storepb.WorkspaceMemoRelatedSetting {
if setting == nil {
return nil
}
return &storepb.WorkspaceMemoRelatedSetting{
DisallowPublicVisibility: setting.DisallowPublicVisibility,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
ContentLengthLimit: setting.ContentLengthLimit,
EnableDoubleClickEdit: setting.EnableDoubleClickEdit,
EnableLinkPreview: setting.EnableLinkPreview,
EnableComment: setting.EnableComment,
Reactions: setting.Reactions,
DisableMarkdownShortcuts: setting.DisableMarkdownShortcuts,
EnableBlurNsfwContent: setting.EnableBlurNsfwContent,
NsfwTags: setting.NsfwTags,
}
}
var ownerCache *v1pb.User
func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) {
if ownerCache != nil {
return ownerCache, nil
}
hostUserType := store.RoleHost
user, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to find owner")
}
if user == nil {
return nil, nil
}
ownerCache = convertUserFromStore(user)
return ownerCache, nil
}

View File

@@ -0,0 +1,61 @@
package frontend
import (
"context"
"embed"
"io/fs"
"net/http"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
//go:embed dist/*
var embeddedFiles embed.FS
type FrontendService struct {
Profile *profile.Profile
Store *store.Store
}
func NewFrontendService(profile *profile.Profile, store *store.Store) *FrontendService {
return &FrontendService{
Profile: profile,
Store: store,
}
}
func (*FrontendService) Serve(_ context.Context, e *echo.Echo) {
skipper := func(c echo.Context) bool {
// Skip API routes.
if util.HasPrefixes(c.Path(), "/api", "/memos.api.v1") {
return true
}
// Skip setting cache headers for index.html
if c.Path() == "/" || c.Path() == "/index.html" {
return false
}
// Set Cache-Control header to allow public caching with a max-age of 7 days.
c.Response().Header().Set(echo.HeaderCacheControl, "public, max-age=604800") // 7 days
return false
}
// Route to serve the main app with HTML5 fallback for SPA behavior.
e.Use(middleware.StaticWithConfig(middleware.StaticConfig{
Filesystem: getFileSystem("dist"),
HTML5: true, // Enable fallback to index.html
Skipper: skipper,
}))
}
func getFileSystem(path string) http.FileSystem {
fs, err := fs.Sub(embeddedFiles, path)
if err != nil {
panic(err)
}
return http.FS(fs)
}

179
server/router/rss/rss.go Normal file
View File

@@ -0,0 +1,179 @@
package rss
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gorilla/feeds"
"github.com/labstack/echo/v4"
"github.com/usememos/gomark"
"github.com/usememos/gomark/renderer"
"github.com/usememos/memos/internal/profile"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
maxRSSItemCount = 100
)
type RSSService struct {
Profile *profile.Profile
Store *store.Store
}
type RSSHeading struct {
Title string
Description string
}
func NewRSSService(profile *profile.Profile, store *store.Store) *RSSService {
return &RSSService{
Profile: profile,
Store: store,
}
}
func (s *RSSService) RegisterRoutes(g *echo.Group) {
g.GET("/explore/rss.xml", s.GetExploreRSS)
g.GET("/u/:username/rss.xml", s.GetUserRSS)
}
func (s *RSSService) GetExploreRSS(c echo.Context) error {
ctx := c.Request().Context()
normalStatus := store.Normal
memoFind := store.FindMemo{
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
}
memoList, err := s.Store.ListMemos(ctx, &memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
}
baseURL := c.Scheme() + "://" + c.Request().Host
rss, err := s.generateRSSFromMemoList(ctx, memoList, baseURL)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").SetInternal(err)
}
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationXMLCharsetUTF8)
return c.String(http.StatusOK, rss)
}
func (s *RSSService) GetUserRSS(c echo.Context) error {
ctx := c.Request().Context()
username := c.Param("username")
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "User not found")
}
normalStatus := store.Normal
memoFind := store.FindMemo{
CreatorID: &user.ID,
RowStatus: &normalStatus,
VisibilityList: []store.Visibility{store.Public},
}
memoList, err := s.Store.ListMemos(ctx, &memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
}
baseURL := c.Scheme() + "://" + c.Request().Host
rss, err := s.generateRSSFromMemoList(ctx, memoList, baseURL)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate rss").SetInternal(err)
}
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationXMLCharsetUTF8)
return c.String(http.StatusOK, rss)
}
func (s *RSSService) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string) (string, error) {
rssHeading, err := getRSSHeading(ctx, s.Store)
if err != nil {
return "", err
}
feed := &feeds.Feed{
Title: rssHeading.Title,
Link: &feeds.Link{Href: baseURL},
Description: rssHeading.Description,
Created: time.Now(),
}
var itemCountLimit = min(len(memoList), maxRSSItemCount)
feed.Items = make([]*feeds.Item, itemCountLimit)
for i := 0; i < itemCountLimit; i++ {
memo := memoList[i]
description, err := getRSSItemDescription(memo.Content)
if err != nil {
return "", err
}
link := &feeds.Link{Href: baseURL + "/memos/" + memo.UID}
feed.Items[i] = &feeds.Item{
Link: link,
Description: description,
Created: time.Unix(memo.CreatedTs, 0),
Id: link.Href,
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
if err != nil {
return "", err
}
if len(attachments) > 0 {
attachment := attachments[0]
enclosure := feeds.Enclosure{}
if attachment.StorageType == storepb.AttachmentStorageType_EXTERNAL || attachment.StorageType == storepb.AttachmentStorageType_S3 {
enclosure.Url = attachment.Reference
} else {
enclosure.Url = fmt.Sprintf("%s/file/attachments/%s/%s", baseURL, attachment.UID, attachment.Filename)
}
enclosure.Length = strconv.Itoa(int(attachment.Size))
enclosure.Type = attachment.Type
feed.Items[i].Enclosure = &enclosure
}
}
rss, err := feed.ToRss()
if err != nil {
return "", err
}
return rss, nil
}
func getRSSItemDescription(content string) (string, error) {
nodes, err := gomark.Parse(content)
if err != nil {
return "", err
}
result := renderer.NewHTMLRenderer().Render(nodes)
return result, nil
}
func getRSSHeading(ctx context.Context, stores *store.Store) (RSSHeading, error) {
settings, err := stores.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return RSSHeading{}, err
}
if settings == nil || settings.CustomProfile == nil {
return RSSHeading{
Title: "Memos",
Description: "An open source, lightweight note-taking service. Easily capture and share your great thoughts.",
}, nil
}
customProfile := settings.CustomProfile
return RSSHeading{
Title: customProfile.Title,
Description: customProfile.Description,
}, nil
}

View File

@@ -0,0 +1,134 @@
package memopayload
import (
"context"
"log/slog"
"slices"
"github.com/pkg/errors"
"github.com/usememos/gomark/ast"
"github.com/usememos/gomark/parser"
"github.com/usememos/gomark/parser/tokenizer"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
type Runner struct {
Store *store.Store
}
func NewRunner(store *store.Store) *Runner {
return &Runner{
Store: store,
}
}
// RunOnce rebuilds the payload of all memos.
func (r *Runner) RunOnce(ctx context.Context) {
// Process memos in batches to avoid loading all memos into memory at once
const batchSize = 100
offset := 0
processed := 0
for {
limit := batchSize
memos, err := r.Store.ListMemos(ctx, &store.FindMemo{
Limit: &limit,
Offset: &offset,
})
if err != nil {
slog.Error("failed to list memos", "err", err)
return
}
// Break if no more memos
if len(memos) == 0 {
break
}
// Process batch
batchSuccessCount := 0
for _, memo := range memos {
if err := RebuildMemoPayload(memo); err != nil {
slog.Error("failed to rebuild memo payload", "err", err, "memoID", memo.ID)
continue
}
if err := r.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Payload: memo.Payload,
}); err != nil {
slog.Error("failed to update memo", "err", err, "memoID", memo.ID)
continue
}
batchSuccessCount++
}
processed += len(memos)
slog.Info("Processed memo batch", "batchSize", len(memos), "successCount", batchSuccessCount, "totalProcessed", processed)
// Move to next batch
offset += len(memos)
}
}
func RebuildMemoPayload(memo *store.Memo) error {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return errors.Wrap(err, "failed to parse content")
}
if memo.Payload == nil {
memo.Payload = &storepb.MemoPayload{}
}
tags := []string{}
property := &storepb.MemoPayload_Property{}
TraverseASTNodes(nodes, func(node ast.Node) {
switch n := node.(type) {
case *ast.Tag:
tag := n.Content
if !slices.Contains(tags, tag) {
tags = append(tags, tag)
}
case *ast.Link, *ast.AutoLink:
property.HasLink = true
case *ast.TaskListItem:
property.HasTaskList = true
if !n.Complete {
property.HasIncompleteTasks = true
}
case *ast.CodeBlock:
property.HasCode = true
case *ast.EmbeddedContent:
// TODO: validate references.
property.References = append(property.References, n.ResourceName)
}
})
memo.Payload.Tags = tags
memo.Payload.Property = property
return nil
}
func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
for _, node := range nodes {
fn(node)
switch n := node.(type) {
case *ast.Paragraph:
TraverseASTNodes(n.Children, fn)
case *ast.Heading:
TraverseASTNodes(n.Children, fn)
case *ast.Blockquote:
TraverseASTNodes(n.Children, fn)
case *ast.List:
TraverseASTNodes(n.Children, fn)
case *ast.OrderedListItem:
TraverseASTNodes(n.Children, fn)
case *ast.UnorderedListItem:
TraverseASTNodes(n.Children, fn)
case *ast.TaskListItem:
TraverseASTNodes(n.Children, fn)
case *ast.Bold:
TraverseASTNodes(n.Children, fn)
}
}
}

View File

@@ -0,0 +1,134 @@
package s3presign
import (
"context"
"log/slog"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/plugin/storage/s3"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
type Runner struct {
Store *store.Store
}
func NewRunner(store *store.Store) *Runner {
return &Runner{
Store: store,
}
}
// Schedule runner every 12 hours.
const runnerInterval = time.Hour * 12
func (r *Runner) Run(ctx context.Context) {
ticker := time.NewTicker(runnerInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
r.RunOnce(ctx)
case <-ctx.Done():
return
}
}
}
func (r *Runner) RunOnce(ctx context.Context) {
r.CheckAndPresign(ctx)
}
func (r *Runner) CheckAndPresign(ctx context.Context) {
workspaceStorageSetting, err := r.Store.GetWorkspaceStorageSetting(ctx)
if err != nil {
return
}
s3StorageType := storepb.AttachmentStorageType_S3
// Limit attachments to a reasonable batch size
const batchSize = 100
offset := 0
for {
limit := batchSize
attachments, err := r.Store.ListAttachments(ctx, &store.FindAttachment{
GetBlob: false,
StorageType: &s3StorageType,
Limit: &limit,
Offset: &offset,
})
if err != nil {
slog.Error("Failed to list attachments for presigning", "error", err)
return
}
// Break if no more attachments
if len(attachments) == 0 {
break
}
// Process batch of attachments
presignCount := 0
for _, attachment := range attachments {
s3ObjectPayload := attachment.Payload.GetS3Object()
if s3ObjectPayload == nil {
continue
}
if s3ObjectPayload.LastPresignedTime != nil {
// Skip if the presigned URL is still valid for the next 4 days.
// The expiration time is set to 5 days.
if time.Now().Before(s3ObjectPayload.LastPresignedTime.AsTime().Add(4 * 24 * time.Hour)) {
continue
}
}
s3Config := workspaceStorageSetting.GetS3Config()
if s3ObjectPayload.S3Config != nil {
s3Config = s3ObjectPayload.S3Config
}
if s3Config == nil {
slog.Error("S3 config is not found")
continue
}
s3Client, err := s3.NewClient(ctx, s3Config)
if err != nil {
slog.Error("Failed to create S3 client", "error", err)
continue
}
presignURL, err := s3Client.PresignGetObject(ctx, s3ObjectPayload.Key)
if err != nil {
slog.Error("Failed to presign URL", "error", err, "attachmentID", attachment.ID)
continue
}
s3ObjectPayload.S3Config = s3Config
s3ObjectPayload.LastPresignedTime = timestamppb.New(time.Now())
if err := r.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
ID: attachment.ID,
Reference: &presignURL,
Payload: &storepb.AttachmentPayload{
Payload: &storepb.AttachmentPayload_S3Object_{
S3Object: s3ObjectPayload,
},
},
}); err != nil {
slog.Error("Failed to update attachment", "error", err, "attachmentID", attachment.ID)
continue
}
presignCount++
}
slog.Info("Presigned batch of S3 attachments", "batchSize", len(attachments), "presigned", presignCount)
// Move to next batch
offset += len(attachments)
}
}

227
server/server.go Normal file
View File

@@ -0,0 +1,227 @@
package server
import (
"context"
"fmt"
"log/slog"
"math"
"net"
"net/http"
"runtime"
"time"
"github.com/google/uuid"
grpcrecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/pkg/errors"
"github.com/soheilhy/cmux"
"google.golang.org/grpc"
"github.com/usememos/memos/internal/profile"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/profiler"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/server/router/frontend"
"github.com/usememos/memos/server/router/rss"
"github.com/usememos/memos/server/runner/s3presign"
"github.com/usememos/memos/store"
)
type Server struct {
Secret string
Profile *profile.Profile
Store *store.Store
echoServer *echo.Echo
grpcServer *grpc.Server
profiler *profiler.Profiler
runnerCancelFuncs []context.CancelFunc
}
func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
s := &Server{
Store: store,
Profile: profile,
}
echoServer := echo.New()
echoServer.Debug = true
echoServer.HideBanner = true
echoServer.HidePort = true
echoServer.Use(middleware.Recover())
s.echoServer = echoServer
// Initialize profiler
s.profiler = profiler.NewProfiler()
s.profiler.RegisterRoutes(echoServer)
s.profiler.StartMemoryMonitor(ctx)
workspaceBasicSetting, err := s.getOrUpsertWorkspaceBasicSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace basic setting")
}
secret := "usememos"
if profile.Mode == "prod" {
secret = workspaceBasicSetting.SecretKey
}
s.Secret = secret
// Register healthz endpoint.
echoServer.GET("/healthz", func(c echo.Context) error {
return c.String(http.StatusOK, "Service ready.")
})
// Serve frontend static files.
frontend.NewFrontendService(profile, store).Serve(ctx, echoServer)
rootGroup := echoServer.Group("")
// Create and register RSS routes.
rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup)
grpcServer := grpc.NewServer(
// Override the maximum receiving message size to math.MaxInt32 for uploading large attachments.
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.ChainUnaryInterceptor(
apiv1.NewLoggerInterceptor().LoggerInterceptor,
grpcrecovery.UnaryServerInterceptor(),
apiv1.NewGRPCAuthInterceptor(store, secret).AuthenticationInterceptor,
))
s.grpcServer = grpcServer
apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store, grpcServer)
// Register gRPC gateway as api v1.
if err := apiV1Service.RegisterGateway(ctx, echoServer); err != nil {
return nil, errors.Wrap(err, "failed to register gRPC gateway")
}
return s, nil
}
func (s *Server) Start(ctx context.Context) error {
var address, network string
if len(s.Profile.UNIXSock) == 0 {
address = fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port)
network = "tcp"
} else {
address = s.Profile.UNIXSock
network = "unix"
}
listener, err := net.Listen(network, address)
if err != nil {
return errors.Wrap(err, "failed to listen")
}
muxServer := cmux.New(listener)
go func() {
grpcListener := muxServer.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
if err := s.grpcServer.Serve(grpcListener); err != nil {
slog.Error("failed to serve gRPC", "error", err)
}
}()
go func() {
httpListener := muxServer.Match(cmux.HTTP1Fast(http.MethodPatch))
s.echoServer.Listener = httpListener
if err := s.echoServer.Start(address); err != nil {
slog.Error("failed to start echo server", "error", err)
}
}()
go func() {
if err := muxServer.Serve(); err != nil {
slog.Error("mux server listen error", "error", err)
}
}()
s.StartBackgroundRunners(ctx)
return nil
}
func (s *Server) Shutdown(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
slog.Info("server shutting down")
// Cancel all background runners
for _, cancelFunc := range s.runnerCancelFuncs {
if cancelFunc != nil {
cancelFunc()
}
}
// Shutdown echo server.
if err := s.echoServer.Shutdown(ctx); err != nil {
slog.Error("failed to shutdown server", slog.String("error", err.Error()))
}
// Shutdown gRPC server.
s.grpcServer.GracefulStop()
// Stop the profiler
if s.profiler != nil {
slog.Info("stopping profiler")
// Log final memory stats
var m runtime.MemStats
runtime.ReadMemStats(&m)
slog.Info("final memory stats before exit",
"heapAlloc", m.Alloc,
"heapSys", m.Sys,
"heapObjects", m.HeapObjects,
"numGoroutine", runtime.NumGoroutine(),
)
}
// Close database connection.
if err := s.Store.Close(); err != nil {
slog.Error("failed to close database", slog.String("error", err.Error()))
}
slog.Info("memos stopped properly")
}
func (s *Server) StartBackgroundRunners(ctx context.Context) {
// Create a separate context for each background runner
// This allows us to control cancellation for each runner independently
s3Context, s3Cancel := context.WithCancel(ctx)
// Store the cancel function so we can properly shut down runners
s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel)
// Create and start S3 presign runner
s3presignRunner := s3presign.NewRunner(s.Store)
s3presignRunner.RunOnce(ctx)
// Start continuous S3 presign runner
go func() {
s3presignRunner.Run(s3Context)
slog.Info("s3presign runner stopped")
}()
// Log the number of goroutines running
slog.Info("background runners started", "goroutines", runtime.NumGoroutine())
}
func (s *Server) getOrUpsertWorkspaceBasicSetting(ctx context.Context) (*storepb.WorkspaceBasicSetting, error) {
workspaceBasicSetting, err := s.Store.GetWorkspaceBasicSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace basic setting")
}
modified := false
if workspaceBasicSetting.SecretKey == "" {
workspaceBasicSetting.SecretKey = uuid.NewString()
modified = true
}
if modified {
workspaceSetting, err := s.Store.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_BASIC,
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
})
if err != nil {
return nil, errors.Wrap(err, "failed to upsert workspace setting")
}
workspaceBasicSetting = workspaceSetting.GetBasicSetting()
}
return workspaceBasicSetting, nil
}