Inline which condition for first/find

This commit is contained in:
Jinzhu 2013-10-27 20:54:23 +08:00
parent ec49f57394
commit f74c9015a1
5 changed files with 67 additions and 16 deletions

View File

@ -3,7 +3,6 @@
Yet Another ORM library for Go, aims for developer friendly Yet Another ORM library for Go, aims for developer friendly
## TODO ## TODO
* Better First method (First(&user, primary_key, where conditions))
* Update, Updates * Update, Updates
* Soft Delete * Soft Delete
* Even more complex where query (with map or struct) * Even more complex where query (with map or struct)

View File

@ -145,13 +145,13 @@ func (s *Chain) Exec(sql string) *Chain {
return s return s
} }
func (s *Chain) First(out interface{}) *Chain { func (s *Chain) First(out interface{}, where ...interface{}) *Chain {
s.do(out).query() s.do(out).query(where...)
return s return s
} }
func (s *Chain) Find(out interface{}) *Chain { func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
s.do(out).query() s.do(out).query(where...)
return s return s
} }

37
do.go
View File

@ -5,6 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strconv"
"strings" "strings"
) )
@ -160,7 +162,11 @@ func (s *Do) prepareQuerySql() *Do {
return s return s
} }
func (s *Do) query() { func (s *Do) query(where ...interface{}) {
if len(where) > 0 {
s.where(where[0], where[1:len(where)]...)
}
var ( var (
is_slice bool is_slice bool
dest_type reflect.Type dest_type reflect.Type
@ -176,7 +182,6 @@ func (s *Do) query() {
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
defer rows.Close() defer rows.Close()
s.err(err) s.err(err)
if rows.Err() != nil { if rows.Err() != nil {
s.err(rows.Err()) s.err(rows.Err())
} }
@ -246,8 +251,28 @@ func (s *Do) pluck(value interface{}) *Do {
return s return s
} }
func (s *Do) buildWhereCondition(clause map[string]interface{}) string { func (s *Do) where(querystring interface{}, args ...interface{}) *Do {
str := "( " + clause["query"].(string) + " )" s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
return s
}
func (s *Do) primaryCondiation(value interface{}) string {
return fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), value)
}
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
switch clause["query"].(type) {
case string:
value := clause["query"].(string)
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return s.primaryCondiation(s.addToVars(id))
} else {
str = "( " + value + " )"
}
case int, int64, int32:
return s.primaryCondiation(s.addToVars(clause["query"]))
}
args := clause["args"].([]interface{}) args := clause["args"].([]interface{})
for _, arg := range args { for _, arg := range args {
@ -269,7 +294,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) string {
str = strings.Replace(str, "?", s.addToVars(arg), 1) str = strings.Replace(str, "?", s.addToVars(arg), 1)
} }
} }
return str return
} }
func (s *Do) whereSql() (sql string) { func (s *Do) whereSql() (sql string) {
@ -277,7 +302,7 @@ func (s *Do) whereSql() (sql string) {
var and_conditions, or_conditions []string var and_conditions, or_conditions []string
if !s.model.PrimaryKeyZero() { if !s.model.PrimaryKeyZero() {
primary_condiation = fmt.Sprintf("(%v = %v)", s.quote(s.model.PrimaryKeyDb()), s.addToVars(s.model.PrimaryKeyValue())) primary_condiation = s.primaryCondiation(s.addToVars(s.model.PrimaryKeyValue()))
} }
for _, clause := range s.whereClause { for _, clause := range s.whereClause {

View File

@ -25,12 +25,12 @@ func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
return s.buildORM().Where(querystring, args...) return s.buildORM().Where(querystring, args...)
} }
func (s *DB) First(out interface{}) *Chain { func (s *DB) First(out interface{}, where ...interface{}) *Chain {
return s.buildORM().First(out) return s.buildORM().First(out, where...)
} }
func (s *DB) Find(out interface{}) *Chain { func (s *DB) Find(out interface{}, where ...interface{}) *Chain {
return s.buildORM().Find(out) return s.buildORM().Find(out, where...)
} }
func (s *DB) Limit(value interface{}) *Chain { func (s *DB) Limit(value interface{}) *Chain {

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"errors" "errors"
"reflect" "reflect"
"strconv"
"testing" "testing"
"time" "time"
) )
@ -63,8 +64,34 @@ func init() {
} }
func TestFirst(t *testing.T) { func TestFirst(t *testing.T) {
var u1, u2 User var u1, u2, u3, u4, u5, u6, u7 User
db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2) db.Where("name = ?", "3").Order("age desc").First(&u1).First(&u2)
db.Where("name = ?", "3").First(&u3, "age = 22").First(&u4, "age = ?", 24).First(&u5, "age = ?", 26)
if !((u5.Id == 0) && (u3.Age == 22 && u3.Name == "3") && (u4.Age == 24 && u4.Name == "3")) {
t.Errorf("Inline where condition for first when search")
}
var us1, us2, us3, us4 []User
db.Find(&us1, "age = 22").Find(&us2, "name = ?", "3").Find(&us3, "age > ?", 20)
if !(len(us1) == 1 && len(us2) == 2 && len(us3) == 3) {
t.Errorf("Inline where condition for find when search")
}
db.Find(&us4, "name = ? and age > ?", "3", "22")
if len(us4) != 1 {
t.Errorf("More complex inline where condition for find, %v", us4)
}
db.First(&u6, u1.Id)
if !(u6.Id == u1.Id && u6.Id != 0) {
t.Errorf("Should find out user with int id")
}
db.First(&u7, strconv.Itoa(int(u1.Id)))
if !(u6.Id == u1.Id && u6.Id != 0) {
t.Errorf("Should find out user with string id")
}
} }
func TestSaveAndFind(t *testing.T) { func TestSaveAndFind(t *testing.T) {