From 7a8c2bbff8d0327b20017b24299394263b94f69f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 11 Feb 2018 23:52:38 +0800 Subject: [PATCH] Refactor build SQL condition --- create_test.go | 6 +- main.go | 2 +- migration_test.go | 3 + query_test.go | 2 +- scope.go | 156 ++++++++++++++-------------------------------- 5 files changed, 57 insertions(+), 112 deletions(-) diff --git a/create_test.go b/create_test.go index 36472914..83b3a4ef 100644 --- a/create_test.go +++ b/create_test.go @@ -27,7 +27,9 @@ func TestCreate(t *testing.T) { } var newUser User - DB.First(&newUser, user.Id) + if err := DB.First(&newUser, user.Id).Error; err != nil { + t.Errorf("No error should happen, but got %v", err) + } if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { t.Errorf("User's PasswordHash should be saved ([]byte)") @@ -38,7 +40,7 @@ func TestCreate(t *testing.T) { } if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") + t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) } if newUser.Latitude != float { diff --git a/main.go b/main.go index fc4859ac..d342571d 100644 --- a/main.go +++ b/main.go @@ -430,7 +430,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { scope := s.NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) return scope.Exec().db diff --git a/migration_test.go b/migration_test.go index d58e1fb5..7c694485 100644 --- a/migration_test.go +++ b/migration_test.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "reflect" + "strconv" "testing" "time" @@ -168,6 +169,8 @@ type Num int64 func (i *Num) Scan(src interface{}) error { switch s := src.(type) { case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) case int64: *i = Num(s) default: diff --git a/query_test.go b/query_test.go index 77449f4f..3c3c74b5 100644 --- a/query_test.go +++ b/query_test.go @@ -99,7 +99,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { var address AddressByZipCode DB.First(&address, "00501") if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") + t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } diff --git a/scope.go b/scope.go index 762904d7..5ac147e4 100644 --- a/scope.go +++ b/scope.go @@ -8,7 +8,6 @@ import ( "fmt" "reflect" "regexp" - "strconv" "strings" "time" ) @@ -521,26 +520,58 @@ func (scope *Scope) primaryCondition(value interface{}) string { return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) } -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { +func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { + var ( + quotedTableName = scope.QuotedTableName() + quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) + equalSQL = "=" + inSQL = "IN" + ) + + // If building not conditions + if !include { + equalSQL = "<>" + inSQL = "NOT IN" + } + switch value := clause["query"].(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) + case sql.NullInt64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) + if !include && reflect.ValueOf(value).Len() == 0 { + return + } + str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) clause["args"] = []interface{}{value} case string: if isNumberRegexp.MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) + } + + if value != "" { + if !include { + if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) + } + } else { + str = fmt.Sprintf("(%v)", value) + } } case map[string]interface{}: var sqls []string for key, value := range value { if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key))) + if !include { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) + } } } return strings.Join(sqls, " AND ") @@ -549,7 +580,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri newScope := scope.New(value) for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -582,6 +613,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { scope.Err(err) } @@ -603,98 +635,6 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri return } -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } else { - return "" - } - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - var newScope = scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - - for pos := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteByte(str[pos]) - } - } - - str = buff.String() - return -} - func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { switch value := clause["query"].(type) { case string: @@ -758,19 +698,19 @@ func (scope *Scope) whereSQL() (sql string) { } for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { orConditions = append(orConditions, sql) } } for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, false); sql != "" { andConditions = append(andConditions, sql) } } @@ -844,7 +784,7 @@ func (scope *Scope) havingSQL() string { var andConditions []string for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } @@ -860,7 +800,7 @@ func (scope *Scope) havingSQL() string { func (scope *Scope) joinsSQL() string { var joinConditions []string for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) } }