Merge branch 'eatigo-master'

This commit is contained in:
Jinzhu 2018-02-12 00:06:50 +08:00
commit 0cc4d47ce5
6 changed files with 136 additions and 106 deletions

View File

@ -27,7 +27,9 @@ func TestCreate(t *testing.T) {
}
var newUser User
DB.First(&newUser, user.Id)
if err := DB.First(&newUser, user.Id).Error; err != nil {
t.Errorf("No error should happen, but got %v", err)
}
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
t.Errorf("User's PasswordHash should be saved ([]byte)")
@ -38,7 +40,7 @@ func TestCreate(t *testing.T) {
}
if newUser.UserNum != Num(111) {
t.Errorf("User's UserNum should be saved (custom type)")
t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum)
}
if newUser.Latitude != float {

View File

@ -430,7 +430,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
// Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB {
scope := s.NewScope(nil)
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSQL)
return scope.Exec().db

View File

@ -631,6 +631,47 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) {
}
}
func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
user := User{Name: "subquery_test_user1", Age: 10}
DB.Save(&user)
user = User{Name: "subquery_test_user2", Age: 11}
DB.Save(&user)
user = User{Name: "subquery_test_user3", Age: 12}
DB.Save(&user)
var count int
err := DB.Raw("select count(*) from (?) tmp",
DB.Table("users").
Select("name").
Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}).
Group("name").
QueryExpr(),
).Count(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
if count != 2 {
t.Errorf("Row count must be 2, instead got %d", count)
}
err = DB.Raw("select count(*) from (?) tmp",
DB.Table("users").
Select("name").
Where("name LIKE ?", "subquery_test%").
Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}).
Group("name").
QueryExpr(),
).Count(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
if count != 1 {
t.Errorf("Row count must be 1, instead got %d", count)
}
}
func TestQueryBuilderSubselectInHaving(t *testing.T) {
user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64}
DB.Save(&user)

View File

@ -7,6 +7,7 @@ import (
"fmt"
"os"
"reflect"
"strconv"
"testing"
"time"
@ -168,6 +169,8 @@ type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
n, _ := strconv.Atoi(string(s))
*i = Num(n)
case int64:
*i = Num(s)
default:

View File

@ -99,7 +99,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
var address AddressByZipCode
DB.First(&address, "00501")
if address.ZipCode != "00501" {
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed")
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
}
}

186
scope.go
View File

@ -1,16 +1,15 @@
package gorm
import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"reflect"
)
// Scope contain current operation's information when you perform any operation on the database
@ -460,7 +459,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
var (
columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name`
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ")
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
)
@ -521,26 +520,58 @@ func (scope *Scope) primaryCondition(value interface{}) string {
return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value)
}
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {
var (
quotedTableName = scope.QuotedTableName()
quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
equalSQL = "="
inSQL = "IN"
)
// If building not conditions
if !include {
equalSQL = "<>"
inSQL = "NOT IN"
}
switch value := clause["query"].(type) {
case sql.NullInt64:
return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
if !include && reflect.ValueOf(value).Len() == 0 {
return
}
str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)
clause["args"] = []interface{}{value}
case string:
if isNumberRegexp.MatchString(value) {
return scope.primaryCondition(scope.AddToVars(value))
} else if value != "" {
return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))
}
if value != "" {
if !include {
if comparisonRegexp.MatchString(value) {
str = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))
}
} else {
str = fmt.Sprintf("(%v)", value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value}
}
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value)))
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key)))
if !include {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))
}
}
}
return strings.Join(sqls, " AND ")
@ -549,12 +580,13 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
newScope := scope.New(value)
for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
replacements := []string{}
args := clause["args"].([]interface{})
for _, arg := range args {
var err error
@ -562,107 +594,44 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = scanner.Value()
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
} else if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
replacements = append(replacements, scope.AddToVars(arg))
} else if b, ok := arg.([]byte); ok {
replacements = append(replacements, scope.AddToVars(b))
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
replacements = append(replacements, strings.Join(tempMarks, ","))
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
replacements = append(replacements, scope.AddToVars(Expr("NULL")))
}
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
replacements = append(replacements, scope.AddToVars(arg))
}
if err != nil {
scope.Err(err)
}
}
return
}
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSQL string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) {
case string:
if isNumberRegexp.MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if comparisonRegexp.MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSQL = fmt.Sprintf("NOT (%v)", value)
buff := bytes.NewBuffer([]byte{})
i := 0
for pos := range str {
if str[pos] == '?' {
buff.WriteString(replacements[i])
i++
} else {
str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value))
notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value))
buff.WriteByte(str[pos])
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey))
clause["args"] = []interface{}{value}
} else {
return ""
}
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key)))
}
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
var newScope = scope.New(value)
for _, field := range newScope.Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
var err error
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = scanner.Value()
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
} else if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
} else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
}
default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = scanner.Value()
}
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
}
if err != nil {
scope.Err(err)
}
}
str = buff.String()
return
}
@ -675,6 +644,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
}
args := clause["args"].([]interface{})
replacements := []string{}
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice:
@ -683,14 +653,28 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
replacements = append(replacements, strings.Join(tempMarks, ","))
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
replacements = append(replacements, scope.AddToVars(arg))
}
}
buff := bytes.NewBuffer([]byte{})
i := 0
for pos := range str {
if str[pos] == '?' {
buff.WriteString(replacements[i])
i++
} else {
buff.WriteByte(str[pos])
}
}
str = buff.String()
return
}
@ -714,19 +698,19 @@ func (scope *Scope) whereSQL() (sql string) {
}
for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
andConditions = append(andConditions, sql)
}
}
for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
orConditions = append(orConditions, sql)
}
}
for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, false); sql != "" {
andConditions = append(andConditions, sql)
}
}
@ -800,7 +784,7 @@ func (scope *Scope) havingSQL() string {
var andConditions []string
for _, clause := range scope.Search.havingConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
andConditions = append(andConditions, sql)
}
}
@ -816,7 +800,7 @@ func (scope *Scope) havingSQL() string {
func (scope *Scope) joinsSQL() string {
var joinConditions []string
for _, clause := range scope.Search.joinConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
}
}