Commit 778a5eb1 authored by Johnny's avatar Johnny

refactor: memo filter

parent 52a5ca2e
This diff is collapsed.
package filter
import (
"fmt"
"strings"
)
// SQLDialect defines database-specific SQL generation methods
type SQLDialect interface {
// Basic field access
GetTablePrefix() string
GetParameterPlaceholder(index int) string
// JSON operations
GetJSONExtract(path string) string
GetJSONArrayLength(path string) string
GetJSONContains(path, element string) string
GetJSONLike(path, pattern string) string
// Boolean operations
GetBooleanValue(value bool) interface{}
GetBooleanComparison(path string, value bool) string
GetBooleanCheck(path string) string
// Timestamp operations
GetTimestampComparison(field string) string
GetCurrentTimestamp() string
}
// DatabaseType represents the type of database
type DatabaseType string
const (
SQLite DatabaseType = "sqlite"
MySQL DatabaseType = "mysql"
PostgreSQL DatabaseType = "postgres"
)
// GetDialect returns the appropriate dialect for the database type
func GetDialect(dbType DatabaseType) SQLDialect {
switch dbType {
case SQLite:
return &SQLiteDialect{}
case MySQL:
return &MySQLDialect{}
case PostgreSQL:
return &PostgreSQLDialect{}
default:
return &SQLiteDialect{} // default fallback
}
}
// SQLiteDialect implements SQLDialect for SQLite
type SQLiteDialect struct{}
func (d *SQLiteDialect) GetTablePrefix() string {
return "`memo`"
}
func (d *SQLiteDialect) GetParameterPlaceholder(index int) string {
return "?"
}
func (d *SQLiteDialect) GetJSONExtract(path string) string {
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
}
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetJSONContains(path, element string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetJSONLike(path, pattern string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetBooleanValue(value bool) interface{} {
if value {
return 1
}
return 0
}
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value))
}
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
}
func (d *SQLiteDialect) GetCurrentTimestamp() string {
return "strftime('%s', 'now')"
}
// MySQLDialect implements SQLDialect for MySQL
type MySQLDialect struct{}
func (d *MySQLDialect) GetTablePrefix() string {
return "`memo`"
}
func (d *MySQLDialect) GetParameterPlaceholder(index int) string {
return "?"
}
func (d *MySQLDialect) GetJSONExtract(path string) string {
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
}
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetJSONContains(path, element string) string {
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetJSONLike(path, pattern string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetBooleanValue(value bool) interface{} {
return value
}
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
boolStr := "false"
if value {
boolStr = "true"
}
return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr)
}
func (d *MySQLDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
}
func (d *MySQLDialect) GetCurrentTimestamp() string {
return "UNIX_TIMESTAMP()"
}
// PostgreSQLDialect implements SQLDialect for PostgreSQL
type PostgreSQLDialect struct{}
func (d *PostgreSQLDialect) GetTablePrefix() string {
return "memo"
}
func (d *PostgreSQLDialect) GetParameterPlaceholder(index int) string {
return fmt.Sprintf("$%d", index)
}
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
// Convert $.property.hasTaskList to payload->'property'->>'hasTaskList'
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
for i, part := range parts {
if i == len(parts)-1 {
result += fmt.Sprintf("->>'%s'", part)
} else {
result += fmt.Sprintf("->'%s'", part)
}
}
return result
}
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
}
func (d *PostgreSQLDialect) GetJSONContains(path, element string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
}
func (d *PostgreSQLDialect) GetJSONLike(path, pattern string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
}
func (d *PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
return value
}
func (d *PostgreSQLDialect) GetBooleanComparison(path string, value bool) string {
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
}
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
}
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("EXTRACT(EPOCH FROM %s.%s)", d.GetTablePrefix(), field)
}
func (d *PostgreSQLDialect) GetCurrentTimestamp() string {
return "EXTRACT(EPOCH FROM NOW())"
}
......@@ -18,6 +18,7 @@ var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("updated_ts", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("tag", cel.StringType),
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
// Current timestamp function.
......
package filter
import (
"fmt"
)
// SQLTemplate holds database-specific SQL fragments
type SQLTemplate struct {
SQLite string
MySQL string
PostgreSQL string
}
// TemplateDBType represents the database type for templates
type TemplateDBType string
const (
SQLiteTemplate TemplateDBType = "sqlite"
MySQLTemplate TemplateDBType = "mysql"
PostgreSQLTemplate TemplateDBType = "postgres"
)
// SQLTemplates contains common SQL patterns for different databases
var SQLTemplates = map[string]SQLTemplate{
"json_extract": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
PostgreSQL: "memo.payload%s",
},
"json_array_length": {
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
},
"json_contains_element": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
},
"json_contains_tag": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
},
"boolean_true": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
},
"boolean_false": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
},
"boolean_not_true": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
},
"boolean_not_false": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
},
"boolean_compare": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
},
"boolean_check": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
},
"table_prefix": {
SQLite: "`memo`",
MySQL: "`memo`",
PostgreSQL: "memo",
},
"timestamp_field": {
SQLite: "`memo`.`%s`",
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
},
"content_like": {
SQLite: "`memo`.`content` LIKE ?",
MySQL: "`memo`.`content` LIKE ?",
PostgreSQL: "memo.content ILIKE ?",
},
"visibility_in": {
SQLite: "`memo`.`visibility` IN (%s)",
MySQL: "`memo`.`visibility` IN (%s)",
PostgreSQL: "memo.visibility IN (%s)",
},
}
// GetSQL returns the appropriate SQL for the given template and database type
func GetSQL(templateName string, dbType TemplateDBType) string {
template, exists := SQLTemplates[templateName]
if !exists {
return ""
}
switch dbType {
case SQLiteTemplate:
return template.SQLite
case MySQLTemplate:
return template.MySQL
case PostgreSQLTemplate:
return template.PostgreSQL
default:
return template.SQLite
}
}
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
switch dbType {
case PostgreSQLTemplate:
return fmt.Sprintf("$%d", index)
default:
return "?"
}
}
// GetParameterValue returns the appropriate parameter value for the database
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
switch templateName {
case "json_contains_element", "json_contains_tag":
if dbType == SQLiteTemplate {
return fmt.Sprintf(`%%"%s"%%`, value)
}
return value
default:
return value
}
}
// FormatPlaceholders formats a list of placeholders for the given database type
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
placeholders := make([]string, count)
for i := 0; i < count; i++ {
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
}
return placeholders
}
This diff is collapsed.
......@@ -95,6 +95,26 @@ func TestConvertExprToSQL(t *testing.T) {
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
args: []any{time.Now().Unix() - 60*60*24},
},
{
filter: `size(tags) == 0`,
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(0)},
},
{
filter: `size(tags) > 0`,
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
args: []any{int64(0)},
},
{
filter: `"work" in tags`,
want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
args: []any{"work"},
},
{
filter: `size(tags) == 2`,
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(2)},
},
}
for _, tt := range tests {
......
This diff is collapsed.
......@@ -95,6 +95,26 @@ func TestRestoreExprToSQL(t *testing.T) {
want: "EXTRACT(EPOCH FROM memo.created_ts) > $1",
args: []any{time.Now().Unix() - 60*60*24},
},
{
filter: `size(tags) == 0`,
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
args: []any{int64(0)},
},
{
filter: `size(tags) > 0`,
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1",
args: []any{int64(0)},
},
{
filter: `"work" in tags`,
want: "memo.payload->'tags' @> jsonb_build_array($1)",
args: []any{"work"},
},
{
filter: `size(tags) == 2`,
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
args: []any{int64(2)},
},
}
for _, tt := range tests {
......
This diff is collapsed.
......@@ -110,14 +110,40 @@ func TestConvertExprToSQL(t *testing.T) {
want: "`memo`.`created_ts` > ?",
args: []any{time.Now().Unix() - 60*60*24},
},
{
filter: `size(tags) == 0`,
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(0)},
},
{
filter: `size(tags) > 0`,
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
args: []any{int64(0)},
},
{
filter: `"work" in tags`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
args: []any{`%"work"%`},
},
{
filter: `size(tags) == 2`,
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(2)},
},
}
for _, tt := range tests {
db := &DB{}
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
if err != nil {
t.Logf("Failed to parse filter: %s, error: %v", tt.filter, err)
}
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil {
t.Logf("Failed to convert filter: %s, error: %v", tt.filter, err)
}
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment