From 72104c6bf0e45cd56535797cc8def5d521303fd9 Mon Sep 17 00:00:00 2001 From: James Kong Date: Fri, 15 Jan 2016 16:38:56 -0800 Subject: [PATCH 1/2] Fixes querying with inline map when a value is nil This changes the inline map query build condition to use `IS NULL` instead of the equality operator when the provided value is `nil`. --- query_test.go | 45 ++++++++++++++++++++++++++++++++++++--------- scope_private.go | 12 ++++++++++-- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/query_test.go b/query_test.go index 274e8e9b..a60c1f2c 100644 --- a/query_test.go +++ b/query_test.go @@ -66,8 +66,8 @@ func TestUIntPrimaryKey(t *testing.T) { func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string + ZipCode string `gorm:"primary_key"` + Address string } DB.AutoMigrate(&AddressByZipCode{}) @@ -207,10 +207,12 @@ func TestSearchWithStruct(t *testing.T) { } func TestSearchWithMap(t *testing.T) { + productID := 1 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) + user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), ProductID: &productID} + DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) var user User DB.First(&user, map[string]interface{}{"name": user1.Name}) @@ -234,6 +236,21 @@ func TestSearchWithMap(t *testing.T) { if len(users) != 1 { t.Errorf("Search all records with inline map") } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": nil}) + if len(users) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&users, map[string]interface{}{"name": user1.Name, "product_id": nil}) + if len(users) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": productID}) + if len(users) != 1 { + t.Errorf("Search all records with inline multiple value map") + } } func TestSearchWithEmptyChain(t *testing.T) { @@ -377,10 +394,15 @@ func TestNot(t *testing.T) { DB.Create(getPreparedUser("user1", "not")) DB.Create(getPreparedUser("user2", "not")) DB.Create(getPreparedUser("user3", "not")) - DB.Create(getPreparedUser("user4", "not")) + + user4 := getPreparedUser("user4", "not") + productID := 1 + user4.ProductID = &productID + DB.Create(user4) + DB := DB.Where("role = ?", "not") - var users1, users2, users3, users4, users5, users6, users7, users8 []User + var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User if DB.Find(&users1).RowsAffected != 4 { t.Errorf("should find 4 not users") } @@ -423,15 +445,20 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - DB.Not("name", []string{"user3"}).Find(&users7) - if len(users1)-len(users7) != int(name3Count) { + DB.Not(map[string]interface{}{"name": "user3", "product_id": nil}).Find(&users7) + if len(users7) != 1 { + t.Errorf("Should find all user's name not equal to 3 who do not have product ids") + } + + DB.Not("name", []string{"user3"}).Find(&users8) + if len(users1)-len(users8) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users8) - if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) { + DB.Not("name", []string{"user3", "user2"}).Find(&users9) + if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { t.Errorf("Should find all users's name not equal 3") } } diff --git a/scope_private.go b/scope_private.go index a154c426..fa5d5f44 100644 --- a/scope_private.go +++ b/scope_private.go @@ -31,7 +31,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) + } } return strings.Join(sqls, " AND ") case interface{}: @@ -97,7 +101,11 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) + } } return strings.Join(sqls, " AND ") case interface{}: From 211ccf4ea61de36d67ff443aa94714f2a0cf0688 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 16 Jan 2016 11:37:16 +0800 Subject: [PATCH 2/2] Fix using pointer value as foreign key --- association_test.go | 4 +++- field.go | 26 ++++++++++++++++++-------- preload.go | 2 +- query_test.go | 19 +++++++++---------- structs_test.go | 2 +- utils_private.go | 5 +++-- 6 files changed, 35 insertions(+), 23 deletions(-) diff --git a/association_test.go b/association_test.go index ab3abd91..f02d4620 100644 --- a/association_test.go +++ b/association_test.go @@ -600,7 +600,9 @@ func TestRelated(t *testing.T) { Company: Company{Name: "company1"}, } - DB.Save(&user) + if err := DB.Save(&user).Error; err != nil { + t.Errorf("No error should happen when saving user") + } if user.CreditCard.ID == 0 { t.Errorf("After user save, credit card should have id") diff --git a/field.go b/field.go index 79e2c0ec..2ed4e732 100644 --- a/field.go +++ b/field.go @@ -13,7 +13,7 @@ type Field struct { Field reflect.Value } -func (field *Field) Set(value interface{}) error { +func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { return errors.New("field value not valid") } @@ -27,15 +27,25 @@ func (field *Field) Set(value interface{}) error { reflectValue = reflect.ValueOf(value) } + fieldValue := field.Field if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(field.Field.Type()) { - field.Field.Set(reflectValue.Convert(field.Field.Type())) - } else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { - if err := scanner.Scan(reflectValue.Interface()); err != nil { - return err - } + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else { - return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type()) + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.Struct.Type.Elem())) + } + fieldValue = fieldValue.Elem() + } + + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) + } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err = scanner.Scan(reflectValue.Interface()) + } else { + err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) + } } } else { field.Field.Set(reflect.Zero(field.Field.Type())) diff --git a/preload.go b/preload.go index d12995f3..f3f2df12 100644 --- a/preload.go +++ b/preload.go @@ -27,7 +27,7 @@ func getRealValue(value reflect.Value, columns []string) (results []interface{}) } func equalAsString(a interface{}, b interface{}) bool { - return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) + return toString(a) == toString(b) } func Preload(scope *Scope) { diff --git a/query_test.go b/query_test.go index a60c1f2c..a7d5bc0e 100644 --- a/query_test.go +++ b/query_test.go @@ -207,11 +207,11 @@ func TestSearchWithStruct(t *testing.T) { } func TestSearchWithMap(t *testing.T) { - productID := 1 + companyID := 1 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), ProductID: &productID} + user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID} DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) var user User @@ -237,17 +237,17 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline map") } - DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": nil}) + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) if len(users) != 0 { t.Errorf("Search all records with inline map containing null value finding 0 records") } - DB.Find(&users, map[string]interface{}{"name": user1.Name, "product_id": nil}) + DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) if len(users) != 1 { t.Errorf("Search all records with inline map containing null value finding 1 record") } - DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": productID}) + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) if len(users) != 1 { t.Errorf("Search all records with inline multiple value map") } @@ -396,8 +396,7 @@ func TestNot(t *testing.T) { DB.Create(getPreparedUser("user3", "not")) user4 := getPreparedUser("user4", "not") - productID := 1 - user4.ProductID = &productID + user4.Company = Company{} DB.Create(user4) DB := DB.Where("role = ?", "not") @@ -445,9 +444,9 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - DB.Not(map[string]interface{}{"name": "user3", "product_id": nil}).Find(&users7) - if len(users7) != 1 { - t.Errorf("Should find all user's name not equal to 3 who do not have product ids") + DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) + if len(users1)-len(users7) != 2 { // not user3 or user4 + t.Errorf("Should find all user's name not equal to 3 who do not have company id") } DB.Not("name", []string{"user3"}).Find(&users8) diff --git a/structs_test.go b/structs_test.go index 20666740..cb9c9260 100644 --- a/structs_test.go +++ b/structs_test.go @@ -28,7 +28,7 @@ type User struct { CreditCard CreditCard Latitude float64 Languages []Language `gorm:"many2many:user_languages;"` - CompanyID int64 + CompanyID *int Company Company Role PasswordHash []byte diff --git a/utils_private.go b/utils_private.go index 50549857..f297857b 100644 --- a/utils_private.go +++ b/utils_private.go @@ -82,9 +82,10 @@ func toString(str interface{}) string { return strings.Join(results, "_") } else if bytes, ok := str.([]byte); ok { return string(bytes) - } else { - return fmt.Sprintf("%v", str) + } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { + return fmt.Sprintf("%v", reflectValue.Interface()) } + return "" } func strInSlice(a string, list []string) bool {