Add method inlineCondition

This commit is contained in:
Jinzhu 2014-01-28 10:01:53 +08:00
parent 6e848fc987
commit 275de4f133
3 changed files with 20 additions and 12 deletions

View File

@ -8,14 +8,6 @@ import (
func Query(scope *Scope) {
defer scope.Trace(time.Now())
inlineCondition, ok := scope.Get("gorm:inline_condition")
if ok {
inlineConditions := inlineCondition.([]interface{})
if len(inlineConditions) > 0 {
scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...)
}
}
var (
isSlice bool
anyRecordFound bool

View File

@ -120,17 +120,17 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
func (s *DB) First(out interface{}, where ...interface{}) *DB {
scope := s.clone().NewScope(out)
scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1)
return scope.Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
return scope.inlineCondition(where).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
scope := s.clone().NewScope(out)
scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1)
return scope.Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
return scope.inlineCondition(where).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db
return s.clone().NewScope(out).inlineCondition(where).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Row() *sql.Row {
@ -150,7 +150,7 @@ func (s *DB) Scan(dest interface{}) *DB {
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone()
if c.First(out, where...).Error == RecordNotFound {
return c.do(out).where(where).initialize().db
return c.NewScope(out).inlineCondition(where).initialize().db
} else if len(s.search.assignAttrs) > 0 {
return c.do(out).updateAttrs(s.search.assignAttrs).db
}

View File

@ -388,6 +388,13 @@ func (scope *Scope) prepareQuerySql() {
return
}
func (scope *Scope) inlineCondition(values []interface{}) *Scope {
if len(values) > 0 {
scope.Search = scope.Search.clone().where(values[0], values[1:]...)
}
return scope
}
func (scope *Scope) row() *sql.Row {
defer scope.Trace(time.Now())
scope.prepareQuerySql()
@ -399,3 +406,12 @@ func (scope *Scope) rows() (*sql.Rows, error) {
scope.prepareQuerySql()
return scope.DB().Query(scope.Sql, scope.SqlVars...)
}
func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereClause {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
return scope
}