Fix query IN with empty slice

This commit is contained in:
Jinzhu 2016-02-08 19:22:46 +08:00
parent 7aab3ae861
commit 0cf369dcff
2 changed files with 23 additions and 9 deletions

View File

@ -155,6 +155,14 @@ func TestSearchWithPlainSQL(t *testing.T) {
t.Errorf("Should found 1 users, but got %v", len(users)) t.Errorf("Should found 1 users, but got %v", len(users))
} }
if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil {
t.Error("no error should happen when query with empty slice, but got: ", err)
}
if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil {
t.Error("no error should happen when query with empty slice, but got: ", err)
}
if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() {
t.Errorf("Should not get RecordNotFound error when looking for none existing records") t.Errorf("Should not get RecordNotFound error when looking for none existing records")
} }

View File

@ -26,7 +26,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value)) return scope.primaryCondition(scope.AddToVars(value))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value} clause["args"] = []interface{}{value}
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
@ -54,13 +54,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok { if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
} else { } else if values := reflect.ValueOf(arg); values.Len() > 0 {
values := reflect.ValueOf(arg)
var tempMarks []string var tempMarks []string
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
} }
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
} }
default: default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok { if valuer, ok := interface{}(arg).(driver.Valuer); ok {
@ -83,7 +84,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value) str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSql = fmt.Sprintf("NOT (%v)", value) notEqualSql = fmt.Sprintf("NOT (%v)", value)
} else { } else {
@ -122,12 +123,17 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
for _, arg := range args { for _, arg := range args {
switch reflect.ValueOf(arg).Kind() { switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
values := reflect.ValueOf(arg) if bytes, ok := arg.([]byte); ok {
var tempMarks []string str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
for i := 0; i < values.Len(); i++ { } else if values := reflect.ValueOf(arg); values.Len() > 0 {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
} }
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default: default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok { if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value() arg, _ = scanner.Value()