diff --git a/query_test.go b/query_test.go index 80ebd473..fac7d4d8 100644 --- a/query_test.go +++ b/query_test.go @@ -222,6 +222,36 @@ func TestSearchWithPlainSQL(t *testing.T) { } } +func TestSearchWithTwoDimensionalArray(t *testing.T) { + var users []User + user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} + user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} + user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} + DB.Create(&user1) + DB.Create(&user2) + DB.Create(&user3) + + if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { + if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { + t.Errorf("No error should happen when query with 2D array, but got %v", err) + + if len(users) != 2 { + t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) + } + } + } + + if dialect := DB.Dialect().GetName(); dialect == "mssql" { + if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { + t.Errorf("No error should happen when query with 2D array, but got %v", err) + + if len(users) != 2 { + t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) + } + } + } +} + func TestSearchWithStruct(t *testing.T) { user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} diff --git a/scope.go b/scope.go index 3fe4675d..cdb772ca 100644 --- a/scope.go +++ b/scope.go @@ -606,6 +606,22 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) replacements = append(replacements, scope.AddToVars(arg)) } else if b, ok := arg.([]byte); ok { replacements = append(replacements, scope.AddToVars(b)) + } else if as, ok := arg.([][]interface{}); ok { + var tempMarks []string + for _, a := range as { + var arrayMarks []string + for _, v := range a { + arrayMarks = append(arrayMarks, scope.AddToVars(v)) + } + + if len(arrayMarks) > 0 { + tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) + } + } + + if len(tempMarks) > 0 { + replacements = append(replacements, strings.Join(tempMarks, ",")) + } } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ {