diff --git a/chainable_api.go b/chainable_api.go index e2ba44cc..acceb58f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -142,7 +142,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}}) + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) } return } diff --git a/statement.go b/statement.go index c03f6f88..d6444fae 100644 --- a/statement.go +++ b/statement.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "sort" "strconv" "strings" "sync" @@ -260,12 +261,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c conds = append(conds, clause.Eq{Column: i, Value: j}) } case map[string]string: - for i, j := range v { - conds = append(conds, clause.Eq{Column: i, Value: j}) + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } case map[string]interface{}: - for i, j := range v { - reflectValue := reflect.Indirect(reflect.ValueOf(j)) + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: values := make([]interface{}, reflectValue.Len()) @@ -273,9 +286,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c values[i] = reflectValue.Index(i).Interface() } - conds = append(conds, clause.IN{Column: i, Values: values}) + conds = append(conds, clause.IN{Column: key, Values: values}) default: - conds = append(conds, clause.Eq{Column: i, Value: j}) + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } } default: diff --git a/tests/query_test.go b/tests/query_test.go index c9eb5903..5a8bbef2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -218,6 +218,25 @@ func TestNot(t *testing.T) { } } +func TestOr(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -269,6 +288,23 @@ func TestSelect(t *testing.T) { if user.Name != result.Name { t.Errorf("Should have user Name when selected it") } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Select("name", "age").Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select([]string{"name", "age"}).Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) + if !regexp.MustCompile("SELECT COALESCE\\(age,.*\\) FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", r.Statement.SQL.String()) + } + // SELECT COALESCE(age,'42') FROM users; } func TestPluckWithSelect(t *testing.T) {