diff --git a/callback_query.go b/callback_query.go index 68f1da38..e671899f 100644 --- a/callback_query.go +++ b/callback_query.go @@ -8,14 +8,6 @@ import ( func Query(scope *Scope) { defer scope.Trace(time.Now()) - inlineCondition, ok := scope.Get("gorm:inline_condition") - if ok { - inlineConditions := inlineCondition.([]interface{}) - if len(inlineConditions) > 0 { - scope.Search = scope.Search.clone().where(inlineConditions[0], inlineConditions[1:]...) - } - } - var ( isSlice bool anyRecordFound bool diff --git a/main.go b/main.go index 0c52b2ed..ec77330b 100644 --- a/main.go +++ b/main.go @@ -120,17 +120,17 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1) - return scope.Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db + return scope.inlineCondition(where).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) scope.Search = scope.Search.clone().order(scope.PrimaryKey() + " DESC").limit(1) - return scope.Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db + return scope.inlineCondition(where).callCallbacks(s.parent.callback.queries).db } func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).Set("gorm:inline_condition", where).callCallbacks(s.parent.callback.queries).db + return s.clone().NewScope(out).inlineCondition(where).callCallbacks(s.parent.callback.queries).db } func (s *DB) Row() *sql.Row { @@ -150,7 +150,7 @@ func (s *DB) Scan(dest interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { c := s.clone() if c.First(out, where...).Error == RecordNotFound { - return c.do(out).where(where).initialize().db + return c.NewScope(out).inlineCondition(where).initialize().db } else if len(s.search.assignAttrs) > 0 { return c.do(out).updateAttrs(s.search.assignAttrs).db } diff --git a/scope.go b/scope.go index 066b96f5..cdbcbf8a 100644 --- a/scope.go +++ b/scope.go @@ -388,6 +388,13 @@ func (scope *Scope) prepareQuerySql() { return } +func (scope *Scope) inlineCondition(values []interface{}) *Scope { + if len(values) > 0 { + scope.Search = scope.Search.clone().where(values[0], values[1:]...) + } + return scope +} + func (scope *Scope) row() *sql.Row { defer scope.Trace(time.Now()) scope.prepareQuerySql() @@ -399,3 +406,12 @@ func (scope *Scope) rows() (*sql.Rows, error) { scope.prepareQuerySql() return scope.DB().Query(scope.Sql, scope.SqlVars...) } + +func (scope *Scope) initialize() *Scope { + for _, clause := range scope.Search.whereClause { + scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) + } + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) + return scope +}