Setup clauses tests

This commit is contained in:
Jinzhu 2020-02-04 09:51:19 +08:00
parent 46b1c85f88
commit 9d19be0826
4 changed files with 71 additions and 8 deletions

View File

@ -9,9 +9,7 @@ import (
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{}) db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{ db.Statement.AddClauseIfNotExists(clause.From{})
Tables: []clause.Table{{Table: clause.CurrentTable}},
})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)

54
clause/clause_test.go Normal file
View File

@ -0,0 +1,54 @@
package clause_test
import (
"fmt"
"reflect"
"sync"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/tests"
)
func TestClause(t *testing.T) {
var (
db, _ = gorm.Open(nil, nil)
results = []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM users", []interface{}{},
}}
)
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
var (
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{
DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{},
}
buildNames []string
)
for _, c := range result.Clauses {
buildNames = append(buildNames, c.Name())
stmt.AddClause(c)
}
stmt.Build(buildNames...)
if stmt.SQL.String() != result.Result {
t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String())
}
if reflect.DeepEqual(stmt.Vars, result.Vars) {
t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars)
}
})
}
}

View File

@ -10,8 +10,11 @@ func (From) Name() string {
return "FROM" return "FROM"
} }
var currentTable = Table{Table: CurrentTable}
// Build build from clause // Build build from clause
func (from From) Build(builder Builder) { func (from From) Build(builder Builder) {
if len(from.Tables) > 0 {
for idx, table := range from.Tables { for idx, table := range from.Tables {
if idx > 0 { if idx > 0 {
builder.WriteByte(',') builder.WriteByte(',')
@ -19,6 +22,9 @@ func (from From) Build(builder Builder) {
builder.WriteQuoted(table) builder.WriteQuoted(table)
} }
} else {
builder.WriteQuoted(currentTable)
}
} }
// MergeExpression merge order by clauses // MergeExpression merge order by clauses

View File

@ -84,6 +84,11 @@ func (stmt Statement) Quote(field interface{}) string {
switch v := field.(type) { switch v := field.(type) {
case clause.Table: case clause.Table:
if v.Table == clause.CurrentTable {
str.WriteString(stmt.Table)
} else {
str.WriteString(v.Table)
}
if v.Alias != "" { if v.Alias != "" {
str.WriteString(" AS ") str.WriteString(" AS ")