forked from mirror/gorm
Fix using pointer value as foreign key
This commit is contained in:
parent
72104c6bf0
commit
211ccf4ea6
|
@ -600,7 +600,9 @@ func TestRelated(t *testing.T) {
|
||||||
Company: Company{Name: "company1"},
|
Company: Company{Name: "company1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Save(&user)
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when saving user")
|
||||||
|
}
|
||||||
|
|
||||||
if user.CreditCard.ID == 0 {
|
if user.CreditCard.ID == 0 {
|
||||||
t.Errorf("After user save, credit card should have id")
|
t.Errorf("After user save, credit card should have id")
|
||||||
|
|
26
field.go
26
field.go
|
@ -13,7 +13,7 @@ type Field struct {
|
||||||
Field reflect.Value
|
Field reflect.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (field *Field) Set(value interface{}) error {
|
func (field *Field) Set(value interface{}) (err error) {
|
||||||
if !field.Field.IsValid() {
|
if !field.Field.IsValid() {
|
||||||
return errors.New("field value not valid")
|
return errors.New("field value not valid")
|
||||||
}
|
}
|
||||||
|
@ -27,15 +27,25 @@ func (field *Field) Set(value interface{}) error {
|
||||||
reflectValue = reflect.ValueOf(value)
|
reflectValue = reflect.ValueOf(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fieldValue := field.Field
|
||||||
if reflectValue.IsValid() {
|
if reflectValue.IsValid() {
|
||||||
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
field.Field.Set(reflectValue.Convert(field.Field.Type()))
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
} else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
|
||||||
if err := scanner.Scan(reflectValue.Interface()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type())
|
if fieldValue.Kind() == reflect.Ptr {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
|
||||||
|
}
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
|
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||||
|
err = scanner.Scan(reflectValue.Interface())
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Field.Set(reflect.Zero(field.Field.Type()))
|
field.Field.Set(reflect.Zero(field.Field.Type()))
|
||||||
|
|
|
@ -27,7 +27,7 @@ func getRealValue(value reflect.Value, columns []string) (results []interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func equalAsString(a interface{}, b interface{}) bool {
|
func equalAsString(a interface{}, b interface{}) bool {
|
||||||
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
|
return toString(a) == toString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Preload(scope *Scope) {
|
func Preload(scope *Scope) {
|
||||||
|
|
|
@ -207,11 +207,11 @@ func TestSearchWithStruct(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSearchWithMap(t *testing.T) {
|
func TestSearchWithMap(t *testing.T) {
|
||||||
productID := 1
|
companyID := 1
|
||||||
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
|
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
|
||||||
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
|
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
|
||||||
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
|
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
|
||||||
user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), ProductID: &productID}
|
user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID}
|
||||||
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
|
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
|
@ -237,17 +237,17 @@ func TestSearchWithMap(t *testing.T) {
|
||||||
t.Errorf("Search all records with inline map")
|
t.Errorf("Search all records with inline map")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": nil})
|
DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil})
|
||||||
if len(users) != 0 {
|
if len(users) != 0 {
|
||||||
t.Errorf("Search all records with inline map containing null value finding 0 records")
|
t.Errorf("Search all records with inline map containing null value finding 0 records")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Find(&users, map[string]interface{}{"name": user1.Name, "product_id": nil})
|
DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil})
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("Search all records with inline map containing null value finding 1 record")
|
t.Errorf("Search all records with inline map containing null value finding 1 record")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Find(&users, map[string]interface{}{"name": user4.Name, "product_id": productID})
|
DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID})
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
t.Errorf("Search all records with inline multiple value map")
|
t.Errorf("Search all records with inline multiple value map")
|
||||||
}
|
}
|
||||||
|
@ -396,8 +396,7 @@ func TestNot(t *testing.T) {
|
||||||
DB.Create(getPreparedUser("user3", "not"))
|
DB.Create(getPreparedUser("user3", "not"))
|
||||||
|
|
||||||
user4 := getPreparedUser("user4", "not")
|
user4 := getPreparedUser("user4", "not")
|
||||||
productID := 1
|
user4.Company = Company{}
|
||||||
user4.ProductID = &productID
|
|
||||||
DB.Create(user4)
|
DB.Create(user4)
|
||||||
|
|
||||||
DB := DB.Where("role = ?", "not")
|
DB := DB.Where("role = ?", "not")
|
||||||
|
@ -445,9 +444,9 @@ func TestNot(t *testing.T) {
|
||||||
t.Errorf("Should find all users's name not equal 3")
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not(map[string]interface{}{"name": "user3", "product_id": nil}).Find(&users7)
|
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
||||||
if len(users7) != 1 {
|
if len(users1)-len(users7) != 2 { // not user3 or user4
|
||||||
t.Errorf("Should find all user's name not equal to 3 who do not have product ids")
|
t.Errorf("Should find all user's name not equal to 3 who do not have company id")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Not("name", []string{"user3"}).Find(&users8)
|
DB.Not("name", []string{"user3"}).Find(&users8)
|
||||||
|
|
|
@ -28,7 +28,7 @@ type User struct {
|
||||||
CreditCard CreditCard
|
CreditCard CreditCard
|
||||||
Latitude float64
|
Latitude float64
|
||||||
Languages []Language `gorm:"many2many:user_languages;"`
|
Languages []Language `gorm:"many2many:user_languages;"`
|
||||||
CompanyID int64
|
CompanyID *int
|
||||||
Company Company
|
Company Company
|
||||||
Role
|
Role
|
||||||
PasswordHash []byte
|
PasswordHash []byte
|
||||||
|
|
|
@ -82,9 +82,10 @@ func toString(str interface{}) string {
|
||||||
return strings.Join(results, "_")
|
return strings.Join(results, "_")
|
||||||
} else if bytes, ok := str.([]byte); ok {
|
} else if bytes, ok := str.([]byte); ok {
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
} else {
|
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
||||||
return fmt.Sprintf("%v", str)
|
return fmt.Sprintf("%v", reflectValue.Interface())
|
||||||
}
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func strInSlice(a string, list []string) bool {
|
func strInSlice(a string, list []string) bool {
|
||||||
|
|
Loading…
Reference in New Issue