Refact where method

This commit is contained in:
Jinzhu 2013-10-27 12:00:39 +08:00
parent 7948705181
commit 44c53f6dd5
4 changed files with 73 additions and 32 deletions

View File

@ -3,7 +3,6 @@
Yet Another ORM library for Go, aims for developer friendly Yet Another ORM library for Go, aims for developer friendly
## TODO ## TODO
* Offset
* Or query * Or query
* Not query * Not query
* Better First method (First(&user, primary_key, where conditions)) * Better First method (First(&user, primary_key, where conditions))

4
orm.go
View File

@ -19,6 +19,8 @@ type Orm struct {
db *sql.DB db *sql.DB
driver string driver string
whereClause []map[string]interface{} whereClause []map[string]interface{}
orClause []map[string]interface{}
notClause []map[string]interface{}
selectStr string selectStr string
orderStrs []string orderStrs []string
offsetStr string offsetStr string
@ -145,10 +147,12 @@ func (s *Orm) Pluck(column string, value interface{}) (orm *Orm) {
} }
func (s *Orm) Or(querystring interface{}, args ...interface{}) *Orm { func (s *Orm) Or(querystring interface{}, args ...interface{}) *Orm {
s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args})
return s return s
} }
func (s *Orm) Not(querystring interface{}, args ...interface{}) *Orm { func (s *Orm) Not(querystring interface{}, args ...interface{}) *Orm {
s.notClause = append(s.notClause, map[string]interface{}{"query": querystring, "args": args})
return s return s
} }

View File

@ -273,3 +273,11 @@ func TestOffset(t *testing.T) {
t.Errorf("Offset should works perfectly") t.Errorf("Offset should works perfectly")
} }
} }
func TestWhereOr(t *testing.T) {
// var users []User
// db.Where("name = ?", "1").Or("name = ?", "3").Find(&users)
// if len(users) != 3 {
// t.Errorf("Should find three users with name 1 and 3")
// }
}

92
sql.go
View File

@ -50,10 +50,11 @@ func (s *Orm) query(out interface{}) {
is_slice = true is_slice = true
dest_type = dest_out.Type().Elem() dest_type = dest_out.Type().Elem()
} }
debug(s.Sql)
debug(s.SqlVars)
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
defer rows.Close() defer rows.Close()
s.Error = err s.Error = err
if rows.Err() != nil { if rows.Err() != nil {
s.Error = rows.Err() s.Error = rows.Err()
@ -167,43 +168,72 @@ func (s *Orm) deleteSql(value interface{}) {
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql()) s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
return return
} }
func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
str := "( " + clause["query"].(string) + " )"
func (s *Orm) whereSql() (sql string) { args := clause["args"].([]interface{})
var conditions []string for _, arg := range args {
if !s.model.PrimaryKeyIsEmpty() { switch reflect.TypeOf(arg).Kind() {
conditions = append(conditions, fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue()))) case reflect.Slice: // For where("id in (?)", []int64{1,2})
} v := reflect.ValueOf(arg)
if len(s.whereClause) > 0 { var temp_marks []string
for _, clause := range s.whereClause { for i := 0; i < v.Len(); i++ {
str := "( " + clause["query"].(string) + " )" temp_marks = append(temp_marks, "?")
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
v := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < v.Len(); i++ {
temp_marks = append(temp_marks, "?")
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
for i := 0; i < v.Len(); i++ {
str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1)
}
default:
str = strings.Replace(str, "?", s.addToVars(arg), 1)
}
} }
conditions = append(conditions, str)
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
for i := 0; i < v.Len(); i++ {
str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1)
}
default:
str = strings.Replace(str, "?", s.addToVars(arg), 1)
} }
} }
return str
}
if len(conditions) > 0 { func (s *Orm) whereSql() (sql string) {
sql = "WHERE " + strings.Join(conditions, " AND ") var primary_condiation string
var and_conditions, or_conditions, not_conditions []string
if !s.model.PrimaryKeyIsEmpty() {
primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue()))
} }
for _, clause := range s.whereClause {
and_conditions = append(and_conditions, s.buildWhereCondition(clause))
}
for _, clause := range s.notClause {
and_conditions = append(and_conditions, "!"+s.buildWhereCondition(clause))
}
for _, clause := range s.orClause {
or_conditions = append(or_conditions, s.buildWhereCondition(clause))
}
and_sql := strings.Join(and_conditions, " AND ")
or_sql := strings.Join(not_conditions, " OR ")
combined_conditions := and_sql
if len(combined_conditions) > 0 {
if len(or_sql) > 0 {
combined_conditions = combined_conditions + " OR " + or_sql
}
} else {
combined_conditions = or_sql
}
if len(primary_condiation) > 0 {
sql = "WHERE " + primary_condiation
if len(combined_conditions) > 0 {
sql = sql + " AND ( " + combined_conditions + " )"
}
} else if len(combined_conditions) > 0 {
sql = "WHERE " + combined_conditions
}
debug(sql)
return return
} }