From c915471169b7e6696edfa9bfc2c8e7b816e70ad6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 3 Nov 2020 10:30:05 +0800 Subject: [PATCH] Support Expression for OrderBy clause --- clause/expression.go | 7 ++++--- clause/order_by.go | 21 +++++++++++++-------- clause/order_by_test.go | 8 ++++++++ tests/query_test.go | 10 ++++++++++ 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 5822a314..725a4909 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -19,8 +19,9 @@ type NegationExpressionBuilder interface { // Expr raw expression type Expr struct { - SQL string - Vars []interface{} + SQL string + Vars []interface{} + WithoutParentheses bool } // Build build raw expression @@ -32,7 +33,7 @@ func (expr Expr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '?' && len(expr.Vars) > idx { - if afterParenthesis { + if afterParenthesis || expr.WithoutParentheses { if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) } else { diff --git a/clause/order_by.go b/clause/order_by.go index a8a9539a..41218025 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -7,7 +7,8 @@ type OrderByColumn struct { } type OrderBy struct { - Columns []OrderByColumn + Columns []OrderByColumn + Expression Expression } // Name where clause name @@ -17,14 +18,18 @@ func (orderBy OrderBy) Name() string { // Build build where clause func (orderBy OrderBy) Build(builder Builder) { - for idx, column := range orderBy.Columns { - if idx > 0 { - builder.WriteByte(',') - } + if orderBy.Expression != nil { + orderBy.Expression.Build(builder) + } else { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(column.Column) - if column.Desc { - builder.WriteString(" DESC") + builder.WriteQuoted(column.Column) + if column.Desc { + builder.WriteString(" DESC") + } } } } diff --git a/clause/order_by_test.go b/clause/order_by_test.go index 2ea2d192..8fd1e2a8 100644 --- a/clause/order_by_test.go +++ b/clause/order_by_test.go @@ -39,6 +39,14 @@ func TestOrderBy(t *testing.T) { }, "SELECT * FROM `users` ORDER BY `name`", nil, }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }, + }, + "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3}, + }, } for idx, result := range results { diff --git a/tests/query_test.go b/tests/query_test.go index bb9aa26d..dc2907e6 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -12,6 +12,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -659,6 +660,15 @@ func TestOrder(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } } func TestLimit(t *testing.T) {