From 211ccf4ea61de36d67ff443aa94714f2a0cf0688 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 16 Jan 2016 11:37:16 +0800 Subject: [PATCH] Fix using pointer value as foreign key --- association_test.go | 4 +++- field.go | 26 ++++++++++++++++++-------- preload.go | 2 +- query_test.go | 19 +++++++++---------- structs_test.go | 2 +- utils_private.go | 5 +++-- 6 files changed, 35 insertions(+), 23 deletions(-) diff --git a/association_test.go b/association_test.go index ab3abd91..f02d4620 100644 --- a/association_test.go +++ b/association_test.go @@ -600,7 +600,9 @@ func TestRelated(t *testing.T) { Company: Company{Name: "company1"}, } - DB.Save(&user) + if err := DB.Save(&user).Error; err != nil { + t.Errorf("No error should happen when saving user") + } if user.CreditCard.ID == 0 { t.Errorf("After user save, credit card should have id") diff --git a/field.go b/field.go index 79e2c0ec..2ed4e732 100644 --- a/field.go +++ b/field.go @@ -13,7 +13,7 @@ type Field struct { Field reflect.Value } -func (field *Field) Set(value interface{}) error { +func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { return errors.New("field value not valid") } @@ -27,15 +27,25 @@ func (field *Field) Set(value interface{}) error { reflectValue = reflect.ValueOf(value) } + fieldValue := field.Field if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(field.Field.Type()) { - field.Field.Set(reflectValue.Convert(field.Field.Type())) - } else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { - if err := scanner.Scan(reflectValue.Interface()); err != nil { - return err - } + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else { - return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type()) + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.Struct.Type.Elem())) + } + fieldValue = fieldValue.Elem() + } + + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) + } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err = scanner.Scan(reflectValue.Interface()) + } else { + err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) + } } } else { field.Field.Set(reflect.Zero(field.Field.Type())) diff --git a/preload.go b/preload.go index d12995f3..f3f2df12 100644 --- a/preload.go +++ b/preload.go @@ -27,7 +27,7 @@ func getRealValue(value reflect.Value, columns []string) (results []interface{}) } func equalAsString(a interface{}, b interface{}) bool { - return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) + return toString(a) == toString(b) } func Preload(scope *Scope) { diff --git a/query_test.go b/query_test.go index a60c1f2c..a7d5bc0e 100644 --- a/query_test.go +++ b/query_test.go @@ -207,11 +207,11 @@ func TestSearchWithStruct(t *testing.T) { } func TestSearchWithMap(t *testing.T) { - productID := 1 + companyID := 1 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), ProductID: &productID} + user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID} DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) var user User @@ -237,17 +237,17 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline map") } - DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": nil}) + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) if len(users) != 0 { t.Errorf("Search all records with inline map containing null value finding 0 records") } - DB.Find(&users, map[string]interface{}{"name": user1.Name, "product_id": nil}) + DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) if len(users) != 1 { t.Errorf("Search all records with inline map containing null value finding 1 record") } - DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": productID}) + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) if len(users) != 1 { t.Errorf("Search all records with inline multiple value map") } @@ -396,8 +396,7 @@ func TestNot(t *testing.T) { DB.Create(getPreparedUser("user3", "not")) user4 := getPreparedUser("user4", "not") - productID := 1 - user4.ProductID = &productID + user4.Company = Company{} DB.Create(user4) DB := DB.Where("role = ?", "not") @@ -445,9 +444,9 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - DB.Not(map[string]interface{}{"name": "user3", "product_id": nil}).Find(&users7) - if len(users7) != 1 { - t.Errorf("Should find all user's name not equal to 3 who do not have product ids") + DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) + if len(users1)-len(users7) != 2 { // not user3 or user4 + t.Errorf("Should find all user's name not equal to 3 who do not have company id") } DB.Not("name", []string{"user3"}).Find(&users8) diff --git a/structs_test.go b/structs_test.go index 20666740..cb9c9260 100644 --- a/structs_test.go +++ b/structs_test.go @@ -28,7 +28,7 @@ type User struct { CreditCard CreditCard Latitude float64 Languages []Language `gorm:"many2many:user_languages;"` - CompanyID int64 + CompanyID *int Company Company Role PasswordHash []byte diff --git a/utils_private.go b/utils_private.go index 50549857..f297857b 100644 --- a/utils_private.go +++ b/utils_private.go @@ -82,9 +82,10 @@ func toString(str interface{}) string { return strings.Join(results, "_") } else if bytes, ok := str.([]byte); ok { return string(bytes) - } else { - return fmt.Sprintf("%v", str) + } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { + return fmt.Sprintf("%v", reflectValue.Interface()) } + return "" } func strInSlice(a string, list []string) bool {