init commit
This commit is contained in:
448
plugin/filter/common_converter.go
Normal file
448
plugin/filter/common_converter.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// CommonSQLConverter handles the common CEL to SQL conversion logic.
|
||||
type CommonSQLConverter struct {
|
||||
dialect SQLDialect
|
||||
paramIndex int
|
||||
}
|
||||
|
||||
// NewCommonSQLConverter creates a new converter with the specified dialect.
|
||||
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
|
||||
return &CommonSQLConverter{
|
||||
dialect: dialect,
|
||||
paramIndex: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect.
|
||||
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
return c.handleLogicalOperator(ctx, v.CallExpr)
|
||||
case "!_":
|
||||
return c.handleNotOperator(ctx, v.CallExpr)
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
return c.handleComparisonOperator(ctx, v.CallExpr)
|
||||
case "@in":
|
||||
return c.handleInOperator(ctx, v.CallExpr)
|
||||
case "contains":
|
||||
return c.handleContainsOperator(ctx, v.CallExpr)
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
return c.handleIdentifier(ctx, v.IdentExpr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator := "AND"
|
||||
if callExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr)
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
value, err := GetExprValue(callExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator := c.getComparisonOperator(callExpr.Function)
|
||||
|
||||
switch identifier {
|
||||
case "created_ts", "updated_ts":
|
||||
return c.handleTimestampComparison(ctx, identifier, operator, value)
|
||||
case "visibility", "content":
|
||||
return c.handleStringComparison(ctx, identifier, operator, value)
|
||||
case "creator_id":
|
||||
return c.handleIntComparison(ctx, identifier, operator, value)
|
||||
case "has_task_list":
|
||||
return c.handleBooleanComparison(ctx, identifier, operator, value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
|
||||
if len(sizeCall.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(sizeCall.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
|
||||
value, err := GetExprValue(callExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
|
||||
operator := c.getComparisonOperator(callExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||
c.dialect.GetJSONArrayLength("$.tags"),
|
||||
operator,
|
||||
c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil {
|
||||
if identifier == "tags" {
|
||||
return c.handleElementInTags(ctx, callExpr.Args[0])
|
||||
}
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier)
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range callExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if identifier == "tag" {
|
||||
return c.handleTagInList(ctx, values)
|
||||
} else if identifier == "visibility" {
|
||||
return c.handleVisibilityInList(ctx, values)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error {
|
||||
element, err := GetConstValue(elementExpr)
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
|
||||
// Use dialect-specific JSON contains logic
|
||||
sqlExpr := c.dialect.GetJSONContains("$.tags", "element")
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For SQLite, we need a different approach since it uses LIKE
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
|
||||
} else {
|
||||
ctx.Args = append(ctx.Args, element)
|
||||
}
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
|
||||
for _, v := range values {
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
|
||||
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
||||
} else {
|
||||
subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element"))
|
||||
args = append(args, v)
|
||||
}
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error {
|
||||
placeholders := []string{}
|
||||
for range values {
|
||||
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(callExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
arg, err := GetConstValue(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
|
||||
identifier := identExpr.GetName()
|
||||
|
||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
|
||||
if identifier == "pinned" {
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampField := c.dialect.GetTimestampComparison(field)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
fieldName := field
|
||||
if field == "visibility" {
|
||||
fieldName = "`visibility`"
|
||||
} else if field == "content" {
|
||||
fieldName = "`content`"
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
|
||||
sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For dialects that need parameters (PostgreSQL)
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*CommonSQLConverter) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
20
plugin/filter/converter.go
Normal file
20
plugin/filter/converter.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ConvertContext struct {
|
||||
Buffer strings.Builder
|
||||
Args []any
|
||||
// The offset of the next argument in the condition string.
|
||||
// Mainly using for PostgreSQL.
|
||||
ArgsOffset int
|
||||
}
|
||||
|
||||
func NewConvertContext() *ConvertContext {
|
||||
return &ConvertContext{
|
||||
Buffer: strings.Builder{},
|
||||
Args: []any{},
|
||||
}
|
||||
}
|
||||
212
plugin/filter/dialect.go
Normal file
212
plugin/filter/dialect.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SQLDialect defines database-specific SQL generation methods.
|
||||
type SQLDialect interface {
|
||||
// Basic field access
|
||||
GetTablePrefix() string
|
||||
GetParameterPlaceholder(index int) string
|
||||
|
||||
// JSON operations
|
||||
GetJSONExtract(path string) string
|
||||
GetJSONArrayLength(path string) string
|
||||
GetJSONContains(path, element string) string
|
||||
GetJSONLike(path, pattern string) string
|
||||
|
||||
// Boolean operations
|
||||
GetBooleanValue(value bool) interface{}
|
||||
GetBooleanComparison(path string, value bool) string
|
||||
GetBooleanCheck(path string) string
|
||||
|
||||
// Timestamp operations
|
||||
GetTimestampComparison(field string) string
|
||||
GetCurrentTimestamp() string
|
||||
}
|
||||
|
||||
// DatabaseType represents the type of database.
|
||||
type DatabaseType string
|
||||
|
||||
const (
|
||||
SQLite DatabaseType = "sqlite"
|
||||
MySQL DatabaseType = "mysql"
|
||||
PostgreSQL DatabaseType = "postgres"
|
||||
)
|
||||
|
||||
// GetDialect returns the appropriate dialect for the database type.
|
||||
func GetDialect(dbType DatabaseType) SQLDialect {
|
||||
switch dbType {
|
||||
case SQLite:
|
||||
return &SQLiteDialect{}
|
||||
case MySQL:
|
||||
return &MySQLDialect{}
|
||||
case PostgreSQL:
|
||||
return &PostgreSQLDialect{}
|
||||
default:
|
||||
return &SQLiteDialect{} // default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// SQLiteDialect implements SQLDialect for SQLite.
|
||||
type SQLiteDialect struct{}
|
||||
|
||||
func (*SQLiteDialect) GetTablePrefix() string {
|
||||
return "`memo`"
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
||||
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONContains(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONLike(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetBooleanValue(value bool) interface{} {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
|
||||
return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetCurrentTimestamp() string {
|
||||
return "strftime('%s', 'now')"
|
||||
}
|
||||
|
||||
// MySQLDialect implements SQLDialect for MySQL.
|
||||
type MySQLDialect struct{}
|
||||
|
||||
func (*MySQLDialect) GetTablePrefix() string {
|
||||
return "`memo`"
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
||||
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONContains(path, _ string) string {
|
||||
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONLike(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
return value
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
|
||||
boolStr := "false"
|
||||
if value {
|
||||
boolStr = "true"
|
||||
}
|
||||
return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr)
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetCurrentTimestamp() string {
|
||||
return "UNIX_TIMESTAMP()"
|
||||
}
|
||||
|
||||
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
|
||||
type PostgreSQLDialect struct{}
|
||||
|
||||
func (*PostgreSQLDialect) GetTablePrefix() string {
|
||||
return "memo"
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||
return fmt.Sprintf("$%d", index)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||
// Convert $.property.hasTaskList to payload->'property'->>'hasTaskList'
|
||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
result += fmt.Sprintf("->>'%s'", part)
|
||||
} else {
|
||||
result += fmt.Sprintf("->'%s'", part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
return value
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string {
|
||||
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("EXTRACT(EPOCH FROM %s.%s)", d.GetTablePrefix(), field)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
|
||||
return "EXTRACT(EPOCH FROM NOW())"
|
||||
}
|
||||
127
plugin/filter/expr.go
Normal file
127
plugin/filter/expr.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// GetConstValue returns the constant value of the expression.
|
||||
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid constant expression")
|
||||
}
|
||||
|
||||
switch v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return v.ConstExpr.GetUint64Value(), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected constant type")
|
||||
}
|
||||
}
|
||||
|
||||
// GetIdentExprName returns the name of the identifier expression.
|
||||
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
|
||||
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
|
||||
if !ok {
|
||||
return "", errors.New("invalid identifier expression")
|
||||
}
|
||||
return expr.GetIdentExpr().GetName(), nil
|
||||
}
|
||||
|
||||
// GetFunctionValue evaluates CEL function calls and returns their value.
|
||||
// This is specifically for time functions like now().
|
||||
func GetFunctionValue(expr *exprv1.Expr) (any, error) {
|
||||
callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid function call expression")
|
||||
}
|
||||
|
||||
switch callExpr.CallExpr.Function {
|
||||
case "now":
|
||||
if len(callExpr.CallExpr.Args) != 0 {
|
||||
return nil, errors.New("now() function takes no arguments")
|
||||
}
|
||||
return time.Now().Unix(), nil
|
||||
case "_-_":
|
||||
// Handle subtraction for expressions like "now() - 60 * 60 * 24"
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("subtraction requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("subtraction operands must be integers")
|
||||
}
|
||||
return leftInt - rightInt, nil
|
||||
case "_*_":
|
||||
// Handle multiplication for expressions like "60 * 60 * 24"
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("multiplication requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("multiplication operands must be integers")
|
||||
}
|
||||
return leftInt * rightInt, nil
|
||||
case "_+_":
|
||||
// Handle addition
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("addition requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("addition operands must be integers")
|
||||
}
|
||||
return leftInt + rightInt, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExprValue attempts to get a value from an expression, trying constants first, then functions.
|
||||
func GetExprValue(expr *exprv1.Expr) (any, error) {
|
||||
// Try to get constant value first
|
||||
if constValue, err := GetConstValue(expr); err == nil {
|
||||
return constValue, nil
|
||||
}
|
||||
|
||||
// If not a constant, try to evaluate as a function
|
||||
return GetFunctionValue(expr)
|
||||
}
|
||||
48
plugin/filter/filter.go
Normal file
48
plugin/filter/filter.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// MemoFilterCELAttributes are the CEL attributes for memo.
|
||||
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("creator_id", cel.IntType),
|
||||
cel.Variable("created_ts", cel.IntType),
|
||||
cel.Variable("updated_ts", cel.IntType),
|
||||
cel.Variable("pinned", cel.BoolType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
cel.Variable("tags", cel.ListType(cel.StringType)),
|
||||
cel.Variable("visibility", cel.StringType),
|
||||
cel.Variable("has_task_list", cel.BoolType),
|
||||
// Current timestamp function.
|
||||
cel.Function("now",
|
||||
cel.Overload("now",
|
||||
[]*cel.Type{},
|
||||
cel.IntType,
|
||||
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
|
||||
return types.Int(time.Now().Unix())
|
||||
}),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
// Parse parses the filter string and returns the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) {
|
||||
e, err := cel.NewEnv(opts...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||
}
|
||||
ast, issues := e.Compile(filter)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("failed to compile filter: %v", issues)
|
||||
}
|
||||
return cel.AstToParsedExpr(ast)
|
||||
}
|
||||
146
plugin/filter/templates.go
Normal file
146
plugin/filter/templates.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SQLTemplate holds database-specific SQL fragments.
|
||||
type SQLTemplate struct {
|
||||
SQLite string
|
||||
MySQL string
|
||||
PostgreSQL string
|
||||
}
|
||||
|
||||
// TemplateDBType represents the database type for templates.
|
||||
type TemplateDBType string
|
||||
|
||||
const (
|
||||
SQLiteTemplate TemplateDBType = "sqlite"
|
||||
MySQLTemplate TemplateDBType = "mysql"
|
||||
PostgreSQLTemplate TemplateDBType = "postgres"
|
||||
)
|
||||
|
||||
// SQLTemplates contains common SQL patterns for different databases.
|
||||
var SQLTemplates = map[string]SQLTemplate{
|
||||
"json_extract": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||
PostgreSQL: "memo.payload%s",
|
||||
},
|
||||
"json_array_length": {
|
||||
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
|
||||
},
|
||||
"json_contains_element": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||
},
|
||||
"json_contains_tag": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||
},
|
||||
"boolean_true": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
|
||||
},
|
||||
"boolean_false": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
|
||||
},
|
||||
"boolean_not_true": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
|
||||
},
|
||||
"boolean_not_false": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
|
||||
},
|
||||
"boolean_compare": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
|
||||
},
|
||||
"boolean_check": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
|
||||
},
|
||||
"table_prefix": {
|
||||
SQLite: "`memo`",
|
||||
MySQL: "`memo`",
|
||||
PostgreSQL: "memo",
|
||||
},
|
||||
"timestamp_field": {
|
||||
SQLite: "`memo`.`%s`",
|
||||
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
|
||||
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
|
||||
},
|
||||
"content_like": {
|
||||
SQLite: "`memo`.`content` LIKE ?",
|
||||
MySQL: "`memo`.`content` LIKE ?",
|
||||
PostgreSQL: "memo.content ILIKE ?",
|
||||
},
|
||||
"visibility_in": {
|
||||
SQLite: "`memo`.`visibility` IN (%s)",
|
||||
MySQL: "`memo`.`visibility` IN (%s)",
|
||||
PostgreSQL: "memo.visibility IN (%s)",
|
||||
},
|
||||
}
|
||||
|
||||
// GetSQL returns the appropriate SQL for the given template and database type.
|
||||
func GetSQL(templateName string, dbType TemplateDBType) string {
|
||||
template, exists := SQLTemplates[templateName]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case SQLiteTemplate:
|
||||
return template.SQLite
|
||||
case MySQLTemplate:
|
||||
return template.MySQL
|
||||
case PostgreSQLTemplate:
|
||||
return template.PostgreSQL
|
||||
default:
|
||||
return template.SQLite
|
||||
}
|
||||
}
|
||||
|
||||
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database.
|
||||
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
|
||||
switch dbType {
|
||||
case PostgreSQLTemplate:
|
||||
return fmt.Sprintf("$%d", index)
|
||||
default:
|
||||
return "?"
|
||||
}
|
||||
}
|
||||
|
||||
// GetParameterValue returns the appropriate parameter value for the database.
|
||||
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
|
||||
switch templateName {
|
||||
case "json_contains_element", "json_contains_tag":
|
||||
if dbType == SQLiteTemplate {
|
||||
return fmt.Sprintf(`%%"%s"%%`, value)
|
||||
}
|
||||
return value
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// FormatPlaceholders formats a list of placeholders for the given database type.
|
||||
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
|
||||
placeholders := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
|
||||
}
|
||||
return placeholders
|
||||
}
|
||||
Reference in New Issue
Block a user