init commit
This commit is contained in:
64
store/activity.go
Normal file
64
store/activity.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// ActivityType enumerates the kinds of activity records the store keeps.
type ActivityType string

const (
	// ActivityTypeMemoComment is recorded when a comment is added to a memo.
	ActivityTypeMemoComment ActivityType = "MEMO_COMMENT"
)

// String returns the raw string form of the activity type.
func (t ActivityType) String() string {
	return string(t)
}
|
||||
|
||||
// ActivityLevel expresses the severity of an activity record.
type ActivityLevel string

const (
	// ActivityLevelInfo is the informational (non-error) level.
	ActivityLevelInfo ActivityLevel = "INFO"
)

// String returns the raw string form of the activity level.
func (l ActivityLevel) String() string {
	return string(l)
}
|
||||
|
||||
// Activity is a persisted record of a notable event (currently only memo
// comments, per ActivityType).
type Activity struct {
	ID int32

	// Standard fields
	CreatorID int32
	CreatedTs int64

	// Domain specific fields
	Type    ActivityType
	Level   ActivityLevel
	Payload *storepb.ActivityPayload
}
|
||||
|
||||
// FindActivity filters activity queries; nil fields are ignored by the
// drivers, so the zero value matches every activity.
type FindActivity struct {
	ID   *int32
	Type *ActivityType
}
|
||||
|
||||
// CreateActivity persists the given activity via the underlying driver and
// returns the stored row.
func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) {
	return s.driver.CreateActivity(ctx, create)
}
|
||||
|
||||
// ListActivities returns all activities matching find via the underlying driver.
func (s *Store) ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) {
	return s.driver.ListActivities(ctx, find)
}
|
||||
|
||||
func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) {
|
||||
list, err := s.ListActivities(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
166
store/attachment.go
Normal file
166
store/attachment.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// Attachment is a stored file — either an inline blob or an external
// reference (e.g. an S3 object) — optionally linked to a memo.
type Attachment struct {
	// ID is the system generated unique identifier for the attachment.
	ID int32
	// UID is the user defined unique identifier for the attachment.
	UID string

	// Standard fields
	CreatorID int32
	CreatedTs int64
	UpdatedTs int64

	// Domain specific fields
	Filename string
	// Blob holds the file content; only populated when requested
	// (FindAttachment.GetBlob).
	Blob []byte
	Type string
	Size int64
	// StorageType selects where the content lives (e.g. LOCAL or S3).
	StorageType storepb.AttachmentStorageType
	// Reference locates the content for external storage types
	// (a filesystem path for LOCAL storage).
	Reference string
	Payload   *storepb.AttachmentPayload

	// The related memo ID.
	MemoID *int32
}
|
||||
|
||||
// FindAttachment filters attachment queries; nil/zero fields are ignored.
type FindAttachment struct {
	// GetBlob controls whether the (potentially large) blob column is fetched.
	GetBlob bool
	ID      *int32
	UID     *string
	CreatorID *int32
	// Filename matches exactly; FilenameSearch does a substring match.
	Filename       *string
	FilenameSearch *string
	MemoID         *int32
	// HasRelatedMemo restricts results to attachments linked to any memo.
	HasRelatedMemo bool
	StorageType    *storepb.AttachmentStorageType
	Limit          *int
	Offset         *int
}
|
||||
|
||||
// UpdateAttachment describes a partial update of the attachment identified
// by ID; nil fields are left unchanged.
type UpdateAttachment struct {
	ID        int32
	UID       *string
	UpdatedTs *int64
	Filename  *string
	MemoID    *int32
	Reference *string
	Payload   *storepb.AttachmentPayload
}
|
||||
|
||||
// DeleteAttachment identifies the attachment to remove.
type DeleteAttachment struct {
	ID int32
	// MemoID is optionally used by drivers to scope the delete to a memo.
	MemoID *int32
}
|
||||
|
||||
func (s *Store) CreateAttachment(ctx context.Context, create *Attachment) (*Attachment, error) {
|
||||
if !base.UIDMatcher.MatchString(create.UID) {
|
||||
return nil, errors.New("invalid uid")
|
||||
}
|
||||
return s.driver.CreateAttachment(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) ListAttachments(ctx context.Context, find *FindAttachment) ([]*Attachment, error) {
|
||||
// Set default limits to prevent loading too many attachments at once
|
||||
if find.Limit == nil && find.GetBlob {
|
||||
// When fetching blobs, we should be especially careful with limits
|
||||
defaultLimit := 10
|
||||
find.Limit = &defaultLimit
|
||||
} else if find.Limit == nil {
|
||||
// Even without blobs, let's default to a reasonable limit
|
||||
defaultLimit := 100
|
||||
find.Limit = &defaultLimit
|
||||
}
|
||||
|
||||
return s.driver.ListAttachments(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) GetAttachment(ctx context.Context, find *FindAttachment) (*Attachment, error) {
|
||||
attachments, err := s.ListAttachments(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(attachments) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return attachments[0], nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAttachment(ctx context.Context, update *UpdateAttachment) error {
|
||||
if update.UID != nil && !base.UIDMatcher.MatchString(*update.UID) {
|
||||
return errors.New("invalid uid")
|
||||
}
|
||||
return s.driver.UpdateAttachment(ctx, update)
|
||||
}
|
||||
|
||||
func (s *Store) DeleteAttachment(ctx context.Context, delete *DeleteAttachment) error {
|
||||
attachment, err := s.GetAttachment(ctx, &FindAttachment{ID: &delete.ID})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get attachment")
|
||||
}
|
||||
if attachment == nil {
|
||||
return errors.New("attachment not found")
|
||||
}
|
||||
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
|
||||
if err := func() error {
|
||||
p := filepath.FromSlash(attachment.Reference)
|
||||
if !filepath.IsAbs(p) {
|
||||
p = filepath.Join(s.profile.Data, p)
|
||||
}
|
||||
err := os.Remove(p)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete local file")
|
||||
}
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return errors.Wrap(err, "failed to delete local file")
|
||||
}
|
||||
} else if attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
if err := func() error {
|
||||
s3ObjectPayload := attachment.Payload.GetS3Object()
|
||||
if s3ObjectPayload == nil {
|
||||
return errors.Errorf("No s3 object found")
|
||||
}
|
||||
workspaceStorageSetting, err := s.GetWorkspaceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get workspace storage setting")
|
||||
}
|
||||
s3Config := s3ObjectPayload.S3Config
|
||||
if s3Config == nil {
|
||||
if workspaceStorageSetting.S3Config == nil {
|
||||
return errors.Errorf("S3 config is not found")
|
||||
}
|
||||
s3Config = workspaceStorageSetting.S3Config
|
||||
}
|
||||
|
||||
s3Client, err := s3.NewClient(ctx, s3Config)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to create s3 client")
|
||||
}
|
||||
if err := s3Client.DeleteObject(ctx, s3ObjectPayload.Key); err != nil {
|
||||
return errors.Wrap(err, "Failed to delete s3 object")
|
||||
}
|
||||
return nil
|
||||
}(); err != nil {
|
||||
slog.Warn("Failed to delete s3 object", slog.Any("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
return s.driver.DeleteAttachment(ctx, delete)
|
||||
}
|
||||
9
store/cache.go
Normal file
9
store/cache.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// getUserSettingCacheKey builds the cache key for a user-scoped setting,
// in the form "<userID>-<key>".
func getUserSettingCacheKey(userID int32, key string) string {
	return fmt.Sprint(userID) + "-" + key
}
|
||||
327
store/cache/cache.go
vendored
Normal file
327
store/cache/cache.go
vendored
Normal file
@@ -0,0 +1,327 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Interface defines the operations a cache must support.
type Interface interface {
	// Set adds a value to the cache with the default TTL.
	Set(ctx context.Context, key string, value any)

	// SetWithTTL adds a value to the cache with a custom TTL.
	SetWithTTL(ctx context.Context, key string, value any, ttl time.Duration)

	// Get retrieves a value from the cache, reporting whether the key was
	// present and unexpired.
	Get(ctx context.Context, key string) (any, bool)

	// Delete removes a value from the cache.
	Delete(ctx context.Context, key string)

	// Clear removes all values from the cache.
	Clear(ctx context.Context)

	// Size returns the number of items in the cache.
	Size() int64

	// Close stops all background tasks and releases resources.
	Close() error
}
|
||||
|
||||
// item represents a cached value with metadata.
type item struct {
	value      any
	expiration time.Time // absolute deadline; entries past it are treated as absent
	size       int       // Approximate size in bytes (informational)
}
|
||||
|
||||
// Config contains options for configuring a cache.
type Config struct {
	// DefaultTTL is the default time-to-live for cache entries.
	DefaultTTL time.Duration

	// CleanupInterval is how often the cache runs cleanup.
	CleanupInterval time.Duration

	// MaxItems is the maximum number of items allowed in the cache.
	// A value <= 0 disables the capacity-based eviction.
	MaxItems int

	// OnEviction is called when an item is evicted from the cache,
	// whether by TTL expiry, capacity pressure, Delete, or Clear.
	OnEviction func(key string, value any)
}
|
||||
|
||||
// DefaultConfig returns a default configuration for the cache.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
DefaultTTL: 10 * time.Minute,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
MaxItems: 1000,
|
||||
OnEviction: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Cache is a thread-safe in-memory cache with TTL and memory management.
type Cache struct {
	data       sync.Map
	config     Config
	itemCount  int64         // Use atomic operations to track item count
	stopChan   chan struct{} // closed by Close to stop the cleanup goroutine
	closedChan chan struct{} // closed by cleanupLoop on exit; Close waits on it
}
|
||||
|
||||
// New creates a new memory cache with the given configuration.
|
||||
func New(config Config) *Cache {
|
||||
c := &Cache{
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go c.cleanupLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
// NewDefault creates a new memory cache with default configuration.
|
||||
func NewDefault() *Cache {
|
||||
return New(DefaultConfig())
|
||||
}
|
||||
|
||||
// Set adds a value to the cache using the configured default TTL.
func (c *Cache) Set(ctx context.Context, key string, value any) {
	c.SetWithTTL(ctx, key, value, c.config.DefaultTTL)
}
|
||||
|
||||
// SetWithTTL adds a value to the cache with a custom TTL.
|
||||
func (c *Cache) SetWithTTL(_ context.Context, key string, value any, ttl time.Duration) {
|
||||
// Estimate size of the item (very rough approximation).
|
||||
size := estimateSize(value)
|
||||
|
||||
// Check if item already exists to avoid double counting.
|
||||
if _, exists := c.data.Load(key); exists {
|
||||
c.data.Delete(key)
|
||||
} else {
|
||||
// Only increment if this is a new key.
|
||||
atomic.AddInt64(&c.itemCount, 1)
|
||||
}
|
||||
|
||||
c.data.Store(key, item{
|
||||
value: value,
|
||||
expiration: time.Now().Add(ttl),
|
||||
size: size,
|
||||
})
|
||||
|
||||
// If we're over the max items, clean up old items.
|
||||
if c.config.MaxItems > 0 && atomic.LoadInt64(&c.itemCount) > int64(c.config.MaxItems) {
|
||||
c.cleanupOldest()
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache.
|
||||
func (c *Cache) Get(_ context.Context, key string) (any, bool) {
|
||||
value, ok := c.data.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
itm, ok := value.(item)
|
||||
if !ok {
|
||||
// If the value is not of type item, it means it was corrupted or not set correctly.
|
||||
c.data.Delete(key)
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(itm.expiration) {
|
||||
c.data.Delete(key)
|
||||
atomic.AddInt64(&c.itemCount, -1)
|
||||
|
||||
if c.config.OnEviction != nil {
|
||||
c.config.OnEviction(key, itm.value)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return itm.value, true
|
||||
}
|
||||
|
||||
// Delete removes a value from the cache.
|
||||
func (c *Cache) Delete(_ context.Context, key string) {
|
||||
if value, loaded := c.data.LoadAndDelete(key); loaded {
|
||||
atomic.AddInt64(&c.itemCount, -1)
|
||||
|
||||
if c.config.OnEviction != nil {
|
||||
if itm, ok := value.(item); ok {
|
||||
c.config.OnEviction(key, itm.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear removes all values from the cache.
|
||||
func (c *Cache) Clear(_ context.Context) {
|
||||
if c.config.OnEviction != nil {
|
||||
c.data.Range(func(key, value any) bool {
|
||||
itm, ok := value.(item)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if keyStr, ok := key.(string); ok {
|
||||
c.config.OnEviction(keyStr, itm.value)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
c.data = sync.Map{}
|
||||
atomic.StoreInt64(&c.itemCount, 0)
|
||||
}
|
||||
|
||||
// Size returns the current number of items tracked by the cache.
func (c *Cache) Size() int64 {
	return atomic.LoadInt64(&c.itemCount)
}
|
||||
|
||||
// Close stops the cache cleanup goroutine and waits for it to exit.
// A second call after a completed Close returns nil immediately.
func (c *Cache) Close() error {
	select {
	case <-c.stopChan:
		// Already closed
		return nil
	default:
		// NOTE(review): two goroutines calling Close concurrently can both
		// reach this branch and panic on the double close(c.stopChan);
		// callers appear expected to Close from a single owner — confirm.
		close(c.stopChan)
		<-c.closedChan // Wait for cleanup goroutine to exit
		return nil
	}
}
|
||||
|
||||
// cleanupLoop periodically cleans up expired items.
|
||||
func (c *Cache) cleanupLoop() {
|
||||
ticker := time.NewTicker(c.config.CleanupInterval)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
close(c.closedChan)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.cleanup()
|
||||
case <-c.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes expired items.
|
||||
func (c *Cache) cleanup() {
|
||||
evicted := make(map[string]any)
|
||||
count := 0
|
||||
|
||||
c.data.Range(func(key, value any) bool {
|
||||
itm, ok := value.(item)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if time.Now().After(itm.expiration) {
|
||||
c.data.Delete(key)
|
||||
count++
|
||||
|
||||
if c.config.OnEviction != nil {
|
||||
if keyStr, ok := key.(string); ok {
|
||||
evicted[keyStr] = itm.value
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if count > 0 {
|
||||
atomic.AddInt64(&c.itemCount, -int64(count))
|
||||
|
||||
// Call eviction callbacks outside the loop to avoid blocking the range
|
||||
if c.config.OnEviction != nil {
|
||||
for k, v := range evicted {
|
||||
c.config.OnEviction(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOldest evicts a batch of the oldest-expiring entries when the cache
// is over MaxItems. It keeps a bounded candidate set (20% of MaxItems) of
// the earliest expirations seen in one Range pass, then deletes them.
// Note it removes at most `threshold` items per call, so one call may not
// bring the cache all the way back under MaxItems.
func (c *Cache) cleanupOldest() {
	// Remove 20% of max items at once
	threshold := max(c.config.MaxItems/5, 1)

	currentCount := atomic.LoadInt64(&c.itemCount)

	// If we're not over the threshold, don't do anything
	if currentCount <= int64(c.config.MaxItems) {
		return
	}

	// Find the oldest items
	type keyExpPair struct {
		key        string
		value      any
		expiration time.Time
	}
	candidates := make([]keyExpPair, 0, threshold)

	c.data.Range(func(key, value any) bool {
		itm, ok := value.(item)
		if !ok {
			return true
		}
		// Fill the candidate set first; selection starts once it is full.
		if keyStr, ok := key.(string); ok && len(candidates) < threshold {
			candidates = append(candidates, keyExpPair{keyStr, itm.value, itm.expiration})
			return true
		}

		// Find the newest item in candidates
		// NOTE(review): this inner scan makes the pass O(n*threshold);
		// acceptable for small MaxItems, revisit if capacities grow.
		newestIdx := 0
		for i := 1; i < len(candidates); i++ {
			if candidates[i].expiration.After(candidates[newestIdx].expiration) {
				newestIdx = i
			}
		}

		// Replace it if this item is older
		if itm.expiration.Before(candidates[newestIdx].expiration) {
			candidates[newestIdx] = keyExpPair{key.(string), itm.value, itm.expiration}
		}

		return true
	})

	// Delete the oldest items
	deletedCount := 0
	for _, candidate := range candidates {
		c.data.Delete(candidate.key)
		deletedCount++

		if c.config.OnEviction != nil {
			c.config.OnEviction(candidate.key, candidate.value)
		}
	}

	// Update count
	if deletedCount > 0 {
		atomic.AddInt64(&c.itemCount, -int64(deletedCount))
	}
}
|
||||
|
||||
// estimateSize returns a rough byte-count estimate for a cached value.
// Strings, byte slices, and generic maps get tailored estimates; anything
// else falls back to a flat conservative figure.
func estimateSize(value any) int {
	if s, ok := value.(string); ok {
		return len(s) + 24 // base size + string overhead
	}
	if b, ok := value.([]byte); ok {
		return len(b) + 24 // base size + slice overhead
	}
	if m, ok := value.(map[string]any); ok {
		return len(m) * 64 // rough estimate
	}
	return 64 // default conservative estimate
}
|
||||
209
store/cache/cache_test.go
vendored
Normal file
209
store/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,209 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCacheBasicOperations covers Set/Get, SetWithTTL, Delete, automatic
// TTL expiry, and Clear on a cache with a 100ms default TTL.
// NOTE(review): relies on real sleeps and wall-clock TTLs, so it can flake
// on heavily loaded CI runners — confirm timings are generous enough.
func TestCacheBasicOperations(t *testing.T) {
	ctx := context.Background()
	config := DefaultConfig()
	config.DefaultTTL = 100 * time.Millisecond
	config.CleanupInterval = 50 * time.Millisecond
	cache := New(config)
	defer cache.Close()

	// Test Set and Get
	cache.Set(ctx, "key1", "value1")
	if val, ok := cache.Get(ctx, "key1"); !ok || val != "value1" {
		t.Errorf("Expected 'value1', got %v, exists: %v", val, ok)
	}

	// Test SetWithTTL
	cache.SetWithTTL(ctx, "key2", "value2", 200*time.Millisecond)
	if val, ok := cache.Get(ctx, "key2"); !ok || val != "value2" {
		t.Errorf("Expected 'value2', got %v, exists: %v", val, ok)
	}

	// Test Delete
	cache.Delete(ctx, "key1")
	if _, ok := cache.Get(ctx, "key1"); ok {
		t.Errorf("Key 'key1' should have been deleted")
	}

	// Test automatic expiration
	time.Sleep(150 * time.Millisecond)
	if _, ok := cache.Get(ctx, "key1"); ok {
		t.Errorf("Key 'key1' should have expired")
	}
	// key2 should still be valid (200ms TTL)
	if _, ok := cache.Get(ctx, "key2"); !ok {
		t.Errorf("Key 'key2' should still be valid")
	}

	// Wait for key2 to expire
	time.Sleep(100 * time.Millisecond)
	if _, ok := cache.Get(ctx, "key2"); ok {
		t.Errorf("Key 'key2' should have expired")
	}

	// Test Clear
	cache.Set(ctx, "key3", "value3")
	cache.Clear(ctx)
	if _, ok := cache.Get(ctx, "key3"); ok {
		t.Errorf("Cache should be empty after Clear()")
	}
}
|
||||
|
||||
// TestCacheEviction fills a 5-item cache past capacity and checks that the
// size stays bounded, at least one original key is evicted, and the two
// most recently inserted keys survive.
func TestCacheEviction(t *testing.T) {
	ctx := context.Background()
	config := DefaultConfig()
	config.MaxItems = 5
	cache := New(config)
	defer cache.Close()

	// Add 5 items (max capacity)
	for i := 0; i < 5; i++ {
		key := fmt.Sprintf("key%d", i)
		cache.Set(ctx, key, i)
	}

	// Verify all 5 items are in the cache
	for i := 0; i < 5; i++ {
		key := fmt.Sprintf("key%d", i)
		if _, ok := cache.Get(ctx, key); !ok {
			t.Errorf("Key '%s' should be in the cache", key)
		}
	}

	// Add 2 more items to trigger eviction
	cache.Set(ctx, "keyA", "valueA")
	cache.Set(ctx, "keyB", "valueB")

	// Verify size is still within limits
	if cache.Size() > int64(config.MaxItems) {
		t.Errorf("Cache size %d exceeds limit %d", cache.Size(), config.MaxItems)
	}

	// Some of the original keys should have been evicted
	evictedCount := 0
	for i := 0; i < 5; i++ {
		key := fmt.Sprintf("key%d", i)
		if _, ok := cache.Get(ctx, key); !ok {
			evictedCount++
		}
	}

	if evictedCount == 0 {
		t.Errorf("No keys were evicted despite exceeding max items")
	}

	// The newer keys should still be present
	if _, ok := cache.Get(ctx, "keyA"); !ok {
		t.Errorf("Key 'keyA' should be in the cache")
	}
	if _, ok := cache.Get(ctx, "keyB"); !ok {
		t.Errorf("Key 'keyB' should be in the cache")
	}
}
|
||||
|
||||
// TestCacheConcurrency hammers the cache from 10 goroutines doing disjoint
// key sets (set, read back, delete half) and then asserts the final count.
// Run with -race to catch data races.
// NOTE(review): the final Size assertion assumes no TTL expiry or capacity
// eviction fired during the run (default config: 10m TTL, 1000 max items,
// exactly 1000 keys inserted) — confirm if the defaults change.
func TestCacheConcurrency(t *testing.T) {
	ctx := context.Background()
	cache := NewDefault()
	defer cache.Close()

	const goroutines = 10
	const operationsPerGoroutine = 100

	var wg sync.WaitGroup
	wg.Add(goroutines)

	for i := 0; i < goroutines; i++ {
		go func(id int) {
			defer wg.Done()

			// Each worker uses its own key prefix, so key sets are disjoint.
			baseKey := fmt.Sprintf("worker%d-", id)

			// Set operations
			for j := 0; j < operationsPerGoroutine; j++ {
				key := fmt.Sprintf("%skey%d", baseKey, j)
				value := fmt.Sprintf("value%d-%d", id, j)
				cache.Set(ctx, key, value)
			}

			// Get operations
			for j := 0; j < operationsPerGoroutine; j++ {
				key := fmt.Sprintf("%skey%d", baseKey, j)
				val, ok := cache.Get(ctx, key)
				if !ok {
					t.Errorf("Key '%s' should exist in cache", key)
					continue
				}
				expected := fmt.Sprintf("value%d-%d", id, j)
				if val != expected {
					t.Errorf("For key '%s', expected '%s', got '%s'", key, expected, val)
				}
			}

			// Delete half the keys
			for j := 0; j < operationsPerGoroutine/2; j++ {
				key := fmt.Sprintf("%skey%d", baseKey, j)
				cache.Delete(ctx, key)
			}
		}(i)
	}

	wg.Wait()

	// Verify size and deletion
	var totalKeysExpected int64 = goroutines * operationsPerGoroutine / 2
	if cache.Size() != totalKeysExpected {
		t.Errorf("Expected cache size to be %d, got %d", totalKeysExpected, cache.Size())
	}
}
|
||||
|
||||
// TestEvictionCallback verifies OnEviction fires both for manual Delete and
// for TTL expiry handled by the cleanup loop. The callback writes into a
// mutex-guarded map that the assertions read under the same lock.
// NOTE(review): timing-based (50ms TTL, 25ms cleanup), so slow runners may
// flake — confirm the sleeps have enough headroom.
func TestEvictionCallback(t *testing.T) {
	ctx := context.Background()
	evicted := make(map[string]interface{})
	evictedMu := sync.Mutex{}

	config := DefaultConfig()
	config.DefaultTTL = 50 * time.Millisecond
	config.CleanupInterval = 25 * time.Millisecond
	config.OnEviction = func(key string, value interface{}) {
		evictedMu.Lock()
		evicted[key] = value
		evictedMu.Unlock()
	}

	cache := New(config)
	defer cache.Close()

	// Add items
	cache.Set(ctx, "key1", "value1")
	cache.Set(ctx, "key2", "value2")

	// Manually delete
	cache.Delete(ctx, "key1")

	// Verify manual deletion triggered callback
	time.Sleep(10 * time.Millisecond) // Small delay to ensure callback processed
	evictedMu.Lock()
	if evicted["key1"] != "value1" {
		t.Errorf("Eviction callback not triggered for manual deletion")
	}
	evictedMu.Unlock()

	// Wait for automatic expiration
	time.Sleep(60 * time.Millisecond)

	// Verify TTL expiration triggered callback
	evictedMu.Lock()
	if evicted["key2"] != "value2" {
		t.Errorf("Eviction callback not triggered for TTL expiration")
	}
	evictedMu.Unlock()
}
|
||||
24
store/common.go
Normal file
24
store/common.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package store
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
	// protojsonUnmarshaler is the shared lenient unmarshaler for proto
	// payload columns: partial messages are accepted and unknown fields are
	// dropped, so rows written by other schema versions still load.
	protojsonUnmarshaler = protojson.UnmarshalOptions{
		AllowPartial:   true,
		DiscardUnknown: true,
	}
)
|
||||
|
||||
// RowStatus is the lifecycle status stored on a row.
type RowStatus string

const (
	// Normal marks an active row.
	Normal RowStatus = "NORMAL"
	// Archived marks an archived (soft-hidden) row.
	Archived RowStatus = "ARCHIVED"
)

// String returns the raw string form of the row status.
func (r RowStatus) String() string {
	return string(r)
}
|
||||
32
store/db/db.go
Normal file
32
store/db/db.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
"github.com/usememos/memos/store/db/mysql"
|
||||
"github.com/usememos/memos/store/db/postgres"
|
||||
"github.com/usememos/memos/store/db/sqlite"
|
||||
)
|
||||
|
||||
// NewDBDriver creates new db driver based on profile.
|
||||
func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
|
||||
var driver store.Driver
|
||||
var err error
|
||||
|
||||
switch profile.Driver {
|
||||
case "sqlite":
|
||||
driver, err = sqlite.NewDB(profile)
|
||||
case "mysql":
|
||||
driver, err = mysql.NewDB(profile)
|
||||
case "postgres":
|
||||
driver, err = postgres.NewDB(profile)
|
||||
default:
|
||||
return nil, errors.New("unknown db driver")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create db driver")
|
||||
}
|
||||
return driver, nil
|
||||
}
|
||||
93
store/db/mysql/activity.go
Normal file
93
store/db/mysql/activity.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
|
||||
stmt := "INSERT INTO `activity` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute statement")
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get last insert id")
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
|
||||
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id32})
|
||||
if err != nil || len(list) == 0 {
|
||||
return nil, errors.Wrap(err, "failed to find activity")
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
// ListActivities returns activities matching find, newest first.
// The WHERE clause is assembled from parameterized fragments; user input
// only ever travels through placeholder args.
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
	// "1 = 1" seeds the clause so fragments can be appended uniformly.
	where, args := []string{"1 = 1"}, []any{}

	if find.ID != nil {
		where, args = append(where, "`id` = ?"), append(args, *find.ID)
	}
	if find.Type != nil {
		where, args = append(where, "`type` = ?"), append(args, find.Type.String())
	}

	// created_ts is stored as a DATETIME; UNIX_TIMESTAMP converts it to the
	// int64 epoch seconds the store layer uses.
	query := "SELECT `id`, `creator_id`, `type`, `level`, `payload`, UNIX_TIMESTAMP(`created_ts`) FROM `activity` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*store.Activity{}
	for rows.Next() {
		activity := &store.Activity{}
		var payloadBytes []byte
		if err := rows.Scan(
			&activity.ID,
			&activity.CreatorID,
			&activity.Type,
			&activity.Level,
			&payloadBytes,
			&activity.CreatedTs,
		); err != nil {
			return nil, err
		}

		// Payload column is protojson; decode with the shared lenient unmarshaler.
		payload := &storepb.ActivityPayload{}
		if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
			return nil, err
		}
		activity.Payload = payload
		list = append(list, activity)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
202
store/db/mysql/attachment.go
Normal file
202
store/db/mysql/attachment.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// CreateAttachment inserts the attachment (the table is named `resource`)
// and reads the stored row back by its auto-increment id.
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
	fields := []string{"`uid`", "`filename`", "`blob`", "`type`", "`size`", "`creator_id`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
	placeholder := []string{"?", "?", "?", "?", "?", "?", "?", "?", "?", "?"}
	// An unspecified storage type is stored as the empty string.
	storageType := ""
	if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
		storageType = create.StorageType.String()
	}
	// Payload is stored as a JSON column; "{}" when absent.
	payloadString := "{}"
	if create.Payload != nil {
		bytes, err := protojson.Marshal(create.Payload)
		if err != nil {
			return nil, errors.Wrap(err, "failed to marshal attachment payload")
		}
		payloadString = string(bytes)
	}
	args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}

	stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
	result, err := d.db.ExecContext(ctx, stmt, args...)
	if err != nil {
		return nil, err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	id32 := int32(id)
	return d.GetAttachment(ctx, &store.FindAttachment{ID: &id32})
}
|
||||
|
||||
// ListAttachments returns attachments matching find, most recently updated
// first. The blob column is only selected when find.GetBlob is set, keeping
// listing queries cheap. All user-supplied values go through placeholders;
// LIMIT/OFFSET are interpolated but come from typed ints, so no injection
// is possible.
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
	// "1 = 1" seeds the clause so fragments can be appended uniformly.
	where, args := []string{"1 = 1"}, []any{}

	if v := find.ID; v != nil {
		where, args = append(where, "`id` = ?"), append(args, *v)
	}
	if v := find.UID; v != nil {
		where, args = append(where, "`uid` = ?"), append(args, *v)
	}
	if v := find.CreatorID; v != nil {
		where, args = append(where, "`creator_id` = ?"), append(args, *v)
	}
	if v := find.Filename; v != nil {
		where, args = append(where, "`filename` = ?"), append(args, *v)
	}
	if v := find.FilenameSearch; v != nil {
		// Substring match; wildcards are added here, the term stays a parameter.
		where, args = append(where, "`filename` LIKE ?"), append(args, "%"+*v+"%")
	}
	if v := find.MemoID; v != nil {
		where, args = append(where, "`memo_id` = ?"), append(args, *v)
	}
	if find.HasRelatedMemo {
		where = append(where, "`memo_id` IS NOT NULL")
	}
	if find.StorageType != nil {
		where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String())
	}

	fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "UNIX_TIMESTAMP(`created_ts`)", "UNIX_TIMESTAMP(`updated_ts`)", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
	if find.GetBlob {
		fields = append(fields, "`blob`")
	}

	query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND "))
	if find.Limit != nil {
		query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
		// MySQL only accepts OFFSET together with LIMIT.
		if find.Offset != nil {
			query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
		}
	}

	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := make([]*store.Attachment, 0)
	for rows.Next() {
		attachment := store.Attachment{}
		// memo_id is nullable; scanned via sql.NullInt32.
		var memoID sql.NullInt32
		var storageType string
		var payloadBytes []byte
		// dests must stay in the same order as fields above.
		dests := []any{
			&attachment.ID,
			&attachment.UID,
			&attachment.Filename,
			&attachment.Type,
			&attachment.Size,
			&attachment.CreatorID,
			&attachment.CreatedTs,
			&attachment.UpdatedTs,
			&memoID,
			&storageType,
			&attachment.Reference,
			&payloadBytes,
		}
		if find.GetBlob {
			dests = append(dests, &attachment.Blob)
		}
		if err := rows.Scan(dests...); err != nil {
			return nil, err
		}

		if memoID.Valid {
			attachment.MemoID = &memoID.Int32
		}
		// Map the stored enum name back to its proto value; unknown names
		// (including "") map to 0, the UNSPECIFIED value.
		attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
		payload := &storepb.AttachmentPayload{}
		if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
			return nil, err
		}
		attachment.Payload = payload
		list = append(list, &attachment)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) GetAttachment(ctx context.Context, find *store.FindAttachment) (*store.Attachment, error) {
|
||||
list, err := d.ListAttachments(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "`reference` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(bytes))
|
||||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `resource` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
10
store/db/mysql/common.go
Normal file
10
store/db/mysql/common.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package mysql
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
	// protojsonUnmarshaler is the shared protojson decoder for JSON payload
	// columns in this package. AllowPartial and DiscardUnknown keep decoding
	// tolerant of drift between stored JSON and the current proto schema.
	protojsonUnmarshaler = protojson.UnmarshalOptions{
		AllowPartial:   true,
		DiscardUnknown: true,
	}
)
|
||||
126
store/db/mysql/idp.go
Normal file
126
store/db/mysql/idp.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = int32(id)
|
||||
return create, nil
|
||||
}
|
||||
|
||||
// ListIdentityProviders returns identity providers matching find, ordered by
// ascending ID.
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
	// "1 = 1" seeds the WHERE clause so conditions can be ANDed uniformly.
	where, args := []string{"1 = 1"}, []any{}
	if v := find.ID; v != nil {
		where, args = append(where, "`id` = ?"), append(args, *v)
	}

	rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
		args...,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var identityProviders []*store.IdentityProvider
	for rows.Next() {
		var identityProvider store.IdentityProvider
		// The type column stores the enum name; decoded below.
		var typeString string
		if err := rows.Scan(
			&identityProvider.ID,
			&identityProvider.Name,
			&typeString,
			&identityProvider.IdentifierFilter,
			&identityProvider.Config,
		); err != nil {
			return nil, err
		}
		identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
		identityProviders = append(identityProviders, &identityProvider)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return identityProviders, nil
}
|
||||
|
||||
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
|
||||
list, err := d.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
identityProvider := list[0]
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "`name` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "`identifier_filter` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
set, args = append(set, "`config` = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `idp` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider, err := d.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &update.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, errors.Errorf("idp %d not found", update.ID)
|
||||
}
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `idp` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
141
store/db/mysql/inbox.go
Normal file
141
store/db/mysql/inbox.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// CreateInbox inserts a new inbox entry and returns the stored row, re-read
// from the database so server-populated columns (e.g. created_ts) are filled.
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
	// The message column is JSON; default to an empty object when unset.
	messageString := "{}"
	if create.Message != nil {
		bytes, err := protojson.Marshal(create.Message)
		if err != nil {
			return nil, errors.Wrap(err, "failed to marshal inbox message")
		}
		messageString = string(bytes)
	}

	fields := []string{"`sender_id`", "`receiver_id`", "`status`", "`message`"}
	placeholder := []string{"?", "?", "?", "?"}
	args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}

	stmt := "INSERT INTO `inbox` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
	result, err := d.db.ExecContext(ctx, stmt, args...)
	if err != nil {
		return nil, err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	// Re-read the row by its auto-generated ID; GetInbox errors if the row
	// is missing, so a successful create always returns a populated inbox.
	id32 := int32(id)
	inbox, err := d.GetInbox(ctx, &store.FindInbox{ID: &id32})
	if err != nil {
		return nil, err
	}
	return inbox, nil
}
|
||||
|
||||
// ListInboxes returns inbox entries matching find, newest first. Pagination
// applies only when find.Limit is provided.
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
	// "1 = 1" seeds the WHERE clause so conditions can be ANDed uniformly.
	where, args := []string{"1 = 1"}, []any{}

	if find.ID != nil {
		where, args = append(where, "`id` = ?"), append(args, *find.ID)
	}
	if find.SenderID != nil {
		where, args = append(where, "`sender_id` = ?"), append(args, *find.SenderID)
	}
	if find.ReceiverID != nil {
		where, args = append(where, "`receiver_id` = ?"), append(args, *find.ReceiverID)
	}
	if find.Status != nil {
		where, args = append(where, "`status` = ?"), append(args, *find.Status)
	}

	// created_ts is selected as unix seconds so it scans into an int64 field.
	query := "SELECT `id`, UNIX_TIMESTAMP(`created_ts`), `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
	if find.Limit != nil {
		query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
		// OFFSET is only valid together with LIMIT in MySQL.
		if find.Offset != nil {
			query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
		}
	}
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*store.Inbox{}
	for rows.Next() {
		inbox := &store.Inbox{}
		var messageBytes []byte
		if err := rows.Scan(
			&inbox.ID,
			&inbox.CreatedTs,
			&inbox.SenderID,
			&inbox.ReceiverID,
			&inbox.Status,
			&messageBytes,
		); err != nil {
			return nil, err
		}

		// Decode the JSON message column into its proto message.
		message := &storepb.InboxMessage{}
		if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
			return nil, err
		}
		inbox.Message = message
		list = append(list, inbox)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
// GetInbox returns the single inbox entry matching find.
//
// NOTE(review): unlike GetMemo/GetAttachment (which return nil, nil for no
// match), this returns an error when the match count is not exactly one.
// CreateInbox relies on this error to detect a failed re-read; confirm
// callers' expectations before aligning the behavior with the other Gets.
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
	list, err := d.ListInboxes(ctx, find)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get inbox")
	}
	if len(list) != 1 {
		return nil, errors.Errorf("unexpected inbox count: %d", len(list))
	}
	return list[0], nil
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"`status` = ?"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE `inbox` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update inbox")
|
||||
}
|
||||
inbox, err := d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `inbox` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete inbox")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
287
store/db/mysql/memo.go
Normal file
287
store/db/mysql/memo.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// CreateMemo inserts a new memo row and returns it re-read from the database
// so server-populated columns (timestamps, row_status defaults) are filled.
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
	fields := []string{"`uid`", "`creator_id`", "`content`", "`visibility`", "`payload`"}
	placeholder := []string{"?", "?", "?", "?", "?"}
	// The payload column is JSON; default to an empty object when unset.
	payload := "{}"
	if create.Payload != nil {
		payloadBytes, err := protojson.Marshal(create.Payload)
		if err != nil {
			return nil, err
		}
		payload = string(payloadBytes)
	}
	args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}

	stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
	result, err := d.db.ExecContext(ctx, stmt, args...)
	if err != nil {
		return nil, err
	}

	rawID, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}
	// Re-read by the auto-generated ID; a nil result means the insert did
	// not actually persist a visible row.
	id := int32(rawID)
	memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
	if err != nil {
		return nil, err
	}
	if memo == nil {
		return nil, errors.Errorf("failed to create memo")
	}
	return memo, nil
}
|
||||
|
||||
// ListMemos returns memos matching find. The query left-joins memo_relation
// (COMMENT rows) so each memo carries its parent memo ID, enabling the
// ExcludeComments filter via the HAVING clause on that alias. Results are
// ordered by pinned state (optionally) and created/updated time; pagination
// applies only when find.Limit is provided.
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
	// "1 = 1" seeds both clause lists so conditions can be ANDed uniformly.
	// Conditions that bind placeholders must append to where and args in
	// lockstep, in the same order the placeholders appear in the query.
	where, having, args := []string{"1 = 1"}, []string{"1 = 1"}, []any{}

	if v := find.ID; v != nil {
		where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
	}
	if v := find.UID; v != nil {
		where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
	}
	if v := find.CreatorID; v != nil {
		where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
	}
	if v := find.RowStatus; v != nil {
		where, args = append(where, "`memo`.`row_status` = ?"), append(args, *v)
	}
	// Timestamp bounds compare against unix seconds.
	if v := find.CreatedTsBefore; v != nil {
		where, args = append(where, "UNIX_TIMESTAMP(`memo`.`created_ts`) < ?"), append(args, *v)
	}
	if v := find.CreatedTsAfter; v != nil {
		where, args = append(where, "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?"), append(args, *v)
	}
	if v := find.UpdatedTsBefore; v != nil {
		where, args = append(where, "UNIX_TIMESTAMP(`memo`.`updated_ts`) < ?"), append(args, *v)
	}
	if v := find.UpdatedTsAfter; v != nil {
		where, args = append(where, "UNIX_TIMESTAMP(`memo`.`updated_ts`) > ?"), append(args, *v)
	}
	// Every search term must match (AND semantics).
	if v := find.ContentSearch; len(v) != 0 {
		for _, s := range v {
			where, args = append(where, "`memo`.`content` LIKE ?"), append(args, "%"+s+"%")
		}
	}
	if v := find.VisibilityList; len(v) != 0 {
		placeholder := []string{}
		for _, visibility := range v {
			placeholder = append(placeholder, "?")
			args = append(args, visibility.String())
		}
		where = append(where, fmt.Sprintf("`memo`.`visibility` in (%s)", strings.Join(placeholder, ",")))
	}
	if v := find.Pinned; v != nil {
		where, args = append(where, "`memo`.`pinned` = ?"), append(args, *v)
	}
	if v := find.PayloadFind; v != nil {
		if v.Raw != nil {
			where, args = append(where, "`memo`.`payload` = ?"), append(args, *v.Raw)
		}
		if len(v.TagSearch) != 0 {
			// Each tag matches either the exact tag or its prefix form
			// ("tag/") so parent tags match their nested children.
			for _, tag := range v.TagSearch {
				where, args = append(where, "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))"), append(args, fmt.Sprintf(`"%s"`, tag), fmt.Sprintf(`"%s/"`, tag))
			}
		}
		// Property flags are stored as booleans inside the payload JSON.
		if v.HasLink {
			where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE")
		}
		if v.HasTaskList {
			where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE")
		}
		if v.HasCode {
			where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE")
		}
		if v.HasIncompleteTasks {
			where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
		}
	}
	if v := find.Filter; v != nil {
		// Parse filter string and return the parsed expression.
		// The filter string should be a CEL expression.
		parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
		if err != nil {
			return nil, err
		}
		convertCtx := filter.NewConvertContext()
		// ConvertExprToSQL converts the parsed expression to a SQL condition string.
		if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
			return nil, err
		}
		condition := convertCtx.Buffer.String()
		if condition != "" {
			where = append(where, fmt.Sprintf("(%s)", condition))
			args = append(args, convertCtx.Args...)
		}
	}
	if find.ExcludeComments {
		// parent_id is a select alias (from the join below), so it must be
		// filtered in HAVING rather than WHERE.
		having = append(having, "`parent_id` IS NULL")
	}

	order := "DESC"
	if find.OrderByTimeAsc {
		order = "ASC"
	}
	orderBy := []string{}
	if find.OrderByPinned {
		orderBy = append(orderBy, "`pinned` DESC")
	}
	if find.OrderByUpdatedTs {
		orderBy = append(orderBy, "`updated_ts` "+order)
	} else {
		orderBy = append(orderBy, "`created_ts` "+order)
	}
	fields := []string{
		"`memo`.`id` AS `id`",
		"`memo`.`uid` AS `uid`",
		"`memo`.`creator_id` AS `creator_id`",
		"UNIX_TIMESTAMP(`memo`.`created_ts`) AS `created_ts`",
		"UNIX_TIMESTAMP(`memo`.`updated_ts`) AS `updated_ts`",
		"`memo`.`row_status` AS `row_status`",
		"`memo`.`visibility` AS `visibility`",
		"`memo`.`pinned` AS `pinned`",
		"`memo`.`payload` AS `payload`",
		"`memo_relation`.`related_memo_id` AS `parent_id`",
	}
	// Content can be large; callers may opt out of fetching it.
	if !find.ExcludeContent {
		fields = append(fields, "`memo`.`content` AS `content`")
	}

	query := "SELECT " + strings.Join(fields, ", ") + " FROM `memo`" + " " +
		"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = 'COMMENT'" + " " +
		"WHERE " + strings.Join(where, " AND ") + " " +
		"HAVING " + strings.Join(having, " AND ") + " " +
		"ORDER BY " + strings.Join(orderBy, ", ")
	if find.Limit != nil {
		query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
		// OFFSET is only valid together with LIMIT in MySQL.
		if find.Offset != nil {
			query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
		}
	}

	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := make([]*store.Memo, 0)
	for rows.Next() {
		var memo store.Memo
		var payloadBytes []byte
		// Destination order must mirror the fields slice built above.
		dests := []any{
			&memo.ID,
			&memo.UID,
			&memo.CreatorID,
			&memo.CreatedTs,
			&memo.UpdatedTs,
			&memo.RowStatus,
			&memo.Visibility,
			&memo.Pinned,
			&payloadBytes,
			&memo.ParentID,
		}
		if !find.ExcludeContent {
			dests = append(dests, &memo.Content)
		}
		if err := rows.Scan(dests...); err != nil {
			return nil, err
		}
		payload := &storepb.MemoPayload{}
		if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
			return nil, errors.Wrap(err, "failed to unmarshal payload")
		}
		memo.Payload = payload
		list = append(list, &memo)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
|
||||
list, err := d.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
// UpdateMemo applies the non-nil fields of update to the memo row identified
// by update.ID. It is a no-op when no field is set.
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
	set, args := []string{}, []any{}
	if v := update.UID; v != nil {
		set, args = append(set, "`uid` = ?"), append(args, *v)
	}
	// Timestamp columns store datetimes; convert from unix seconds.
	if v := update.CreatedTs; v != nil {
		set, args = append(set, "`created_ts` = FROM_UNIXTIME(?)"), append(args, *v)
	}
	if v := update.UpdatedTs; v != nil {
		set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
	}
	if v := update.RowStatus; v != nil {
		set, args = append(set, "`row_status` = ?"), append(args, *v)
	}
	if v := update.Content; v != nil {
		set, args = append(set, "`content` = ?"), append(args, *v)
	}
	if v := update.Visibility; v != nil {
		set, args = append(set, "`visibility` = ?"), append(args, *v)
	}
	if v := update.Pinned; v != nil {
		set, args = append(set, "`pinned` = ?"), append(args, *v)
	}
	if v := update.Payload; v != nil {
		// The payload column stores protojson-encoded JSON.
		payloadBytes, err := protojson.Marshal(v)
		if err != nil {
			return err
		}
		set, args = append(set, "`payload` = ?"), append(args, string(payloadBytes))
	}
	// Nothing to update: avoid executing an invalid empty SET clause.
	if len(set) == 0 {
		return nil
	}
	args = append(args, update.ID)

	stmt := "UPDATE `memo` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
	if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
		return err
	}
	return nil
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `memo` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
304
store/db/mysql/memo_filter.go
Normal file
304
store/db/mysql/memo_filter.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
// ConvertExprToSQL converts a parsed CEL filter expression into a MySQL
// condition, writing SQL text into ctx.Buffer and the corresponding
// placeholder values into ctx.Args.
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
	return d.convertWithTemplates(ctx, expr)
}
|
||||
|
||||
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
const dbType = filter.MySQLTemplate
|
||||
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
// Handle size(tags) comparison
|
||||
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
|
||||
filter.GetSQL("json_array_length", dbType), operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "visibility" || identifier == "content" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
var sqlTemplate string
|
||||
if identifier == "visibility" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
|
||||
} else if identifier == "content" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
} else if identifier == "creator_id" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "has_task_list" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
// Use template for boolean comparison
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_false", dbType)
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||
// This is "element in collection" - the second argument is the collection
|
||||
if !slices.Contains([]string{"tags"}, identifier) {
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||
}
|
||||
|
||||
if identifier == "tags" {
|
||||
// Handle "element" in tags
|
||||
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if identifier == "tag" {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
for _, v := range values {
|
||||
subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
|
||||
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||
}
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
} else if identifier == "visibility" {
|
||||
placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
|
||||
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
identifier := v.IdentExpr.GetName()
|
||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
if identifier == "pinned" {
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
// Handle has_task_list as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
130
store/db/mysql/memo_filter_test.go
Normal file
130
store/db/mysql/memo_filter_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
// TestConvertExprToSQL verifies that CEL memo-filter expressions are
// converted into the expected MySQL condition strings and positional
// argument lists.
func TestConvertExprToSQL(t *testing.T) {
	tests := []struct {
		filter string // CEL filter expression under test
		want   string // expected SQL condition written to the buffer
		args   []any  // expected query arguments collected alongside
	}{
		{
			filter: `tag in ["tag1", "tag2"]`,
			want:   "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))",
			args:   []any{"tag1", "tag2"},
		},
		{
			filter: `!(tag in ["tag1", "tag2"])`,
			want:   "NOT ((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)))",
			args:   []any{"tag1", "tag2"},
		},
		{
			filter: `content.contains("memos")`,
			want:   "`memo`.`content` LIKE ?",
			args:   []any{"%memos%"},
		},
		{
			filter: `visibility in ["PUBLIC"]`,
			want:   "`memo`.`visibility` IN (?)",
			args:   []any{"PUBLIC"},
		},
		{
			filter: `visibility in ["PUBLIC", "PRIVATE"]`,
			want:   "`memo`.`visibility` IN (?,?)",
			args:   []any{"PUBLIC", "PRIVATE"},
		},
		{
			filter: `tag in ['tag1'] || content.contains('hello')`,
			want:   "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR `memo`.`content` LIKE ?)",
			args:   []any{"tag1", "%hello%"},
		},
		{
			// A bare constant produces no condition at all.
			filter: `1`,
			want:   "",
			args:   []any{},
		},
		{
			filter: `pinned`,
			want:   "`memo`.`pinned` IS TRUE",
			args:   []any{},
		},
		{
			filter: `has_task_list`,
			want:   "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
			args:   []any{},
		},
		{
			filter: `has_task_list == true`,
			want:   "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
			args:   []any{},
		},
		{
			filter: `has_task_list != false`,
			want:   "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
			args:   []any{},
		},
		{
			filter: `has_task_list == false`,
			want:   "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
			args:   []any{},
		},
		{
			filter: `!has_task_list`,
			want:   "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON))",
			args:   []any{},
		},
		{
			filter: `has_task_list && pinned`,
			want:   "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`pinned` IS TRUE)",
			args:   []any{},
		},
		{
			filter: `has_task_list && content.contains("todo")`,
			want:   "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`content` LIKE ?)",
			args:   []any{"%todo%"},
		},
		{
			// NOTE(review): the expectation calls time.Now() in the test while the
			// converter evaluates now() itself — this could flake across a second
			// boundary; consider injecting a clock.
			filter: `created_ts > now() - 60 * 60 * 24`,
			want:   "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
			args:   []any{time.Now().Unix() - 60*60*24},
		},
		{
			filter: `size(tags) == 0`,
			want:   "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
			args:   []any{int64(0)},
		},
		{
			filter: `size(tags) > 0`,
			want:   "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
			args:   []any{int64(0)},
		},
		{
			filter: `"work" in tags`,
			want:   "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
			args:   []any{"work"},
		},
		{
			filter: `size(tags) == 2`,
			want:   "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
			args:   []any{int64(2)},
		},
	}

	for _, tt := range tests {
		db := &DB{}
		parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
		require.NoError(t, err)
		convertCtx := filter.NewConvertContext()
		err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
		require.NoError(t, err)
		require.Equal(t, tt.want, convertCtx.Buffer.String())
		require.Equal(t, tt.args, convertCtx.Args)
	}
}
|
||||
111
store/db/mysql/memo_relation.go
Normal file
111
store/db/mysql/memo_relation.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?)"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoRelation := store.MemoRelation{
|
||||
MemoID: create.MemoID,
|
||||
RelatedMemoID: create.RelatedMemoID,
|
||||
Type: create.Type,
|
||||
}
|
||||
|
||||
return &memoRelation, nil
|
||||
}
|
||||
|
||||
// ListMemoRelations returns memo relations matching find. When MemoFilter is
// set, the filter is compiled into a SQL condition and applied to BOTH sides
// of the relation (memo_id and related_memo_id) via subqueries against memo.
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
	where, args := []string{"TRUE"}, []any{}
	if find.MemoID != nil {
		where, args = append(where, "`memo_id` = ?"), append(args, find.MemoID)
	}
	if find.RelatedMemoID != nil {
		where, args = append(where, "`related_memo_id` = ?"), append(args, find.RelatedMemoID)
	}
	if find.Type != nil {
		where, args = append(where, "`type` = ?"), append(args, find.Type)
	}
	if find.MemoFilter != nil {
		// Parse filter string and return the parsed expression.
		// The filter string should be a CEL expression.
		parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
		if err != nil {
			return nil, err
		}
		convertCtx := filter.NewConvertContext()
		// ConvertExprToSQL converts the parsed expression to a SQL condition string.
		if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
			return nil, err
		}
		condition := convertCtx.Buffer.String()
		if condition != "" {
			// The same condition is applied once per relation side, so its
			// argument list must be appended twice, in order.
			where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
			where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
			args = append(args, append(convertCtx.Args, convertCtx.Args...)...)
		}
	}

	rows, err := d.db.QueryContext(ctx, "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE "+strings.Join(where, " AND "), args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*store.MemoRelation{}
	for rows.Next() {
		memoRelation := &store.MemoRelation{}
		if err := rows.Scan(
			&memoRelation.MemoID,
			&memoRelation.RelatedMemoID,
			&memoRelation.Type,
		); err != nil {
			return nil, err
		}
		list = append(list, memoRelation)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "`related_memo_id` = ?"), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, delete.Type)
|
||||
}
|
||||
stmt := "DELETE FROM `memo_relation` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
53
store/db/mysql/migration_history.go
Normal file
53
store/db/mysql/migration_history.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
query := "SELECT `version`, UNIX_TIMESTAMP(`created_ts`) FROM `migration_history` ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
|
||||
stmt := "INSERT INTO `migration_history` (`version`) VALUES (?) ON DUPLICATE KEY UPDATE `version` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, upsert.Version, upsert.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var migrationHistory store.MigrationHistory
|
||||
stmt = "SELECT `version`, UNIX_TIMESTAMP(`created_ts`) FROM `migration_history` WHERE `version` = ?"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &migrationHistory, nil
|
||||
}
|
||||
68
store/db/mysql/mysql.go
Normal file
68
store/db/mysql/mysql.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// DB is the MySQL implementation of store.Driver.
type DB struct {
	db      *sql.DB          // underlying database handle
	profile *profile.Profile // server profile the driver was created from
	config  *mysql.Config    // parsed DSN configuration (MultiStatements forced on)
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
// Open MySQL connection with parameter.
|
||||
// multiStatements=true is required for migration.
|
||||
// See more in: https://github.com/go-sql-driver/mysql#multistatements
|
||||
dsn, err := mergeDSN(profile.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driver := DB{profile: profile}
|
||||
driver.config, err = mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, errors.New("Parse DSN eroor")
|
||||
}
|
||||
|
||||
driver.db, err = sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db: %s", profile.DSN)
|
||||
}
|
||||
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying *sql.DB handle.
func (d *DB) GetDB() *sql.DB {
	return d.db
}
|
||||
|
||||
// Close closes the underlying database connection pool.
func (d *DB) Close() error {
	return d.db.Close()
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// mergeDSN parses baseDSN and re-renders it with MultiStatements enabled,
// which schema migrations require (multiple statements per Exec call).
func mergeDSN(baseDSN string) (string, error) {
	config, err := mysql.ParseDSN(baseDSN)
	if err != nil {
		return "", errors.Wrapf(err, "failed to parse DSN: %s", baseDSN)
	}

	config.MultiStatements = true
	return config.FormatDSN(), nil
}
|
||||
104
store/db/mysql/reaction.go
Normal file
104
store/db/mysql/reaction.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"`creator_id`", "`content_id`", "`reaction_type`"}
|
||||
placeholder := []string{"?", "?", "?"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO `reaction` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := int32(rawID)
|
||||
reaction, err := d.GetReaction(ctx, &store.FindReaction{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reaction == nil {
|
||||
return nil, errors.Errorf("failed to create reaction")
|
||||
}
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "`creator_id` = ?"), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "`content_id` = ?"), append(args, *find.ContentID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
UNIX_TIMESTAMP(created_ts) AS created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
|
||||
list, err := d.ListReactions(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reaction := list[0]
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
// DeleteReaction removes the reaction row with the given ID.
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
	_, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID)
	return err
}
|
||||
162
store/db/mysql/user.go
Normal file
162
store/db/mysql/user.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`", "`avatar_url`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
|
||||
stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
list, err := d.ListUsers(ctx, &store.FindUser{ID: &id32})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected user count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "`username` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "`email` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "`nickname` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "`avatar_url` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "`password_hash` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "`description` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "`role` = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := "UPDATE `user` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ListUsers returns users matching find, ordered by creation time (newest
// first) and then row status. Limit, when set, caps the number of rows.
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
	where, args := []string{"1 = 1"}, []any{}

	if v := find.ID; v != nil {
		where, args = append(where, "`id` = ?"), append(args, *v)
	}
	if v := find.Username; v != nil {
		where, args = append(where, "`username` = ?"), append(args, *v)
	}
	if v := find.Role; v != nil {
		where, args = append(where, "`role` = ?"), append(args, *v)
	}
	if v := find.Email; v != nil {
		where, args = append(where, "`email` = ?"), append(args, *v)
	}
	if v := find.Nickname; v != nil {
		where, args = append(where, "`nickname` = ?"), append(args, *v)
	}

	orderBy := []string{"`created_ts` DESC", "`row_status` DESC"}
	query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, `description`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings.Join(where, " AND ") + " ORDER BY " + strings.Join(orderBy, ", ")
	if v := find.Limit; v != nil {
		// The limit is an int formatted directly into the query, not a bind parameter.
		query += fmt.Sprintf(" LIMIT %d", *v)
	}
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := make([]*store.User, 0)
	for rows.Next() {
		var user store.User
		if err := rows.Scan(
			&user.ID,
			&user.Username,
			&user.Role,
			&user.Email,
			&user.Nickname,
			&user.PasswordHash,
			&user.AvatarURL,
			&user.Description,
			&user.CreatedTs,
			&user.UpdatedTs,
			&user.RowStatus,
		); err != nil {
			return nil, err
		}
		list = append(list, &user)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
|
||||
list, err := d.ListUsers(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected user count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
// DeleteUser removes the user row with the given ID.
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
	result, err := d.db.ExecContext(ctx, "DELETE FROM `user` WHERE `id` = ?", delete.ID)
	if err != nil {
		return err
	}
	// RowsAffected is checked only to surface driver errors; the count itself is unused.
	if _, err := result.RowsAffected(); err != nil {
		return err
	}
	return nil
}
|
||||
56
store/db/mysql/user_setting.go
Normal file
56
store/db/mysql/user_setting.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := "INSERT INTO `user_setting` (`user_id`, `key`, `value`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value, upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
// ListUserSettings returns user settings matching find. The `key` column is
// stored as the proto enum's string name and mapped back to the enum on read.
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
	where, args := []string{"1 = 1"}, []any{}

	if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
		where, args = append(where, "`key` = ?"), append(args, v.String())
	}
	if v := find.UserID; v != nil {
		where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
	}

	query := "SELECT `user_id`, `key`, `value` FROM `user_setting` WHERE " + strings.Join(where, " AND ")
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	userSettingList := make([]*store.UserSetting, 0)
	for rows.Next() {
		userSetting := &store.UserSetting{}
		var keyString string
		if err := rows.Scan(
			&userSetting.UserID,
			&keyString,
			&userSetting.Value,
		); err != nil {
			return nil, err
		}
		// Unknown key names map to 0 (KEY_UNSPECIFIED) via the generated value map.
		userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
		userSettingList = append(userSettingList, userSetting)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return userSettingList, nil
}
|
||||
65
store/db/mysql/workspace_setting.go
Normal file
65
store/db/mysql/workspace_setting.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *store.WorkspaceSetting) (*store.WorkspaceSetting, error) {
|
||||
stmt := "INSERT INTO `system_setting` (`name`, `value`, `description`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?, `description` = ?"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
upsert.Name,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*store.WorkspaceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "`name` = ?"), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := "SELECT `name`, `value`, `description` FROM `system_setting` WHERE " + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.WorkspaceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.WorkspaceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// DeleteWorkspaceSetting removes the `system_setting` row with the given name.
func (d *DB) DeleteWorkspaceSetting(ctx context.Context, delete *store.DeleteWorkspaceSetting) error {
	stmt := "DELETE FROM `system_setting` WHERE `name` = ?"
	_, err := d.db.ExecContext(ctx, stmt, delete.Name)
	return err
}
|
||||
81
store/db/postgres/activity.go
Normal file
81
store/db/postgres/activity.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"creator_id", "type", "level", "payload"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
// ListActivities returns activities matching find, newest first. The JSON
// payload column is unmarshaled back into storepb.ActivityPayload.
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
	where, args := []string{"1 = 1"}, []any{}
	if find.ID != nil {
		// placeholder is a package-level helper (defined elsewhere) — presumably
		// renders a PostgreSQL $N bind marker; confirm against its definition.
		where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
	}
	if find.Type != nil {
		where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
	}

	query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*store.Activity{}
	for rows.Next() {
		activity := &store.Activity{}
		var payloadBytes []byte
		if err := rows.Scan(
			&activity.ID,
			&activity.CreatorID,
			&activity.Type,
			&activity.Level,
			&payloadBytes,
			&activity.CreatedTs,
		); err != nil {
			return nil, err
		}

		payload := &storepb.ActivityPayload{}
		// protojsonUnmarshaler is a package-level unmarshaler defined elsewhere in this package.
		if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
			return nil, err
		}
		activity.Payload = payload
		list = append(list, activity)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
186
store/db/postgres/attachment.go
Normal file
186
store/db/postgres/attachment.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
|
||||
fields := []string{"uid", "filename", "blob", "type", "size", "creator_id", "memo_id", "storage_type", "reference", "payload"}
|
||||
storageType := ""
|
||||
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
|
||||
storageType = create.StorageType.String()
|
||||
}
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return create, nil
|
||||
}
|
||||
|
||||
// ListAttachments returns attachment rows (table "resource") matching find,
// most recently updated first, with optional LIMIT/OFFSET paging. The blob
// column is fetched only when find.GetBlob is set.
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
	// "1 = 1" seeds the WHERE clause so filters can always be ANDed on.
	where, args := []string{"1 = 1"}, []any{}

	if v := find.ID; v != nil {
		where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.UID; v != nil {
		where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.CreatorID; v != nil {
		where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.Filename; v != nil {
		where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.FilenameSearch; v != nil {
		// Substring match: the search term is wrapped in % wildcards.
		where, args = append(where, "filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
	}
	if v := find.MemoID; v != nil {
		where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
	}
	if find.HasRelatedMemo {
		where = append(where, "memo_id IS NOT NULL")
	}
	if v := find.StorageType; v != nil {
		where, args = append(where, "storage_type = "+placeholder(len(args)+1)), append(args, v.String())
	}

	// The scan destinations below must stay in sync with this field list.
	fields := []string{"id", "uid", "filename", "type", "size", "creator_id", "created_ts", "updated_ts", "memo_id", "storage_type", "reference", "payload"}
	if find.GetBlob {
		fields = append(fields, "blob")
	}

	query := fmt.Sprintf(`
		SELECT
			%s
		FROM resource
		WHERE %s
		ORDER BY updated_ts DESC
	`, strings.Join(fields, ", "), strings.Join(where, " AND "))
	if find.Limit != nil {
		query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
		// OFFSET is only applied together with LIMIT.
		if find.Offset != nil {
			query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
		}
	}

	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := make([]*store.Attachment, 0)
	for rows.Next() {
		attachment := store.Attachment{}
		// memo_id is nullable; storage_type and payload need post-processing.
		var memoID sql.NullInt32
		var storageType string
		var payloadBytes []byte
		dests := []any{
			&attachment.ID,
			&attachment.UID,
			&attachment.Filename,
			&attachment.Type,
			&attachment.Size,
			&attachment.CreatorID,
			&attachment.CreatedTs,
			&attachment.UpdatedTs,
			&memoID,
			&storageType,
			&attachment.Reference,
			&payloadBytes,
		}
		if find.GetBlob {
			dests = append(dests, &attachment.Blob)
		}
		if err := rows.Scan(dests...); err != nil {
			return nil, err
		}

		// A NULL memo_id leaves attachment.MemoID nil.
		if memoID.Valid {
			attachment.MemoID = &memoID.Int32
		}
		// Unknown storage-type strings map to the enum zero value.
		attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
		payload := &storepb.AttachmentPayload{}
		if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
			return nil, err
		}
		attachment.Payload = payload
		list = append(list, &attachment)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "reference = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(bytes))
|
||||
}
|
||||
|
||||
stmt := `UPDATE resource SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
|
||||
args = append(args, update.ID)
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := `DELETE FROM resource WHERE id = $1`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
26
store/db/postgres/common.go
Normal file
26
store/db/postgres/common.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
var (
	// protojsonUnmarshaler decodes JSON payload columns into protobuf
	// messages. DiscardUnknown keeps decoding tolerant of fields written by
	// other (e.g. newer) versions of the schema.
	protojsonUnmarshaler = protojson.UnmarshalOptions{
		DiscardUnknown: true,
	}
)
|
||||
|
||||
// placeholder returns the PostgreSQL positional parameter for the 1-based
// index n, e.g. "$3".
func placeholder(n int) string {
	return "$" + fmt.Sprint(n)
}

// placeholders returns a comma-separated list of the first n positional
// parameters: "$1, $2, ..., $n". n == 0 yields the empty string.
func placeholders(n int) string {
	// Pre-size: the final length is known, so avoid repeated growth copies.
	list := make([]string, 0, n)
	for i := 0; i < n; i++ {
		list = append(list, placeholder(i+1))
	}
	return strings.Join(list, ", ")
}
|
||||
117
store/db/postgres/idp.go
Normal file
117
store/db/postgres/idp.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
fields := []string{"name", "type", "identifier_filter", "config"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider := create
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
// ListIdentityProviders returns identity providers matching the optional ID
// filter, ordered by ID ascending.
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
	// "1 = 1" seeds the WHERE clause so filters can always be ANDed on.
	where, args := []string{"1 = 1"}, []any{}
	if v := find.ID; v != nil {
		where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
	}

	rows, err := d.db.QueryContext(ctx, `
		SELECT
			id,
			name,
			type,
			identifier_filter,
			config
		FROM idp
		WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
		args...,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var identityProviders []*store.IdentityProvider
	for rows.Next() {
		var identityProvider store.IdentityProvider
		// The textual type column is mapped back to the storepb enum below.
		var typeString string
		if err := rows.Scan(
			&identityProvider.ID,
			&identityProvider.Name,
			&typeString,
			&identityProvider.IdentifierFilter,
			&identityProvider.Config,
		); err != nil {
			return nil, err
		}

		// Unknown type strings map to the enum zero value.
		identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
		identityProviders = append(identityProviders, &identityProvider)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return identityProviders, nil
}
|
||||
|
||||
// UpdateIdentityProvider applies the non-nil fields of update to the idp row
// identified by update.ID and returns the row as stored afterwards.
//
// NOTE(review): when no fields are set, the generated statement is
// "UPDATE idp SET  WHERE id = $1", which is invalid SQL and will error at
// execution time — callers appear to always supply at least one field; confirm.
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
	set, args := []string{}, []any{}
	if v := update.Name; v != nil {
		set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.IdentifierFilter; v != nil {
		set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.Config; v != nil {
		set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, *v)
	}

	// RETURNING avoids a second round trip to read the updated row back.
	stmt := `
		UPDATE idp
		SET ` + strings.Join(set, ", ") + `
		WHERE id = ` + placeholder(len(args)+1) + `
		RETURNING id, name, type, identifier_filter, config
	`
	args = append(args, update.ID)

	var identityProvider store.IdentityProvider
	var typeString string
	if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
		&identityProvider.ID,
		&identityProvider.Name,
		&typeString,
		&identityProvider.IdentifierFilter,
		&identityProvider.Config,
	); err != nil {
		return nil, err
	}

	// Unknown type strings map to the enum zero value.
	identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
	return &identityProvider, nil
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"id = $1"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
141
store/db/postgres/inbox.go
Normal file
141
store/db/postgres/inbox.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"sender_id", "receiver_id", "status", "message"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status)
|
||||
}
|
||||
|
||||
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
|
||||
list, err := d.ListInboxes(ctx, find)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get inbox")
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected inbox count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"status = $1"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message"
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
279
store/db/postgres/memo.go
Normal file
279
store/db/postgres/memo.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
fields := []string{"uid", "creator_id", "content", "visibility", "payload"}
|
||||
payload := "{}"
|
||||
if create.Payload != nil {
|
||||
payloadBytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = string(payloadBytes)
|
||||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
// ListMemos returns memos matching find. Filtering supports direct field
// matches, content search, visibility lists, JSONB payload predicates, and a
// CEL filter expression; ordering, LIMIT/OFFSET paging, and optional content
// exclusion are also honored.
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
	// "1 = 1" seeds the WHERE clause so filters can always be ANDed on.
	where, args := []string{"1 = 1"}, []any{}

	if v := find.ID; v != nil {
		where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.UID; v != nil {
		where, args = append(where, "memo.uid = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.CreatorID; v != nil {
		where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.RowStatus; v != nil {
		where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.CreatedTsBefore; v != nil {
		where, args = append(where, "memo.created_ts < "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.CreatedTsAfter; v != nil {
		where, args = append(where, "memo.created_ts > "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.UpdatedTsBefore; v != nil {
		where, args = append(where, "memo.updated_ts < "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.UpdatedTsAfter; v != nil {
		where, args = append(where, "memo.updated_ts > "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.ContentSearch; len(v) != 0 {
		// Every search term must match (AND of case-insensitive ILIKEs).
		for _, s := range v {
			where, args = append(where, "memo.content ILIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", s))
		}
	}
	if v := find.VisibilityList; len(v) != 0 {
		holders := []string{}
		for _, visibility := range v {
			holders = append(holders, placeholder(len(args)+1))
			args = append(args, visibility.String())
		}
		where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", ")))
	}
	if v := find.Pinned; v != nil {
		where, args = append(where, "memo.pinned = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.PayloadFind; v != nil {
		if v.Raw != nil {
			where, args = append(where, "memo.payload = "+placeholder(len(args)+1)), append(args, *v.Raw)
		}
		if len(v.TagSearch) != 0 {
			// Match a tag exactly or as the parent of a nested tag
			// ("tag/..."); tags are JSON strings, hence the quoted patterns.
			for _, tag := range v.TagSearch {
				where, args = append(where, "EXISTS (SELECT 1 FROM jsonb_array_elements(memo.payload->'tags') AS tag WHERE tag::text = "+placeholder(len(args)+1)+" OR tag::text LIKE "+placeholder(len(args)+2)+")"), append(args, fmt.Sprintf(`"%s"`, tag), fmt.Sprintf(`"%s/%%"`, tag))
			}
		}
		if v.HasLink {
			where = append(where, "(memo.payload->'property'->>'hasLink')::BOOLEAN IS TRUE")
		}
		if v.HasTaskList {
			where = append(where, "(memo.payload->'property'->>'hasTaskList')::BOOLEAN IS TRUE")
		}
		if v.HasCode {
			where = append(where, "(memo.payload->'property'->>'hasCode')::BOOLEAN IS TRUE")
		}
		if v.HasIncompleteTasks {
			where = append(where, "(memo.payload->'property'->>'hasIncompleteTasks')::BOOLEAN IS TRUE")
		}
	}
	if v := find.Filter; v != nil {
		// Parse filter string and return the parsed expression.
		// The filter string should be a CEL expression.
		parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
		if err != nil {
			return nil, err
		}
		convertCtx := filter.NewConvertContext()
		// ArgsOffset keeps the generated $N placeholders numbered after the
		// arguments already collected above.
		convertCtx.ArgsOffset = len(args)
		// ConvertExprToSQL converts the parsed expression to a SQL condition string.
		if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
			return nil, err
		}
		condition := convertCtx.Buffer.String()
		if condition != "" {
			where = append(where, fmt.Sprintf("(%s)", condition))
			args = append(args, convertCtx.Args...)
		}
	}
	if find.ExcludeComments {
		// Comment memos are the child side of a COMMENT relation (see the
		// LEFT JOIN below); a NULL related_memo_id means "not a comment".
		where = append(where, "memo_relation.related_memo_id IS NULL")
	}

	order := "DESC"
	if find.OrderByTimeAsc {
		order = "ASC"
	}
	orderBy := []string{}
	if find.OrderByPinned {
		orderBy = append(orderBy, "pinned DESC")
	}
	if find.OrderByUpdatedTs {
		orderBy = append(orderBy, "updated_ts "+order)
	} else {
		orderBy = append(orderBy, "created_ts "+order)
	}
	// The scan destinations below must stay in sync with this field list.
	fields := []string{
		`memo.id AS id`,
		`memo.uid AS uid`,
		`memo.creator_id AS creator_id`,
		`memo.created_ts AS created_ts`,
		`memo.updated_ts AS updated_ts`,
		`memo.row_status AS row_status`,
		`memo.visibility AS visibility`,
		`memo.pinned AS pinned`,
		`memo.payload AS payload`,
		`memo_relation.related_memo_id AS parent_id`,
	}
	if !find.ExcludeContent {
		fields = append(fields, `memo.content AS content`)
	}

	query := `SELECT ` + strings.Join(fields, ", ") + `
	FROM memo
	LEFT JOIN memo_relation ON memo.id = memo_relation.memo_id AND memo_relation.type = 'COMMENT'
	WHERE ` + strings.Join(where, " AND ") + `
	ORDER BY ` + strings.Join(orderBy, ", ")
	if find.Limit != nil {
		query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
		// OFFSET is only applied together with LIMIT.
		if find.Offset != nil {
			query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
		}
	}

	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := make([]*store.Memo, 0)
	for rows.Next() {
		var memo store.Memo
		var payloadBytes []byte
		dests := []any{
			&memo.ID,
			&memo.UID,
			&memo.CreatorID,
			&memo.CreatedTs,
			&memo.UpdatedTs,
			&memo.RowStatus,
			&memo.Visibility,
			&memo.Pinned,
			&payloadBytes,
			&memo.ParentID,
		}
		if !find.ExcludeContent {
			dests = append(dests, &memo.Content)
		}
		if err := rows.Scan(dests...); err != nil {
			return nil, err
		}
		payload := &storepb.MemoPayload{}
		if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
			return nil, errors.Wrap(err, "failed to unmarshal payload")
		}
		memo.Payload = payload
		list = append(list, &memo)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
|
||||
list, err := d.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
// UpdateMemo applies the non-nil fields of update to the memo row identified
// by update.ID. A call with no fields set is a no-op.
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
	set, args := []string{}, []any{}
	if v := update.UID; v != nil {
		set, args = append(set, "uid = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.CreatedTs; v != nil {
		set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.UpdatedTs; v != nil {
		set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.RowStatus; v != nil {
		set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.Content; v != nil {
		set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.Visibility; v != nil {
		set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.Pinned; v != nil {
		set, args = append(set, "pinned = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := update.Payload; v != nil {
		// Payload is stored as JSON text; marshal the protobuf message first.
		payloadBytes, err := protojson.Marshal(v)
		if err != nil {
			return err
		}
		set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(payloadBytes))
	}
	// Nothing to update; avoid generating an invalid "UPDATE memo SET" statement.
	if len(set) == 0 {
		return nil
	}

	stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
	args = append(args, update.ID)
	if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
		return err
	}
	return nil
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"id = " + placeholder(1)}, []any{delete.ID}
|
||||
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete memo")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
326
store/db/postgres/memo_filter.go
Normal file
326
store/db/postgres/memo_filter.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
// ConvertExprToSQL renders the parsed CEL expression as a PostgreSQL WHERE
// condition into ctx.Buffer, appending bind values to ctx.Args.
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
	const dbType = filter.PostgreSQLTemplate
	// Start numbering at ctx.ArgsOffset+1 (not len(ctx.Args)+1) so the
	// generated $N placeholders line up with query arguments the caller has
	// already collected before invoking the filter conversion.
	_, err := d.convertWithParameterIndex(ctx, expr, dbType, ctx.ArgsOffset+1)
	return err
}
|
||||
|
||||
func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.Expr, dbType filter.TemplateDBType, paramIndex int) (int, error) {
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
newParamIndex, err = d.convertWithParameterIndex(ctx, v.CallExpr.Args[1], dbType, newParamIndex)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
return newParamIndex, nil
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
return newParamIndex, nil
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
// Handle size(tags) comparison
|
||||
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||
return paramIndex, errors.New("size function requires exactly one argument")
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if identifier != "tags" {
|
||||
return paramIndex, errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("size comparison value must be an integer")
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||
filter.GetSQL("json_array_length", dbType), operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampSQL, operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "visibility" || identifier == "content" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid string value")
|
||||
}
|
||||
|
||||
var sqlTemplate string
|
||||
if identifier == "visibility" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".visibility"
|
||||
} else if identifier == "content" {
|
||||
sqlTemplate = filter.GetSQL("content_like", dbType)
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", valueStr))
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "creator_id" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid int value")
|
||||
}
|
||||
|
||||
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".creator_id"
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "has_task_list" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
// Use parameterized template for boolean comparison (PostgreSQL only)
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
sqlTemplate := fmt.Sprintf(filter.GetSQL("boolean_compare", dbType), operator)
|
||||
sqlTemplate = strings.Replace(sqlTemplate, "?", placeholder, 1)
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||
// This is "element in collection" - the second argument is the collection
|
||||
if !slices.Contains([]string{"tags"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||
}
|
||||
|
||||
if identifier == "tags" {
|
||||
// Handle "element" in tags
|
||||
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
sql := strings.Replace(filter.GetSQL("json_contains_element", dbType), "?", placeholder, 1)
|
||||
if _, err := ctx.Buffer.WriteString(sql); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
return paramIndex, nil
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if identifier == "tag" {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
currentParamIndex := paramIndex
|
||||
for _, v := range values {
|
||||
// Use parameter index for each placeholder
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, currentParamIndex)
|
||||
subcondition := strings.Replace(filter.GetSQL("json_contains_tag", dbType), "?", placeholder, 1)
|
||||
subconditions = append(subconditions, subcondition)
|
||||
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||
currentParamIndex++
|
||||
}
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
}
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
return paramIndex + len(args), nil
|
||||
} else if identifier == "visibility" {
|
||||
placeholders := filter.FormatPlaceholders(dbType, len(values), paramIndex)
|
||||
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return paramIndex + len(values), nil
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if identifier != "content" {
|
||||
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
sql := strings.Replace(filter.GetSQL("content_like", dbType), "?", placeholder, 1)
|
||||
if _, err := ctx.Buffer.WriteString(sql); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
identifier := v.IdentExpr.GetName()
|
||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
if identifier == "pinned" {
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".pinned IS TRUE"); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
// Handle has_task_list as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return paramIndex, nil
|
||||
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
130
store/db/postgres/memo_filter_test.go
Normal file
130
store/db/postgres/memo_filter_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
// TestRestoreExprToSQL verifies that CEL memo-filter expressions are converted
// into the expected PostgreSQL condition strings and positional arguments.
func TestRestoreExprToSQL(t *testing.T) {
	tests := []struct {
		filter string // CEL filter expression, as accepted by filter.Parse
		want   string // expected SQL condition written into the convert context buffer
		args   []any  // expected positional arguments collected alongside the SQL
	}{
		// Tag membership expands to one jsonb containment check per tag, OR-ed together.
		{
			filter: `tag in ["tag1", "tag2"]`,
			want:   "(memo.payload->'tags' @> jsonb_build_array($1) OR memo.payload->'tags' @> jsonb_build_array($2))",
			args:   []any{"tag1", "tag2"},
		},
		// Negation wraps the whole converted condition in NOT (...).
		{
			filter: `!(tag in ["tag1", "tag2"])`,
			want:   `NOT ((memo.payload->'tags' @> jsonb_build_array($1) OR memo.payload->'tags' @> jsonb_build_array($2)))`,
			args:   []any{"tag1", "tag2"},
		},
		// Content search uses parameterized ILIKE with the term wrapped in wildcards.
		{
			filter: `content.contains("memos")`,
			want:   "memo.content ILIKE $1",
			args:   []any{"%memos%"},
		},
		{
			filter: `visibility in ["PUBLIC"]`,
			want:   "memo.visibility IN ($1)",
			args:   []any{"PUBLIC"},
		},
		{
			filter: `visibility in ["PUBLIC", "PRIVATE"]`,
			want:   "memo.visibility IN ($1,$2)",
			args:   []any{"PUBLIC", "PRIVATE"},
		},
		// Logical OR across sub-conditions keeps a single running placeholder index.
		{
			filter: `tag in ['tag1'] || content.contains('hello')`,
			want:   "(memo.payload->'tags' @> jsonb_build_array($1) OR memo.content ILIKE $2)",
			args:   []any{"tag1", "%hello%"},
		},
		// A bare constant produces no SQL condition at all.
		{
			filter: `1`,
			want:   "",
			args:   []any{},
		},
		// Bare boolean identifiers.
		{
			filter: `pinned`,
			want:   "memo.pinned IS TRUE",
			args:   []any{},
		},
		{
			filter: `has_task_list`,
			want:   "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
			args:   []any{},
		},
		// Boolean comparisons are parameterized rather than inlined into the SQL.
		{
			filter: `has_task_list == true`,
			want:   "(memo.payload->'property'->>'hasTaskList')::boolean = $1",
			args:   []any{true},
		},
		{
			filter: `has_task_list != false`,
			want:   "(memo.payload->'property'->>'hasTaskList')::boolean != $1",
			args:   []any{false},
		},
		{
			filter: `has_task_list == false`,
			want:   "(memo.payload->'property'->>'hasTaskList')::boolean = $1",
			args:   []any{false},
		},
		{
			filter: `!has_task_list`,
			want:   "NOT ((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE)",
			args:   []any{},
		},
		{
			filter: `has_task_list && pinned`,
			want:   "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.pinned IS TRUE)",
			args:   []any{},
		},
		{
			filter: `has_task_list && content.contains("todo")`,
			want:   "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.content ILIKE $1)",
			args:   []any{"%todo%"},
		},
		// NOTE(review): the expected argument is computed from time.Now() when this
		// table literal is built, while the converter evaluates now() later during
		// ConvertExprToSQL — if the test straddles a second boundary the two values
		// can differ by one. Consider asserting with a tolerance. TODO confirm.
		{
			filter: `created_ts > now() - 60 * 60 * 24`,
			want:   "EXTRACT(EPOCH FROM memo.created_ts) > $1",
			args:   []any{time.Now().Unix() - 60*60*24},
		},
		// size(tags) maps to jsonb_array_length with a COALESCE for memos without tags.
		{
			filter: `size(tags) == 0`,
			want:   "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
			args:   []any{int64(0)},
		},
		{
			filter: `size(tags) > 0`,
			want:   "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1",
			args:   []any{int64(0)},
		},
		// "element in collection" syntax for tags.
		{
			filter: `"work" in tags`,
			want:   "memo.payload->'tags' @> jsonb_build_array($1)",
			args:   []any{"work"},
		},
		{
			filter: `size(tags) == 2`,
			want:   "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
			args:   []any{int64(2)},
		},
	}

	for _, tt := range tests {
		db := &DB{}
		parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
		require.NoError(t, err)
		convertCtx := filter.NewConvertContext()
		err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
		require.NoError(t, err)
		require.Equal(t, tt.want, convertCtx.Buffer.String())
		require.Equal(t, tt.args, convertCtx.Args)
	}
}
|
||||
124
store/db/postgres/memo_relation.go
Normal file
124
store/db/postgres/memo_relation.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// UpsertMemoRelation inserts a memo relation row and returns the
// (memo_id, related_memo_id, type) tuple scanned back from the database.
//
// NOTE(review): despite the "Upsert" name, the statement is a plain INSERT
// with no ON CONFLICT clause, so inserting an already-existing relation
// surfaces the database's constraint error (if any) to the caller — confirm
// whether ON CONFLICT handling was intended.
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
	stmt := `
		INSERT INTO memo_relation (
			memo_id,
			related_memo_id,
			type
		)
		VALUES (` + placeholders(3) + `)
		RETURNING memo_id, related_memo_id, type
	`
	memoRelation := &store.MemoRelation{}
	if err := d.db.QueryRowContext(
		ctx,
		stmt,
		create.MemoID,
		create.RelatedMemoID,
		create.Type,
	).Scan(
		&memoRelation.MemoID,
		&memoRelation.RelatedMemoID,
		&memoRelation.Type,
	); err != nil {
		return nil, err
	}

	return memoRelation, nil
}
|
||||
|
||||
// ListMemoRelations returns memo relation rows matching the optional criteria
// in find. Criteria left nil are not constrained. When MemoFilter is set, the
// CEL expression is converted to a memo condition that both endpoints of the
// relation must satisfy.
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
	where, args := []string{"1 = 1"}, []any{}
	if find.MemoID != nil {
		where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
	}
	if find.RelatedMemoID != nil {
		where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
	}
	if find.Type != nil {
		where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
	}
	if find.MemoFilter != nil {
		// Parse filter string and return the parsed expression.
		// The filter string should be a CEL expression.
		parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
		if err != nil {
			return nil, err
		}
		convertCtx := filter.NewConvertContext()
		// Offset the converter's $N placeholder numbering past the args already
		// collected above, so the generated condition does not clash with them.
		convertCtx.ArgsOffset = len(args)
		// ConvertExprToSQL converts the parsed expression to a SQL condition string.
		if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
			return nil, err
		}
		condition := convertCtx.Buffer.String()
		if condition != "" {
			// The same condition (with identical $N placeholders) is applied to both
			// endpoints; PostgreSQL allows a placeholder to be referenced more than
			// once, so the converted args are appended only once.
			where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
			where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
			args = append(args, convertCtx.Args...)
		}
	}

	rows, err := d.db.QueryContext(ctx, `
		SELECT
			memo_id,
			related_memo_id,
			type
		FROM memo_relation
		WHERE `+strings.Join(where, " AND "), args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*store.MemoRelation{}
	for rows.Next() {
		memoRelation := &store.MemoRelation{}
		if err := rows.Scan(
			&memoRelation.MemoID,
			&memoRelation.RelatedMemoID,
			&memoRelation.Type,
		); err != nil {
			return nil, err
		}
		list = append(list, memoRelation)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
|
||||
}
|
||||
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
store/db/postgres/migration_history.go
Normal file
57
store/db/postgres/migration_history.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// UpsertMigrationHistory records a schema migration version and returns the
// stored row. On conflict the DO UPDATE clause sets version to its own value:
// this is deliberately a no-op change made only so that RETURNING yields a
// row in both the insert and the conflict case.
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
	stmt := `
		INSERT INTO migration_history (
			version
		)
		VALUES ($1)
		ON CONFLICT(version) DO UPDATE
		SET
			version=EXCLUDED.version
		RETURNING version, created_ts
	`
	var migrationHistory store.MigrationHistory
	if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
		&migrationHistory.Version,
		&migrationHistory.CreatedTs,
	); err != nil {
		return nil, err
	}

	return &migrationHistory, nil
}
|
||||
57
store/db/postgres/postgres.go
Normal file
57
store/db/postgres/postgres.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
|
||||
// Import the PostgreSQL driver.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// DB is the PostgreSQL implementation of the store driver.
type DB struct {
	db      *sql.DB          // underlying database handle
	profile *profile.Profile // server profile this driver was created with
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
if profile == nil {
|
||||
return nil, errors.New("profile is nil")
|
||||
}
|
||||
|
||||
// Open the PostgreSQL connection
|
||||
db, err := sql.Open("postgres", profile.DSN)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open database: %s", err)
|
||||
return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN)
|
||||
}
|
||||
|
||||
var driver store.Driver = &DB{
|
||||
db: db,
|
||||
profile: profile,
|
||||
}
|
||||
|
||||
// Return the DB struct
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying *sql.DB handle.
func (d *DB) GetDB() *sql.DB {
	return d.db
}
|
||||
|
||||
// Close closes the underlying database connection pool.
func (d *DB) Close() error {
	return d.db.Close()
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
79
store/db/postgres/reaction.go
Normal file
79
store/db/postgres/reaction.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"creator_id", "content_id", "reaction_type"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO reaction (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&upsert.ID,
|
||||
&upsert.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reaction := upsert
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "content_id = "+placeholder(len(args)+1)), append(args, *find.ContentID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// DeleteReaction removes the reaction row with the given ID.
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
	_, err := d.db.ExecContext(ctx, "DELETE FROM reaction WHERE id = $1", delete.ID)
	return err
}
|
||||
166
store/db/postgres/user.go
Normal file
166
store/db/postgres/user.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// CreateUser inserts a new user row and scans the server-generated columns
// (id, description, created_ts, updated_ts, row_status) back into create
// before returning it.
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
	fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
	args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
	// "user" is quoted because it is a reserved word in PostgreSQL.
	stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, description, created_ts, updated_ts, row_status"
	if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
		&create.ID,
		&create.Description,
		&create.CreatedTs,
		&create.UpdatedTs,
		&create.RowStatus,
	); err != nil {
		return nil, err
	}

	return create, nil
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE "user"
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, username, role, email, nickname, password_hash, avatar_url, description, created_ts, updated_ts, row_status
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
user := &store.User{}
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ListUsers returns user rows matching the optional criteria in find, ordered
// by creation time (newest first) and then row status. Criteria left nil are
// not constrained; Limit, when set, caps the number of rows returned.
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
	where, args := []string{"1 = 1"}, []any{}

	if v := find.ID; v != nil {
		where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.Username; v != nil {
		where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.Role; v != nil {
		where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.Email; v != nil {
		where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
	}
	if v := find.Nickname; v != nil {
		where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
	}

	orderBy := []string{"created_ts DESC", "row_status DESC"}
	query := `
		SELECT
			id,
			username,
			role,
			email,
			nickname,
			password_hash,
			avatar_url,
			description,
			created_ts,
			updated_ts,
			row_status
		FROM "user"
		WHERE ` + strings.Join(where, " AND ") + ` ORDER BY ` + strings.Join(orderBy, ", ")
	if v := find.Limit; v != nil {
		// LIMIT is interpolated with %d on an int rather than parameterized;
		// the integer format cannot inject SQL.
		query += fmt.Sprintf(" LIMIT %d", *v)
	}
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := make([]*store.User, 0)
	for rows.Next() {
		var user store.User
		if err := rows.Scan(
			&user.ID,
			&user.Username,
			&user.Role,
			&user.Email,
			&user.Nickname,
			&user.PasswordHash,
			&user.AvatarURL,
			&user.Description,
			&user.CreatedTs,
			&user.UpdatedTs,
			&user.RowStatus,
		); err != nil {
			return nil, err
		}
		list = append(list, &user)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
69
store/db/postgres/user_setting.go
Normal file
69
store/db/postgres/user_setting.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// UpsertUserSetting inserts or updates the value stored for the
// (user_id, key) pair and returns the input unchanged — the statement has no
// RETURNING clause, so nothing is read back from the database.
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
	stmt := `
		INSERT INTO user_setting (
			user_id, key, value
		)
		VALUES ($1, $2, $3)
		ON CONFLICT(user_id, key) DO UPDATE
		SET value = EXCLUDED.value
	`
	if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
		return nil, err
	}
	return upsert, nil
}
|
||||
|
||||
// ListUserSettings returns user setting rows matching the optional criteria
// in find.
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
	where, args := []string{"1 = 1"}, []any{}

	// KEY_UNSPECIFIED acts as "no key filter".
	if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
		where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
	}
	if v := find.UserID; v != nil {
		where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
	}

	query := `
		SELECT
			user_id,
			key,
			value
		FROM user_setting
		WHERE ` + strings.Join(where, " AND ")
	rows, err := d.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	userSettingList := make([]*store.UserSetting, 0)
	for rows.Next() {
		userSetting := &store.UserSetting{}
		var keyString string
		if err := rows.Scan(
			&userSetting.UserID,
			&keyString,
			&userSetting.Value,
		); err != nil {
			return nil, err
		}
		// Map the stored string back to the proto enum; an unrecognized string
		// maps to the zero value (KEY_UNSPECIFIED) since the map lookup yields 0.
		userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
		userSettingList = append(userSettingList, userSetting)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return userSettingList, nil
}
|
||||
72
store/db/postgres/workspace_setting.go
Normal file
72
store/db/postgres/workspace_setting.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// UpsertWorkspaceSetting inserts or updates a workspace setting by name.
// Workspace settings are persisted in the system_setting table. The input is
// returned unchanged — the statement has no RETURNING clause.
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *store.WorkspaceSetting) (*store.WorkspaceSetting, error) {
	stmt := `
		INSERT INTO system_setting (
			name, value, description
		)
		VALUES ($1, $2, $3)
		ON CONFLICT(name) DO UPDATE
		SET
			value = EXCLUDED.value,
			description = EXCLUDED.description
	`
	if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
		return nil, err
	}

	return upsert, nil
}
|
||||
|
||||
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*store.WorkspaceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
name,
|
||||
value,
|
||||
description
|
||||
FROM system_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.WorkspaceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.WorkspaceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteWorkspaceSetting(ctx context.Context, delete *store.DeleteWorkspaceSetting) error {
|
||||
stmt := `DELETE FROM system_setting WHERE name = $1`
|
||||
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
|
||||
return err
|
||||
}
|
||||
83
store/db/sqlite/activity.go
Normal file
83
store/db/sqlite/activity.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
|
||||
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type.String())
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `creator_id`, `type`, `level`, `payload`, `created_ts` FROM `activity` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
var payloadBytes []byte
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
182
store/db/sqlite/attachment.go
Normal file
182
store/db/sqlite/attachment.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
|
||||
fields := []string{"`uid`", "`filename`", "`blob`", "`type`", "`size`", "`creator_id`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?", "?", "?", "?", "?"}
|
||||
storageType := ""
|
||||
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
|
||||
storageType = create.StorageType.String()
|
||||
}
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "`memo_id` IS NOT NULL")
|
||||
}
|
||||
if find.StorageType != nil {
|
||||
where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
}
|
||||
|
||||
fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "`created_ts`", "`updated_ts`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "`blob`")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND "))
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Attachment, 0)
|
||||
for rows.Next() {
|
||||
attachment := store.Attachment{}
|
||||
var memoID sql.NullInt32
|
||||
var storageType string
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&attachment.ID,
|
||||
&attachment.UID,
|
||||
&attachment.Filename,
|
||||
&attachment.Type,
|
||||
&attachment.Size,
|
||||
&attachment.CreatorID,
|
||||
&attachment.CreatedTs,
|
||||
&attachment.UpdatedTs,
|
||||
&memoID,
|
||||
&storageType,
|
||||
&attachment.Reference,
|
||||
&payloadBytes,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &attachment.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoID.Valid {
|
||||
attachment.MemoID = &memoID.Int32
|
||||
}
|
||||
attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
|
||||
payload := &storepb.AttachmentPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachment.Payload = payload
|
||||
list = append(list, &attachment)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "`reference` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(bytes))
|
||||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update attachment")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `resource` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
9
store/db/sqlite/common.go
Normal file
9
store/db/sqlite/common.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package sqlite
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
||||
117
store/db/sqlite/idp.go
Normal file
117
store/db/sqlite/idp.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider := create
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
FROM idp
|
||||
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "identifier_filter = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
set, args = append(set, "config = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
`
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
return &identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"id = ?"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
132
store/db/sqlite/inbox.go
Normal file
132
store/db/sqlite/inbox.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"`sender_id`", "`receiver_id`", "`status`", "`message`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
|
||||
stmt := "INSERT INTO `inbox` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
where, args = append(where, "`sender_id` = ?"), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
where, args = append(where, "`receiver_id` = ?"), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
where, args = append(where, "`status` = ?"), append(args, *find.Status)
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"`status` = ?"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE `inbox` SET " + strings.Join(set, ", ") + " WHERE `id` = ? RETURNING `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message`"
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `inbox` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
265
store/db/sqlite/memo.go
Normal file
265
store/db/sqlite/memo.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
fields := []string{"`uid`", "`creator_id`", "`content`", "`visibility`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?"}
|
||||
payload := "{}"
|
||||
if create.Payload != nil {
|
||||
payloadBytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = string(payloadBytes)
|
||||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`, `row_status`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "`memo`.`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsBefore; v != nil {
|
||||
where, args = append(where, "`memo`.`created_ts` < ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsAfter; v != nil {
|
||||
where, args = append(where, "`memo`.`created_ts` > ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UpdatedTsBefore; v != nil {
|
||||
where, args = append(where, "`memo`.`updated_ts` < ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UpdatedTsAfter; v != nil {
|
||||
where, args = append(where, "`memo`.`updated_ts` > ?"), append(args, *v)
|
||||
}
|
||||
if v := find.ContentSearch; len(v) != 0 {
|
||||
for _, s := range v {
|
||||
where, args = append(where, "`memo`.`content` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", s))
|
||||
}
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
placeholder := []string{}
|
||||
for _, visibility := range v {
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, visibility.String())
|
||||
}
|
||||
where = append(where, fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ",")))
|
||||
}
|
||||
if v := find.Pinned; v != nil {
|
||||
where, args = append(where, "`memo`.`pinned` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.PayloadFind; v != nil {
|
||||
if v.Raw != nil {
|
||||
where, args = append(where, "`memo`.`payload` = ?"), append(args, *v.Raw)
|
||||
}
|
||||
if len(v.TagSearch) != 0 {
|
||||
for _, tag := range v.TagSearch {
|
||||
where, args = append(where, "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)"), append(args, fmt.Sprintf(`%%"%s"%%`, tag), fmt.Sprintf(`%%"%s/%%`, tag))
|
||||
}
|
||||
}
|
||||
if v.HasLink {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE")
|
||||
}
|
||||
if v.HasTaskList {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE")
|
||||
}
|
||||
if v.HasCode {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE")
|
||||
}
|
||||
if v.HasIncompleteTasks {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
|
||||
}
|
||||
}
|
||||
if v := find.Filter; v != nil {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
if condition != "" {
|
||||
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||
args = append(args, convertCtx.Args...)
|
||||
}
|
||||
}
|
||||
if find.ExcludeComments {
|
||||
where = append(where, "`parent_id` IS NULL")
|
||||
}
|
||||
|
||||
order := "DESC"
|
||||
if find.OrderByTimeAsc {
|
||||
order = "ASC"
|
||||
}
|
||||
orderBy := []string{}
|
||||
if find.OrderByPinned {
|
||||
orderBy = append(orderBy, "`pinned` DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
orderBy = append(orderBy, "`updated_ts` "+order)
|
||||
} else {
|
||||
orderBy = append(orderBy, "`created_ts` "+order)
|
||||
}
|
||||
fields := []string{
|
||||
"`memo`.`id` AS `id`",
|
||||
"`memo`.`uid` AS `uid`",
|
||||
"`memo`.`creator_id` AS `creator_id`",
|
||||
"`memo`.`created_ts` AS `created_ts`",
|
||||
"`memo`.`updated_ts` AS `updated_ts`",
|
||||
"`memo`.`row_status` AS `row_status`",
|
||||
"`memo`.`visibility` AS `visibility`",
|
||||
"`memo`.`pinned` AS `pinned`",
|
||||
"`memo`.`payload` AS `payload`",
|
||||
"`memo_relation`.`related_memo_id` AS `parent_id`",
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
fields = append(fields, "`memo`.`content` AS `content`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + "FROM `memo` " +
|
||||
"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = \"COMMENT\" " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"ORDER BY " + strings.Join(orderBy, ", ")
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&memo.ID,
|
||||
&memo.UID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&payloadBytes,
|
||||
&memo.ParentID,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
dests = append(dests, &memo.Content)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := &storepb.MemoPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal payload")
|
||||
}
|
||||
memo.Payload = payload
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
set, args = append(set, "`created_ts` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
set, args = append(set, "`content` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
set, args = append(set, "`visibility` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Pinned; v != nil {
|
||||
set, args = append(set, "`pinned` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
payloadBytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(payloadBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `memo` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `memo` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
304
store/db/sqlite/memo_filter.go
Normal file
304
store/db/sqlite/memo_filter.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
return d.convertWithTemplates(ctx, expr)
|
||||
}
|
||||
|
||||
// convertWithTemplates walks a parsed CEL expression tree and writes an
// equivalent SQL condition into ctx.Buffer, appending bound parameters to
// ctx.Args. Dialect-specific SQL fragments come from the filter package's
// SQLite template set, so this function only decides WHICH fragment to emit.
//
// Unsupported expressions are not an error per se: an expression kind that
// matches no branch simply writes nothing (e.g. a bare constant filter).
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
	const dbType = filter.SQLiteTemplate

	if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
		switch v.CallExpr.Function {
		case "_||_", "_&&_":
			// Binary logical operators: emit "(lhs AND rhs)" / "(lhs OR rhs)",
			// recursing into both operands.
			if len(v.CallExpr.Args) != 2 {
				return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
			}
			if _, err := ctx.Buffer.WriteString("("); err != nil {
				return err
			}
			if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
				return err
			}
			operator := "AND"
			if v.CallExpr.Function == "_||_" {
				operator = "OR"
			}
			if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
				return err
			}
			if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
				return err
			}
			if _, err := ctx.Buffer.WriteString(")"); err != nil {
				return err
			}
		case "!_":
			// Logical negation: emit "NOT (operand)".
			if len(v.CallExpr.Args) != 1 {
				return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
			}
			if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
				return err
			}
			if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
				return err
			}
			if _, err := ctx.Buffer.WriteString(")"); err != nil {
				return err
			}
		case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
			if len(v.CallExpr.Args) != 2 {
				return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
			}
			// Check if the left side is a function call like size(tags)
			if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
				if leftCallExpr.CallExpr.Function == "size" {
					// Handle size(tags) comparison; only the "tags" collection
					// is supported, and the compared value must be an integer.
					if len(leftCallExpr.CallExpr.Args) != 1 {
						return errors.New("size function requires exactly one argument")
					}
					identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
					if err != nil {
						return err
					}
					if identifier != "tags" {
						return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
					}
					value, err := filter.GetExprValue(v.CallExpr.Args[1])
					if err != nil {
						return err
					}
					valueInt, ok := value.(int64)
					if !ok {
						return errors.New("size comparison value must be an integer")
					}
					operator := d.getComparisonOperator(v.CallExpr.Function)

					// json_array_length template computes the tag count in SQL;
					// the count is bound as a parameter.
					if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
						filter.GetSQL("json_array_length", dbType), operator)); err != nil {
						return err
					}
					ctx.Args = append(ctx.Args, valueInt)
					// size(...) comparison fully handled; skip the plain
					// identifier path below.
					return nil
				}
			}

			// Plain "identifier <op> value" comparison.
			identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
			if err != nil {
				return err
			}
			if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
				return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
			}
			value, err := filter.GetExprValue(v.CallExpr.Args[1])
			if err != nil {
				return err
			}
			operator := d.getComparisonOperator(v.CallExpr.Function)

			if identifier == "created_ts" || identifier == "updated_ts" {
				// Timestamps: any comparison operator, integer value bound as
				// a parameter against the timestamp_field template.
				valueInt, ok := value.(int64)
				if !ok {
					return errors.New("invalid integer timestamp value")
				}

				timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
				if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
					return err
				}
				ctx.Args = append(ctx.Args, valueInt)
			} else if identifier == "visibility" || identifier == "content" {
				// String columns: equality/inequality only.
				if operator != "=" && operator != "!=" {
					return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
				}
				valueStr, ok := value.(string)
				if !ok {
					return errors.New("invalid string value")
				}

				var sqlTemplate string
				if identifier == "visibility" {
					sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
				} else if identifier == "content" {
					sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
				}
				if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
					return err
				}
				ctx.Args = append(ctx.Args, valueStr)
			} else if identifier == "creator_id" {
				// Integer column: equality/inequality only.
				if operator != "=" && operator != "!=" {
					return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
				}
				valueInt, ok := value.(int64)
				if !ok {
					return errors.New("invalid int value")
				}

				sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
				if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
					return err
				}
				ctx.Args = append(ctx.Args, valueInt)
			} else if identifier == "has_task_list" {
				// Boolean payload flag: equality/inequality only. The boolean
				// value is folded into the chosen template rather than bound
				// as a parameter.
				if operator != "=" && operator != "!=" {
					return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
				}
				valueBool, ok := value.(bool)
				if !ok {
					return errors.New("invalid boolean value for has_task_list")
				}
				// Use template for boolean comparison
				var sqlTemplate string
				if operator == "=" {
					if valueBool {
						sqlTemplate = filter.GetSQL("boolean_true", dbType)
					} else {
						sqlTemplate = filter.GetSQL("boolean_false", dbType)
					}
				} else { // operator == "!="
					if valueBool {
						sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
					} else {
						sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
					}
				}
				if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
					return err
				}
			}
		case "@in":
			if len(v.CallExpr.Args) != 2 {
				return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
			}

			// Check if this is "element in collection" syntax
			if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
				// This is "element in collection" - the second argument is the collection
				if !slices.Contains([]string{"tags"}, identifier) {
					return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
				}

				if identifier == "tags" {
					// Handle "element" in tags
					element, err := filter.GetConstValue(v.CallExpr.Args[0])
					if err != nil {
						return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
					}
					if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
						return err
					}
					ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
				}
				return nil
			}

			// Original logic for "identifier in [list]" syntax
			identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
			if err != nil {
				return err
			}
			if !slices.Contains([]string{"tag", "visibility"}, identifier) {
				return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
			}

			// Collect the constant list elements on the right-hand side.
			values := []any{}
			for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
				value, err := filter.GetConstValue(element)
				if err != nil {
					return err
				}
				values = append(values, value)
			}
			if identifier == "tag" {
				// "tag in [...]" becomes an OR of per-tag containment checks;
				// a single tag skips the surrounding parentheses.
				subconditions := []string{}
				args := []any{}
				for _, v := range values {
					subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
					args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
				}
				if len(subconditions) == 1 {
					if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
						return err
					}
				} else {
					if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
						return err
					}
				}
				ctx.Args = append(ctx.Args, args...)
			} else if identifier == "visibility" {
				// "visibility in [...]" becomes an IN (...) clause with one
				// placeholder per value.
				placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
				visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
				if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
					return err
				}
				ctx.Args = append(ctx.Args, values...)
			}
		case "contains":
			// content.contains("x") -> content LIKE '%x%' (via template),
			// with the wildcard-wrapped value bound as a parameter.
			if len(v.CallExpr.Args) != 1 {
				return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
			}
			identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
			if err != nil {
				return err
			}
			if identifier != "content" {
				return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
			}
			arg, err := filter.GetConstValue(v.CallExpr.Args[0])
			if err != nil {
				return err
			}
			if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
				return err
			}
			ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
		}
	} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
		// Bare boolean identifiers used directly as a filter, e.g. `pinned`.
		identifier := v.IdentExpr.GetName()
		if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
			return errors.Errorf("invalid identifier %s", identifier)
		}
		if identifier == "pinned" {
			if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
				return err
			}
		} else if identifier == "has_task_list" {
			// Handle has_task_list as a standalone boolean identifier
			if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
				return err
			}
		}
	}
	return nil
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
151
store/db/sqlite/memo_filter_test.go
Normal file
151
store/db/sqlite/memo_filter_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"]`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)",
|
||||
args: []any{`%"tag1"%`, `%"tag2"%`},
|
||||
},
|
||||
{
|
||||
filter: `!(tag in ["tag1", "tag2"])`,
|
||||
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
|
||||
args: []any{`%"tag1"%`, `%"tag2"%`},
|
||||
},
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
|
||||
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
|
||||
args: []any{`%"tag1"%`, `%"tag2"%`, `%"tag3"%`, `%"tag4"%`},
|
||||
},
|
||||
{
|
||||
filter: `content.contains("memos")`,
|
||||
want: "`memo`.`content` LIKE ?",
|
||||
args: []any{"%memos%"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC"]`,
|
||||
want: "`memo`.`visibility` IN (?)",
|
||||
args: []any{"PUBLIC"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "`memo`.`visibility` IN (?,?)",
|
||||
args: []any{"PUBLIC", "PRIVATE"},
|
||||
},
|
||||
{
|
||||
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR `memo`.`content` LIKE ?)",
|
||||
args: []any{`%"tag1"%`, "%hello%"},
|
||||
},
|
||||
{
|
||||
filter: `1`,
|
||||
want: "",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `pinned`,
|
||||
want: "`memo`.`pinned` IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `!pinned`,
|
||||
want: "NOT (`memo`.`pinned` IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `creator_id == 101 || visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "(`memo`.`creator_id` = ? OR `memo`.`visibility` IN (?,?))",
|
||||
args: []any{int64(101), "PUBLIC", "PRIVATE"},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == true`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list != false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `!has_task_list`,
|
||||
want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && pinned`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`pinned` IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && content.contains("todo")`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`content` LIKE ?)",
|
||||
args: []any{"%todo%"},
|
||||
},
|
||||
{
|
||||
filter: `created_ts > now() - 60 * 60 * 24`,
|
||||
want: "`memo`.`created_ts` > ?",
|
||||
args: []any{time.Now().Unix() - 60*60*24},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 0`,
|
||||
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) > 0`,
|
||||
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `"work" in tags`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
args: []any{`%"work"%`},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 2`,
|
||||
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(2)},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
db := &DB{}
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
t.Logf("Failed to parse filter: %s, error: %v", tt.filter, err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
if err != nil {
|
||||
t.Logf("Failed to convert filter: %s, error: %v", tt.filter, err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
125
store/db/sqlite/memo_relation.go
Normal file
125
store/db/sqlite/memo_relation.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := `
|
||||
INSERT INTO memo_relation (
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
RETURNING memo_id, related_memo_id, type
|
||||
`
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := d.db.QueryRowContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
).Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return memoRelation, nil
|
||||
}
|
||||
|
||||
// ListMemoRelations returns memo relations matching find. When MemoFilter is
// set, it is parsed as a CEL expression and converted into a SQL condition
// that must hold for BOTH endpoints of the relation (memo_id and
// related_memo_id), each via a subquery against the memo table.
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
	// "TRUE" seeds the WHERE list so clauses can always be AND-joined.
	where, args := []string{"TRUE"}, []any{}
	if find.MemoID != nil {
		where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
	}
	if find.RelatedMemoID != nil {
		where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
	}
	if find.Type != nil {
		where, args = append(where, "type = ?"), append(args, find.Type)
	}
	if find.MemoFilter != nil {
		// Parse filter string and return the parsed expression.
		// The filter string should be a CEL expression.
		parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
		if err != nil {
			return nil, err
		}
		convertCtx := filter.NewConvertContext()
		// ConvertExprToSQL converts the parsed expression to a SQL condition string.
		if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
			return nil, err
		}
		condition := convertCtx.Buffer.String()
		if condition != "" {
			where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
			where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
			// The condition text is embedded twice (once per endpoint), so its
			// bound parameters must be appended twice, in the same order.
			args = append(args, append(convertCtx.Args, convertCtx.Args...)...)
		}
	}

	rows, err := d.db.QueryContext(ctx, `
		SELECT
			memo_id,
			related_memo_id,
			type
		FROM memo_relation
		WHERE `+strings.Join(where, " AND "), args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*store.MemoRelation{}
	for rows.Next() {
		memoRelation := &store.MemoRelation{}
		if err := rows.Scan(
			&memoRelation.MemoID,
			&memoRelation.RelatedMemoID,
			&memoRelation.Type,
		); err != nil {
			return nil, err
		}
		list = append(list, memoRelation)
	}

	// Surface any iteration error (e.g. a connection failure mid-scan).
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return list, nil
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "type = ?"), append(args, delete.Type)
|
||||
}
|
||||
stmt := `
|
||||
DELETE FROM memo_relation
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
store/db/sqlite/migration_history.go
Normal file
57
store/db/sqlite/migration_history.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
query := "SELECT `version`, `created_ts` FROM `migration_history` ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
|
||||
stmt := `
|
||||
INSERT INTO migration_history (
|
||||
version
|
||||
)
|
||||
VALUES (?)
|
||||
ON CONFLICT(version) DO UPDATE
|
||||
SET
|
||||
version=EXCLUDED.version
|
||||
RETURNING version, created_ts
|
||||
`
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &migrationHistory, nil
|
||||
}
|
||||
80
store/db/sqlite/reaction.go
Normal file
80
store/db/sqlite/reaction.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"`creator_id`", "`content_id`", "`reaction_type`"}
|
||||
placeholder := []string{"?", "?", "?"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO `reaction` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&upsert.ID,
|
||||
&upsert.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reaction := upsert
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "content_id = ?"), append(args, *find.ContentID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID)
|
||||
return err
|
||||
}
|
||||
70
store/db/sqlite/sqlite.go
Normal file
70
store/db/sqlite/sqlite.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
// Import the SQLite driver.
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// DB is the SQLite implementation of the store driver. It bundles the open
// database handle with the server profile it was created from.
type DB struct {
	db      *sql.DB          // underlying SQLite connection pool
	profile *profile.Profile // profile used to open the database (holds the DSN)
}
|
||||
|
||||
// NewDB opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a
|
||||
// database name and connection information.
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
// Ensure a DSN is set before attempting to open the database.
|
||||
if profile.DSN == "" {
|
||||
return nil, errors.New("dsn required")
|
||||
}
|
||||
|
||||
// Connect to the database with some sane settings:
|
||||
// - No shared-cache: it's obsolete; WAL journal mode is a better solution.
|
||||
// - No foreign key constraints: it's currently disabled by default, but it's a
|
||||
// good practice to be explicit and prevent future surprises on SQLite upgrades.
|
||||
// - Journal mode set to WAL: it's the recommended journal mode for most applications
|
||||
// as it prevents locking issues.
|
||||
//
|
||||
// Notes:
|
||||
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
|
||||
//
|
||||
// References:
|
||||
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
|
||||
// - https://www.sqlite.org/sharedcache.html
|
||||
// - https://www.sqlite.org/pragma.html
|
||||
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN)
|
||||
}
|
||||
|
||||
driver := DB{db: sqliteDB, profile: profile}
|
||||
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
// Check if the database is initialized by checking if the memo table exists.
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='memo')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
170
store/db/sqlite/user.go
Normal file
170
store/db/sqlite/user.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`, `avatar_url`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING id, description, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.Description,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "updated_ts = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "avatar_url = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "description = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := `
|
||||
UPDATE user
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, username, role, email, nickname, password_hash, avatar_url, description, created_ts, updated_ts, row_status
|
||||
`
|
||||
user := &store.User{}
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
orderBy := []string{"created_ts DESC", "row_status DESC"}
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url,
|
||||
description,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status
|
||||
FROM user
|
||||
WHERE ` + strings.Join(where, " AND ") + ` ORDER BY ` + strings.Join(orderBy, ", ")
|
||||
if v := find.Limit; v != nil {
|
||||
query += fmt.Sprintf(" LIMIT %d", *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, `
|
||||
DELETE FROM user WHERE id = ?
|
||||
`, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
68
store/db/sqlite/user_setting.go
Normal file
68
store/db/sqlite/user_setting.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (
|
||||
user_id, key, value
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &store.UserSetting{}
|
||||
var keyString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&keyString,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
72
store/db/sqlite/workspace_setting.go
Normal file
72
store/db/sqlite/workspace_setting.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *store.WorkspaceSetting) (*store.WorkspaceSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO system_setting (
|
||||
name, value, description
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(name) DO UPDATE
|
||||
SET
|
||||
value = EXCLUDED.value,
|
||||
description = EXCLUDED.description
|
||||
`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*store.WorkspaceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "name = ?"), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
name,
|
||||
value,
|
||||
description
|
||||
FROM system_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.WorkspaceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.WorkspaceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteWorkspaceSetting(ctx context.Context, delete *store.DeleteWorkspaceSetting) error {
|
||||
stmt := "DELETE FROM system_setting WHERE name = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
|
||||
return err
|
||||
}
|
||||
79
store/driver.go
Normal file
79
store/driver.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
// Driver is an interface for store driver.
|
||||
// It contains all methods that store database driver should implement.
|
||||
type Driver interface {
|
||||
GetDB() *sql.DB
|
||||
Close() error
|
||||
|
||||
IsInitialized(ctx context.Context) (bool, error)
|
||||
|
||||
// MigrationHistory model related methods.
|
||||
FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error)
|
||||
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)
|
||||
|
||||
// Activity model related methods.
|
||||
CreateActivity(ctx context.Context, create *Activity) (*Activity, error)
|
||||
ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error)
|
||||
|
||||
// Attachment model related methods.
|
||||
CreateAttachment(ctx context.Context, create *Attachment) (*Attachment, error)
|
||||
ListAttachments(ctx context.Context, find *FindAttachment) ([]*Attachment, error)
|
||||
UpdateAttachment(ctx context.Context, update *UpdateAttachment) error
|
||||
DeleteAttachment(ctx context.Context, delete *DeleteAttachment) error
|
||||
|
||||
// Memo model related methods.
|
||||
CreateMemo(ctx context.Context, create *Memo) (*Memo, error)
|
||||
ListMemos(ctx context.Context, find *FindMemo) ([]*Memo, error)
|
||||
UpdateMemo(ctx context.Context, update *UpdateMemo) error
|
||||
DeleteMemo(ctx context.Context, delete *DeleteMemo) error
|
||||
|
||||
// MemoRelation model related methods.
|
||||
UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error)
|
||||
ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error)
|
||||
DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error
|
||||
|
||||
// WorkspaceSetting model related methods.
|
||||
UpsertWorkspaceSetting(ctx context.Context, upsert *WorkspaceSetting) (*WorkspaceSetting, error)
|
||||
ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*WorkspaceSetting, error)
|
||||
DeleteWorkspaceSetting(ctx context.Context, delete *DeleteWorkspaceSetting) error
|
||||
|
||||
// User model related methods.
|
||||
CreateUser(ctx context.Context, create *User) (*User, error)
|
||||
UpdateUser(ctx context.Context, update *UpdateUser) (*User, error)
|
||||
ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
|
||||
DeleteUser(ctx context.Context, delete *DeleteUser) error
|
||||
|
||||
// UserSetting model related methods.
|
||||
UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error)
|
||||
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error)
|
||||
|
||||
// IdentityProvider model related methods.
|
||||
CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error)
|
||||
ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error)
|
||||
UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error)
|
||||
DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error
|
||||
|
||||
// Inbox model related methods.
|
||||
CreateInbox(ctx context.Context, create *Inbox) (*Inbox, error)
|
||||
ListInboxes(ctx context.Context, find *FindInbox) ([]*Inbox, error)
|
||||
UpdateInbox(ctx context.Context, update *UpdateInbox) (*Inbox, error)
|
||||
DeleteInbox(ctx context.Context, delete *DeleteInbox) error
|
||||
|
||||
// Reaction model related methods.
|
||||
UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error)
|
||||
ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error)
|
||||
DeleteReaction(ctx context.Context, delete *DeleteReaction) error
|
||||
|
||||
// Shortcut related methods.
|
||||
ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error
|
||||
}
|
||||
182
store/idp.go
Normal file
182
store/idp.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
type IdentityProvider struct {
|
||||
ID int32
|
||||
Name string
|
||||
Type storepb.IdentityProvider_Type
|
||||
IdentifierFilter string
|
||||
Config string
|
||||
}
|
||||
|
||||
type FindIdentityProvider struct {
|
||||
ID *int32
|
||||
}
|
||||
|
||||
type UpdateIdentityProvider struct {
|
||||
ID int32
|
||||
Name *string
|
||||
IdentifierFilter *string
|
||||
Config *string
|
||||
}
|
||||
|
||||
type DeleteIdentityProvider struct {
|
||||
ID int32
|
||||
}
|
||||
|
||||
func (s *Store) CreateIdentityProvider(ctx context.Context, create *storepb.IdentityProvider) (*storepb.IdentityProvider, error) {
|
||||
raw, err := convertIdentityProviderToRaw(create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProviderRaw, err := s.driver.CreateIdentityProvider(ctx, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*storepb.IdentityProvider, error) {
|
||||
list, err := s.driver.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProviders := []*storepb.IdentityProvider{}
|
||||
for _, raw := range list {
|
||||
identityProvider, err := convertIdentityProviderFromRaw(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProviders = append(identityProviders, identityProvider)
|
||||
}
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
|
||||
list, err := s.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if len(list) > 1 {
|
||||
return nil, errors.Errorf("Found multiple identity providers with ID %d", *find.ID)
|
||||
}
|
||||
|
||||
identityProvider := list[0]
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
type UpdateIdentityProviderV1 struct {
|
||||
ID int32
|
||||
Type storepb.IdentityProvider_Type
|
||||
Name *string
|
||||
IdentifierFilter *string
|
||||
Config *storepb.IdentityProviderConfig
|
||||
}
|
||||
|
||||
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderV1) (*storepb.IdentityProvider, error) {
|
||||
updateRaw := &UpdateIdentityProvider{
|
||||
ID: update.ID,
|
||||
}
|
||||
if update.Name != nil {
|
||||
updateRaw.Name = update.Name
|
||||
}
|
||||
if update.IdentifierFilter != nil {
|
||||
updateRaw.IdentifierFilter = update.IdentifierFilter
|
||||
}
|
||||
if update.Config != nil {
|
||||
configRaw, err := convertIdentityProviderConfigToRaw(update.Type, update.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updateRaw.Config = &configRaw
|
||||
}
|
||||
identityProviderRaw, err := s.driver.UpdateIdentityProvider(ctx, updateRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
|
||||
err := s.driver.DeleteIdentityProvider(ctx, delete)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
|
||||
identityProvider := &storepb.IdentityProvider{
|
||||
Id: raw.ID,
|
||||
Name: raw.Name,
|
||||
Type: raw.Type,
|
||||
IdentifierFilter: raw.IdentifierFilter,
|
||||
}
|
||||
config, err := convertIdentityProviderConfigFromRaw(identityProvider.Type, raw.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Config = config
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
|
||||
raw := &IdentityProvider{
|
||||
ID: identityProvider.Id,
|
||||
Name: identityProvider.Name,
|
||||
Type: identityProvider.Type,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
}
|
||||
configRaw, err := convertIdentityProviderConfigToRaw(identityProvider.Type, identityProvider.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw.Config = configRaw
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigFromRaw(identityProviderType storepb.IdentityProvider_Type, raw string) (*storepb.IdentityProviderConfig, error) {
|
||||
config := &storepb.IdentityProviderConfig{}
|
||||
if identityProviderType == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := &storepb.OAuth2Config{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(raw), oauth2Config); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to unmarshal OAuth2Config")
|
||||
}
|
||||
config.Config = &storepb.IdentityProviderConfig_Oauth2Config{Oauth2Config: oauth2Config}
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigToRaw(identityProviderType storepb.IdentityProvider_Type, config *storepb.IdentityProviderConfig) (string, error) {
|
||||
raw := ""
|
||||
if identityProviderType == storepb.IdentityProvider_OAUTH2 {
|
||||
bytes, err := protojson.Marshal(config.GetOauth2Config())
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "Failed to marshal OAuth2Config")
|
||||
}
|
||||
raw = string(bytes)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
64
store/inbox.go
Normal file
64
store/inbox.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// InboxStatus is the status for an inbox entry.
type InboxStatus string

const (
	// UNREAD marks an inbox entry the receiver has not yet read.
	UNREAD InboxStatus = "UNREAD"
	// ARCHIVED marks an inbox entry the receiver has archived.
	ARCHIVED InboxStatus = "ARCHIVED"
)

// String returns the status as its plain string value.
func (status InboxStatus) String() string {
	return string(status)
}
|
||||
|
||||
type Inbox struct {
|
||||
ID int32
|
||||
CreatedTs int64
|
||||
SenderID int32
|
||||
ReceiverID int32
|
||||
Status InboxStatus
|
||||
Message *storepb.InboxMessage
|
||||
}
|
||||
|
||||
type UpdateInbox struct {
|
||||
ID int32
|
||||
Status InboxStatus
|
||||
}
|
||||
|
||||
type FindInbox struct {
|
||||
ID *int32
|
||||
SenderID *int32
|
||||
ReceiverID *int32
|
||||
Status *InboxStatus
|
||||
|
||||
// Pagination
|
||||
Limit *int
|
||||
Offset *int
|
||||
}
|
||||
|
||||
type DeleteInbox struct {
|
||||
ID int32
|
||||
}
|
||||
|
||||
func (s *Store) CreateInbox(ctx context.Context, create *Inbox) (*Inbox, error) {
|
||||
return s.driver.CreateInbox(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) ListInboxes(ctx context.Context, find *FindInbox) ([]*Inbox, error) {
|
||||
return s.driver.ListInboxes(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateInbox(ctx context.Context, update *UpdateInbox) (*Inbox, error) {
|
||||
return s.driver.UpdateInbox(ctx, update)
|
||||
}
|
||||
|
||||
func (s *Store) DeleteInbox(ctx context.Context, delete *DeleteInbox) error {
|
||||
return s.driver.DeleteInbox(ctx, delete)
|
||||
}
|
||||
147
store/memo.go
Normal file
147
store/memo.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// Visibility is the type of a memo's visibility.
type Visibility string

const (
	// Public is the PUBLIC visibility.
	Public Visibility = "PUBLIC"
	// Protected is the PROTECTED visibility.
	Protected Visibility = "PROTECTED"
	// Private is the PRIVATE visibility.
	Private Visibility = "PRIVATE"
)

// String returns the visibility's string form. Any unrecognized value
// deliberately falls back to "PRIVATE" as the safest default.
func (v Visibility) String() string {
	switch v {
	case Public, Protected, Private:
		return string(v)
	default:
		return "PRIVATE"
	}
}
|
||||
|
||||
type Memo struct {
|
||||
// ID is the system generated unique identifier for the memo.
|
||||
ID int32
|
||||
// UID is the user defined unique identifier for the memo.
|
||||
UID string
|
||||
|
||||
// Standard fields
|
||||
RowStatus RowStatus
|
||||
CreatorID int32
|
||||
CreatedTs int64
|
||||
UpdatedTs int64
|
||||
|
||||
// Domain specific fields
|
||||
Content string
|
||||
Visibility Visibility
|
||||
Pinned bool
|
||||
Payload *storepb.MemoPayload
|
||||
|
||||
// Composed fields
|
||||
ParentID *int32
|
||||
}
|
||||
|
||||
type FindMemo struct {
|
||||
ID *int32
|
||||
UID *string
|
||||
|
||||
// Standard fields
|
||||
RowStatus *RowStatus
|
||||
CreatorID *int32
|
||||
CreatedTsAfter *int64
|
||||
CreatedTsBefore *int64
|
||||
UpdatedTsAfter *int64
|
||||
UpdatedTsBefore *int64
|
||||
|
||||
// Domain specific fields
|
||||
ContentSearch []string
|
||||
VisibilityList []Visibility
|
||||
Pinned *bool
|
||||
PayloadFind *FindMemoPayload
|
||||
ExcludeContent bool
|
||||
ExcludeComments bool
|
||||
Filter *string
|
||||
|
||||
// Pagination
|
||||
Limit *int
|
||||
Offset *int
|
||||
|
||||
// Ordering
|
||||
OrderByUpdatedTs bool
|
||||
OrderByPinned bool
|
||||
OrderByTimeAsc bool
|
||||
}
|
||||
|
||||
type FindMemoPayload struct {
|
||||
Raw *string
|
||||
TagSearch []string
|
||||
HasLink bool
|
||||
HasTaskList bool
|
||||
HasCode bool
|
||||
HasIncompleteTasks bool
|
||||
}
|
||||
|
||||
type UpdateMemo struct {
|
||||
ID int32
|
||||
UID *string
|
||||
CreatedTs *int64
|
||||
UpdatedTs *int64
|
||||
RowStatus *RowStatus
|
||||
Content *string
|
||||
Visibility *Visibility
|
||||
Pinned *bool
|
||||
Payload *storepb.MemoPayload
|
||||
}
|
||||
|
||||
type DeleteMemo struct {
|
||||
ID int32
|
||||
}
|
||||
|
||||
func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) {
|
||||
if !base.UIDMatcher.MatchString(create.UID) {
|
||||
return nil, errors.New("invalid uid")
|
||||
}
|
||||
return s.driver.CreateMemo(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*Memo, error) {
|
||||
return s.driver.ListMemos(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) {
|
||||
list, err := s.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error {
|
||||
if update.UID != nil && !base.UIDMatcher.MatchString(*update.UID) {
|
||||
return errors.New("invalid uid")
|
||||
}
|
||||
return s.driver.UpdateMemo(ctx, update)
|
||||
}
|
||||
|
||||
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
|
||||
return s.driver.DeleteMemo(ctx, delete)
|
||||
}
|
||||
45
store/memo_relation.go
Normal file
45
store/memo_relation.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type MemoRelationType string
|
||||
|
||||
const (
|
||||
// MemoRelationReference is the type for a reference memo relation.
|
||||
MemoRelationReference MemoRelationType = "REFERENCE"
|
||||
// MemoRelationComment is the type for a comment memo relation.
|
||||
MemoRelationComment MemoRelationType = "COMMENT"
|
||||
)
|
||||
|
||||
type MemoRelation struct {
|
||||
MemoID int32
|
||||
RelatedMemoID int32
|
||||
Type MemoRelationType
|
||||
}
|
||||
|
||||
type FindMemoRelation struct {
|
||||
MemoID *int32
|
||||
RelatedMemoID *int32
|
||||
Type *MemoRelationType
|
||||
MemoFilter *string
|
||||
}
|
||||
|
||||
type DeleteMemoRelation struct {
|
||||
MemoID *int32
|
||||
RelatedMemoID *int32
|
||||
Type *MemoRelationType
|
||||
}
|
||||
|
||||
func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) {
|
||||
return s.driver.UpsertMemoRelation(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) {
|
||||
return s.driver.ListMemoRelations(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error {
|
||||
return s.driver.DeleteMemoRelation(ctx, delete)
|
||||
}
|
||||
9
store/migration/mysql/0.17/00__inbox.sql
Normal file
9
store/migration/mysql/0.17/00__inbox.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- inbox
|
||||
CREATE TABLE `inbox` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`sender_id` INT NOT NULL,
|
||||
`receiver_id` INT NOT NULL,
|
||||
`status` TEXT NOT NULL,
|
||||
`message` TEXT NOT NULL
|
||||
);
|
||||
1
store/migration/mysql/0.17/01__delete_activity.sql
Normal file
1
store/migration/mysql/0.17/01__delete_activity.sql
Normal file
@@ -0,0 +1 @@
|
||||
DELETE FROM `activity`;
|
||||
3
store/migration/mysql/0.18/00__extend_text.sql
Normal file
3
store/migration/mysql/0.18/00__extend_text.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE `system_setting` MODIFY `value` LONGTEXT NOT NULL;
|
||||
ALTER TABLE `user_setting` MODIFY `value` LONGTEXT NOT NULL;
|
||||
ALTER TABLE `user` MODIFY `avatar_url` LONGTEXT NOT NULL;
|
||||
10
store/migration/mysql/0.18/01__webhook.sql
Normal file
10
store/migration/mysql/0.18/01__webhook.sql
Normal file
@@ -0,0 +1,10 @@
|
||||
-- webhook
|
||||
CREATE TABLE `webhook` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`row_status` VARCHAR(256) NOT NULL DEFAULT 'NORMAL',
|
||||
`creator_id` INT NOT NULL,
|
||||
`name` TEXT NOT NULL,
|
||||
`url` TEXT NOT NULL
|
||||
);
|
||||
4
store/migration/mysql/0.18/02__user_setting.sql
Normal file
4
store/migration/mysql/0.18/02__user_setting.sql
Normal file
@@ -0,0 +1,4 @@
|
||||
UPDATE `user_setting` SET `key` = 'USER_SETTING_LOCALE', `value` = REPLACE(`value`, '"', '') WHERE `key` = 'locale';
|
||||
UPDATE `user_setting` SET `key` = 'USER_SETTING_APPEARANCE', `value` = REPLACE(`value`, '"', '') WHERE `key` = 'appearance';
|
||||
UPDATE `user_setting` SET `key` = 'USER_SETTING_MEMO_VISIBILITY', `value` = REPLACE(`value`, '"', '') WHERE `key` = 'memo-visibility';
|
||||
UPDATE `user_setting` SET `key` = 'USER_SETTING_TELEGRAM_USER_ID', `value` = REPLACE(`value`, '"', '') WHERE `key` = 'telegram-user-id';
|
||||
15
store/migration/mysql/0.19/00__add_resource_name.sql
Normal file
15
store/migration/mysql/0.19/00__add_resource_name.sql
Normal file
@@ -0,0 +1,15 @@
|
||||
ALTER TABLE `memo` ADD COLUMN `resource_name` VARCHAR(256) AFTER `id`;
|
||||
|
||||
UPDATE `memo` SET `resource_name` = uuid();
|
||||
|
||||
ALTER TABLE `memo` MODIFY COLUMN `resource_name` VARCHAR(256) NOT NULL;
|
||||
|
||||
CREATE UNIQUE INDEX idx_memo_resource_name ON `memo` (`resource_name`);
|
||||
|
||||
ALTER TABLE `resource` ADD COLUMN `resource_name` VARCHAR(256) AFTER `id`;
|
||||
|
||||
UPDATE `resource` SET `resource_name` = uuid();
|
||||
|
||||
ALTER TABLE `resource` MODIFY COLUMN `resource_name` VARCHAR(256) NOT NULL;
|
||||
|
||||
CREATE UNIQUE INDEX idx_resource_resource_name ON `resource` (`resource_name`);
|
||||
9
store/migration/mysql/0.20/00__reaction.sql
Normal file
9
store/migration/mysql/0.20/00__reaction.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- reaction
|
||||
CREATE TABLE `reaction` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
`created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`creator_id` INT NOT NULL,
|
||||
`content_id` VARCHAR(256) NOT NULL,
|
||||
`reaction_type` VARCHAR(256) NOT NULL,
|
||||
UNIQUE(`creator_id`,`content_id`,`reaction_type`)
|
||||
);
|
||||
1
store/migration/mysql/0.21/00__user_description.sql
Normal file
1
store/migration/mysql/0.21/00__user_description.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE `user` ADD COLUMN `description` VARCHAR(256) NOT NULL DEFAULT '';
|
||||
3
store/migration/mysql/0.21/01__rename_uid.sql
Normal file
3
store/migration/mysql/0.21/01__rename_uid.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE `memo` RENAME COLUMN `resource_name` TO `uid`;
|
||||
|
||||
ALTER TABLE `resource` RENAME COLUMN `resource_name` TO `uid`;
|
||||
11
store/migration/mysql/0.22/00__resource_storage_type.sql
Normal file
11
store/migration/mysql/0.22/00__resource_storage_type.sql
Normal file
@@ -0,0 +1,11 @@
|
||||
ALTER TABLE `resource` ADD COLUMN `storage_type` VARCHAR(256) NOT NULL DEFAULT '';
|
||||
ALTER TABLE `resource` ADD COLUMN `reference` VARCHAR(256) NOT NULL DEFAULT '';
|
||||
ALTER TABLE `resource` ADD COLUMN `payload` TEXT NOT NULL;
|
||||
|
||||
UPDATE `resource` SET `payload` = '{}';
|
||||
|
||||
UPDATE `resource` SET `storage_type` = 'LOCAL', `reference` = `internal_path` WHERE `internal_path` IS NOT NULL AND `internal_path` != '';
|
||||
UPDATE `resource` SET `storage_type` = 'EXTERNAL', `reference` = `external_link` WHERE `external_link` IS NOT NULL AND `external_link` != '';
|
||||
|
||||
ALTER TABLE `resource` DROP COLUMN `internal_path`;
|
||||
ALTER TABLE `resource` DROP COLUMN `external_link`;
|
||||
3
store/migration/mysql/0.22/01__memo_tags.sql
Normal file
3
store/migration/mysql/0.22/01__memo_tags.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE `memo` ADD COLUMN `tags_temp` JSON;
|
||||
UPDATE `memo` SET `tags_temp` = '[]';
|
||||
ALTER TABLE `memo` CHANGE COLUMN `tags_temp` `tags` JSON NOT NULL;
|
||||
3
store/migration/mysql/0.22/02__memo_payload.sql
Normal file
3
store/migration/mysql/0.22/02__memo_payload.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE `memo` ADD COLUMN `payload_temp` JSON;
|
||||
UPDATE `memo` SET `payload_temp` = '{}';
|
||||
ALTER TABLE `memo` CHANGE COLUMN `payload_temp` `payload` JSON NOT NULL;
|
||||
1
store/migration/mysql/0.22/03__drop_tag.sql
Normal file
1
store/migration/mysql/0.22/03__drop_tag.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS `tag`;
|
||||
12
store/migration/mysql/0.23/00__reactions.sql
Normal file
12
store/migration/mysql/0.23/00__reactions.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
UPDATE `reaction` SET `reaction_type` = '👍' WHERE `reaction_type` = 'THUMBS_UP';
|
||||
UPDATE `reaction` SET `reaction_type` = '👎' WHERE `reaction_type` = 'THUMBS_DOWN';
|
||||
UPDATE `reaction` SET `reaction_type` = '💛' WHERE `reaction_type` = 'HEART';
|
||||
UPDATE `reaction` SET `reaction_type` = '🔥' WHERE `reaction_type` = 'FIRE';
|
||||
UPDATE `reaction` SET `reaction_type` = '👏' WHERE `reaction_type` = 'CLAPPING_HANDS';
|
||||
UPDATE `reaction` SET `reaction_type` = '😂' WHERE `reaction_type` = 'LAUGH';
|
||||
UPDATE `reaction` SET `reaction_type` = '👌' WHERE `reaction_type` = 'OK_HAND';
|
||||
UPDATE `reaction` SET `reaction_type` = '🚀' WHERE `reaction_type` = 'ROCKET';
|
||||
UPDATE `reaction` SET `reaction_type` = '👀' WHERE `reaction_type` = 'EYES';
|
||||
UPDATE `reaction` SET `reaction_type` = '🤔' WHERE `reaction_type` = 'THINKING_FACE';
|
||||
UPDATE `reaction` SET `reaction_type` = '🤡' WHERE `reaction_type` = 'CLOWN_FACE';
|
||||
UPDATE `reaction` SET `reaction_type` = '❓' WHERE `reaction_type` = 'QUESTION_MARK';
|
||||
2
store/migration/mysql/0.24/00__memo.sql
Normal file
2
store/migration/mysql/0.24/00__memo.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- Drop deprecated tags column.
|
||||
ALTER TABLE `memo` DROP COLUMN `tags`;
|
||||
8
store/migration/mysql/0.24/01__memo_pinned.sql
Normal file
8
store/migration/mysql/0.24/01__memo_pinned.sql
Normal file
@@ -0,0 +1,8 @@
|
||||
-- Add pinned column.
ALTER TABLE `memo` ADD COLUMN `pinned` BOOLEAN NOT NULL DEFAULT FALSE;

-- Update pinned column from memo_organizer.
-- Carries the legacy per-user pin flag from memo_organizer over to the
-- memo row itself; memos with no organizer row keep the FALSE default.
UPDATE memo
JOIN memo_organizer ON memo.id = memo_organizer.memo_id
SET memo.pinned = TRUE
WHERE memo_organizer.pinned = 1;
||||
2
store/migration/mysql/0.24/02__s3_reference_length.sql
Normal file
2
store/migration/mysql/0.24/02__s3_reference_length.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- https://github.com/usememos/memos/issues/4322
|
||||
ALTER TABLE `resource` MODIFY `reference` TEXT NOT NULL DEFAULT ('');
|
||||
1
store/migration/mysql/0.25/00__remove_webhook.sql
Normal file
1
store/migration/mysql/0.25/00__remove_webhook.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS webhook;
|
||||
121
store/migration/mysql/LATEST.sql
Normal file
121
store/migration/mysql/LATEST.sql
Normal file
@@ -0,0 +1,121 @@
|
||||
-- migration_history: records which schema versions have been applied.
CREATE TABLE `migration_history` (
  `version` VARCHAR(256) NOT NULL PRIMARY KEY,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

-- system_setting: workspace-level settings, keyed by name.
CREATE TABLE `system_setting` (
  `name` VARCHAR(256) NOT NULL PRIMARY KEY,
  `value` LONGTEXT NOT NULL,
  `description` TEXT NOT NULL
);

-- user
CREATE TABLE `user` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `row_status` VARCHAR(256) NOT NULL DEFAULT 'NORMAL',
  `username` VARCHAR(256) NOT NULL UNIQUE,
  `role` VARCHAR(256) NOT NULL DEFAULT 'USER',
  `email` VARCHAR(256) NOT NULL DEFAULT '',
  `nickname` VARCHAR(256) NOT NULL DEFAULT '',
  `password_hash` VARCHAR(256) NOT NULL,
  `avatar_url` LONGTEXT NOT NULL,
  `description` VARCHAR(256) NOT NULL DEFAULT ''
);

-- user_setting: per-user settings; one row per (user_id, key).
CREATE TABLE `user_setting` (
  `user_id` INT NOT NULL,
  `key` VARCHAR(256) NOT NULL,
  `value` LONGTEXT NOT NULL,
  UNIQUE(`user_id`,`key`)
);

-- memo
CREATE TABLE `memo` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `uid` VARCHAR(256) NOT NULL UNIQUE,
  `creator_id` INT NOT NULL,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `row_status` VARCHAR(256) NOT NULL DEFAULT 'NORMAL',
  `content` TEXT NOT NULL,
  `visibility` VARCHAR(256) NOT NULL DEFAULT 'PRIVATE',
  `pinned` BOOLEAN NOT NULL DEFAULT FALSE,
  `payload` JSON NOT NULL
);

-- memo_organizer: legacy per-user pin flags for memos.
CREATE TABLE `memo_organizer` (
  `memo_id` INT NOT NULL,
  `user_id` INT NOT NULL,
  `pinned` INT NOT NULL DEFAULT '0',
  UNIQUE(`memo_id`,`user_id`)
);

-- memo_relation: links between memos (e.g. reference, comment).
CREATE TABLE `memo_relation` (
  `memo_id` INT NOT NULL,
  `related_memo_id` INT NOT NULL,
  `type` VARCHAR(256) NOT NULL,
  UNIQUE(`memo_id`,`related_memo_id`,`type`)
);

-- resource: file attachments; blob is only set for DB-stored resources.
CREATE TABLE `resource` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `uid` VARCHAR(256) NOT NULL UNIQUE,
  `creator_id` INT NOT NULL,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `filename` TEXT NOT NULL,
  `blob` MEDIUMBLOB,
  `type` VARCHAR(256) NOT NULL DEFAULT '',
  `size` INT NOT NULL DEFAULT '0',
  `memo_id` INT DEFAULT NULL,
  `storage_type` VARCHAR(256) NOT NULL DEFAULT '',
  `reference` TEXT NOT NULL DEFAULT (''),
  `payload` TEXT NOT NULL
);

-- activity
CREATE TABLE `activity` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `creator_id` INT NOT NULL,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `type` VARCHAR(256) NOT NULL DEFAULT '',
  `level` VARCHAR(256) NOT NULL DEFAULT 'INFO',
  `payload` TEXT NOT NULL
);

-- idp: SSO identity providers; config is a serialized JSON string.
CREATE TABLE `idp` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `name` TEXT NOT NULL,
  `type` TEXT NOT NULL,
  `identifier_filter` VARCHAR(256) NOT NULL DEFAULT '',
  `config` TEXT NOT NULL
);

-- inbox: notification messages between users.
CREATE TABLE `inbox` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `sender_id` INT NOT NULL,
  `receiver_id` INT NOT NULL,
  `status` TEXT NOT NULL,
  `message` TEXT NOT NULL
);

-- reaction: one row per (creator, content, reaction type).
CREATE TABLE `reaction` (
  `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `creator_id` INT NOT NULL,
  `content_id` VARCHAR(256) NOT NULL,
  `reaction_type` VARCHAR(256) NOT NULL,
  UNIQUE(`creator_id`,`content_id`,`reaction_type`)
);
|
||||
15
store/migration/postgres/0.19/00__add_resource_name.sql
Normal file
15
store/migration/postgres/0.19/00__add_resource_name.sql
Normal file
@@ -0,0 +1,15 @@
|
||||
ALTER TABLE memo ADD COLUMN resource_name TEXT;
|
||||
|
||||
UPDATE memo SET resource_name = uuid_in(md5(random()::text || random()::text)::cstring);
|
||||
|
||||
ALTER TABLE memo ALTER COLUMN resource_name SET NOT NULL;
|
||||
|
||||
CREATE UNIQUE INDEX idx_memo_resource_name ON memo (resource_name);
|
||||
|
||||
ALTER TABLE resource ADD COLUMN resource_name TEXT;
|
||||
|
||||
UPDATE resource SET resource_name = uuid_in(md5(random()::text || random()::text)::cstring);
|
||||
|
||||
ALTER TABLE resource ALTER COLUMN resource_name SET NOT NULL;
|
||||
|
||||
CREATE UNIQUE INDEX idx_resource_resource_name ON resource (resource_name);
|
||||
9
store/migration/postgres/0.20/00__reaction.sql
Normal file
9
store/migration/postgres/0.20/00__reaction.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- reaction
|
||||
CREATE TABLE reaction (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
creator_id INTEGER NOT NULL,
|
||||
content_id TEXT NOT NULL,
|
||||
reaction_type TEXT NOT NULL,
|
||||
UNIQUE(creator_id, content_id, reaction_type)
|
||||
);
|
||||
1
store/migration/postgres/0.21/00__user_description.sql
Normal file
1
store/migration/postgres/0.21/00__user_description.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE "user" ADD COLUMN description TEXT NOT NULL DEFAULT '';
|
||||
3
store/migration/postgres/0.21/01__rename_uid.sql
Normal file
3
store/migration/postgres/0.21/01__rename_uid.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE memo RENAME COLUMN resource_name TO uid;
|
||||
|
||||
ALTER TABLE resource RENAME COLUMN resource_name TO uid;
|
||||
11
store/migration/postgres/0.22/00__resource_storage_type.sql
Normal file
11
store/migration/postgres/0.22/00__resource_storage_type.sql
Normal file
@@ -0,0 +1,11 @@
|
||||
ALTER TABLE resource ADD COLUMN storage_type TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE resource ADD COLUMN reference TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE resource ADD COLUMN payload TEXT NOT NULL DEFAULT '{}';
|
||||
|
||||
UPDATE resource SET storage_type = 'LOCAL', reference = internal_path WHERE internal_path IS NOT NULL AND internal_path != '';
|
||||
|
||||
UPDATE resource SET storage_type = 'EXTERNAL', reference = external_link WHERE external_link IS NOT NULL AND external_link != '';
|
||||
|
||||
ALTER TABLE resource DROP COLUMN internal_path;
|
||||
|
||||
ALTER TABLE resource DROP COLUMN external_link;
|
||||
1
store/migration/postgres/0.22/01__memo_tags.sql
Normal file
1
store/migration/postgres/0.22/01__memo_tags.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE memo ADD COLUMN tags JSONB NOT NULL DEFAULT '[]';
|
||||
1
store/migration/postgres/0.22/02__memo_payload.sql
Normal file
1
store/migration/postgres/0.22/02__memo_payload.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE memo ADD COLUMN payload JSONB NOT NULL DEFAULT '{}';
|
||||
1
store/migration/postgres/0.22/03__drop_tag.sql
Normal file
1
store/migration/postgres/0.22/03__drop_tag.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS tag;
|
||||
12
store/migration/postgres/0.23/00__reactions.sql
Normal file
12
store/migration/postgres/0.23/00__reactions.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
UPDATE "reaction" SET "reaction_type" = '👍' WHERE "reaction_type" = 'THUMBS_UP';
|
||||
UPDATE "reaction" SET "reaction_type" = '👎' WHERE "reaction_type" = 'THUMBS_DOWN';
|
||||
UPDATE "reaction" SET "reaction_type" = '💛' WHERE "reaction_type" = 'HEART';
|
||||
UPDATE "reaction" SET "reaction_type" = '🔥' WHERE "reaction_type" = 'FIRE';
|
||||
UPDATE "reaction" SET "reaction_type" = '👏' WHERE "reaction_type" = 'CLAPPING_HANDS';
|
||||
UPDATE "reaction" SET "reaction_type" = '😂' WHERE "reaction_type" = 'LAUGH';
|
||||
UPDATE "reaction" SET "reaction_type" = '👌' WHERE "reaction_type" = 'OK_HAND';
|
||||
UPDATE "reaction" SET "reaction_type" = '🚀' WHERE "reaction_type" = 'ROCKET';
|
||||
UPDATE "reaction" SET "reaction_type" = '👀' WHERE "reaction_type" = 'EYES';
|
||||
UPDATE "reaction" SET "reaction_type" = '🤔' WHERE "reaction_type" = 'THINKING_FACE';
|
||||
UPDATE "reaction" SET "reaction_type" = '🤡' WHERE "reaction_type" = 'CLOWN_FACE';
|
||||
UPDATE "reaction" SET "reaction_type" = '❓' WHERE "reaction_type" = 'QUESTION_MARK';
|
||||
2
store/migration/postgres/0.24/00__memo.sql
Normal file
2
store/migration/postgres/0.24/00__memo.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- Drop deprecated tags column.
|
||||
ALTER TABLE memo DROP COLUMN IF EXISTS tags;
|
||||
8
store/migration/postgres/0.24/01__memo_pinned.sql
Normal file
8
store/migration/postgres/0.24/01__memo_pinned.sql
Normal file
@@ -0,0 +1,8 @@
|
||||
-- Add pinned column.
|
||||
ALTER TABLE memo ADD COLUMN pinned BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Update pinned column from memo_organizer.
|
||||
UPDATE memo
|
||||
SET pinned = TRUE
|
||||
FROM memo_organizer
|
||||
WHERE memo.id = memo_organizer.memo_id AND memo_organizer.pinned = 1;
|
||||
1
store/migration/postgres/0.25/00__remove_webhook.sql
Normal file
1
store/migration/postgres/0.25/00__remove_webhook.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS webhook;
|
||||
121
store/migration/postgres/LATEST.sql
Normal file
121
store/migration/postgres/LATEST.sql
Normal file
@@ -0,0 +1,121 @@
|
||||
-- migration_history
|
||||
CREATE TABLE migration_history (
|
||||
version TEXT NOT NULL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW())
|
||||
);
|
||||
|
||||
-- system_setting
|
||||
CREATE TABLE system_setting (
|
||||
name TEXT NOT NULL PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
description TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- user
|
||||
CREATE TABLE "user" (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
row_status TEXT NOT NULL DEFAULT 'NORMAL',
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
role TEXT NOT NULL DEFAULT 'USER',
|
||||
email TEXT NOT NULL DEFAULT '',
|
||||
nickname TEXT NOT NULL DEFAULT '',
|
||||
password_hash TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE user_setting (
|
||||
user_id INTEGER NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
UNIQUE(user_id, key)
|
||||
);
|
||||
|
||||
-- memo
|
||||
CREATE TABLE memo (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid TEXT NOT NULL UNIQUE,
|
||||
creator_id INTEGER NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
row_status TEXT NOT NULL DEFAULT 'NORMAL',
|
||||
content TEXT NOT NULL,
|
||||
visibility TEXT NOT NULL DEFAULT 'PRIVATE',
|
||||
pinned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
payload JSONB NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE memo_organizer (
|
||||
memo_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE(memo_id, user_id)
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE memo_relation (
|
||||
memo_id INTEGER NOT NULL,
|
||||
related_memo_id INTEGER NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
UNIQUE(memo_id, related_memo_id, type)
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE resource (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid TEXT NOT NULL UNIQUE,
|
||||
creator_id INTEGER NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
filename TEXT NOT NULL,
|
||||
blob BYTEA,
|
||||
type TEXT NOT NULL DEFAULT '',
|
||||
size INTEGER NOT NULL DEFAULT 0,
|
||||
memo_id INTEGER DEFAULT NULL,
|
||||
storage_type TEXT NOT NULL DEFAULT '',
|
||||
reference TEXT NOT NULL DEFAULT '',
|
||||
payload TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
type TEXT NOT NULL DEFAULT '',
|
||||
level TEXT NOT NULL DEFAULT 'INFO',
|
||||
payload JSONB NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter TEXT NOT NULL DEFAULT '',
|
||||
config JSONB NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
-- inbox
|
||||
CREATE TABLE inbox (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
sender_id INTEGER NOT NULL,
|
||||
receiver_id INTEGER NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
message TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- reaction
|
||||
CREATE TABLE reaction (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
creator_id INTEGER NOT NULL,
|
||||
content_id TEXT NOT NULL,
|
||||
reaction_type TEXT NOT NULL,
|
||||
UNIQUE(creator_id, content_id, reaction_type)
|
||||
);
|
||||
9
store/migration/sqlite/0.10/00__activity.sql
Normal file
9
store/migration/sqlite/0.10/00__activity.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
creator_id INTEGER NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
type TEXT NOT NULL DEFAULT '',
|
||||
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
|
||||
payload TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
4
store/migration/sqlite/0.11/00__user_avatar.sql
Normal file
4
store/migration/sqlite/0.11/00__user_avatar.sql
Normal file
@@ -0,0 +1,4 @@
|
||||
ALTER TABLE
|
||||
user
|
||||
ADD
|
||||
COLUMN avatar_url TEXT NOT NULL DEFAULT '';
|
||||
8
store/migration/sqlite/0.11/01__idp.sql
Normal file
8
store/migration/sqlite/0.11/01__idp.sql
Normal file
@@ -0,0 +1,8 @@
|
||||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter TEXT NOT NULL DEFAULT '',
|
||||
config TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
7
store/migration/sqlite/0.11/02__storage.sql
Normal file
7
store/migration/sqlite/0.11/02__storage.sql
Normal file
@@ -0,0 +1,7 @@
|
||||
-- storage
|
||||
CREATE TABLE storage (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
config TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
6
store/migration/sqlite/0.12/00__user_setting.sql
Normal file
6
store/migration/sqlite/0.12/00__user_setting.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
UPDATE
|
||||
user_setting
|
||||
SET
|
||||
key = 'memo-visibility'
|
||||
WHERE
|
||||
key = 'memoVisibility';
|
||||
69
store/migration/sqlite/0.12/01__system_setting.sql
Normal file
69
store/migration/sqlite/0.12/01__system_setting.sql
Normal file
@@ -0,0 +1,69 @@
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'server-id'
|
||||
WHERE
|
||||
name = 'serverId';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'secret-session'
|
||||
WHERE
|
||||
name = 'secretSessionName';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'allow-signup'
|
||||
WHERE
|
||||
name = 'allowSignUp';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'disable-public-memos'
|
||||
WHERE
|
||||
name = 'disablePublicMemos';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'additional-style'
|
||||
WHERE
|
||||
name = 'additionalStyle';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'additional-script'
|
||||
WHERE
|
||||
name = 'additionalScript';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'customized-profile'
|
||||
WHERE
|
||||
name = 'customizedProfile';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'storage-service-id'
|
||||
WHERE
|
||||
name = 'storageServiceId';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'local-storage-path'
|
||||
WHERE
|
||||
name = 'localStoragePath';
|
||||
|
||||
UPDATE
|
||||
system_setting
|
||||
SET
|
||||
name = 'openai-config'
|
||||
WHERE
|
||||
name = 'openAIConfig';
|
||||
@@ -0,0 +1,4 @@
|
||||
ALTER TABLE
|
||||
resource
|
||||
ADD
|
||||
COLUMN internal_path TEXT NOT NULL DEFAULT '';
|
||||
18
store/migration/sqlite/0.12/04__resource_public_id.sql
Normal file
18
store/migration/sqlite/0.12/04__resource_public_id.sql
Normal file
@@ -0,0 +1,18 @@
|
||||
ALTER TABLE
|
||||
resource
|
||||
ADD
|
||||
COLUMN public_id TEXT NOT NULL DEFAULT '';
|
||||
|
||||
CREATE UNIQUE INDEX resource_id_public_id_unique_index ON resource (id, public_id);
|
||||
|
||||
UPDATE
|
||||
resource
|
||||
SET
|
||||
public_id = printf (
|
||||
'%s-%s-%s-%s-%s',
|
||||
lower(hex(randomblob(4))),
|
||||
lower(hex(randomblob(2))),
|
||||
lower(hex(randomblob(2))),
|
||||
lower(hex(randomblob(2))),
|
||||
lower(hex(randomblob(6)))
|
||||
);
|
||||
7
store/migration/sqlite/0.13/00__memo_relation.sql
Normal file
7
store/migration/sqlite/0.13/00__memo_relation.sql
Normal file
@@ -0,0 +1,7 @@
|
||||
-- memo_relation
|
||||
CREATE TABLE memo_relation (
|
||||
memo_id INTEGER NOT NULL,
|
||||
related_memo_id INTEGER NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
UNIQUE(memo_id, related_memo_id, type)
|
||||
);
|
||||
22
store/migration/sqlite/0.13/01__remove_memo_organizer_id.sql
Normal file
22
store/migration/sqlite/0.13/01__remove_memo_organizer_id.sql
Normal file
@@ -0,0 +1,22 @@
|
||||
DROP TABLE IF EXISTS memo_organizer_temp;
|
||||
|
||||
CREATE TABLE memo_organizer_temp (
|
||||
memo_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
pinned INTEGER NOT NULL CHECK (pinned IN (0, 1)) DEFAULT 0,
|
||||
UNIQUE(memo_id, user_id)
|
||||
);
|
||||
|
||||
INSERT INTO
|
||||
memo_organizer_temp (memo_id, user_id, pinned)
|
||||
SELECT
|
||||
memo_id,
|
||||
user_id,
|
||||
pinned
|
||||
FROM
|
||||
memo_organizer;
|
||||
|
||||
DROP TABLE memo_organizer;
|
||||
|
||||
ALTER TABLE
|
||||
memo_organizer_temp RENAME TO memo_organizer;
|
||||
25
store/migration/sqlite/0.14/00__drop_resource_public_id.sql
Normal file
25
store/migration/sqlite/0.14/00__drop_resource_public_id.sql
Normal file
@@ -0,0 +1,25 @@
|
||||
DROP TABLE IF EXISTS resource_temp;
|
||||
|
||||
CREATE TABLE resource_temp (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
creator_id INTEGER NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
filename TEXT NOT NULL DEFAULT '',
|
||||
blob BLOB DEFAULT NULL,
|
||||
external_link TEXT NOT NULL DEFAULT '',
|
||||
type TEXT NOT NULL DEFAULT '',
|
||||
size INTEGER NOT NULL DEFAULT 0,
|
||||
internal_path TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
INSERT INTO
|
||||
resource_temp (id, creator_id, created_ts, updated_ts, filename, blob, external_link, type, size, internal_path)
|
||||
SELECT
|
||||
id, creator_id, created_ts, updated_ts, filename, blob, external_link, type, size, internal_path
|
||||
FROM
|
||||
resource;
|
||||
|
||||
DROP TABLE resource;
|
||||
|
||||
ALTER TABLE resource_temp RENAME TO resource;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user