diff --git a/README.md b/README.md index 9a671386..b0dcb7bd 100644 --- a/README.md +++ b/README.md @@ -570,7 +570,7 @@ db.Model(&user).Related(&card, "CreditCard") //// SELECT * FROM credit_cards WHERE user_id = 123; // 123 is user's primary key // CreditCard is user's field name, it means get user's CreditCard relations and fill it into variable card // If the field name is same as the variable's type name, like above example, it could be omitted, like: -db.Model(&user).Related(&creditCard, "CreditCard") +db.Model(&user).Related(&card) ``` ### Belongs To @@ -859,7 +859,7 @@ db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&co //// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users) //// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count) -db.Model(User{}).Where("name = ?", "jinzhu").Count(&count) +db.Model(&User{}).Where("name = ?", "jinzhu").Count(&count) //// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count) db.Table("deleted_users").Count(&count) diff --git a/preload_test.go b/preload_test.go index 2fc441a9..7bc3f389 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1064,6 +1064,64 @@ func TestNestedManyToManyPreload3(t *testing.T) { } } +func TestNestedManyToManyPreload4(t *testing.T) { + type ( + Level4 struct { + ID uint + Value string + Level3ID uint + } + Level3 struct { + ID uint + Value string + Level4s []*Level4 + } + Level2 struct { + ID uint + Value string + Level3s []*Level3 `gorm:"many2many:level2_level3;"` + } + Level1 struct { + ID uint + Value string + Level2s []*Level2 `gorm:"many2many:level1_level2;"` + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level4{}) + DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists("level2_level3") + + dummy := Level1{ + Value: "Level1", + Level2s: []*Level2{&Level2{ + Value: "Level2", + Level3s: []*Level3{&Level3{ + Value: "Level3", + Level4s: []*Level4{&Level4{ + Value: "Level4", + }}, + }}, + }}, + } + + if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + if err := DB.Save(&dummy).Error; err != nil { + t.Error(err) + } + + var level1 Level1 + if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { + t.Error(err) + } +} + func TestManyToManyPreloadForPointer(t *testing.T) { type ( Level1 struct { diff --git a/query_test.go b/query_test.go index b762dee5..b376dc82 100644 --- a/query_test.go +++ b/query_test.go @@ -155,6 +155,14 @@ func TestSearchWithPlainSQL(t *testing.T) { t.Errorf("Should found 1 users, but got %v", len(users)) } + if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { + t.Error("no error should happen when query with empty slice, but got: ", err) + } + + if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { + t.Error("no error should happen when query with empty slice, but got: ", err) + } + if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { t.Errorf("Should not get RecordNotFound error when looking for none existing records") } diff --git a/scope_private.go b/scope_private.go index 138bd6fd..4fd48833 100644 --- a/scope_private.go +++ b/scope_private.go @@ -27,7 +27,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return scope.primaryCondition(scope.AddToVars(value)) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) + str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string @@ -55,13 +55,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else { - values := reflect.ValueOf(arg) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { @@ -84,7 +85,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { + } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) notEqualSql = fmt.Sprintf("NOT (%v)", value) } else { @@ -123,12 +124,17 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string for _, arg := range args { switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value()