From 0cf369dcff417ed51c8e439d8047f1b7503d87ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Feb 2016 19:22:46 +0800 Subject: [PATCH] Fix query IN with empty slice --- query_test.go | 8 ++++++++ scope_private.go | 24 +++++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/query_test.go b/query_test.go index a7d5bc0e..ed7a518d 100644 --- a/query_test.go +++ b/query_test.go @@ -155,6 +155,14 @@ func TestSearchWithPlainSQL(t *testing.T) { t.Errorf("Should found 1 users, but got %v", len(users)) } + if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { + t.Error("no error should happen when query with empty slice, but got: ", err) + } + + if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { + t.Error("no error should happen when query with empty slice, but got: ", err) + } + if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { t.Errorf("Should not get RecordNotFound error when looking for none existing records") } diff --git a/scope_private.go b/scope_private.go index fa5d5f44..2b6c3332 100644 --- a/scope_private.go +++ b/scope_private.go @@ -26,7 +26,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return scope.primaryCondition(scope.AddToVars(value)) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) + str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string @@ -54,13 +54,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else { - values := reflect.ValueOf(arg) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { @@ -83,7 +84,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { + } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) notEqualSql = fmt.Sprintf("NOT (%v)", value) } else { @@ -122,12 +123,17 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string for _, arg := range args { switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + if bytes, ok := arg.([]byte); ok { + str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + } else { + str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value()