diff --git a/create_test.go b/create_test.go index 36472914..83b3a4ef 100644 --- a/create_test.go +++ b/create_test.go @@ -27,7 +27,9 @@ func TestCreate(t *testing.T) { } var newUser User - DB.First(&newUser, user.Id) + if err := DB.First(&newUser, user.Id).Error; err != nil { + t.Errorf("No error should happen, but got %v", err) + } if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { t.Errorf("User's PasswordHash should be saved ([]byte)") @@ -38,7 +40,7 @@ func TestCreate(t *testing.T) { } if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") + t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) } if newUser.Latitude != float { diff --git a/main.go b/main.go index fc4859ac..d342571d 100644 --- a/main.go +++ b/main.go @@ -430,7 +430,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { scope := s.NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) return scope.Exec().db diff --git a/main_test.go b/main_test.go index 83e6f7aa..66c46af0 100644 --- a/main_test.go +++ b/main_test.go @@ -631,6 +631,47 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { } } +func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { + user := User{Name: "subquery_test_user1", Age: 10} + DB.Save(&user) + user = User{Name: "subquery_test_user2", Age: 11} + DB.Save(&user) + user = User{Name: "subquery_test_user3", Age: 12} + DB.Save(&user) + + var count int + err := DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_test%"). + Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } +} + func TestQueryBuilderSubselectInHaving(t *testing.T) { user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} DB.Save(&user) diff --git a/migration_test.go b/migration_test.go index d58e1fb5..7c694485 100644 --- a/migration_test.go +++ b/migration_test.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "reflect" + "strconv" "testing" "time" @@ -168,6 +169,8 @@ type Num int64 func (i *Num) Scan(src interface{}) error { switch s := src.(type) { case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) case int64: *i = Num(s) default: diff --git a/query_test.go b/query_test.go index 77449f4f..3c3c74b5 100644 --- a/query_test.go +++ b/query_test.go @@ -99,7 +99,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { var address AddressByZipCode DB.First(&address, "00501") if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") + t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } diff --git a/scope.go b/scope.go index 29508d8d..5ac147e4 100644 --- a/scope.go +++ b/scope.go @@ -1,16 +1,15 @@ package gorm import ( + "bytes" "database/sql" "database/sql/driver" "errors" "fmt" + "reflect" "regexp" - "strconv" "strings" "time" - - "reflect" ) // Scope contain current operation's information when you perform any operation on the database @@ -460,7 +459,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { var ( columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") + comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") ) @@ -521,26 +520,58 @@ func (scope *Scope) primaryCondition(value interface{}) string { return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) } -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { +func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { + var ( + quotedTableName = scope.QuotedTableName() + quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) + equalSQL = "=" + inSQL = "IN" + ) + + // If building not conditions + if !include { + equalSQL = "<>" + inSQL = "NOT IN" + } + switch value := clause["query"].(type) { + case sql.NullInt64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: + if !include && reflect.ValueOf(value).Len() == 0 { + return + } + str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) + clause["args"] = []interface{}{value} case string: if isNumberRegexp.MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) + } + + if value != "" { + if !include { + if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) + } + } else { + str = fmt.Sprintf("(%v)", value) + } } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string for key, value := range value { if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key))) + if !include { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) + } } } return strings.Join(sqls, " AND ") @@ -549,12 +580,13 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri newScope := scope.New(value) for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") } + replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { var err error @@ -562,107 +594,44 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case reflect.Slice: // For where("id in (?)", []int64{1,2}) if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } else if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + replacements = append(replacements, scope.AddToVars(arg)) + } else if b, ok := arg.([]byte); ok { + replacements = append(replacements, scope.AddToVars(b)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + replacements = append(replacements, scope.AddToVars(Expr("NULL"))) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, err = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { scope.Err(err) } } - return -} -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) + buff := bytes.NewBuffer([]byte{}) + i := 0 + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value)) + buff.WriteByte(str[pos]) } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } else { - return "" - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - var newScope = scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") } - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } else if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) - } - if err != nil { - scope.Err(err) - } - } + str = buff.String() + return } @@ -675,6 +644,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) } args := clause["args"].([]interface{}) + replacements := []string{} for _, arg := range args { switch reflect.ValueOf(arg).Kind() { case reflect.Slice: @@ -683,14 +653,28 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteByte(str[pos]) + } + } + + str = buff.String() + return } @@ -714,19 +698,19 @@ func (scope *Scope) whereSQL() (sql string) { } for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { orConditions = append(orConditions, sql) } } for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, false); sql != "" { andConditions = append(andConditions, sql) } } @@ -800,7 +784,7 @@ func (scope *Scope) havingSQL() string { var andConditions []string for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } @@ -816,7 +800,7 @@ func (scope *Scope) havingSQL() string { func (scope *Scope) joinsSQL() string { var joinConditions []string for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) } }