From 72104c6bf0e45cd56535797cc8def5d521303fd9 Mon Sep 17 00:00:00 2001 From: James Kong Date: Fri, 15 Jan 2016 16:38:56 -0800 Subject: [PATCH] Fixes querying with inline map when a value is nil This changes the inline map query build condition to use `IS NULL` instead of the equality operator when the provided value is `nil`. --- query_test.go | 45 ++++++++++++++++++++++++++++++++++++--------- scope_private.go | 12 ++++++++++-- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/query_test.go b/query_test.go index 274e8e9b..a60c1f2c 100644 --- a/query_test.go +++ b/query_test.go @@ -66,8 +66,8 @@ func TestUIntPrimaryKey(t *testing.T) { func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string + ZipCode string `gorm:"primary_key"` + Address string } DB.AutoMigrate(&AddressByZipCode{}) @@ -207,10 +207,12 @@ func TestSearchWithStruct(t *testing.T) { } func TestSearchWithMap(t *testing.T) { + productID := 1 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) + user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), ProductID: &productID} + DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) var user User DB.First(&user, map[string]interface{}{"name": user1.Name}) @@ -234,6 +236,21 @@ func TestSearchWithMap(t *testing.T) { if len(users) != 1 { t.Errorf("Search all records with inline map") } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": nil}) + if len(users) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&users, map[string]interface{}{"name": user1.Name, "product_id": nil}) + if len(users) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": productID}) + if len(users) != 1 { + t.Errorf("Search all records with inline multiple value map") + } } func TestSearchWithEmptyChain(t *testing.T) { @@ -377,10 +394,15 @@ func TestNot(t *testing.T) { DB.Create(getPreparedUser("user1", "not")) DB.Create(getPreparedUser("user2", "not")) DB.Create(getPreparedUser("user3", "not")) - DB.Create(getPreparedUser("user4", "not")) + + user4 := getPreparedUser("user4", "not") + productID := 1 + user4.ProductID = &productID + DB.Create(user4) + DB := DB.Where("role = ?", "not") - var users1, users2, users3, users4, users5, users6, users7, users8 []User + var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User if DB.Find(&users1).RowsAffected != 4 { t.Errorf("should find 4 not users") } @@ -423,15 +445,20 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - DB.Not("name", []string{"user3"}).Find(&users7) - if len(users1)-len(users7) != int(name3Count) { + DB.Not(map[string]interface{}{"name": "user3", "product_id": nil}).Find(&users7) + if len(users7) != 1 { + t.Errorf("Should find all user's name not equal to 3 who do not have product ids") + } + + DB.Not("name", []string{"user3"}).Find(&users8) + if len(users1)-len(users8) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users8) - if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) { + DB.Not("name", []string{"user3", "user2"}).Find(&users9) + if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { t.Errorf("Should find all users's name not equal 3") } } diff --git a/scope_private.go b/scope_private.go index a154c426..fa5d5f44 100644 --- a/scope_private.go +++ b/scope_private.go @@ -31,7 +31,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) + } } return strings.Join(sqls, " AND ") case interface{}: @@ -97,7 +101,11 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) + } } return strings.Join(sqls, " AND ") case interface{}: