init commit
This commit is contained in:
93
store/db/mysql/activity.go
Normal file
93
store/db/mysql/activity.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal activity payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
|
||||
stmt := "INSERT INTO `activity` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute statement")
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get last insert id")
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
|
||||
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id32})
|
||||
if err != nil || len(list) == 0 {
|
||||
return nil, errors.Wrap(err, "failed to find activity")
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type.String())
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `creator_id`, `type`, `level`, `payload`, UNIX_TIMESTAMP(`created_ts`) FROM `activity` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
var payloadBytes []byte
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
202
store/db/mysql/attachment.go
Normal file
202
store/db/mysql/attachment.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*store.Attachment, error) {
|
||||
fields := []string{"`uid`", "`filename`", "`blob`", "`type`", "`size`", "`creator_id`", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?", "?", "?", "?", "?"}
|
||||
storageType := ""
|
||||
if create.StorageType != storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED {
|
||||
storageType = create.StorageType.String()
|
||||
}
|
||||
payloadString := "{}"
|
||||
if create.Payload != nil {
|
||||
bytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
return d.GetAttachment(ctx, &store.FindAttachment{ID: &id32})
|
||||
}
|
||||
|
||||
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "`filename` LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "`memo_id` IS NOT NULL")
|
||||
}
|
||||
if find.StorageType != nil {
|
||||
where, args = append(where, "`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
}
|
||||
|
||||
fields := []string{"`id`", "`uid`", "`filename`", "`type`", "`size`", "`creator_id`", "UNIX_TIMESTAMP(`created_ts`)", "UNIX_TIMESTAMP(`updated_ts`)", "`memo_id`", "`storage_type`", "`reference`", "`payload`"}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "`blob`")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("SELECT %s FROM `resource` WHERE %s ORDER BY `updated_ts` DESC", strings.Join(fields, ", "), strings.Join(where, " AND "))
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Attachment, 0)
|
||||
for rows.Next() {
|
||||
attachment := store.Attachment{}
|
||||
var memoID sql.NullInt32
|
||||
var storageType string
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&attachment.ID,
|
||||
&attachment.UID,
|
||||
&attachment.Filename,
|
||||
&attachment.Type,
|
||||
&attachment.Size,
|
||||
&attachment.CreatorID,
|
||||
&attachment.CreatedTs,
|
||||
&attachment.UpdatedTs,
|
||||
&memoID,
|
||||
&storageType,
|
||||
&attachment.Reference,
|
||||
&payloadBytes,
|
||||
}
|
||||
if find.GetBlob {
|
||||
dests = append(dests, &attachment.Blob)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoID.Valid {
|
||||
attachment.MemoID = &memoID.Int32
|
||||
}
|
||||
attachment.StorageType = storepb.AttachmentStorageType(storepb.AttachmentStorageType_value[storageType])
|
||||
payload := &storepb.AttachmentPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachment.Payload = payload
|
||||
list = append(list, &attachment)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetAttachment(ctx context.Context, find *store.FindAttachment) (*store.Attachment, error) {
|
||||
list, err := d.ListAttachments(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachment) error {
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
set, args = append(set, "`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
set, args = append(set, "`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Reference; v != nil {
|
||||
set, args = append(set, "`reference` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
bytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal attachment payload")
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(bytes))
|
||||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `resource` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
10
store/db/mysql/common.go
Normal file
10
store/db/mysql/common.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package mysql
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
AllowPartial: true,
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
||||
126
store/db/mysql/idp.go
Normal file
126
store/db/mysql/idp.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = int32(id)
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviders []*store.IdentityProvider
|
||||
for rows.Next() {
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProvider.Config,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
|
||||
identityProviders = append(identityProviders, &identityProvider)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identityProviders, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
|
||||
list, err := d.ListIdentityProviders(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
identityProvider := list[0]
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "`name` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "`identifier_filter` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
set, args = append(set, "`config` = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `idp` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identityProvider, err := d.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &update.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, errors.Errorf("idp %d not found", update.ID)
|
||||
}
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `idp` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
141
store/db/mysql/inbox.go
Normal file
141
store/db/mysql/inbox.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) {
|
||||
messageString := "{}"
|
||||
if create.Message != nil {
|
||||
bytes, err := protojson.Marshal(create.Message)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal inbox message")
|
||||
}
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
fields := []string{"`sender_id`", "`receiver_id`", "`status`", "`message`"}
|
||||
placeholder := []string{"?", "?", "?", "?"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
|
||||
stmt := "INSERT INTO `inbox` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
inbox, err := d.GetInbox(ctx, &store.FindInbox{ID: &id32})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
where, args = append(where, "`sender_id` = ?"), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
where, args = append(where, "`receiver_id` = ?"), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
where, args = append(where, "`status` = ?"), append(args, *find.Status)
|
||||
}
|
||||
|
||||
query := "SELECT `id`, UNIX_TIMESTAMP(`created_ts`), `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbox.Message = message
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
|
||||
list, err := d.ListInboxes(ctx, find)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get inbox")
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected inbox count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
set, args := []string{"`status` = ?"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE `inbox` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update inbox")
|
||||
}
|
||||
inbox, err := d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `inbox` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete inbox")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
287
store/db/mysql/memo.go
Normal file
287
store/db/mysql/memo.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
fields := []string{"`uid`", "`creator_id`", "`content`", "`visibility`", "`payload`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?"}
|
||||
payload := "{}"
|
||||
if create.Payload != nil {
|
||||
payloadBytes, err := protojson.Marshal(create.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = string(payloadBytes)
|
||||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := int32(rawID)
|
||||
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, errors.Errorf("failed to create memo")
|
||||
}
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
where, having, args := []string{"1 = 1"}, []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "`memo`.`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsBefore; v != nil {
|
||||
where, args = append(where, "UNIX_TIMESTAMP(`memo`.`created_ts`) < ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsAfter; v != nil {
|
||||
where, args = append(where, "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UpdatedTsBefore; v != nil {
|
||||
where, args = append(where, "UNIX_TIMESTAMP(`memo`.`updated_ts`) < ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UpdatedTsAfter; v != nil {
|
||||
where, args = append(where, "UNIX_TIMESTAMP(`memo`.`updated_ts`) > ?"), append(args, *v)
|
||||
}
|
||||
if v := find.ContentSearch; len(v) != 0 {
|
||||
for _, s := range v {
|
||||
where, args = append(where, "`memo`.`content` LIKE ?"), append(args, "%"+s+"%")
|
||||
}
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
placeholder := []string{}
|
||||
for _, visibility := range v {
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, visibility.String())
|
||||
}
|
||||
where = append(where, fmt.Sprintf("`memo`.`visibility` in (%s)", strings.Join(placeholder, ",")))
|
||||
}
|
||||
if v := find.Pinned; v != nil {
|
||||
where, args = append(where, "`memo`.`pinned` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.PayloadFind; v != nil {
|
||||
if v.Raw != nil {
|
||||
where, args = append(where, "`memo`.`payload` = ?"), append(args, *v.Raw)
|
||||
}
|
||||
if len(v.TagSearch) != 0 {
|
||||
for _, tag := range v.TagSearch {
|
||||
where, args = append(where, "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))"), append(args, fmt.Sprintf(`"%s"`, tag), fmt.Sprintf(`"%s/"`, tag))
|
||||
}
|
||||
}
|
||||
if v.HasLink {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE")
|
||||
}
|
||||
if v.HasTaskList {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE")
|
||||
}
|
||||
if v.HasCode {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE")
|
||||
}
|
||||
if v.HasIncompleteTasks {
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
|
||||
}
|
||||
}
|
||||
if v := find.Filter; v != nil {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
if condition != "" {
|
||||
where = append(where, fmt.Sprintf("(%s)", condition))
|
||||
args = append(args, convertCtx.Args...)
|
||||
}
|
||||
}
|
||||
if find.ExcludeComments {
|
||||
having = append(having, "`parent_id` IS NULL")
|
||||
}
|
||||
|
||||
order := "DESC"
|
||||
if find.OrderByTimeAsc {
|
||||
order = "ASC"
|
||||
}
|
||||
orderBy := []string{}
|
||||
if find.OrderByPinned {
|
||||
orderBy = append(orderBy, "`pinned` DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
orderBy = append(orderBy, "`updated_ts` "+order)
|
||||
} else {
|
||||
orderBy = append(orderBy, "`created_ts` "+order)
|
||||
}
|
||||
fields := []string{
|
||||
"`memo`.`id` AS `id`",
|
||||
"`memo`.`uid` AS `uid`",
|
||||
"`memo`.`creator_id` AS `creator_id`",
|
||||
"UNIX_TIMESTAMP(`memo`.`created_ts`) AS `created_ts`",
|
||||
"UNIX_TIMESTAMP(`memo`.`updated_ts`) AS `updated_ts`",
|
||||
"`memo`.`row_status` AS `row_status`",
|
||||
"`memo`.`visibility` AS `visibility`",
|
||||
"`memo`.`pinned` AS `pinned`",
|
||||
"`memo`.`payload` AS `payload`",
|
||||
"`memo_relation`.`related_memo_id` AS `parent_id`",
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
fields = append(fields, "`memo`.`content` AS `content`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `memo`" + " " +
|
||||
"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = 'COMMENT'" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"HAVING " + strings.Join(having, " AND ") + " " +
|
||||
"ORDER BY " + strings.Join(orderBy, ", ")
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
var payloadBytes []byte
|
||||
dests := []any{
|
||||
&memo.ID,
|
||||
&memo.UID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
&payloadBytes,
|
||||
&memo.ParentID,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
dests = append(dests, &memo.Content)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := &storepb.MemoPayload{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal payload")
|
||||
}
|
||||
memo.Payload = payload
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) {
|
||||
list, err := d.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memo := list[0]
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UID; v != nil {
|
||||
set, args = append(set, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
set, args = append(set, "`created_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
set, args = append(set, "`content` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
set, args = append(set, "`visibility` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Pinned; v != nil {
|
||||
set, args = append(set, "`pinned` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Payload; v != nil {
|
||||
payloadBytes, err := protojson.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set, args = append(set, "`payload` = ?"), append(args, string(payloadBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := "UPDATE `memo` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
where, args := []string{"`id` = ?"}, []any{delete.ID}
|
||||
stmt := "DELETE FROM `memo` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
304
store/db/mysql/memo_filter.go
Normal file
304
store/db/mysql/memo_filter.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
return d.convertWithTemplates(ctx, expr)
|
||||
}
|
||||
|
||||
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
const dbType = filter.MySQLTemplate
|
||||
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
// Handle size(tags) comparison
|
||||
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
|
||||
filter.GetSQL("json_array_length", dbType), operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "visibility" || identifier == "content" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
var sqlTemplate string
|
||||
if identifier == "visibility" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
|
||||
} else if identifier == "content" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
} else if identifier == "creator_id" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "has_task_list" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
// Use template for boolean comparison
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_false", dbType)
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||
// This is "element in collection" - the second argument is the collection
|
||||
if !slices.Contains([]string{"tags"}, identifier) {
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||
}
|
||||
|
||||
if identifier == "tags" {
|
||||
// Handle "element" in tags
|
||||
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if identifier == "tag" {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
for _, v := range values {
|
||||
subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
|
||||
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||
}
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
} else if identifier == "visibility" {
|
||||
placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
|
||||
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
identifier := v.IdentExpr.GetName()
|
||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
if identifier == "pinned" {
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
// Handle has_task_list as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
130
store/db/mysql/memo_filter_test.go
Normal file
130
store/db/mysql/memo_filter_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"]`,
|
||||
want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))",
|
||||
args: []any{"tag1", "tag2"},
|
||||
},
|
||||
{
|
||||
filter: `!(tag in ["tag1", "tag2"])`,
|
||||
want: "NOT ((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)))",
|
||||
args: []any{"tag1", "tag2"},
|
||||
},
|
||||
{
|
||||
filter: `content.contains("memos")`,
|
||||
want: "`memo`.`content` LIKE ?",
|
||||
args: []any{"%memos%"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC"]`,
|
||||
want: "`memo`.`visibility` IN (?)",
|
||||
args: []any{"PUBLIC"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "`memo`.`visibility` IN (?,?)",
|
||||
args: []any{"PUBLIC", "PRIVATE"},
|
||||
},
|
||||
{
|
||||
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||
want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR `memo`.`content` LIKE ?)",
|
||||
args: []any{"tag1", "%hello%"},
|
||||
},
|
||||
{
|
||||
filter: `1`,
|
||||
want: "",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `pinned`,
|
||||
want: "`memo`.`pinned` IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == true`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list != false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `!has_task_list`,
|
||||
want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON))",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && pinned`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`pinned` IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && content.contains("todo")`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`content` LIKE ?)",
|
||||
args: []any{"%todo%"},
|
||||
},
|
||||
{
|
||||
filter: `created_ts > now() - 60 * 60 * 24`,
|
||||
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
|
||||
args: []any{time.Now().Unix() - 60*60*24},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 0`,
|
||||
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) > 0`,
|
||||
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `"work" in tags`,
|
||||
want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
args: []any{"work"},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 2`,
|
||||
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(2)},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
db := &DB{}
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
111
store/db/mysql/memo_relation.go
Normal file
111
store/db/mysql/memo_relation.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?)"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoRelation := store.MemoRelation{
|
||||
MemoID: create.MemoID,
|
||||
RelatedMemoID: create.RelatedMemoID,
|
||||
Type: create.Type,
|
||||
}
|
||||
|
||||
return &memoRelation, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if find.MemoID != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, find.MemoID)
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
where, args = append(where, "`related_memo_id` = ?"), append(args, find.RelatedMemoID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, find.Type)
|
||||
}
|
||||
if find.MemoFilter != nil {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
if condition != "" {
|
||||
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
|
||||
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
|
||||
args = append(args, append(convertCtx.Args, convertCtx.Args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE "+strings.Join(where, " AND "), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.MemoRelation{}
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
where, args := []string{"TRUE"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
where, args = append(where, "`memo_id` = ?"), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
where, args = append(where, "`related_memo_id` = ?"), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
where, args = append(where, "`type` = ?"), append(args, delete.Type)
|
||||
}
|
||||
stmt := "DELETE FROM `memo_relation` WHERE " + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
53
store/db/mysql/migration_history.go
Normal file
53
store/db/mysql/migration_history.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
query := "SELECT `version`, UNIX_TIMESTAMP(`created_ts`) FROM `migration_history` ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
|
||||
stmt := "INSERT INTO `migration_history` (`version`) VALUES (?) ON DUPLICATE KEY UPDATE `version` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, upsert.Version, upsert.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var migrationHistory store.MigrationHistory
|
||||
stmt = "SELECT `version`, UNIX_TIMESTAMP(`created_ts`) FROM `migration_history` WHERE `version` = ?"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &migrationHistory, nil
|
||||
}
|
||||
68
store/db/mysql/mysql.go
Normal file
68
store/db/mysql/mysql.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
config *mysql.Config
|
||||
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
// Open MySQL connection with parameter.
|
||||
// multiStatements=true is required for migration.
|
||||
// See more in: https://github.com/go-sql-driver/mysql#multistatements
|
||||
dsn, err := mergeDSN(profile.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driver := DB{profile: profile}
|
||||
driver.config, err = mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, errors.New("Parse DSN eroor")
|
||||
}
|
||||
|
||||
driver.db, err = sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db: %s", profile.DSN)
|
||||
}
|
||||
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func mergeDSN(baseDSN string) (string, error) {
|
||||
config, err := mysql.ParseDSN(baseDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to parse DSN: %s", baseDSN)
|
||||
}
|
||||
|
||||
config.MultiStatements = true
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
104
store/db/mysql/reaction.go
Normal file
104
store/db/mysql/reaction.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store.Reaction, error) {
|
||||
fields := []string{"`creator_id`", "`content_id`", "`reaction_type`"}
|
||||
placeholder := []string{"?", "?", "?"}
|
||||
args := []interface{}{upsert.CreatorID, upsert.ContentID, upsert.ReactionType}
|
||||
stmt := "INSERT INTO `reaction` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := int32(rawID)
|
||||
reaction, err := d.GetReaction(ctx, &store.FindReaction{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reaction == nil {
|
||||
return nil, errors.Errorf("failed to create reaction")
|
||||
}
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
if find.CreatorID != nil {
|
||||
where, args = append(where, "`creator_id` = ?"), append(args, *find.CreatorID)
|
||||
}
|
||||
if find.ContentID != nil {
|
||||
where, args = append(where, "`content_id` = ?"), append(args, *find.ContentID)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
UNIX_TIMESTAMP(created_ts) AS created_ts,
|
||||
creator_id,
|
||||
content_id,
|
||||
reaction_type
|
||||
FROM reaction
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Reaction{}
|
||||
for rows.Next() {
|
||||
reaction := &store.Reaction{}
|
||||
if err := rows.Scan(
|
||||
&reaction.ID,
|
||||
&reaction.CreatedTs,
|
||||
&reaction.CreatorID,
|
||||
&reaction.ContentID,
|
||||
&reaction.ReactionType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, reaction)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
|
||||
list, err := d.ListReactions(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reaction := list[0]
|
||||
return reaction, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID)
|
||||
return err
|
||||
}
|
||||
162
store/db/mysql/user.go
Normal file
162
store/db/mysql/user.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
fields := []string{"`username`", "`role`", "`email`", "`nickname`", "`password_hash`", "`avatar_url`"}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
|
||||
stmt := "INSERT INTO user (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id32 := int32(id)
|
||||
list, err := d.ListUsers(ctx, &store.FindUser{ID: &id32})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected user count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
set, args = append(set, "`updated_ts` = FROM_UNIXTIME(?)"), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "`row_status` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
set, args = append(set, "`username` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "`email` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "`nickname` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
set, args = append(set, "`avatar_url` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "`password_hash` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Description; v != nil {
|
||||
set, args = append(set, "`description` = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "`role` = ?"), append(args, *v)
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := "UPDATE `user` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "`username` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "`role` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "`email` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "`nickname` = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
orderBy := []string{"`created_ts` DESC", "`row_status` DESC"}
|
||||
query := "SELECT `id`, `username`, `role`, `email`, `nickname`, `password_hash`, `avatar_url`, `description`, UNIX_TIMESTAMP(`created_ts`), UNIX_TIMESTAMP(`updated_ts`), `row_status` FROM `user` WHERE " + strings.Join(where, " AND ") + " ORDER BY " + strings.Join(orderBy, ", ")
|
||||
if v := find.Limit; v != nil {
|
||||
query += fmt.Sprintf(" LIMIT %d", *v)
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
var user store.User
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.Description,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
|
||||
list, err := d.ListUsers(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Errorf("unexpected user count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM `user` WHERE `id` = ?", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
56
store/db/mysql/user_setting.go
Normal file
56
store/db/mysql/user_setting.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) {
|
||||
stmt := "INSERT INTO `user_setting` (`user_id`, `key`, `value`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key.String(), upsert.Value, upsert.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "`key` = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := "SELECT `user_id`, `key`, `value` FROM `user_setting` WHERE " + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*store.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &store.UserSetting{}
|
||||
var keyString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserID,
|
||||
&keyString,
|
||||
&userSetting.Value,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSetting_Key(storepb.UserSetting_Key_value[keyString])
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
65
store/db/mysql/workspace_setting.go
Normal file
65
store/db/mysql/workspace_setting.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *store.WorkspaceSetting) (*store.WorkspaceSetting, error) {
|
||||
stmt := "INSERT INTO `system_setting` (`name`, `value`, `description`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `value` = ?, `description` = ?"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
upsert.Name,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
upsert.Value,
|
||||
upsert.Description,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*store.WorkspaceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Name != "" {
|
||||
where, args = append(where, "`name` = ?"), append(args, find.Name)
|
||||
}
|
||||
|
||||
query := "SELECT `name`, `value`, `description` FROM `system_setting` WHERE " + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.WorkspaceSetting{}
|
||||
for rows.Next() {
|
||||
systemSettingMessage := &store.WorkspaceSetting{}
|
||||
if err := rows.Scan(
|
||||
&systemSettingMessage.Name,
|
||||
&systemSettingMessage.Value,
|
||||
&systemSettingMessage.Description,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, systemSettingMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteWorkspaceSetting(ctx context.Context, delete *store.DeleteWorkspaceSetting) error {
|
||||
stmt := "DELETE FROM `system_setting` WHERE `name` = ?"
|
||||
_, err := d.db.ExecContext(ctx, stmt, delete.Name)
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user