From 86c04795b754c96ec5bbeee05284a35e8caa4de1 Mon Sep 17 00:00:00 2001 From: Jinzhu <wosmvp@gmail.com> Date: Sun, 11 Feb 2018 15:52:52 +0800 Subject: [PATCH] Port PR #1655 to Not query builder --- main_test.go | 19 ++++++++++++++- scope.go | 68 +++++++++++++++++++++++++++++++--------------------- 2 files changed, 59 insertions(+), 28 deletions(-) diff --git a/main_test.go b/main_test.go index 48a8bd63..66c46af0 100644 --- a/main_test.go +++ b/main_test.go @@ -636,7 +636,7 @@ func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { DB.Save(&user) user = User{Name: "subquery_test_user2", Age: 11} DB.Save(&user) - user = User{Name: "subquery_test_user2", Age: 12} + user = User{Name: "subquery_test_user3", Age: 12} DB.Save(&user) var count int @@ -647,12 +647,29 @@ func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { Group("name"). QueryExpr(), ).Count(&count).Error + if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } if count != 2 { t.Errorf("Row count must be 2, instead got %d", count) } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_test%"). + Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } } func TestQueryBuilderSubselectInHaving(t *testing.T) { diff --git a/scope.go b/scope.go index ba9bd37c..762904d7 100644 --- a/scope.go +++ b/scope.go @@ -460,7 +460,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { var ( columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") + comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") ) @@ -523,17 +523,17 @@ func (scope *Scope) primaryCondition(value interface{}) string { func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { switch value := clause["query"].(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: + return scope.primaryCondition(scope.AddToVars(value)) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: + str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) + clause["args"] = []interface{}{value} case string: if isNumberRegexp.MatchString(value) { return scope.primaryCondition(scope.AddToVars(value)) } else if value != "" { str = fmt.Sprintf("(%v)", value) } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string for key, value := range value { @@ -582,6 +582,9 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { + scope.Err(err) + } } buff := bytes.NewBuffer([]byte{}) @@ -593,9 +596,6 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } else { buff.WriteByte(str[pos]) } - if err != nil { - scope.Err(err) - } } str = buff.String() @@ -604,21 +604,9 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string var primaryKey = scope.PrimaryKey() switch value := clause["query"].(type) { - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value)) - } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: @@ -628,6 +616,15 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string } else { return "" } + case string: + if isNumberRegexp.MatchString(value) { + id, _ := strconv.Atoi(value) + return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), id) + } else if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) + } case map[string]interface{}: var sqls []string for key, value := range value { @@ -642,13 +639,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string var sqls []string var newScope = scope.New(value) for _, field := range newScope.Fields() { - if !field.IsBlank { + if !field.IsIgnored && !field.IsBlank { sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") } + replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { var err error @@ -656,28 +654,44 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case reflect.Slice: // For where("id in (?)", []int64{1,2}) if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } else if b, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(b), 1) + replacements = append(replacements, scope.AddToVars(b)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + replacements = append(replacements, scope.AddToVars(Expr("NULL"))) } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, err = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + + replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { scope.Err(err) } } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteByte(str[pos]) + } + } + + str = buff.String() return }