Support Expression for OrderBy clause

This commit is contained in:
Jinzhu 2020-11-03 10:30:05 +08:00
parent 57b033e2dd
commit c915471169
4 changed files with 35 additions and 11 deletions

View File

@ -21,6 +21,7 @@ type NegationExpressionBuilder interface {
type Expr struct { type Expr struct {
SQL string SQL string
Vars []interface{} Vars []interface{}
WithoutParentheses bool
} }
// Build build raw expression // Build build raw expression
@ -32,7 +33,7 @@ func (expr Expr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) { for _, v := range []byte(expr.SQL) {
if v == '?' && len(expr.Vars) > idx { if v == '?' && len(expr.Vars) > idx {
if afterParenthesis { if afterParenthesis || expr.WithoutParentheses {
if _, ok := expr.Vars[idx].(driver.Valuer); ok { if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx]) builder.AddVar(builder, expr.Vars[idx])
} else { } else {

View File

@ -8,6 +8,7 @@ type OrderByColumn struct {
type OrderBy struct { type OrderBy struct {
Columns []OrderByColumn Columns []OrderByColumn
Expression Expression
} }
// Name where clause name // Name where clause name
@ -17,6 +18,9 @@ func (orderBy OrderBy) Name() string {
// Build build where clause // Build build where clause
func (orderBy OrderBy) Build(builder Builder) { func (orderBy OrderBy) Build(builder Builder) {
if orderBy.Expression != nil {
orderBy.Expression.Build(builder)
} else {
for idx, column := range orderBy.Columns { for idx, column := range orderBy.Columns {
if idx > 0 { if idx > 0 {
builder.WriteByte(',') builder.WriteByte(',')
@ -28,6 +32,7 @@ func (orderBy OrderBy) Build(builder Builder) {
} }
} }
} }
}
// MergeClause merge order by clauses // MergeClause merge order by clauses
func (orderBy OrderBy) MergeClause(clause *Clause) { func (orderBy OrderBy) MergeClause(clause *Clause) {

View File

@ -39,6 +39,14 @@ func TestOrderBy(t *testing.T) {
}, },
"SELECT * FROM `users` ORDER BY `name`", nil, "SELECT * FROM `users` ORDER BY `name`", nil,
}, },
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
},
},
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3},
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -659,6 +660,15 @@ func TestOrder(t *testing.T) {
if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) { if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
} }
stmt := dryDB.Clauses(clause.OrderBy{
Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
}).Find(&User{}).Statement
explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) {
t.Fatalf("Build Order condition, but got %v", explainedSQL)
}
} }
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {