Finish clauses tests

This commit is contained in:
Jinzhu 2020-02-07 23:45:35 +08:00
parent 0160bab7dc
commit 1f38ec4410
35 changed files with 1282 additions and 544 deletions

View File

@ -31,8 +31,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
} }
if len(whereConds) > 0 { if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(&clause.Where{
AndConditions: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...), tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...),
}) })
} }
return return
@ -59,8 +59,8 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(&clause.Where{
AndConditions: tx.Statement.BuildCondtion(query, args...), tx.Statement.BuildCondtion(query, args...),
}) })
return return
} }
@ -68,10 +68,8 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
// Not add NOT condition // Not add NOT condition
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(&clause.Where{
AndConditions: []clause.Expression{ []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)},
clause.NotConditions(tx.Statement.BuildCondtion(query, args...)),
},
}) })
return return
} }
@ -79,10 +77,8 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
// Or add OR conditions // Or add OR conditions
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(&clause.Where{
OrConditions: []clause.OrConditions{ []clause.Expression{clause.Or(tx.Statement.BuildCondtion(query, args...)...)},
tx.Statement.BuildCondtion(query, args...),
},
}) })
return return
} }
@ -113,13 +109,13 @@ func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := value.(type) { switch v := value.(type) {
case clause.OrderBy: case clause.OrderByColumn:
db.Statement.AddClause(clause.OrderByClause{ db.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderBy{v}, Columns: []clause.OrderByColumn{v},
}) })
default: default:
db.Statement.AddClause(clause.OrderByClause{ db.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderBy{{ Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
}}, }},
}) })

View File

@ -1,5 +1,26 @@
package clause package clause
// Interface clause interface
type Interface interface {
Name() string
Build(Builder)
MergeClause(*Clause)
}
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface {
Build(Clause, Builder)
}
// Builder builder interface
type Builder interface {
WriteByte(byte) error
Write(sql ...string) error
WriteQuoted(field interface{}) error
AddVar(vars ...interface{}) string
Quote(field interface{}) string
}
// Clause // Clause
type Clause struct { type Clause struct {
Name string // WHERE Name string // WHERE
@ -18,7 +39,7 @@ func (c Clause) Build(builder Builder) {
} else { } else {
builders := c.BeforeExpressions builders := c.BeforeExpressions
if c.Name != "" { if c.Name != "" {
builders = append(builders, Expr{c.Name}) builders = append(builders, Expr{SQL: c.Name})
} }
builders = append(builders, c.AfterNameExpressions...) builders = append(builders, c.AfterNameExpressions...)
@ -35,28 +56,27 @@ func (c Clause) Build(builder Builder) {
} }
} }
// Interface clause interface const (
type Interface interface { PrimaryKey string = "@@@priamry_key@@@"
Name() string CurrentTable string = "@@@table@@@"
Build(Builder) )
MergeExpression(Expression)
var (
currentTable = Table{Name: CurrentTable}
PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
)
// Column quote with name
type Column struct {
Table string
Name string
Alias string
Raw bool
} }
// OverrideNameInterface override name interface // Table quote with name
type OverrideNameInterface interface { type Table struct {
OverrideName() string Name string
} Alias string
Raw bool
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface {
Build(Clause, Builder)
}
// Builder builder interface
type Builder interface {
WriteByte(byte) error
Write(sql ...string) error
WriteQuoted(field interface{}) error
AddVar(vars ...interface{}) string
Quote(field interface{}) string
} }

View File

@ -1,8 +1,8 @@
package clause_test package clause_test
import ( import (
"fmt"
"reflect" "reflect"
"strings"
"sync" "sync"
"testing" "testing"
@ -12,45 +12,32 @@ import (
"github.com/jinzhu/gorm/tests" "github.com/jinzhu/gorm/tests"
) )
func TestClauses(t *testing.T) { var db, _ = gorm.Open(tests.DummyDialector{}, nil)
var (
db, _ = gorm.Open(tests.DummyDialector{}, nil)
results = []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}},
"SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"},
},
}
)
for idx, result := range results { func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
var ( var (
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{
DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{},
}
buildNames []string buildNames []string
buildNamesMap = map[string]bool{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
) )
for _, c := range result.Clauses { for _, c := range clauses {
if _, ok := buildNamesMap[c.Name()]; !ok {
buildNames = append(buildNames, c.Name()) buildNames = append(buildNames, c.Name())
buildNamesMap[c.Name()] = true
}
stmt.AddClause(c) stmt.AddClause(c)
} }
stmt.Build(buildNames...) stmt.Build(buildNames...)
if stmt.SQL.String() != result.Result { if strings.TrimSpace(stmt.SQL.String()) != result {
t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String()) t.Errorf("SQL expects %v got %v", result, stmt.SQL.String())
} }
if reflect.DeepEqual(stmt.Vars, result.Vars) { if !reflect.DeepEqual(stmt.Vars, vars) {
t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars) t.Errorf("Vars expects %+v got %v", stmt.Vars, vars)
}
})
} }
} }

23
clause/delete.go Normal file
View File

@ -0,0 +1,23 @@
package clause
type Delete struct {
Modifier string
}
func (d Delete) Name() string {
return "DELETE"
}
func (d Delete) Build(builder Builder) {
builder.Write("DELETE")
if d.Modifier != "" {
builder.WriteByte(' ')
builder.Write(d.Modifier)
}
}
func (d Delete) MergeClause(clause *Clause) {
clause.Name = ""
clause.Expression = d
}

31
clause/delete_test.go Normal file
View File

@ -0,0 +1,31 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestDelete(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Delete{}, clause.From{}},
"DELETE FROM `users`", nil,
},
{
[]clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}},
"DELETE LOW_PRIORITY FROM `users`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -1,14 +1,6 @@
package clause package clause
const ( import "strings"
PrimaryKey string = "@@@priamry_key@@@"
CurrentTable string = "@@@table@@@"
)
var PrimaryColumn = Column{
Table: CurrentTable,
Name: PrimaryKey,
}
// Expression expression interface // Expression expression interface
type Expression interface { type Expression interface {
@ -20,27 +12,155 @@ type NegationExpressionBuilder interface {
NegationBuild(builder Builder) NegationBuild(builder Builder)
} }
// Column quote with name
type Column struct {
Table string
Name string
Alias string
Raw bool
}
// Table quote with name
type Table struct {
Table string
Alias string
Raw bool
}
// Expr raw expression // Expr raw expression
type Expr struct { type Expr struct {
Value string SQL string
Vars []interface{}
} }
// Build build raw expression // Build build raw expression
func (expr Expr) Build(builder Builder) { func (expr Expr) Build(builder Builder) {
builder.Write(expr.Value) sql := expr.SQL
for _, v := range expr.Vars {
sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1)
}
builder.Write(sql)
}
// IN Whether a value is within a set of values
type IN struct {
Column interface{}
Values []interface{}
}
func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.Write(" IN (NULL)")
case 1:
builder.Write(" = ", builder.AddVar(in.Values...))
default:
builder.Write(" IN (", builder.AddVar(in.Values...), ")")
}
}
func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) {
case 0:
case 1:
builder.Write(" <> ", builder.AddVar(in.Values...))
default:
builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")")
}
}
// Eq equal to for where
type Eq struct {
Column interface{}
Value interface{}
}
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
if eq.Value == nil {
builder.Write(" IS NULL")
} else {
builder.Write(" = ", builder.AddVar(eq.Value))
}
}
func (eq Eq) NegationBuild(builder Builder) {
Neq{eq.Column, eq.Value}.Build(builder)
}
// Neq not equal to for where
type Neq Eq
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
if neq.Value == nil {
builder.Write(" IS NOT NULL")
} else {
builder.Write(" <> ", builder.AddVar(neq.Value))
}
}
func (neq Neq) NegationBuild(builder Builder) {
Eq{neq.Column, neq.Value}.Build(builder)
}
// Gt greater than for where
type Gt Eq
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.Write(" > ", builder.AddVar(gt.Value))
}
func (gt Gt) NegationBuild(builder Builder) {
Lte{gt.Column, gt.Value}.Build(builder)
}
// Gte greater than or equal to for where
type Gte Eq
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.Write(" >= ", builder.AddVar(gte.Value))
}
func (gte Gte) NegationBuild(builder Builder) {
Lt{gte.Column, gte.Value}.Build(builder)
}
// Lt less than for where
type Lt Eq
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.Write(" < ", builder.AddVar(lt.Value))
}
func (lt Lt) NegationBuild(builder Builder) {
Gte{lt.Column, lt.Value}.Build(builder)
}
// Lte less than or equal to for where
type Lte Eq
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.Write(" <= ", builder.AddVar(lte.Value))
}
func (lte Lte) NegationBuild(builder Builder) {
Gt{lte.Column, lte.Value}.Build(builder)
}
// Like whether string matches regular expression
type Like Eq
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
}
// Map
type Map map[interface{}]interface{}
func (m Map) Build(builder Builder) {
// TODO
}
func (m Map) NegationBuild(builder Builder) {
// TODO
} }

View File

@ -3,15 +3,31 @@ package clause
// From from clause // From from clause
type From struct { type From struct {
Tables []Table Tables []Table
Joins []Join
}
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin = "INNER"
LeftJoin = "LEFT"
RightJoin = "RIGHT"
)
// Join join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
} }
// Name from clause name // Name from clause name
func (From) Name() string { func (from From) Name() string {
return "FROM" return "FROM"
} }
var currentTable = Table{Table: CurrentTable}
// Build build from clause // Build build from clause
func (from From) Build(builder Builder) { func (from From) Build(builder Builder) {
if len(from.Tables) > 0 { if len(from.Tables) > 0 {
@ -25,11 +41,42 @@ func (from From) Build(builder Builder) {
} else { } else {
builder.WriteQuoted(currentTable) builder.WriteQuoted(currentTable)
} }
for _, join := range from.Joins {
builder.WriteByte(' ')
join.Build(builder)
}
} }
// MergeExpression merge order by clauses func (join Join) Build(builder Builder) {
func (from From) MergeExpression(expr Expression) { if join.Type != "" {
if v, ok := expr.(From); ok { builder.Write(string(join.Type))
builder.WriteByte(' ')
}
builder.Write("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.Write(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.Write(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
}
// MergeClause merge from clause
func (from From) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(From); ok {
from.Tables = append(v.Tables, from.Tables...) from.Tables = append(v.Tables, from.Tables...)
from.Joins = append(v.Joins, from.Joins...)
} }
clause.Expression = from
} }

75
clause/from_test.go Normal file
View File

@ -0,0 +1,75 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestFrom(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM `users`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{
Tables: []clause.Table{{Name: "users"}},
Joins: []clause.Join{
{
Type: clause.InnerJoin,
Table: clause.Table{Name: "articles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
},
},
},
},
},
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{
Tables: []clause.Table{{Name: "users"}},
Joins: []clause.Join{
{
Type: clause.InnerJoin,
Table: clause.Table{Name: "articles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
},
}, {
Type: clause.LeftJoin,
Table: clause.Table{Name: "companies"},
Using: []string{"company_name"},
},
},
}, clause.From{
Joins: []clause.Join{
{
Type: clause.RightJoin,
Table: clause.Table{Name: "profiles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
},
},
},
},
},
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`) RIGHT JOIN `profiles` ON `profiles`.`email` = `users`.`email`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -2,5 +2,36 @@ package clause
// GroupBy group by clause // GroupBy group by clause
type GroupBy struct { type GroupBy struct {
Columns []Column
Having Where Having Where
} }
// Name from clause name
func (groupBy GroupBy) Name() string {
return "GROUP BY"
}
// Build build group by clause
func (groupBy GroupBy) Build(builder Builder) {
for idx, column := range groupBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
if len(groupBy.Having.Exprs) > 0 {
builder.Write(" HAVING ")
groupBy.Having.Build(builder)
}
}
// MergeClause merge group by clause
func (groupBy GroupBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(GroupBy); ok {
groupBy.Columns = append(v.Columns, groupBy.Columns...)
groupBy.Having.Exprs = append(v.Having.Exprs, groupBy.Having.Exprs...)
}
clause.Expression = groupBy
}

40
clause/group_by_test.go Normal file
View File

@ -0,0 +1,40 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestGroupBy(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
Columns: []clause.Column{{Name: "role"}},
Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}},
}},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
Columns: []clause.Column{{Name: "role"}},
Having: clause.Where{[]clause.Expression{clause.Eq{"role", "admin"}}},
}, clause.GroupBy{
Columns: []clause.Column{{Name: "gender"}},
Having: clause.Where{[]clause.Expression{clause.Neq{"gender", "U"}}},
}},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -2,7 +2,7 @@ package clause
type Insert struct { type Insert struct {
Table Table Table Table
Priority string Modifier string
} }
// Name insert clause name // Name insert clause name
@ -12,23 +12,28 @@ func (insert Insert) Name() string {
// Build build insert clause // Build build insert clause
func (insert Insert) Build(builder Builder) { func (insert Insert) Build(builder Builder) {
if insert.Priority != "" { if insert.Modifier != "" {
builder.Write(insert.Priority) builder.Write(insert.Modifier)
builder.WriteByte(' ') builder.WriteByte(' ')
} }
builder.Write("INTO ") builder.Write("INTO ")
if insert.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {
builder.WriteQuoted(insert.Table) builder.WriteQuoted(insert.Table)
} }
// MergeExpression merge insert clauses
func (insert Insert) MergeExpression(expr Expression) {
if v, ok := expr.(Insert); ok {
if insert.Priority == "" {
insert.Priority = v.Priority
} }
if insert.Table.Table == "" {
// MergeClause merge insert clause
func (insert Insert) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Insert); ok {
if insert.Modifier == "" {
insert.Modifier = v.Modifier
}
if insert.Table.Name == "" {
insert.Table = v.Table insert.Table = v.Table
} }
} }
clause.Expression = insert
} }

35
clause/insert_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestInsert(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Insert{}},
"INSERT INTO `users`", nil,
},
{
[]clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}},
"INSERT LOW_PRIORITY INTO `users`", nil,
},
{
[]clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
"INSERT LOW_PRIORITY INTO `products`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -1,23 +0,0 @@
package clause
// Join join clause
type Join struct {
Table From // From
Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
Using []Column
ON Where
}
// TODO multiple joins
func (join Join) Build(builder Builder) {
// TODO
}
func (join Join) MergeExpression(expr Expression) {
// if j, ok := expr.(Join); ok {
// join.builders = append(join.builders, j.builders...)
// } else {
// join.builders = append(join.builders, expr)
// }
}

View File

@ -1,6 +1,44 @@
package clause package clause
import "strconv"
// Limit limit clause // Limit limit clause
type Limit struct { type Limit struct {
Offset uint Limit int
Offset int
}
// Name where clause name
func (limit Limit) Name() string {
return "LIMIT"
}
// Build build where clause
func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 {
builder.Write("LIMIT ")
builder.Write(strconv.Itoa(limit.Limit))
if limit.Offset > 0 {
builder.Write(" OFFSET ")
builder.Write(strconv.Itoa(limit.Offset))
}
}
}
// MergeClause merge order by clauses
func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit > 0 {
limit.Limit = v.Limit
}
if limit.Offset == 0 && v.Offset > 0 {
limit.Offset = v.Offset
}
}
clause.Expression = limit
} }

46
clause/limit_test.go Normal file
View File

@ -0,0 +1,46 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestLimit(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10,
Offset: 20,
}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
"SELECT * FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

48
clause/locking.go Normal file
View File

@ -0,0 +1,48 @@
package clause
type For struct {
Lockings []Locking
}
type Locking struct {
Strength string
Table Table
Options string
}
// Name where clause name
func (f For) Name() string {
return "FOR"
}
// Build build where clause
func (f For) Build(builder Builder) {
for idx, locking := range f.Lockings {
if idx > 0 {
builder.WriteByte(' ')
}
builder.Write("FOR ")
builder.Write(locking.Strength)
if locking.Table.Name != "" {
builder.Write(" OF ")
builder.WriteQuoted(locking.Table)
}
if locking.Options != "" {
builder.WriteByte(' ')
builder.Write(locking.Options)
}
}
}
// MergeClause merge order by clauses
func (f For) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(For); ok {
f.Lockings = append(v.Lockings, f.Lockings...)
}
clause.Expression = f
}

43
clause/locking_test.go Normal file
View File

@ -0,0 +1,43 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestFor(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE"}},
}},
"SELECT * FROM `users` FOR UPDATE", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
}},
"SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE", Options: "NOWAIT"}},
}},
"SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users` FOR UPDATE NOWAIT", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -1,6 +0,0 @@
package clause
type OnConflict struct {
ON string // duplicate key
Values *Values // update c=c+1
}

View File

@ -1,38 +1,47 @@
package clause package clause
type OrderBy struct { type OrderByColumn struct {
Column Column Column Column
Desc bool Desc bool
Reorder bool Reorder bool
} }
type OrderByClause struct { type OrderBy struct {
Columns []OrderBy Columns []OrderByColumn
} }
// Name where clause name // Name where clause name
func (orderBy OrderByClause) Name() string { func (orderBy OrderBy) Name() string {
return "ORDER BY" return "ORDER BY"
} }
// Build build where clause // Build build where clause
func (orderBy OrderByClause) Build(builder Builder) { func (orderBy OrderBy) Build(builder Builder) {
for i := len(orderBy.Columns) - 1; i >= 0; i-- { for idx, column := range orderBy.Columns {
builder.WriteQuoted(orderBy.Columns[i].Column) if idx > 0 {
builder.WriteByte(',')
}
if orderBy.Columns[i].Desc { builder.WriteQuoted(column.Column)
if column.Desc {
builder.Write(" DESC") builder.Write(" DESC")
} }
}
}
// MergeClause merge order by clauses
func (orderBy OrderBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(OrderBy); ok {
for i := len(orderBy.Columns) - 1; i >= 0; i-- {
if orderBy.Columns[i].Reorder { if orderBy.Columns[i].Reorder {
break orderBy.Columns = orderBy.Columns[i:]
} clause.Expression = orderBy
return
} }
} }
// MergeExpression merge order by clauses
func (orderBy OrderByClause) MergeExpression(expr Expression) {
if v, ok := expr.(OrderByClause); ok {
orderBy.Columns = append(v.Columns, orderBy.Columns...) orderBy.Columns = append(v.Columns, orderBy.Columns...)
} }
clause.Expression = orderBy
} }

49
clause/order_by_test.go Normal file
View File

@ -0,0 +1,49 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestOrderBy(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}},
"SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}},
},
},
"SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}, clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}},
},
},
"SELECT * FROM `users` ORDER BY `name`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -1,258 +0,0 @@
package clause
import "strings"
////////////////////////////////////////////////////////////////////////////////
// Query Expressions
////////////////////////////////////////////////////////////////////////////////
func Add(exprs ...Expression) AddConditions {
return AddConditions(exprs)
}
func Or(exprs ...Expression) OrConditions {
return OrConditions(exprs)
}
type AddConditions []Expression
func (cs AddConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" AND ")
}
c.Build(builder)
}
}
type OrConditions []Expression
func (cs OrConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" OR ")
}
c.Build(builder)
}
}
type NotConditions []Expression
func (cs NotConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" AND ")
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.Write(" NOT ")
c.Build(builder)
}
}
}
// String raw sql for where
type String struct {
SQL string
Values []interface{}
}
func (str String) Build(builder Builder) {
sql := str.SQL
for _, v := range str.Values {
sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1)
}
builder.Write(sql)
}
// IN Whether a value is within a set of values
type IN struct {
Column interface{}
Values []interface{}
}
func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.Write(" IN (NULL)")
case 1:
builder.Write(" = ", builder.AddVar(in.Values...))
default:
builder.Write(" IN (", builder.AddVar(in.Values...), ")")
}
}
func (in IN) NegationBuild(builder Builder) {
switch len(in.Values) {
case 0:
case 1:
builder.Write(" <> ", builder.AddVar(in.Values...))
default:
builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")")
}
}
// Eq equal to for where
type Eq struct {
Column interface{}
Value interface{}
}
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
if eq.Value == nil {
builder.Write(" IS NULL")
} else {
builder.Write(" = ", builder.AddVar(eq.Value))
}
}
func (eq Eq) NegationBuild(builder Builder) {
Neq{eq.Column, eq.Value}.Build(builder)
}
// Neq not equal to for where
type Neq struct {
Column interface{}
Value interface{}
}
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
if neq.Value == nil {
builder.Write(" IS NOT NULL")
} else {
builder.Write(" <> ", builder.AddVar(neq.Value))
}
}
func (neq Neq) NegationBuild(builder Builder) {
Eq{neq.Column, neq.Value}.Build(builder)
}
// Gt greater than for where
type Gt struct {
Column interface{}
Value interface{}
}
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.Write(" > ", builder.AddVar(gt.Value))
}
func (gt Gt) NegationBuild(builder Builder) {
Lte{gt.Column, gt.Value}.Build(builder)
}
// Gte greater than or equal to for where
type Gte struct {
Column interface{}
Value interface{}
}
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.Write(" >= ", builder.AddVar(gte.Value))
}
func (gte Gte) NegationBuild(builder Builder) {
Lt{gte.Column, gte.Value}.Build(builder)
}
// Lt less than for where
type Lt struct {
Column interface{}
Value interface{}
}
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.Write(" < ", builder.AddVar(lt.Value))
}
func (lt Lt) NegationBuild(builder Builder) {
Gte{lt.Column, lt.Value}.Build(builder)
}
// Lte less than or equal to for where
type Lte struct {
Column interface{}
Value interface{}
}
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.Write(" <= ", builder.AddVar(lte.Value))
}
func (lte Lte) NegationBuild(builder Builder) {
Gt{lte.Column, lte.Value}.Build(builder)
}
// Like whether string matches regular expression
type Like struct {
Column interface{}
Value interface{}
}
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.Write(" NOT LIKE ", builder.AddVar(like.Value))
}
// Map
type Map map[interface{}]interface{}
func (m Map) Build(builder Builder) {
// TODO
}
func (m Map) NegationBuild(builder Builder) {
// TODO
}
// Attrs
type Attrs struct {
Value interface{}
Select []string
Omit []string
}
func (attrs Attrs) Build(builder Builder) {
// TODO
// builder.WriteQuoted(like.Column)
// builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (attrs Attrs) NegationBuild(builder Builder) {
// TODO
}
// ID
type ID struct {
Value []interface{}
}
func (id ID) Build(builder Builder) {
if len(id.Value) == 1 {
}
// TODO
// builder.WriteQuoted(like.Column)
// builder.Write(" LIKE ", builder.AddVar(like.Value))
}
func (id ID) NegationBuild(builder Builder) {
// TODO
}

30
clause/returning.go Normal file
View File

@ -0,0 +1,30 @@
package clause
type Returning struct {
Columns []Column
}
// Name where clause name
func (returning Returning) Name() string {
return "RETURNING"
}
// Build build where clause
func (returning Returning) Build(builder Builder) {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
}
// MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok {
returning.Columns = append(v.Columns, returning.Columns...)
}
clause.Expression = returning
}

36
clause/returning_test.go Normal file
View File

@ -0,0 +1,36 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestReturning(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}},
"SELECT * FROM `users` RETURNING `users`.`id`", nil,
}, {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -1,32 +1,18 @@
package clause package clause
// SelectInterface select clause interface
type SelectInterface interface {
Selects() []Column
Omits() []Column
}
// Select select attrs when querying, updating, creating // Select select attrs when querying, updating, creating
type Select struct { type Select struct {
SelectColumns []Column Columns []Column
OmitColumns []Column Omits []Column
} }
func (s Select) Name() string { func (s Select) Name() string {
return "SELECT" return "SELECT"
} }
func (s Select) Selects() []Column {
return s.SelectColumns
}
func (s Select) Omits() []Column {
return s.OmitColumns
}
func (s Select) Build(builder Builder) { func (s Select) Build(builder Builder) {
if len(s.SelectColumns) > 0 { if len(s.Columns) > 0 {
for idx, column := range s.SelectColumns { for idx, column := range s.Columns {
if idx > 0 { if idx > 0 {
builder.WriteByte(',') builder.WriteByte(',')
} }
@ -37,13 +23,10 @@ func (s Select) Build(builder Builder) {
} }
} }
func (s Select) MergeExpression(expr Expression) { func (s Select) MergeClause(clause *Clause) {
if v, ok := expr.(SelectInterface); ok { if v, ok := clause.Expression.(Select); ok {
if len(s.SelectColumns) == 0 { s.Columns = append(v.Columns, s.Columns...)
s.SelectColumns = v.Selects() s.Omits = append(v.Omits, s.Omits...)
}
if len(s.OmitColumns) == 0 {
s.OmitColumns = v.Omits()
}
} }
clause.Expression = s
} }

41
clause/select_test.go Normal file
View File

@ -0,0 +1,41 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestSelect(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Columns: []clause.Column{clause.PrimaryColumn},
}, clause.From{}},
"SELECT `users`.`id` FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Columns: []clause.Column{clause.PrimaryColumn},
}, clause.Select{
Columns: []clause.Column{{Name: "name"}},
}, clause.From{}},
"SELECT `users`.`id`,`name` FROM `users`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

37
clause/set.go Normal file
View File

@ -0,0 +1,37 @@
package clause
type Set []Assignment
type Assignment struct {
Column Column
Value interface{}
}
func (set Set) Name() string {
return "SET"
}
func (set Set) Build(builder Builder) {
if len(set) > 0 {
for idx, assignment := range set {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(assignment.Column)
builder.WriteByte('=')
builder.Write(builder.AddVar(assignment.Value))
}
} else {
builder.WriteQuoted(PrimaryColumn)
builder.WriteByte('=')
builder.WriteQuoted(PrimaryColumn)
}
}
// MergeClause merge assignments clauses
func (set Set) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Set); ok {
set = append(v, set...)
}
clause.Expression = set
}

38
clause/set_test.go Normal file
View File

@ -0,0 +1,38 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestSet(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{
clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
},
"UPDATE `users` SET `users`.`id`=?", []interface{}{1},
},
{
[]clause.Interface{
clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
},
"UPDATE `users` SET `users`.`id`=?,`name`=?", []interface{}{1, "jinzhu"},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

38
clause/update.go Normal file
View File

@ -0,0 +1,38 @@
package clause
type Update struct {
Modifier string
Table Table
}
// Name update clause name
func (update Update) Name() string {
return "UPDATE"
}
// Build build update clause
func (update Update) Build(builder Builder) {
if update.Modifier != "" {
builder.Write(update.Modifier)
builder.WriteByte(' ')
}
if update.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {
builder.WriteQuoted(update.Table)
}
}
// MergeClause merge update clause
func (update Update) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Update); ok {
if update.Modifier == "" {
update.Modifier = v.Modifier
}
if update.Table.Name == "" {
update.Table = v.Table
}
}
clause.Expression = update
}

35
clause/update_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestUpdate(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Update{}},
"UPDATE `users`", nil,
},
{
[]clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}},
"UPDATE LOW_PRIORITY `users`", nil,
},
{
[]clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
"UPDATE LOW_PRIORITY `products`", nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -25,11 +25,11 @@ func (values Values) Build(builder Builder) {
builder.Write(" VALUES ") builder.Write(" VALUES ")
for idx, value := range values.Values { for idx, value := range values.Values {
builder.WriteByte('(')
if idx > 0 { if idx > 0 {
builder.WriteByte(',') builder.WriteByte(',')
} }
builder.WriteByte('(')
builder.Write(builder.AddVar(value...)) builder.Write(builder.AddVar(value...))
builder.WriteByte(')') builder.WriteByte(')')
} }
@ -37,3 +37,11 @@ func (values Values) Build(builder Builder) {
builder.Write("DEFAULT VALUES") builder.Write("DEFAULT VALUES")
} }
} }
// MergeClause merge values clauses
func (values Values) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Values); ok {
values.Values = append(v.Values, values.Values...)
}
clause.Expression = values
}

33
clause/values_test.go Normal file
View File

@ -0,0 +1,33 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestValues(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{
clause.Insert{},
clause.Values{
Columns: []clause.Column{{Name: "name"}, {Name: "age"}},
Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
},
},
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -2,9 +2,7 @@ package clause
// Where where clause // Where where clause
type Where struct { type Where struct {
AndConditions AddConditions Exprs []Expression
OrConditions []OrConditions
builders []Expression
} }
// Name where clause name // Name where clause name
@ -14,64 +12,122 @@ func (where Where) Name() string {
// Build build where clause // Build build where clause
func (where Where) Build(builder Builder) { func (where Where) Build(builder Builder) {
var withConditions bool // Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if len(where.AndConditions) > 0 { if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 {
withConditions = true if idx != 0 {
where.AndConditions.Build(builder) where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
} }
break
if len(where.builders) > 0 {
for _, b := range where.builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
} }
} }
var singleOrConditions []OrConditions for idx, expr := range where.Exprs {
for _, or := range where.OrConditions { if expr != nil {
if len(or) == 1 { if idx > 0 {
if withConditions { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.Write(" OR ") builder.Write(" OR ")
or.Build(builder)
} else { } else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ") builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
} }
} }
if !withConditions { expr.Build(builder)
builder.Write(" FALSE") }
} }
return return
} }
// MergeExpression merge where clauses // MergeClause merge where clauses
func (where Where) MergeExpression(expr Expression) { func (where Where) MergeClause(clause *Clause) {
if w, ok := expr.(Where); ok { if w, ok := clause.Expression.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...) where.Exprs = append(w.Exprs, where.Exprs...)
where.OrConditions = append(where.OrConditions, w.OrConditions...) }
where.builders = append(where.builders, w.builders...)
clause.Expression = where
}
func And(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
return AndConditions{Exprs: exprs}
}
type AndConditions struct {
Exprs []Expression
}
func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 {
builder.Write("(")
}
for idx, c := range and.Exprs {
if idx > 0 {
builder.Write(" AND ")
}
c.Build(builder)
}
if len(and.Exprs) > 1 {
builder.Write(")")
}
}
func Or(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
return OrConditions{Exprs: exprs}
}
type OrConditions struct {
Exprs []Expression
}
func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 {
builder.Write("(")
}
for idx, c := range or.Exprs {
if idx > 0 {
builder.Write(" OR ")
}
c.Build(builder)
}
if len(or.Exprs) > 1 {
builder.Write(")")
}
}
func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
return NotConditions{Exprs: exprs}
}
type NotConditions struct {
Exprs []Expression
}
func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.Write("(")
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.Write(" AND ")
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else { } else {
where.builders = append(where.builders, expr) builder.Write(" NOT ")
c.Build(builder)
}
}
if len(not.Exprs) > 1 {
builder.Write(")")
} }
} }

63
clause/where_test.go Normal file
View File

@ -0,0 +1,63 @@
package clause_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm/clause"
)
func TestWhere(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}

View File

@ -22,7 +22,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance().Limit(1).Order(clause.OrderBy{ tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true, Desc: true,
}) })

View File

@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"log"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -26,7 +25,7 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
if len(clauses) > 0 { if len(clauses) > 0 {
instance.Statement.Build(clauses...) instance.Statement.Build(clauses...)
} }
return instance.Statement.SQL.String(), instance.Statement.Vars return strings.TrimSpace(instance.Statement.SQL.String()), instance.Statement.Vars
} }
// AddError add error to instance // AddError add error to instance
@ -85,10 +84,10 @@ func (stmt Statement) Quote(field interface{}) string {
switch v := field.(type) { switch v := field.(type) {
case clause.Table: case clause.Table:
if v.Table == clause.CurrentTable { if v.Name == clause.CurrentTable {
str.WriteString(stmt.Table) str.WriteString(stmt.Table)
} else { } else {
str.WriteString(v.Table) str.WriteString(v.Name)
} }
if v.Alias != "" { if v.Alias != "" {
@ -126,7 +125,7 @@ func (stmt Statement) Quote(field interface{}) string {
str.WriteByte(stmt.DB.quoteChars[1]) str.WriteByte(stmt.DB.quoteChars[1])
} }
default: default:
fmt.Sprint(field) str.WriteString(fmt.Sprint(field))
} }
str.WriteByte(stmt.DB.quoteChars[1]) str.WriteByte(stmt.DB.quoteChars[1])
@ -141,19 +140,28 @@ func (stmt *Statement) AddVar(vars ...interface{}) string {
placeholders.WriteByte(',') placeholders.WriteByte(',')
} }
if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { switch v := v.(type) {
stmt.NamedVars = append(stmt.NamedVars, namedArg) case sql.NamedArg:
if len(v.Name) > 0 {
stmt.NamedVars = append(stmt.NamedVars, v)
placeholders.WriteByte('@') placeholders.WriteByte('@')
placeholders.WriteString(namedArg.Name) placeholders.WriteString(v.Name)
} else if arrs, ok := v.([]interface{}); ok { } else {
stmt.Vars = append(stmt.Vars, v.Value)
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
}
case clause.Column:
placeholders.WriteString(stmt.Quote(v))
case []interface{}:
placeholders.WriteByte('(') placeholders.WriteByte('(')
if len(arrs) > 0 { if len(v) > 0 {
placeholders.WriteString(stmt.AddVar(arrs...)) placeholders.WriteString(stmt.AddVar(v...))
} else { } else {
placeholders.WriteString("NULL") placeholders.WriteString("NULL")
} }
placeholders.WriteByte(')') placeholders.WriteByte(')')
} else { default:
stmt.Vars = append(stmt.Vars, v)
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v)) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v))
} }
} }
@ -166,42 +174,18 @@ func (stmt *Statement) AddClause(v clause.Interface) {
optimizer.OptimizeStatement(stmt) optimizer.OptimizeStatement(stmt)
} }
c, _ := stmt.Clauses[v.Name()] c, ok := stmt.Clauses[v.Name()]
if namer, ok := v.(clause.OverrideNameInterface); ok { if !ok {
c.Name = namer.OverrideName()
} else {
c.Name = v.Name() c.Name = v.Name()
} }
v.MergeClause(&c)
if c.Expression != nil {
v.MergeExpression(c.Expression)
}
c.Expression = v
stmt.Clauses[v.Name()] = c stmt.Clauses[v.Name()] = c
} }
// AddClauseIfNotExists add clause if not exists // AddClauseIfNotExists add clause if not exists
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
if optimizer, ok := v.(StatementOptimizer); ok { if _, ok := stmt.Clauses[v.Name()]; !ok {
optimizer.OptimizeStatement(stmt) stmt.AddClause(v)
}
log.Println(v.Name())
if c, ok := stmt.Clauses[v.Name()]; !ok {
if namer, ok := v.(clause.OverrideNameInterface); ok {
c.Name = namer.OverrideName()
} else {
c.Name = v.Name()
}
if c.Expression != nil {
v.MergeExpression(c.Expression)
}
c.Expression = v
stmt.Clauses[v.Name()] = c
log.Println(stmt.Clauses[v.Name()])
} }
} }
@ -211,7 +195,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
if i, err := strconv.Atoi(sql); err != nil { if i, err := strconv.Atoi(sql); err != nil {
query = i query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Expression{clause.String{SQL: sql, Values: args}} return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
} }
} }
@ -255,7 +239,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
} }
if len(conditions) == 0 { if len(conditions) == 0 {
conditions = append(conditions, clause.ID{Value: args}) conditions = append(conditions, clause.IN{Column: clause.PrimaryColumn, Values: args})
} }
return conditions return conditions