From 86c04795b754c96ec5bbeee05284a35e8caa4de1 Mon Sep 17 00:00:00 2001
From: Jinzhu <wosmvp@gmail.com>
Date: Sun, 11 Feb 2018 15:52:52 +0800
Subject: [PATCH] Port PR #1655 to Not query builder

---
 main_test.go | 19 ++++++++++++++-
 scope.go     | 68 +++++++++++++++++++++++++++++++---------------------
 2 files changed, 59 insertions(+), 28 deletions(-)

diff --git a/main_test.go b/main_test.go
index 48a8bd63..66c46af0 100644
--- a/main_test.go
+++ b/main_test.go
@@ -636,7 +636,7 @@ func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
 	DB.Save(&user)
 	user = User{Name: "subquery_test_user2", Age: 11}
 	DB.Save(&user)
-	user = User{Name: "subquery_test_user2", Age: 12}
+	user = User{Name: "subquery_test_user3", Age: 12}
 	DB.Save(&user)
 
 	var count int
@@ -647,12 +647,29 @@ func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
 			Group("name").
 			QueryExpr(),
 	).Count(&count).Error
+
 	if err != nil {
 		t.Errorf("Expected to get no errors, but got %v", err)
 	}
 	if count != 2 {
 		t.Errorf("Row count must be 2, instead got %d", count)
 	}
+
+	err = DB.Raw("select count(*) from (?) tmp",
+		DB.Table("users").
+			Select("name").
+			Where("name LIKE ?", "subquery_test%").
+			Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}).
+			Group("name").
+			QueryExpr(),
+	).Count(&count).Error
+
+	if err != nil {
+		t.Errorf("Expected to get no errors, but got %v", err)
+	}
+	if count != 1 {
+		t.Errorf("Row count must be 1, instead got %d", count)
+	}
 }
 
 func TestQueryBuilderSubselectInHaving(t *testing.T) {
diff --git a/scope.go b/scope.go
index ba9bd37c..762904d7 100644
--- a/scope.go
+++ b/scope.go
@@ -460,7 +460,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
 var (
 	columnRegexp        = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name`
 	isNumberRegexp      = regexp.MustCompile("^\\s*\\d+\\s*$")                   // match if string is number
-	comparisonRegexp    = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ")
+	comparisonRegexp    = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
 	countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
 )
 
@@ -523,17 +523,17 @@ func (scope *Scope) primaryCondition(value interface{}) string {
 
 func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
 	switch value := clause["query"].(type) {
+	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
+		return scope.primaryCondition(scope.AddToVars(value))
+	case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
+		str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()))
+		clause["args"] = []interface{}{value}
 	case string:
 		if isNumberRegexp.MatchString(value) {
 			return scope.primaryCondition(scope.AddToVars(value))
 		} else if value != "" {
 			str = fmt.Sprintf("(%v)", value)
 		}
-	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
-		return scope.primaryCondition(scope.AddToVars(value))
-	case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
-		str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()))
-		clause["args"] = []interface{}{value}
 	case map[string]interface{}:
 		var sqls []string
 		for key, value := range value {
@@ -582,6 +582,9 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
 
 			replacements = append(replacements, scope.AddToVars(arg))
 		}
+		if err != nil {
+			scope.Err(err)
+		}
 	}
 
 	buff := bytes.NewBuffer([]byte{})
@@ -593,9 +596,6 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
 		} else {
 			buff.WriteByte(str[pos])
 		}
-		if err != nil {
-			scope.Err(err)
-		}
 	}
 
 	str = buff.String()
@@ -604,21 +604,9 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
 }
 
 func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
-	var notEqualSQL string
 	var primaryKey = scope.PrimaryKey()
 
 	switch value := clause["query"].(type) {
-	case string:
-		if isNumberRegexp.MatchString(value) {
-			id, _ := strconv.Atoi(value)
-			return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
-		} else if comparisonRegexp.MatchString(value) {
-			str = fmt.Sprintf(" NOT (%v) ", value)
-			notEqualSQL = fmt.Sprintf("NOT (%v)", value)
-		} else {
-			str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value))
-			notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value))
-		}
 	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
 		return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value)
 	case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
@@ -628,6 +616,15 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
 		} else {
 			return ""
 		}
+	case string:
+		if isNumberRegexp.MatchString(value) {
+			id, _ := strconv.Atoi(value)
+			return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), id)
+		} else if comparisonRegexp.MatchString(value) {
+			str = fmt.Sprintf("NOT (%v)", value)
+		} else {
+			str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value))
+		}
 	case map[string]interface{}:
 		var sqls []string
 		for key, value := range value {
@@ -642,13 +639,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
 		var sqls []string
 		var newScope = scope.New(value)
 		for _, field := range newScope.Fields() {
-			if !field.IsBlank {
+			if !field.IsIgnored && !field.IsBlank {
 				sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 			}
 		}
 		return strings.Join(sqls, " AND ")
 	}
 
+	replacements := []string{}
 	args := clause["args"].([]interface{})
 	for _, arg := range args {
 		var err error
@@ -656,28 +654,44 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
 		case reflect.Slice: // For where("id in (?)", []int64{1,2})
 			if scanner, ok := interface{}(arg).(driver.Valuer); ok {
 				arg, err = scanner.Value()
-				str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
+				replacements = append(replacements, scope.AddToVars(arg))
 			} else if b, ok := arg.([]byte); ok {
-				str = strings.Replace(str, "?", scope.AddToVars(b), 1)
+				replacements = append(replacements, scope.AddToVars(b))
 			} else if values := reflect.ValueOf(arg); values.Len() > 0 {
 				var tempMarks []string
 				for i := 0; i < values.Len(); i++ {
 					tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
 				}
-				str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
+				replacements = append(replacements, strings.Join(tempMarks, ","))
 			} else {
-				str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
+				replacements = append(replacements, scope.AddToVars(Expr("NULL")))
 			}
 		default:
 			if scanner, ok := interface{}(arg).(driver.Valuer); ok {
 				arg, err = scanner.Value()
 			}
-			str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
+
+			replacements = append(replacements, scope.AddToVars(arg))
 		}
+
 		if err != nil {
 			scope.Err(err)
 		}
 	}
+
+	buff := bytes.NewBuffer([]byte{})
+	i := 0
+
+	for pos := range str {
+		if str[pos] == '?' {
+			buff.WriteString(replacements[i])
+			i++
+		} else {
+			buff.WriteByte(str[pos])
+		}
+	}
+
+	str = buff.String()
 	return
 }