forked from mirror/gorm
Setup clauses tests
This commit is contained in:
parent
46b1c85f88
commit
9d19be0826
|
@ -9,9 +9,7 @@ import (
|
|||
|
||||
func Query(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Select{})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{
|
||||
Tables: []clause.Table{{Table: clause.CurrentTable}},
|
||||
})
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
|
||||
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
|
||||
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestClause(t *testing.T) {
|
||||
var (
|
||||
db, _ = gorm.Open(nil, nil)
|
||||
results = []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}},
|
||||
"SELECT * FROM users", []interface{}{},
|
||||
}}
|
||||
)
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
var (
|
||||
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt = gorm.Statement{
|
||||
DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
buildNames []string
|
||||
)
|
||||
|
||||
for _, c := range result.Clauses {
|
||||
buildNames = append(buildNames, c.Name())
|
||||
stmt.AddClause(c)
|
||||
}
|
||||
|
||||
stmt.Build(buildNames...)
|
||||
|
||||
if stmt.SQL.String() != result.Result {
|
||||
t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String())
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(stmt.Vars, result.Vars) {
|
||||
t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -10,14 +10,20 @@ func (From) Name() string {
|
|||
return "FROM"
|
||||
}
|
||||
|
||||
var currentTable = Table{Table: CurrentTable}
|
||||
|
||||
// Build build from clause
|
||||
func (from From) Build(builder Builder) {
|
||||
for idx, table := range from.Tables {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
if len(from.Tables) > 0 {
|
||||
for idx, table := range from.Tables {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(table)
|
||||
builder.WriteQuoted(table)
|
||||
}
|
||||
} else {
|
||||
builder.WriteQuoted(currentTable)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -84,6 +84,11 @@ func (stmt Statement) Quote(field interface{}) string {
|
|||
|
||||
switch v := field.(type) {
|
||||
case clause.Table:
|
||||
if v.Table == clause.CurrentTable {
|
||||
str.WriteString(stmt.Table)
|
||||
} else {
|
||||
str.WriteString(v.Table)
|
||||
}
|
||||
|
||||
if v.Alias != "" {
|
||||
str.WriteString(" AS ")
|
||||
|
|
Loading…
Reference in New Issue