Add clause tests

This commit is contained in:
Jinzhu 2020-02-05 11:14:58 +08:00
parent 9d19be0826
commit 0160bab7dc
13 changed files with 92 additions and 21 deletions

View File

@ -80,7 +80,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{
OrConditions: []clause.OrConditions{
tx.Statement.BuildCondtion(query, args...),
},
})

View File

@ -12,17 +12,19 @@ import (
"github.com/jinzhu/gorm/tests"
)
func TestClause(t *testing.T) {
func TestClauses(t *testing.T) {
var (
db, _ = gorm.Open(nil, nil)
db, _ = gorm.Open(tests.DummyDialector{}, nil)
results = []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM users", []interface{}{},
}}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}},
"SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"},
},
}
)
for idx, result := range results {

View File

@ -5,6 +5,11 @@ const (
CurrentTable string = "@@@table@@@"
)
var PrimaryColumn = Column{
Table: CurrentTable,
Name: PrimaryKey,
}
// Expression expression interface
type Expression interface {
Build(builder Builder)

View File

@ -6,6 +6,14 @@ import "strings"
// Query Expressions
////////////////////////////////////////////////////////////////////////////////
func Add(exprs ...Expression) AddConditions {
return AddConditions(exprs)
}
func Or(exprs ...Expression) OrConditions {
return OrConditions(exprs)
}
type AddConditions []Expression
func (cs AddConditions) Build(builder Builder) {
@ -17,9 +25,9 @@ func (cs AddConditions) Build(builder Builder) {
}
}
type ORConditions []Expression
type OrConditions []Expression
func (cs ORConditions) Build(builder Builder) {
func (cs OrConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" OR ")

View File

@ -3,7 +3,7 @@ package clause
// Where where clause
type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
OrConditions []OrConditions
builders []Expression
}
@ -31,8 +31,8 @@ func (where Where) Build(builder Builder) {
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
var singleOrConditions []OrConditions
for _, or := range where.OrConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
@ -69,7 +69,7 @@ func (where Where) Build(builder Builder) {
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.OrConditions = append(where.OrConditions, w.OrConditions...)
where.builders = append(where.builders, w.builders...)
} else {
where.builders = append(where.builders, expr)

View File

@ -27,3 +27,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
return "?"
}
func (Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}

View File

@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // "name"
}

View File

@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}

5
go.mod
View File

@ -3,7 +3,8 @@ module github.com/jinzhu/gorm
go 1.13
require (
github.com/go-sql-driver/mysql v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0
github.com/lib/pq v1.3.0
github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/lib/pq v1.3.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
)

19
gorm.go
View File

@ -23,16 +23,21 @@ type Config struct {
NowFunc func() time.Time
}
type shared struct {
callbacks *callbacks
cacheStore *sync.Map
quoteChars [2]byte
}
// DB GORM DB definition
type DB struct {
*Config
Dialector
Instance
DB CommonDB
ClauseBuilders map[string]clause.ClauseBuilder
DB CommonDB
clone bool
callbacks *callbacks
cacheStore *sync.Map
*shared
}
// Session session config when create session with Session() method
@ -65,13 +70,16 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
Dialector: dialector,
ClauseBuilders: map[string]clause.ClauseBuilder{},
clone: true,
cacheStore: &sync.Map{},
shared: &shared{
cacheStore: &sync.Map{},
},
}
db.callbacks = initializeCallbacks(db)
if dialector != nil {
err = dialector.Initialize(db)
db.quoteChars = dialector.QuoteChars()
}
return
}
@ -146,8 +154,7 @@ func (db *DB) getInstance() *DB {
Dialector: db.Dialector,
ClauseBuilders: db.ClauseBuilders,
DB: db.DB,
callbacks: db.callbacks,
cacheStore: db.cacheStore,
shared: db.shared,
}
}

View File

@ -10,6 +10,7 @@ type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt *Statement, v interface{}) string
QuoteChars() [2]byte
}
// CommonDB common db interface

View File

@ -81,6 +81,7 @@ func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string {
var str strings.Builder
str.WriteByte(stmt.DB.quoteChars[0])
switch v := field.(type) {
case clause.Table:
@ -91,8 +92,11 @@ func (stmt Statement) Quote(field interface{}) string {
}
if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
}
case clause.Column:
if v.Table != "" {
@ -101,7 +105,9 @@ func (stmt Statement) Quote(field interface{}) string {
} else {
str.WriteString(v.Table)
}
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteByte('.')
str.WriteByte(stmt.DB.quoteChars[0])
}
if v.Name == clause.PrimaryKey {
@ -111,14 +117,19 @@ func (stmt Statement) Quote(field interface{}) string {
} else {
str.WriteString(v.Name)
}
if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
}
default:
fmt.Sprint(field)
}
str.WriteByte(stmt.DB.quoteChars[1])
return str.String()
}

24
tests/dummy_dialecter.go Normal file
View File

@ -0,0 +1,24 @@
package tests
import (
"github.com/jinzhu/gorm"
)
type DummyDialector struct {
}
func (DummyDialector) Initialize(*gorm.DB) error {
return nil
}
func (DummyDialector) Migrator() gorm.Migrator {
return nil
}
func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
func (DummyDialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}