diff --git a/chainable_api.go b/chainable_api.go index cac7495d..432026cf 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -31,8 +31,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } if len(whereConds) > 0 { - tx.Statement.AddClause(clause.Where{ - AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), + tx.Statement.AddClause(&clause.Where{ + tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), }) } return @@ -59,8 +59,8 @@ func (db *DB) Omit(columns ...string) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{ - AndConditions: tx.Statement.BuildCondtion(query, args...), + tx.Statement.AddClause(&clause.Where{ + tx.Statement.BuildCondtion(query, args...), }) return } @@ -68,10 +68,8 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { // Not add NOT condition func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{ - AndConditions: []clause.Expression{ - clause.NotConditions(tx.Statement.BuildCondtion(query, args...)), - }, + tx.Statement.AddClause(&clause.Where{ + []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}, }) return } @@ -79,10 +77,8 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{ - OrConditions: []clause.OrConditions{ - tx.Statement.BuildCondtion(query, args...), - }, + tx.Statement.AddClause(&clause.Where{ + []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)}, }) return } @@ -113,13 +109,13 @@ func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() switch v := value.(type) { - case clause.OrderBy: - db.Statement.AddClause(clause.OrderByClause{ - Columns: []clause.OrderBy{v}, + case clause.OrderByColumn: + db.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{v}, }) default: - db.Statement.AddClause(clause.OrderByClause{ - Columns: []clause.OrderBy{{ + db.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, }}, }) diff --git a/clause/clause.go b/clause/clause.go index 6d4698e9..df8e3a57 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -1,5 +1,26 @@ package clause +// Interface clause interface +type Interface interface { + Name() string + Build(Builder) + MergeClause(*Clause) +} + +// ClauseBuilder clause builder, allows to custmize how to build clause +type ClauseBuilder interface { + Build(Clause, Builder) +} + +// Builder builder interface +type Builder interface { + WriteByte(byte) error + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string +} + // Clause type Clause struct { Name string // WHERE @@ -18,7 +39,7 @@ func (c Clause) Build(builder Builder) { } else { builders := c.BeforeExpressions if c.Name != "" { - builders = append(builders, Expr{c.Name}) + builders = append(builders, Expr{SQL: c.Name}) } builders = append(builders, c.AfterNameExpressions...) @@ -35,28 +56,27 @@ func (c Clause) Build(builder Builder) { } } -// Interface clause interface -type Interface interface { - Name() string - Build(Builder) - MergeExpression(Expression) +const ( + PrimaryKey string = "@@@priamry_key@@@" + CurrentTable string = "@@@table@@@" +) + +var ( + currentTable = Table{Name: CurrentTable} + PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} +) + +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool } -// OverrideNameInterface override name interface -type OverrideNameInterface interface { - OverrideName() string -} - -// ClauseBuilder clause builder, allows to custmize how to build clause -type ClauseBuilder interface { - Build(Clause, Builder) -} - -// Builder builder interface -type Builder interface { - WriteByte(byte) error - Write(sql ...string) error - WriteQuoted(field interface{}) error - AddVar(vars ...interface{}) string - Quote(field interface{}) string +// Table quote with name +type Table struct { + Name string + Alias string + Raw bool } diff --git a/clause/clause_test.go b/clause/clause_test.go index 37f07686..30ea9343 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -1,8 +1,8 @@ package clause_test import ( - "fmt" "reflect" + "strings" "sync" "testing" @@ -12,45 +12,32 @@ import ( "github.com/jinzhu/gorm/tests" ) -func TestClauses(t *testing.T) { +var db, _ = gorm.Open(tests.DummyDialector{}, nil) + +func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { var ( - db, _ = gorm.Open(tests.DummyDialector{}, nil) - results = []struct { - Clauses []clause.Interface - Result string - Vars []interface{} - }{ - { - []clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}}, - "SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"}, - }, - } + buildNames []string + buildNamesMap = map[string]bool{} + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) - for idx, result := range results { - t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { - var ( - user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) - stmt = gorm.Statement{ - DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}, - } - buildNames []string - ) + for _, c := range clauses { + if _, ok := buildNamesMap[c.Name()]; !ok { + buildNames = append(buildNames, c.Name()) + buildNamesMap[c.Name()] = true + } - for _, c := range result.Clauses { - buildNames = append(buildNames, c.Name()) - stmt.AddClause(c) - } + stmt.AddClause(c) + } - stmt.Build(buildNames...) + stmt.Build(buildNames...) - if stmt.SQL.String() != result.Result { - t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String()) - } + if strings.TrimSpace(stmt.SQL.String()) != result { + t.Errorf("SQL expects %v got %v", result, stmt.SQL.String()) + } - if reflect.DeepEqual(stmt.Vars, result.Vars) { - t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars) - } - }) + if !reflect.DeepEqual(stmt.Vars, vars) { + t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) } } diff --git a/clause/delete.go b/clause/delete.go new file mode 100644 index 00000000..2a622b45 --- /dev/null +++ b/clause/delete.go @@ -0,0 +1,23 @@ +package clause + +type Delete struct { + Modifier string +} + +func (d Delete) Name() string { + return "DELETE" +} + +func (d Delete) Build(builder Builder) { + builder.Write("DELETE") + + if d.Modifier != "" { + builder.WriteByte(' ') + builder.Write(d.Modifier) + } +} + +func (d Delete) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = d +} diff --git a/clause/delete_test.go b/clause/delete_test.go new file mode 100644 index 00000000..2faf8364 --- /dev/null +++ b/clause/delete_test.go @@ -0,0 +1,31 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestDelete(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Delete{}, clause.From{}}, + "DELETE FROM `users`", nil, + }, + { + []clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}}, + "DELETE LOW_PRIORITY FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/expression.go b/clause/expression.go index 3ddc146d..048b0980 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,14 +1,6 @@ package clause -const ( - PrimaryKey string = "@@@priamry_key@@@" - CurrentTable string = "@@@table@@@" -) - -var PrimaryColumn = Column{ - Table: CurrentTable, - Name: PrimaryKey, -} +import "strings" // Expression expression interface type Expression interface { @@ -20,27 +12,155 @@ type NegationExpressionBuilder interface { NegationBuild(builder Builder) } -// Column quote with name -type Column struct { - Table string - Name string - Alias string - Raw bool -} - -// Table quote with name -type Table struct { - Table string - Alias string - Raw bool -} - // Expr raw expression type Expr struct { - Value string + SQL string + Vars []interface{} } // Build build raw expression func (expr Expr) Build(builder Builder) { - builder.Write(expr.Value) + sql := expr.SQL + for _, v := range expr.Vars { + sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) + } + builder.Write(sql) +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder Builder) { + builder.WriteQuoted(in.Column) + + switch len(in.Values) { + case 0: + builder.Write(" IN (NULL)") + case 1: + builder.Write(" = ", builder.AddVar(in.Values...)) + default: + builder.Write(" IN (", builder.AddVar(in.Values...), ")") + } +} + +func (in IN) NegationBuild(builder Builder) { + switch len(in.Values) { + case 0: + case 1: + builder.Write(" <> ", builder.AddVar(in.Values...)) + default: + builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder Builder) { + builder.WriteQuoted(eq.Column) + + if eq.Value == nil { + builder.Write(" IS NULL") + } else { + builder.Write(" = ", builder.AddVar(eq.Value)) + } +} + +func (eq Eq) NegationBuild(builder Builder) { + Neq{eq.Column, eq.Value}.Build(builder) +} + +// Neq not equal to for where +type Neq Eq + +func (neq Neq) Build(builder Builder) { + builder.WriteQuoted(neq.Column) + + if neq.Value == nil { + builder.Write(" IS NOT NULL") + } else { + builder.Write(" <> ", builder.AddVar(neq.Value)) + } +} + +func (neq Neq) NegationBuild(builder Builder) { + Eq{neq.Column, neq.Value}.Build(builder) +} + +// Gt greater than for where +type Gt Eq + +func (gt Gt) Build(builder Builder) { + builder.WriteQuoted(gt.Column) + builder.Write(" > ", builder.AddVar(gt.Value)) +} + +func (gt Gt) NegationBuild(builder Builder) { + Lte{gt.Column, gt.Value}.Build(builder) +} + +// Gte greater than or equal to for where +type Gte Eq + +func (gte Gte) Build(builder Builder) { + builder.WriteQuoted(gte.Column) + builder.Write(" >= ", builder.AddVar(gte.Value)) +} + +func (gte Gte) NegationBuild(builder Builder) { + Lt{gte.Column, gte.Value}.Build(builder) +} + +// Lt less than for where +type Lt Eq + +func (lt Lt) Build(builder Builder) { + builder.WriteQuoted(lt.Column) + builder.Write(" < ", builder.AddVar(lt.Value)) +} + +func (lt Lt) NegationBuild(builder Builder) { + Gte{lt.Column, lt.Value}.Build(builder) +} + +// Lte less than or equal to for where +type Lte Eq + +func (lte Lte) Build(builder Builder) { + builder.WriteQuoted(lte.Column) + builder.Write(" <= ", builder.AddVar(lte.Value)) +} + +func (lte Lte) NegationBuild(builder Builder) { + Gt{lte.Column, lte.Value}.Build(builder) +} + +// Like whether string matches regular expression +type Like Eq + +func (like Like) Build(builder Builder) { + builder.WriteQuoted(like.Column) + builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (like Like) NegationBuild(builder Builder) { + builder.WriteQuoted(like.Column) + builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) +} + +// Map +type Map map[interface{}]interface{} + +func (m Map) Build(builder Builder) { + // TODO +} + +func (m Map) NegationBuild(builder Builder) { + // TODO } diff --git a/clause/from.go b/clause/from.go index b7665bc3..f01065b5 100644 --- a/clause/from.go +++ b/clause/from.go @@ -3,15 +3,31 @@ package clause // From from clause type From struct { Tables []Table + Joins []Join +} + +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin = "INNER" + LeftJoin = "LEFT" + RightJoin = "RIGHT" +) + +// Join join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string } // Name from clause name -func (From) Name() string { +func (from From) Name() string { return "FROM" } -var currentTable = Table{Table: CurrentTable} - // Build build from clause func (from From) Build(builder Builder) { if len(from.Tables) > 0 { @@ -25,11 +41,42 @@ func (from From) Build(builder Builder) { } else { builder.WriteQuoted(currentTable) } -} -// MergeExpression merge order by clauses -func (from From) MergeExpression(expr Expression) { - if v, ok := expr.(From); ok { - from.Tables = append(v.Tables, from.Tables...) + for _, join := range from.Joins { + builder.WriteByte(' ') + join.Build(builder) } } + +func (join Join) Build(builder Builder) { + if join.Type != "" { + builder.Write(string(join.Type)) + builder.WriteByte(' ') + } + + builder.Write("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.Write(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.Write(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } +} + +// MergeClause merge from clause +func (from From) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(From); ok { + from.Tables = append(v.Tables, from.Tables...) + from.Joins = append(v.Joins, from.Joins...) + } + clause.Expression = from +} diff --git a/clause/from_test.go b/clause/from_test.go new file mode 100644 index 00000000..4b7b0e18 --- /dev/null +++ b/clause/from_test.go @@ -0,0 +1,75 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestFrom(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: "articles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, + }, + }, + }, + }, + }, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: "articles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, + }, + }, { + Type: clause.LeftJoin, + Table: clause.Table{Name: "companies"}, + Using: []string{"company_name"}, + }, + }, + }, clause.From{ + Joins: []clause.Join{ + { + Type: clause.RightJoin, + Table: clause.Table{Name: "profiles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, + }, + }, + }, + }, + }, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`) RIGHT JOIN `profiles` ON `profiles`.`email` = `users`.`email`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/group_by.go b/clause/group_by.go index bce94109..8d164731 100644 --- a/clause/group_by.go +++ b/clause/group_by.go @@ -2,5 +2,36 @@ package clause // GroupBy group by clause type GroupBy struct { - Having Where + Columns []Column + Having Where +} + +// Name from clause name +func (groupBy GroupBy) Name() string { + return "GROUP BY" +} + +// Build build group by clause +func (groupBy GroupBy) Build(builder Builder) { + for idx, column := range groupBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + + if len(groupBy.Having.Exprs) > 0 { + builder.Write(" HAVING ") + groupBy.Having.Build(builder) + } +} + +// MergeClause merge group by clause +func (groupBy GroupBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(GroupBy); ok { + groupBy.Columns = append(v.Columns, groupBy.Columns...) + groupBy.Having.Exprs = append(v.Having.Exprs, groupBy.Having.Exprs...) + } + clause.Expression = groupBy } diff --git a/clause/group_by_test.go b/clause/group_by_test.go new file mode 100644 index 00000000..35be84a4 --- /dev/null +++ b/clause/group_by_test.go @@ -0,0 +1,40 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestGroupBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ + Columns: []clause.Column{{Name: "role"}}, + Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}, + }}, + "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ + Columns: []clause.Column{{Name: "role"}}, + Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}}, + }, clause.GroupBy{ + Columns: []clause.Column{{Name: "gender"}}, + Having: clause.Where{[]clause.Expression{clause.Neq{"gender", "U"}}}, + }}, + "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/insert.go b/clause/insert.go index e056b35e..3f86c98f 100644 --- a/clause/insert.go +++ b/clause/insert.go @@ -2,7 +2,7 @@ package clause type Insert struct { Table Table - Priority string + Modifier string } // Name insert clause name @@ -12,23 +12,28 @@ func (insert Insert) Name() string { // Build build insert clause func (insert Insert) Build(builder Builder) { - if insert.Priority != "" { - builder.Write(insert.Priority) + if insert.Modifier != "" { + builder.Write(insert.Modifier) builder.WriteByte(' ') } builder.Write("INTO ") - builder.WriteQuoted(insert.Table) + if insert.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(insert.Table) + } } -// MergeExpression merge insert clauses -func (insert Insert) MergeExpression(expr Expression) { - if v, ok := expr.(Insert); ok { - if insert.Priority == "" { - insert.Priority = v.Priority +// MergeClause merge insert clause +func (insert Insert) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Insert); ok { + if insert.Modifier == "" { + insert.Modifier = v.Modifier } - if insert.Table.Table == "" { + if insert.Table.Name == "" { insert.Table = v.Table } } + clause.Expression = insert } diff --git a/clause/insert_test.go b/clause/insert_test.go new file mode 100644 index 00000000..b1a57803 --- /dev/null +++ b/clause/insert_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestInsert(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Insert{}}, + "INSERT INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/join.go b/clause/join.go deleted file mode 100644 index 6b0e8f97..00000000 --- a/clause/join.go +++ /dev/null @@ -1,23 +0,0 @@ -package clause - -// Join join clause -type Join struct { - Table From // From - Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN - Using []Column - ON Where -} - -// TODO multiple joins - -func (join Join) Build(builder Builder) { - // TODO -} - -func (join Join) MergeExpression(expr Expression) { - // if j, ok := expr.(Join); ok { - // join.builders = append(join.builders, j.builders...) - // } else { - // join.builders = append(join.builders, expr) - // } -} diff --git a/clause/limit.go b/clause/limit.go index 8fbc0055..7b16f339 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -1,6 +1,44 @@ package clause +import "strconv" + // Limit limit clause type Limit struct { - Offset uint + Limit int + Offset int +} + +// Name where clause name +func (limit Limit) Name() string { + return "LIMIT" +} + +// Build build where clause +func (limit Limit) Build(builder Builder) { + if limit.Limit > 0 { + builder.Write("LIMIT ") + builder.Write(strconv.Itoa(limit.Limit)) + + if limit.Offset > 0 { + builder.Write(" OFFSET ") + builder.Write(strconv.Itoa(limit.Offset)) + } + } +} + +// MergeClause merge order by clauses +func (limit Limit) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(Limit); ok { + if limit.Limit == 0 && v.Limit > 0 { + limit.Limit = v.Limit + } + + if limit.Offset == 0 && v.Offset > 0 { + limit.Offset = v.Offset + } + } + + clause.Expression = limit } diff --git a/clause/limit_test.go b/clause/limit_test.go new file mode 100644 index 00000000..7b76aaf4 --- /dev/null +++ b/clause/limit_test.go @@ -0,0 +1,46 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestLimit(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ + Limit: 10, + Offset: 20, + }}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + "SELECT * FROM `users` LIMIT 10", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/locking.go b/clause/locking.go new file mode 100644 index 00000000..48b84b34 --- /dev/null +++ b/clause/locking.go @@ -0,0 +1,48 @@ +package clause + +type For struct { + Lockings []Locking +} + +type Locking struct { + Strength string + Table Table + Options string +} + +// Name where clause name +func (f For) Name() string { + return "FOR" +} + +// Build build where clause +func (f For) Build(builder Builder) { + for idx, locking := range f.Lockings { + if idx > 0 { + builder.WriteByte(' ') + } + + builder.Write("FOR ") + builder.Write(locking.Strength) + if locking.Table.Name != "" { + builder.Write(" OF ") + builder.WriteQuoted(locking.Table) + } + + if locking.Options != "" { + builder.WriteByte(' ') + builder.Write(locking.Options) + } + } +} + +// MergeClause merge order by clauses +func (f For) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(For); ok { + f.Lockings = append(v.Lockings, f.Lockings...) + } + + clause.Expression = f +} diff --git a/clause/locking_test.go b/clause/locking_test.go new file mode 100644 index 00000000..6b054404 --- /dev/null +++ b/clause/locking_test.go @@ -0,0 +1,43 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestFor(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE"}}, + }}, + "SELECT * FROM `users` FOR UPDATE", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + }}, + "SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users`", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + }, clause.For{ + Lockings: []clause.Locking{{Strength: "UPDATE", Options: "NOWAIT"}}, + }}, + "SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users` FOR UPDATE NOWAIT", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/on_conflict.go b/clause/on_conflict.go deleted file mode 100644 index 5cbe3dd7..00000000 --- a/clause/on_conflict.go +++ /dev/null @@ -1,6 +0,0 @@ -package clause - -type OnConflict struct { - ON string // duplicate key - Values *Values // update c=c+1 -} diff --git a/clause/order_by.go b/clause/order_by.go index 6025e1ba..2734f2bc 100644 --- a/clause/order_by.go +++ b/clause/order_by.go @@ -1,38 +1,47 @@ package clause -type OrderBy struct { +type OrderByColumn struct { Column Column Desc bool Reorder bool } -type OrderByClause struct { - Columns []OrderBy +type OrderBy struct { + Columns []OrderByColumn } // Name where clause name -func (orderBy OrderByClause) Name() string { +func (orderBy OrderBy) Name() string { return "ORDER BY" } // Build build where clause -func (orderBy OrderByClause) Build(builder Builder) { - for i := len(orderBy.Columns) - 1; i >= 0; i-- { - builder.WriteQuoted(orderBy.Columns[i].Column) +func (orderBy OrderBy) Build(builder Builder) { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } - if orderBy.Columns[i].Desc { + builder.WriteQuoted(column.Column) + if column.Desc { builder.Write(" DESC") } - - if orderBy.Columns[i].Reorder { - break - } } } -// MergeExpression merge order by clauses -func (orderBy OrderByClause) MergeExpression(expr Expression) { - if v, ok := expr.(OrderByClause); ok { +// MergeClause merge order by clauses +func (orderBy OrderBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(OrderBy); ok { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + if orderBy.Columns[i].Reorder { + orderBy.Columns = orderBy.Columns[i:] + clause.Expression = orderBy + return + } + } + orderBy.Columns = append(v.Columns, orderBy.Columns...) } + + clause.Expression = orderBy } diff --git a/clause/order_by_test.go b/clause/order_by_test.go new file mode 100644 index 00000000..2c74a322 --- /dev/null +++ b/clause/order_by_test.go @@ -0,0 +1,49 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestOrderBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }}, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}}, + }, + }, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}}, + }, + }, + "SELECT * FROM `users` ORDER BY `name`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/query.go b/clause/query.go deleted file mode 100644 index ce609014..00000000 --- a/clause/query.go +++ /dev/null @@ -1,258 +0,0 @@ -package clause - -import "strings" - -//////////////////////////////////////////////////////////////////////////////// -// Query Expressions -//////////////////////////////////////////////////////////////////////////////// - -func Add(exprs ...Expression) AddConditions { - return AddConditions(exprs) -} - -func Or(exprs ...Expression) OrConditions { - return OrConditions(exprs) -} - -type AddConditions []Expression - -func (cs AddConditions) Build(builder Builder) { - for idx, c := range cs { - if idx > 0 { - builder.Write(" AND ") - } - c.Build(builder) - } -} - -type OrConditions []Expression - -func (cs OrConditions) Build(builder Builder) { - for idx, c := range cs { - if idx > 0 { - builder.Write(" OR ") - } - c.Build(builder) - } -} - -type NotConditions []Expression - -func (cs NotConditions) Build(builder Builder) { - for idx, c := range cs { - if idx > 0 { - builder.Write(" AND ") - } - - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.Write(" NOT ") - c.Build(builder) - } - } -} - -// String raw sql for where -type String struct { - SQL string - Values []interface{} -} - -func (str String) Build(builder Builder) { - sql := str.SQL - for _, v := range str.Values { - sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) - } - builder.Write(sql) -} - -// IN Whether a value is within a set of values -type IN struct { - Column interface{} - Values []interface{} -} - -func (in IN) Build(builder Builder) { - builder.WriteQuoted(in.Column) - - switch len(in.Values) { - case 0: - builder.Write(" IN (NULL)") - case 1: - builder.Write(" = ", builder.AddVar(in.Values...)) - default: - builder.Write(" IN (", builder.AddVar(in.Values...), ")") - } -} - -func (in IN) NegationBuild(builder Builder) { - switch len(in.Values) { - case 0: - case 1: - builder.Write(" <> ", builder.AddVar(in.Values...)) - default: - builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") - } -} - -// Eq equal to for where -type Eq struct { - Column interface{} - Value interface{} -} - -func (eq Eq) Build(builder Builder) { - builder.WriteQuoted(eq.Column) - - if eq.Value == nil { - builder.Write(" IS NULL") - } else { - builder.Write(" = ", builder.AddVar(eq.Value)) - } -} - -func (eq Eq) NegationBuild(builder Builder) { - Neq{eq.Column, eq.Value}.Build(builder) -} - -// Neq not equal to for where -type Neq struct { - Column interface{} - Value interface{} -} - -func (neq Neq) Build(builder Builder) { - builder.WriteQuoted(neq.Column) - - if neq.Value == nil { - builder.Write(" IS NOT NULL") - } else { - builder.Write(" <> ", builder.AddVar(neq.Value)) - } -} - -func (neq Neq) NegationBuild(builder Builder) { - Eq{neq.Column, neq.Value}.Build(builder) -} - -// Gt greater than for where -type Gt struct { - Column interface{} - Value interface{} -} - -func (gt Gt) Build(builder Builder) { - builder.WriteQuoted(gt.Column) - builder.Write(" > ", builder.AddVar(gt.Value)) -} - -func (gt Gt) NegationBuild(builder Builder) { - Lte{gt.Column, gt.Value}.Build(builder) -} - -// Gte greater than or equal to for where -type Gte struct { - Column interface{} - Value interface{} -} - -func (gte Gte) Build(builder Builder) { - builder.WriteQuoted(gte.Column) - builder.Write(" >= ", builder.AddVar(gte.Value)) -} - -func (gte Gte) NegationBuild(builder Builder) { - Lt{gte.Column, gte.Value}.Build(builder) -} - -// Lt less than for where -type Lt struct { - Column interface{} - Value interface{} -} - -func (lt Lt) Build(builder Builder) { - builder.WriteQuoted(lt.Column) - builder.Write(" < ", builder.AddVar(lt.Value)) -} - -func (lt Lt) NegationBuild(builder Builder) { - Gte{lt.Column, lt.Value}.Build(builder) -} - -// Lte less than or equal to for where -type Lte struct { - Column interface{} - Value interface{} -} - -func (lte Lte) Build(builder Builder) { - builder.WriteQuoted(lte.Column) - builder.Write(" <= ", builder.AddVar(lte.Value)) -} - -func (lte Lte) NegationBuild(builder Builder) { - Gt{lte.Column, lte.Value}.Build(builder) -} - -// Like whether string matches regular expression -type Like struct { - Column interface{} - Value interface{} -} - -func (like Like) Build(builder Builder) { - builder.WriteQuoted(like.Column) - builder.Write(" LIKE ", builder.AddVar(like.Value)) -} - -func (like Like) NegationBuild(builder Builder) { - builder.WriteQuoted(like.Column) - builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) -} - -// Map -type Map map[interface{}]interface{} - -func (m Map) Build(builder Builder) { - // TODO -} - -func (m Map) NegationBuild(builder Builder) { - // TODO -} - -// Attrs -type Attrs struct { - Value interface{} - Select []string - Omit []string -} - -func (attrs Attrs) Build(builder Builder) { - // TODO - // builder.WriteQuoted(like.Column) - // builder.Write(" LIKE ", builder.AddVar(like.Value)) -} - -func (attrs Attrs) NegationBuild(builder Builder) { - // TODO -} - -// ID -type ID struct { - Value []interface{} -} - -func (id ID) Build(builder Builder) { - if len(id.Value) == 1 { - } - // TODO - // builder.WriteQuoted(like.Column) - // builder.Write(" LIKE ", builder.AddVar(like.Value)) -} - -func (id ID) NegationBuild(builder Builder) { - // TODO -} diff --git a/clause/returning.go b/clause/returning.go new file mode 100644 index 00000000..04bc96da --- /dev/null +++ b/clause/returning.go @@ -0,0 +1,30 @@ +package clause + +type Returning struct { + Columns []Column +} + +// Name where clause name +func (returning Returning) Name() string { + return "RETURNING" +} + +// Build build where clause +func (returning Returning) Build(builder Builder) { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } +} + +// MergeClause merge order by clauses +func (returning Returning) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Returning); ok { + returning.Columns = append(v.Columns, returning.Columns...) + } + + clause.Expression = returning +} diff --git a/clause/returning_test.go b/clause/returning_test.go new file mode 100644 index 00000000..e9fed1cb --- /dev/null +++ b/clause/returning_test.go @@ -0,0 +1,36 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestReturning(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`", nil, + }, { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }, clause.Returning{ + []clause.Column{{Name: "name"}, {Name: "age"}}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/select.go b/clause/select.go index 7f0e4438..4bb1af8d 100644 --- a/clause/select.go +++ b/clause/select.go @@ -1,32 +1,18 @@ package clause -// SelectInterface select clause interface -type SelectInterface interface { - Selects() []Column - Omits() []Column -} - // Select select attrs when querying, updating, creating type Select struct { - SelectColumns []Column - OmitColumns []Column + Columns []Column + Omits []Column } func (s Select) Name() string { return "SELECT" } -func (s Select) Selects() []Column { - return s.SelectColumns -} - -func (s Select) Omits() []Column { - return s.OmitColumns -} - func (s Select) Build(builder Builder) { - if len(s.SelectColumns) > 0 { - for idx, column := range s.SelectColumns { + if len(s.Columns) > 0 { + for idx, column := range s.Columns { if idx > 0 { builder.WriteByte(',') } @@ -37,13 +23,10 @@ func (s Select) Build(builder Builder) { } } -func (s Select) MergeExpression(expr Expression) { - if v, ok := expr.(SelectInterface); ok { - if len(s.SelectColumns) == 0 { - s.SelectColumns = v.Selects() - } - if len(s.OmitColumns) == 0 { - s.OmitColumns = v.Omits() - } +func (s Select) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Select); ok { + s.Columns = append(v.Columns, s.Columns...) + s.Omits = append(v.Omits, s.Omits...) } + clause.Expression = s } diff --git a/clause/select_test.go b/clause/select_test.go new file mode 100644 index 00000000..8255e51b --- /dev/null +++ b/clause/select_test.go @@ -0,0 +1,41 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestSelect(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.From{}}, + "SELECT `users`.`id` FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.Select{ + Columns: []clause.Column{{Name: "name"}}, + }, clause.From{}}, + "SELECT `users`.`id`,`name` FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/set.go b/clause/set.go new file mode 100644 index 00000000..3b7e972d --- /dev/null +++ b/clause/set.go @@ -0,0 +1,37 @@ +package clause + +type Set []Assignment + +type Assignment struct { + Column Column + Value interface{} +} + +func (set Set) Name() string { + return "SET" +} + +func (set Set) Build(builder Builder) { + if len(set) > 0 { + for idx, assignment := range set { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + builder.Write(builder.AddVar(assignment.Value)) + } + } else { + builder.WriteQuoted(PrimaryColumn) + builder.WriteByte('=') + builder.WriteQuoted(PrimaryColumn) + } +} + +// MergeClause merge assignments clauses +func (set Set) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Set); ok { + set = append(v, set...) + } + clause.Expression = set +} diff --git a/clause/set_test.go b/clause/set_test.go new file mode 100644 index 00000000..85754737 --- /dev/null +++ b/clause/set_test.go @@ -0,0 +1,38 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestSet(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + }, + "UPDATE `users` SET `users`.`id`=?", []interface{}{1}, + }, + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), + }, + "UPDATE `users` SET `users`.`id`=?,`name`=?", []interface{}{1, "jinzhu"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/update.go b/clause/update.go new file mode 100644 index 00000000..c375b373 --- /dev/null +++ b/clause/update.go @@ -0,0 +1,38 @@ +package clause + +type Update struct { + Modifier string + Table Table +} + +// Name update clause name +func (update Update) Name() string { + return "UPDATE" +} + +// Build build update clause +func (update Update) Build(builder Builder) { + if update.Modifier != "" { + builder.Write(update.Modifier) + builder.WriteByte(' ') + } + + if update.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(update.Table) + } +} + +// MergeClause merge update clause +func (update Update) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Update); ok { + if update.Modifier == "" { + update.Modifier = v.Modifier + } + if update.Table.Name == "" { + update.Table = v.Table + } + } + clause.Expression = update +} diff --git a/clause/update_test.go b/clause/update_test.go new file mode 100644 index 00000000..adc48f03 --- /dev/null +++ b/clause/update_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestUpdate(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Update{}}, + "UPDATE `users`", nil, + }, + { + []clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `users`", nil, + }, + { + []clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/value.go b/clause/values.go similarity index 76% rename from clause/value.go rename to clause/values.go index 4de0d91e..594b92e2 100644 --- a/clause/value.go +++ b/clause/values.go @@ -25,11 +25,11 @@ func (values Values) Build(builder Builder) { builder.Write(" VALUES ") for idx, value := range values.Values { - builder.WriteByte('(') if idx > 0 { builder.WriteByte(',') } + builder.WriteByte('(') builder.Write(builder.AddVar(value...)) builder.WriteByte(')') } @@ -37,3 +37,11 @@ func (values Values) Build(builder Builder) { builder.Write("DEFAULT VALUES") } } + +// MergeClause merge values clauses +func (values Values) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Values); ok { + values.Values = append(v.Values, values.Values...) + } + clause.Expression = values +} diff --git a/clause/values_test.go b/clause/values_test.go new file mode 100644 index 00000000..ced4f1e6 --- /dev/null +++ b/clause/values_test.go @@ -0,0 +1,33 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestValues(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Insert{}, + clause.Values{ + Columns: []clause.Column{{Name: "name"}, {Name: "age"}}, + Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, + }, + }, + "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/where.go b/clause/where.go index de82662c..d0f57ed1 100644 --- a/clause/where.go +++ b/clause/where.go @@ -2,9 +2,7 @@ package clause // Where where clause type Where struct { - AndConditions AddConditions - OrConditions []OrConditions - builders []Expression + Exprs []Expression } // Name where clause name @@ -14,64 +12,122 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { - var withConditions bool - - if len(where.AndConditions) > 0 { - withConditions = true - where.AndConditions.Build(builder) - } - - if len(where.builders) > 0 { - for _, b := range where.builders { - if withConditions { - builder.Write(" AND ") + // Switch position if the first query expression is a single Or condition + for idx, expr := range where.Exprs { + if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 { + if idx != 0 { + where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] } - withConditions = true - b.Build(builder) + break } } - var singleOrConditions []OrConditions - for _, or := range where.OrConditions { - if len(or) == 1 { - if withConditions { - builder.Write(" OR ") - or.Build(builder) - } else { - singleOrConditions = append(singleOrConditions, or) + for idx, expr := range where.Exprs { + if expr != nil { + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.Write(" OR ") + } else { + builder.Write(" AND ") + } } - } else { - withConditions = true - builder.Write(" AND (") - or.Build(builder) - builder.WriteByte(')') - } - } - for _, or := range singleOrConditions { - if withConditions { - builder.Write(" AND ") - or.Build(builder) - } else { - withConditions = true - or.Build(builder) + expr.Build(builder) } } - if !withConditions { - builder.Write(" FALSE") - } - return } -// MergeExpression merge where clauses -func (where Where) MergeExpression(expr Expression) { - if w, ok := expr.(Where); ok { - where.AndConditions = append(where.AndConditions, w.AndConditions...) - where.OrConditions = append(where.OrConditions, w.OrConditions...) - where.builders = append(where.builders, w.builders...) - } else { - where.builders = append(where.builders, expr) +// MergeClause merge where clauses +func (where Where) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(Where); ok { + where.Exprs = append(w.Exprs, where.Exprs...) + } + + clause.Expression = where +} + +func And(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return AndConditions{Exprs: exprs} +} + +type AndConditions struct { + Exprs []Expression +} + +func (and AndConditions) Build(builder Builder) { + if len(and.Exprs) > 1 { + builder.Write("(") + } + for idx, c := range and.Exprs { + if idx > 0 { + builder.Write(" AND ") + } + c.Build(builder) + } + if len(and.Exprs) > 1 { + builder.Write(")") + } +} + +func Or(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return OrConditions{Exprs: exprs} +} + +type OrConditions struct { + Exprs []Expression +} + +func (or OrConditions) Build(builder Builder) { + if len(or.Exprs) > 1 { + builder.Write("(") + } + for idx, c := range or.Exprs { + if idx > 0 { + builder.Write(" OR ") + } + c.Build(builder) + } + if len(or.Exprs) > 1 { + builder.Write(")") + } +} + +func Not(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return NotConditions{Exprs: exprs} +} + +type NotConditions struct { + Exprs []Expression +} + +func (not NotConditions) Build(builder Builder) { + if len(not.Exprs) > 1 { + builder.Write("(") + } + for idx, c := range not.Exprs { + if idx > 0 { + builder.Write(" AND ") + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.Write(" NOT ") + c.Build(builder) + } + } + if len(not.Exprs) > 1 { + builder.Write(")") } } diff --git a/clause/where_test.go b/clause/where_test.go new file mode 100644 index 00000000..450a0c89 --- /dev/null +++ b/clause/where_test.go @@ -0,0 +1,63 @@ +package clause_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm/clause" +) + +func TestWhere(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/finisher_api.go b/finisher_api.go index 06809651..5389ed6a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -22,7 +22,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance().Limit(1).Order(clause.OrderBy{ + tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) diff --git a/statement.go b/statement.go index bc07b6e4..5dd49623 100644 --- a/statement.go +++ b/statement.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - "log" "strconv" "strings" "sync" @@ -26,7 +25,7 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { if len(clauses) > 0 { instance.Statement.Build(clauses...) } - return instance.Statement.SQL.String(), instance.Statement.Vars + return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars } // AddError add error to instance @@ -85,10 +84,10 @@ func (stmt Statement) Quote(field interface{}) string { switch v := field.(type) { case clause.Table: - if v.Table == clause.CurrentTable { + if v.Name == clause.CurrentTable { str.WriteString(stmt.Table) } else { - str.WriteString(v.Table) + str.WriteString(v.Name) } if v.Alias != "" { @@ -126,7 +125,7 @@ func (stmt Statement) Quote(field interface{}) string { str.WriteByte(stmt.DB.quoteChars[1]) } default: - fmt.Sprint(field) + str.WriteString(fmt.Sprint(field)) } str.WriteByte(stmt.DB.quoteChars[1]) @@ -141,19 +140,28 @@ func (stmt *Statement) AddVar(vars ...interface{}) string { placeholders.WriteByte(',') } - if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { - stmt.NamedVars = append(stmt.NamedVars, namedArg) - placeholders.WriteByte('@') - placeholders.WriteString(namedArg.Name) - } else if arrs, ok := v.([]interface{}); ok { + switch v := v.(type) { + case sql.NamedArg: + if len(v.Name) > 0 { + stmt.NamedVars = append(stmt.NamedVars, v) + placeholders.WriteByte('@') + placeholders.WriteString(v.Name) + } else { + stmt.Vars = append(stmt.Vars, v.Value) + placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) + } + case clause.Column: + placeholders.WriteString(stmt.Quote(v)) + case []interface{}: placeholders.WriteByte('(') - if len(arrs) > 0 { - placeholders.WriteString(stmt.AddVar(arrs...)) + if len(v) > 0 { + placeholders.WriteString(stmt.AddVar(v...)) } else { placeholders.WriteString("NULL") } placeholders.WriteByte(')') - } else { + default: + stmt.Vars = append(stmt.Vars, v) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) } } @@ -166,42 +174,18 @@ func (stmt *Statement) AddClause(v clause.Interface) { optimizer.OptimizeStatement(stmt) } - c, _ := stmt.Clauses[v.Name()] - if namer, ok := v.(clause.OverrideNameInterface); ok { - c.Name = namer.OverrideName() - } else { + c, ok := stmt.Clauses[v.Name()] + if !ok { c.Name = v.Name() } - - if c.Expression != nil { - v.MergeExpression(c.Expression) - } - - c.Expression = v + v.MergeClause(&c) stmt.Clauses[v.Name()] = c } // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { - if optimizer, ok := v.(StatementOptimizer); ok { - optimizer.OptimizeStatement(stmt) - } - - log.Println(v.Name()) - if c, ok := stmt.Clauses[v.Name()]; !ok { - if namer, ok := v.(clause.OverrideNameInterface); ok { - c.Name = namer.OverrideName() - } else { - c.Name = v.Name() - } - - if c.Expression != nil { - v.MergeExpression(c.Expression) - } - - c.Expression = v - stmt.Clauses[v.Name()] = c - log.Println(stmt.Clauses[v.Name()]) + if _, ok := stmt.Clauses[v.Name()]; !ok { + stmt.AddClause(v) } } @@ -211,7 +195,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con if i, err := strconv.Atoi(sql); err != nil { query = i } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { - return []clause.Expression{clause.String{SQL: sql, Values: args}} + return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} } } @@ -255,7 +239,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con } if len(conditions) == 0 { - conditions = append(conditions, clause.ID{Value: args}) + conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args}) } return conditions