Add clause tests

This commit is contained in:
Jinzhu 2020-02-05 11:14:58 +08:00
parent 9d19be0826
commit 0160bab7dc
13 changed files with 92 additions and 21 deletions

View File

@ -80,7 +80,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{ tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{ OrConditions: []clause.OrConditions{
tx.Statement.BuildCondtion(query, args...), tx.Statement.BuildCondtion(query, args...),
}, },
}) })

View File

@ -12,17 +12,19 @@ import (
"github.com/jinzhu/gorm/tests" "github.com/jinzhu/gorm/tests"
) )
func TestClause(t *testing.T) { func TestClauses(t *testing.T) {
var ( var (
db, _ = gorm.Open(nil, nil) db, _ = gorm.Open(tests.DummyDialector{}, nil)
results = []struct { results = []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
Vars []interface{} Vars []interface{}
}{{ }{
[]clause.Interface{clause.Select{}, clause.From{}}, {
"SELECT * FROM users", []interface{}{}, []clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}},
}} "SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"},
},
}
) )
for idx, result := range results { for idx, result := range results {

View File

@ -5,6 +5,11 @@ const (
CurrentTable string = "@@@table@@@" CurrentTable string = "@@@table@@@"
) )
var PrimaryColumn = Column{
Table: CurrentTable,
Name: PrimaryKey,
}
// Expression expression interface // Expression expression interface
type Expression interface { type Expression interface {
Build(builder Builder) Build(builder Builder)

View File

@ -6,6 +6,14 @@ import "strings"
// Query Expressions // Query Expressions
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
func Add(exprs ...Expression) AddConditions {
return AddConditions(exprs)
}
func Or(exprs ...Expression) OrConditions {
return OrConditions(exprs)
}
type AddConditions []Expression type AddConditions []Expression
func (cs AddConditions) Build(builder Builder) { func (cs AddConditions) Build(builder Builder) {
@ -17,9 +25,9 @@ func (cs AddConditions) Build(builder Builder) {
} }
} }
type ORConditions []Expression type OrConditions []Expression
func (cs ORConditions) Build(builder Builder) { func (cs OrConditions) Build(builder Builder) {
for idx, c := range cs { for idx, c := range cs {
if idx > 0 { if idx > 0 {
builder.Write(" OR ") builder.Write(" OR ")

View File

@ -3,7 +3,7 @@ package clause
// Where where clause // Where where clause
type Where struct { type Where struct {
AndConditions AddConditions AndConditions AddConditions
ORConditions []ORConditions OrConditions []OrConditions
builders []Expression builders []Expression
} }
@ -31,8 +31,8 @@ func (where Where) Build(builder Builder) {
} }
} }
var singleOrConditions []ORConditions var singleOrConditions []OrConditions
for _, or := range where.ORConditions { for _, or := range where.OrConditions {
if len(or) == 1 { if len(or) == 1 {
if withConditions { if withConditions {
builder.Write(" OR ") builder.Write(" OR ")
@ -69,7 +69,7 @@ func (where Where) Build(builder Builder) {
func (where Where) MergeExpression(expr Expression) { func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok { if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...) where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...) where.OrConditions = append(where.OrConditions, w.OrConditions...)
where.builders = append(where.builders, w.builders...) where.builders = append(where.builders, w.builders...)
} else { } else {
where.builders = append(where.builders, expr) where.builders = append(where.builders, expr)

View File

@ -27,3 +27,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
return "?" return "?"
} }
func (Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}

View File

@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "?"
} }
func (Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // "name"
}

View File

@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "?"
} }
func (Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}

5
go.mod
View File

@ -3,7 +3,8 @@ module github.com/jinzhu/gorm
go 1.13 go 1.13
require ( require (
github.com/go-sql-driver/mysql v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/lib/pq v1.3.0 github.com/lib/pq v1.3.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
) )

17
gorm.go
View File

@ -23,16 +23,21 @@ type Config struct {
NowFunc func() time.Time NowFunc func() time.Time
} }
type shared struct {
callbacks *callbacks
cacheStore *sync.Map
quoteChars [2]byte
}
// DB GORM DB definition // DB GORM DB definition
type DB struct { type DB struct {
*Config *Config
Dialector Dialector
Instance Instance
DB CommonDB
ClauseBuilders map[string]clause.ClauseBuilder ClauseBuilders map[string]clause.ClauseBuilder
DB CommonDB
clone bool clone bool
callbacks *callbacks *shared
cacheStore *sync.Map
} }
// Session session config when create session with Session() method // Session session config when create session with Session() method
@ -65,13 +70,16 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
Dialector: dialector, Dialector: dialector,
ClauseBuilders: map[string]clause.ClauseBuilder{}, ClauseBuilders: map[string]clause.ClauseBuilder{},
clone: true, clone: true,
shared: &shared{
cacheStore: &sync.Map{}, cacheStore: &sync.Map{},
},
} }
db.callbacks = initializeCallbacks(db) db.callbacks = initializeCallbacks(db)
if dialector != nil { if dialector != nil {
err = dialector.Initialize(db) err = dialector.Initialize(db)
db.quoteChars = dialector.QuoteChars()
} }
return return
} }
@ -146,8 +154,7 @@ func (db *DB) getInstance() *DB {
Dialector: db.Dialector, Dialector: db.Dialector,
ClauseBuilders: db.ClauseBuilders, ClauseBuilders: db.ClauseBuilders,
DB: db.DB, DB: db.DB,
callbacks: db.callbacks, shared: db.shared,
cacheStore: db.cacheStore,
} }
} }

View File

@ -10,6 +10,7 @@ type Dialector interface {
Initialize(*DB) error Initialize(*DB) error
Migrator() Migrator Migrator() Migrator
BindVar(stmt *Statement, v interface{}) string BindVar(stmt *Statement, v interface{}) string
QuoteChars() [2]byte
} }
// CommonDB common db interface // CommonDB common db interface

View File

@ -81,6 +81,7 @@ func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
// Quote returns quoted value // Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string { func (stmt Statement) Quote(field interface{}) string {
var str strings.Builder var str strings.Builder
str.WriteByte(stmt.DB.quoteChars[0])
switch v := field.(type) { switch v := field.(type) {
case clause.Table: case clause.Table:
@ -91,8 +92,11 @@ func (stmt Statement) Quote(field interface{}) string {
} }
if v.Alias != "" { if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ") str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias) str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
} }
case clause.Column: case clause.Column:
if v.Table != "" { if v.Table != "" {
@ -101,7 +105,9 @@ func (stmt Statement) Quote(field interface{}) string {
} else { } else {
str.WriteString(v.Table) str.WriteString(v.Table)
} }
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteByte('.') str.WriteByte('.')
str.WriteByte(stmt.DB.quoteChars[0])
} }
if v.Name == clause.PrimaryKey { if v.Name == clause.PrimaryKey {
@ -111,14 +117,19 @@ func (stmt Statement) Quote(field interface{}) string {
} else { } else {
str.WriteString(v.Name) str.WriteString(v.Name)
} }
if v.Alias != "" { if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ") str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias) str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
} }
default: default:
fmt.Sprint(field) fmt.Sprint(field)
} }
str.WriteByte(stmt.DB.quoteChars[1])
return str.String() return str.String()
} }

24
tests/dummy_dialecter.go Normal file
View File

@ -0,0 +1,24 @@
package tests
import (
"github.com/jinzhu/gorm"
)
type DummyDialector struct {
}
func (DummyDialector) Initialize(*gorm.DB) error {
return nil
}
func (DummyDialector) Migrator() gorm.Migrator {
return nil
}
func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (DummyDialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}