diff --git a/association_test.go b/association_test.go index c2f55d0e..8ba198e5 100644 --- a/association_test.go +++ b/association_test.go @@ -600,7 +600,9 @@ func TestRelated(t *testing.T) { Company: Company{Name: "company1"}, } - DB.Save(&user) + if err := DB.Save(&user).Error; err != nil { + t.Errorf("No error should happen when saving user") + } if user.CreditCard.ID == 0 { t.Errorf("After user save, credit card should have id") diff --git a/field.go b/field.go index 79e2c0ec..2ed4e732 100644 --- a/field.go +++ b/field.go @@ -13,7 +13,7 @@ type Field struct { Field reflect.Value } -func (field *Field) Set(value interface{}) error { +func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { return errors.New("field value not valid") } @@ -27,15 +27,25 @@ func (field *Field) Set(value interface{}) error { reflectValue = reflect.ValueOf(value) } + fieldValue := field.Field if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(field.Field.Type()) { - field.Field.Set(reflectValue.Convert(field.Field.Type())) - } else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { - if err := scanner.Scan(reflectValue.Interface()); err != nil { - return err - } + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else { - return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type()) + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.Struct.Type.Elem())) + } + fieldValue = fieldValue.Elem() + } + + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) + } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err = scanner.Scan(reflectValue.Interface()) + } else { + err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) + } } } else { field.Field.Set(reflect.Zero(field.Field.Type())) diff --git a/preload.go b/preload.go index 20c9aacf..692280ef 100644 --- a/preload.go +++ b/preload.go @@ -7,7 +7,6 @@ import ( "strings" ) -// Preload preload relations callback func Preload(scope *Scope) { if scope.Search.preload == nil || scope.HasError() { return diff --git a/query_test.go b/query_test.go index 71ced650..a7d5bc0e 100644 --- a/query_test.go +++ b/query_test.go @@ -207,10 +207,12 @@ func TestSearchWithStruct(t *testing.T) { } func TestSearchWithMap(t *testing.T) { + companyID := 1 user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) + user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID} + DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) var user User DB.First(&user, map[string]interface{}{"name": user1.Name}) @@ -234,6 +236,21 @@ func TestSearchWithMap(t *testing.T) { if len(users) != 1 { t.Errorf("Search all records with inline map") } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) + if len(users) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) + if len(users) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) + if len(users) != 1 { + t.Errorf("Search all records with inline multiple value map") + } } func TestSearchWithEmptyChain(t *testing.T) { @@ -377,10 +394,14 @@ func TestNot(t *testing.T) { DB.Create(getPreparedUser("user1", "not")) DB.Create(getPreparedUser("user2", "not")) DB.Create(getPreparedUser("user3", "not")) - DB.Create(getPreparedUser("user4", "not")) + + user4 := getPreparedUser("user4", "not") + user4.Company = Company{} + DB.Create(user4) + DB := DB.Where("role = ?", "not") - var users1, users2, users3, users4, users5, users6, users7, users8 []User + var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User if DB.Find(&users1).RowsAffected != 4 { t.Errorf("should find 4 not users") } @@ -423,15 +444,20 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - DB.Not("name", []string{"user3"}).Find(&users7) - if len(users1)-len(users7) != int(name3Count) { + DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) + if len(users1)-len(users7) != 2 { // not user3 or user4 + t.Errorf("Should find all user's name not equal to 3 who do not have company id") + } + + DB.Not("name", []string{"user3"}).Find(&users8) + if len(users1)-len(users8) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users8) - if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) { + DB.Not("name", []string{"user3", "user2"}).Find(&users9) + if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { t.Errorf("Should find all users's name not equal 3") } } diff --git a/scope_private.go b/scope_private.go index a9288bb2..9d0283d7 100644 --- a/scope_private.go +++ b/scope_private.go @@ -32,7 +32,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) + } } return strings.Join(sqls, " AND ") case interface{}: @@ -98,7 +102,11 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + if value != nil { + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) + } } return strings.Join(sqls, " AND ") case interface{}: diff --git a/structs_test.go b/structs_test.go index e595df58..42eb6bc3 100644 --- a/structs_test.go +++ b/structs_test.go @@ -28,7 +28,7 @@ type User struct { CreditCard CreditCard Latitude float64 Languages []Language `gorm:"many2many:user_languages;"` - CompanyID int64 + CompanyID *int Company Company Role PasswordHash []byte diff --git a/utils_private.go b/utils_private.go index f8f918fb..c4f0e963 100644 --- a/utils_private.go +++ b/utils_private.go @@ -87,9 +87,10 @@ func toString(str interface{}) string { return strings.Join(results, "_") } else if bytes, ok := str.([]byte); ok { return string(bytes) - } else { - return fmt.Sprintf("%v", str) + } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { + return fmt.Sprintf("%v", reflectValue.Interface()) } + return "" } func makeSlice(elemType reflect.Type) interface{} {