From 6937d713c31e23eef0c0377e73d494a631f4e9f3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 6 Jun 2020 22:52:08 +0800 Subject: [PATCH] Refactor clauses --- clause/clause.go | 44 +++++++++++++++++++++++------------------- clause/locking_test.go | 2 +- clause/where.go | 18 ++++++++--------- clause/where_test.go | 2 +- finisher_api.go | 7 ++++--- statement.go | 16 +++++++-------- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index 9a5d1273..b3e96332 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -24,42 +24,46 @@ type Builder interface { // Clause type Clause struct { - Name string // WHERE - Priority float64 - BeforeExpressions []Expression - AfterNameExpressions []Expression - AfterExpressions []Expression - Expression Expression - Builder ClauseBuilder + Name string // WHERE + BeforeExpression Expression + AfterNameExpression Expression + AfterExpression Expression + Expression Expression + Builder ClauseBuilder } // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { c.Builder(c, builder) - } else { - builders := c.BeforeExpressions + } else if c.Expression != nil { + if c.BeforeExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') + } + if c.Name != "" { - builders = append(builders, Expr{SQL: c.Name}) + builder.WriteString(c.Name) + builder.WriteByte(' ') } - builders = append(builders, c.AfterNameExpressions...) - if c.Expression != nil { - builders = append(builders, c.Expression) + if c.AfterNameExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') } - for idx, expr := range append(builders, c.AfterExpressions...) { - if idx != 0 { - builder.WriteByte(' ') - } - expr.Build(builder) + c.Expression.Build(builder) + + if c.AfterExpression != nil { + builder.WriteByte(' ') + c.AfterExpression.Build(builder) } } } const ( - PrimaryKey string = "@@@priamry_key@@@" - CurrentTable string = "@@@table@@@" + PrimaryKey string = "@@@py@@@" // primary key + CurrentTable string = "@@@ct@@@" // current table ) var ( diff --git a/clause/locking_test.go b/clause/locking_test.go index 5ca30ef0..0e607312 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm/clause" ) -func TestFor(t *testing.T) { +func TestLocking(t *testing.T) { results := []struct { Clauses []clause.Interface Result string diff --git a/clause/where.go b/clause/where.go index 08c78b22..015addf8 100644 --- a/clause/where.go +++ b/clause/where.go @@ -14,7 +14,7 @@ func (where Where) Name() string { func (where Where) Build(builder Builder) { // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { - if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 { + if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { if idx != 0 { where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] } @@ -23,17 +23,15 @@ func (where Where) Build(builder Builder) { } for idx, expr := range where.Exprs { - if expr != nil { - if idx > 0 { - if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") - } else { - builder.WriteString(" AND ") - } + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.WriteString(" OR ") + } else { + builder.WriteString(" AND ") } - - expr.Build(builder) } + + expr.Build(builder) } return diff --git a/clause/where_test.go b/clause/where_test.go index 894e11f4..95bba820 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -27,7 +27,7 @@ func TestWhere(t *testing.T) { }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ - Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, }}, "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, }, diff --git a/finisher_api.go b/finisher_api.go index e94fd095..434f0e22 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -32,13 +32,14 @@ func (db *DB) Save(value interface{}) (tx *DB) { for idx, pf := range tx.Statement.Schema.PrimaryFields { if pv, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) - where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} return + } else { + where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} } } - } - tx.Statement.AddClause(where) + tx.Statement.AddClause(where) + } } if len(tx.Statement.Selects) == 0 { diff --git a/statement.go b/statement.go index 2c814547..ec9e021f 100644 --- a/statement.go +++ b/statement.go @@ -201,19 +201,19 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { func (stmt *Statement) AddClause(v clause.Interface) { if optimizer, ok := v.(StatementModifier); ok { optimizer.ModifyStatement(stmt) + } else { + c, ok := stmt.Clauses[v.Name()] + if !ok { + c.Name = v.Name() + } + v.MergeClause(&c) + stmt.Clauses[v.Name()] = c } - - c, ok := stmt.Clauses[v.Name()] - if !ok { - c.Name = v.Name() - } - v.MergeClause(&c) - stmt.Clauses[v.Name()] = c } // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { - if c, ok := stmt.Clauses[v.Name()]; !ok && c.Expression == nil { + if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { stmt.AddClause(v) } }