diff --git a/callback.go b/callback.go index be9a7f12..603e5111 100644 --- a/callback.go +++ b/callback.go @@ -9,6 +9,7 @@ type callback struct { updates []*func(scope *Scope) deletes []*func(scope *Scope) queries []*func(scope *Scope) + rowQueries []*func(scope *Scope) processors []*callbackProcessor } @@ -55,6 +56,10 @@ func (c *callback) Query() *callbackProcessor { return c.addProcessor("query") } +func (c *callback) RowQuery() *callbackProcessor { + return c.addProcessor("row_query") +} + func (cp *callbackProcessor) Before(name string) *callbackProcessor { cp.before = name return cp @@ -168,7 +173,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { } func (c *callback) sort() { - creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{} + var creates, updates, deletes, queries, rowQueries []*callbackProcessor for _, processor := range c.processors { switch processor.typ { @@ -180,6 +185,8 @@ func (c *callback) sort() { deletes = append(deletes, processor) case "query": queries = append(queries, processor) + case "row_query": + rowQueries = append(rowQueries, processor) } } @@ -187,6 +194,7 @@ func (c *callback) sort() { c.updates = sortProcessors(updates) c.deletes = sortProcessors(deletes) c.queries = sortProcessors(queries) + c.rowQueries = sortProcessors(rowQueries) } var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_query.go b/callback_query.go index 5daa5fec..0eea6f89 100644 --- a/callback_query.go +++ b/callback_query.go @@ -16,17 +16,17 @@ func Query(scope *Scope) { destType reflect.Type ) - var dest = scope.IndirectValue() - if value, ok := scope.InstanceGet("gorm:query_destination"); ok { - dest = reflect.Indirect(reflect.ValueOf(value)) - } - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) } } + var dest = scope.IndirectValue() + if value, ok := scope.InstanceGet("gorm:query_destination"); ok { + dest = reflect.Indirect(reflect.ValueOf(value)) + } + if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() diff --git a/main.go b/main.go index f39a373f..04f59bcf 100644 --- a/main.go +++ b/main.go @@ -211,6 +211,10 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } +func (s *DB) Scan(dest interface{}) *DB { + return s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db +} + func (s *DB) Row() *sql.Row { return s.NewScope(s.Value).row() } @@ -219,8 +223,16 @@ func (s *DB) Rows() (*sql.Rows, error) { return s.NewScope(s.Value).rows() } -func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db +func (s *DB) Pluck(column string, value interface{}) *DB { + return s.NewScope(s.Value).pluck(column, value).db +} + +func (s *DB) Count(value interface{}) *DB { + return s.NewScope(s.Value).count(value).db +} + +func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { + return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { @@ -307,18 +319,6 @@ func (s *DB) Model(value interface{}) *DB { return c } -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db -} - -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - func (s *DB) Table(name string) *DB { clone := s.clone() clone.search.Table(name) diff --git a/scope_private.go b/scope_private.go index 03a47acc..b9476455 100644 --- a/scope_private.go +++ b/scope_private.go @@ -336,12 +336,14 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore func (scope *Scope) row() *sql.Row { defer scope.Trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.prepareQuerySql() return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) } func (scope *Scope) rows() (*sql.Rows, error) { defer scope.Trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.prepareQuerySql() return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) }