From c422d75f4b474d36f60a9559273d08d080bc0c28 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 May 2020 18:50:20 +0800 Subject: [PATCH] Add Scopes tests --- callbacks/delete.go | 2 -- clause/expression.go | 30 +++++++++++++++++++++++++-- tests/scopes_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++ tests/utils.go | 2 +- 4 files changed, 77 insertions(+), 5 deletions(-) create mode 100644 tests/scopes_test.go diff --git a/callbacks/delete.go b/callbacks/delete.go index 1c59afbe..b3278c83 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -1,7 +1,6 @@ package callbacks import ( - "fmt" "reflect" "github.com/jinzhu/gorm" @@ -38,7 +37,6 @@ func Delete(db *gorm.DB) { if db.Statement.Schema != nil && !db.Statement.Unscoped { for _, c := range db.Statement.Schema.DeleteClauses { db.Statement.AddClause(c) - fmt.Println(db.Statement.SQL.String()) } } diff --git a/clause/expression.go b/clause/expression.go index 067774d4..e54da1af 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,5 +1,7 @@ package clause +import "reflect" + // Expression expression interface type Expression interface { Build(builder Builder) @@ -18,12 +20,36 @@ type Expr struct { // Build build raw expression func (expr Expr) Build(builder Builder) { - var idx int + var ( + afterParenthesis bool + idx int + ) + for _, v := range []byte(expr.SQL) { if v == '?' { - builder.AddVar(builder, expr.Vars[idx]) + if afterParenthesis { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + idx++ } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } builder.WriteByte(v) } } diff --git a/tests/scopes_test.go b/tests/scopes_test.go new file mode 100644 index 00000000..c0530da5 --- /dev/null +++ b/tests/scopes_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func NameIn1And2(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) +} + +func NameIn2And3(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) +} + +func NameIn(names []string) func(d *gorm.DB) *gorm.DB { + return func(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", names) + } +} + +func TestScopes(t *testing.T) { + var users = []*User{ + GetUser("ScopeUser1", Config{}), + GetUser("ScopeUser2", Config{}), + GetUser("ScopeUser3", Config{}), + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Scopes(NameIn1And2).Find(&users1) + if len(users1) != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", len(users1)) + } + + DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) + if len(users2) != 1 { + t.Errorf("Should found one user's name is 2, but got %v", len(users2)) + } + + DB.Scopes(NameIn([]string{users[0].Name, users[2].Name})).Find(&users3) + if len(users3) != 2 { + t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) + } +} diff --git a/tests/utils.go b/tests/utils.go index 92163d5c..041dc9b1 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -87,7 +87,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { format := "2006-01-02T15:04:05Z07:00" if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { - t.Errorf("%v: expect: %v, got %v", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) } } else if got != expect { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)