mirror of https://github.com/go-gorm/gorm.git
Fixes querying with inline map when a value is nil
This changes the inline map query build condition to use `IS NULL` instead of the equality operator when the provided value is `nil`.
This commit is contained in:
parent
341d047aa7
commit
72104c6bf0
|
@ -66,8 +66,8 @@ func TestUIntPrimaryKey(t *testing.T) {
|
||||||
|
|
||||||
func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
|
func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
|
||||||
type AddressByZipCode struct {
|
type AddressByZipCode struct {
|
||||||
ZipCode string `gorm:"primary_key"`
|
ZipCode string `gorm:"primary_key"`
|
||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.AutoMigrate(&AddressByZipCode{})
|
DB.AutoMigrate(&AddressByZipCode{})
|
||||||
|
@ -207,10 +207,12 @@ func TestSearchWithStruct(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSearchWithMap(t *testing.T) {
|
func TestSearchWithMap(t *testing.T) {
|
||||||
|
productID := 1
|
||||||
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
|
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
|
||||||
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
|
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
|
||||||
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
|
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
|
||||||
DB.Save(&user1).Save(&user2).Save(&user3)
|
user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), ProductID: &productID}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
DB.First(&user, map[string]interface{}{"name": user1.Name})
|
DB.First(&user, map[string]interface{}{"name": user1.Name})
|
||||||
|
@ -234,6 +236,21 @@ func TestSearchWithMap(t *testing.T) {
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("Search all records with inline map")
|
t.Errorf("Search all records with inline map")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": nil})
|
||||||
|
if len(users) != 0 {
|
||||||
|
t.Errorf("Search all records with inline map containing null value finding 0 records")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user1.Name, "product_id": nil})
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline map containing null value finding 1 record")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": productID})
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline multiple value map")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSearchWithEmptyChain(t *testing.T) {
|
func TestSearchWithEmptyChain(t *testing.T) {
|
||||||
|
@ -377,10 +394,15 @@ func TestNot(t *testing.T) {
|
||||||
DB.Create(getPreparedUser("user1", "not"))
|
DB.Create(getPreparedUser("user1", "not"))
|
||||||
DB.Create(getPreparedUser("user2", "not"))
|
DB.Create(getPreparedUser("user2", "not"))
|
||||||
DB.Create(getPreparedUser("user3", "not"))
|
DB.Create(getPreparedUser("user3", "not"))
|
||||||
DB.Create(getPreparedUser("user4", "not"))
|
|
||||||
|
user4 := getPreparedUser("user4", "not")
|
||||||
|
productID := 1
|
||||||
|
user4.ProductID = &productID
|
||||||
|
DB.Create(user4)
|
||||||
|
|
||||||
DB := DB.Where("role = ?", "not")
|
DB := DB.Where("role = ?", "not")
|
||||||
|
|
||||||
var users1, users2, users3, users4, users5, users6, users7, users8 []User
|
var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User
|
||||||
if DB.Find(&users1).RowsAffected != 4 {
|
if DB.Find(&users1).RowsAffected != 4 {
|
||||||
t.Errorf("should find 4 not users")
|
t.Errorf("should find 4 not users")
|
||||||
}
|
}
|
||||||
|
@ -423,15 +445,20 @@ func TestNot(t *testing.T) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not("name", []string{"user3"}).Find(&users7)
|
DB.Not(map[string]interface{}{"name": "user3", "product_id": nil}).Find(&users7)
|
||||||
if len(users1)-len(users7) != int(name3Count) {
|
if len(users7) != 1 {
|
||||||
|
t.Errorf("Should find all user's name not equal to 3 who do not have product ids")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not("name", []string{"user3"}).Find(&users8)
|
||||||
|
if len(users1)-len(users8) != int(name3Count) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
var name2Count int64
|
var name2Count int64
|
||||||
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
||||||
DB.Not("name", []string{"user3", "user2"}).Find(&users8)
|
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
||||||
if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) {
|
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range value {
|
for key, value := range value {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
|
if value != nil {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||||
|
} else {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
|
@ -97,7 +101,11 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range value {
|
for key, value := range value {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
|
if value != nil {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
|
||||||
|
} else {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
|
|
Loading…
Reference in New Issue