diff --git a/callback_query.go b/callback_query.go index 2cb2152a..68f1da38 100644 --- a/callback_query.go +++ b/callback_query.go @@ -1,9 +1,7 @@ package gorm import ( - "fmt" "reflect" - "strings" "time" ) @@ -25,6 +23,9 @@ func Query(scope *Scope) { ) var dest = reflect.Indirect(reflect.ValueOf(scope.Value)) + if value, ok := scope.Get("gorm:query_destination"); ok { + dest = reflect.Indirect(reflect.ValueOf(value)) + } if dest.Kind() == reflect.Slice { isSlice = true @@ -33,11 +34,7 @@ func Query(scope *Scope) { scope.Search = scope.Search.clone().limit(1) } - if scope.Search.raw { - scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.SelectSql(), scope.TableName(), scope.CombinedConditionSql())) - } + scope.prepareQuerySql() if !scope.HasError() { rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...) diff --git a/main.go b/main.go index f2ad76e2..0c52b2ed 100644 --- a/main.go +++ b/main.go @@ -109,6 +109,14 @@ func (s *DB) Unscoped() *DB { return s.clone().search.unscoped().db } +func (s *DB) Attrs(attrs ...interface{}) *DB { + return s.clone().search.attrs(attrs...).db +} + +func (s *DB) Assign(attrs ...interface{}) *DB { + return s.clone().search.assign(attrs...).db +} + func (s *DB) First(out interface{}, where ...interface{}) *DB { scope := s.clone().NewScope(out) scope.Search = scope.Search.clone().order(scope.PrimaryKey()).limit(1) @@ -126,23 +134,17 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { } func (s *DB) Row() *sql.Row { - return s.do(s.Value).row() + return s.NewScope(s.Value).row() } func (s *DB) Rows() (*sql.Rows, error) { - return s.do(s.Value).rows() + return s.NewScope(s.Value).rows() } func (s *DB) Scan(dest interface{}) *DB { - return s.do(s.Value).query(dest).db -} - -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.attrs(attrs...).db -} - -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.assign(attrs...).db + scope := s.clone().NewScope(s.Value).Set("gorm:query_destination", dest) + Query(scope) + return scope.db } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { diff --git a/scope.go b/scope.go index 334229f6..066b96f5 100644 --- a/scope.go +++ b/scope.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql" "errors" "fmt" "github.com/jinzhu/gorm/dialect" @@ -378,10 +379,23 @@ func (scope *Scope) CommitOrRollback() *Scope { return scope } -func (scope *Scope) SelectSql() string { - if len(scope.Search.selectStr) == 0 { - return "*" +func (scope *Scope) prepareQuerySql() { + if scope.Search.raw { + scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) } else { - return scope.Search.selectStr + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.TableName(), scope.CombinedConditionSql())) } + return +} + +func (scope *Scope) row() *sql.Row { + defer scope.Trace(time.Now()) + scope.prepareQuerySql() + return scope.DB().QueryRow(scope.Sql, scope.SqlVars...) +} + +func (scope *Scope) rows() (*sql.Rows, error) { + defer scope.Trace(time.Now()) + scope.prepareQuerySql() + return scope.DB().Query(scope.Sql, scope.SqlVars...) }