This commit is contained in:
Jinzhu 2020-07-05 12:23:45 +08:00
parent 1a2fabb34d
commit 4e066c9590
3 changed files with 56 additions and 7 deletions

View File

@ -142,7 +142,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
} }
return return
} }

View File

@ -6,6 +6,7 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -260,12 +261,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
conds = append(conds, clause.Eq{Column: i, Value: j}) conds = append(conds, clause.Eq{Column: i, Value: j})
} }
case map[string]string: case map[string]string:
for i, j := range v { var keys = make([]string, 0, len(v))
conds = append(conds, clause.Eq{Column: i, Value: j}) for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} }
case map[string]interface{}: case map[string]interface{}:
for i, j := range v { var keys = make([]string, 0, len(v))
reflectValue := reflect.Indirect(reflect.ValueOf(j)) for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
values := make([]interface{}, reflectValue.Len()) values := make([]interface{}, reflectValue.Len())
@ -273,9 +286,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c
values[i] = reflectValue.Index(i).Interface() values[i] = reflectValue.Index(i).Interface()
} }
conds = append(conds, clause.IN{Column: i, Values: values}) conds = append(conds, clause.IN{Column: key, Values: values})
default: default:
conds = append(conds, clause.Eq{Column: i, Value: j}) conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} }
} }
default: default:

View File

@ -218,6 +218,25 @@ func TestNot(t *testing.T) {
} }
} }
func TestOr(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true})
result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
}
func TestPluck(t *testing.T) { func TestPluck(t *testing.T) {
users := []*User{ users := []*User{
GetUser("pluck-user1", Config{}), GetUser("pluck-user1", Config{}),
@ -269,6 +288,23 @@ func TestSelect(t *testing.T) {
if user.Name != result.Name { if user.Name != result.Name {
t.Errorf("Should have user Name when selected it") t.Errorf("Should have user Name when selected it")
} }
dryDB := DB.Session(&gorm.Session{DryRun: true})
r := dryDB.Select("name", "age").Find(&User{})
if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String())
}
r = dryDB.Select([]string{"name", "age"}).Find(&User{})
if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String())
}
r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{})
if !regexp.MustCompile("SELECT COALESCE\\(age,.*\\) FROM .*users.*").MatchString(r.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String())
}
// SELECT COALESCE(age,'42') FROM users;
} }
func TestPluckWithSelect(t *testing.T) { func TestPluckWithSelect(t *testing.T) {