From 8e67a08774bb60a6380b9b2e761d440e361b3d9e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Jun 2021 15:38:20 +0800 Subject: [PATCH] Fix Scopes with Row, close #4465 --- callbacks/associations.go | 2 +- callbacks/create.go | 19 +++++++++---------- callbacks/row.go | 3 ++- finisher_api.go | 6 +++--- tests/count_test.go | 1 - tests/scopes_test.go | 9 +++++++++ 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 6d74f20d..78f976c3 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -373,7 +373,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, }) if tx.Statement.FullSaveAssociations { - tx = tx.InstanceSet("gorm:update_track_time", true) + tx = tx.Set("gorm:update_track_time", true) } if len(selects) > 0 { diff --git a/callbacks/create.go b/callbacks/create.go index e46d3d05..04ee6b30 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -243,9 +243,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + _, updateTrackTime = stmt.Get("gorm:update_track_time") curTime = stmt.DB.NowFunc() isZero bool ) + stmt.Settings.Delete("gorm:update_track_time") + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} for _, db := range stmt.Schema.DBNames { @@ -284,11 +287,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(rv, curTime) values.Values[i][idx], _ = field.ValueOf(rv) } - } else if field.AutoUpdateTime > 0 { - if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) - } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) } } @@ -326,11 +327,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(stmt.ReflectValue, curTime) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } - } else if field.AutoUpdateTime > 0 { - if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) - } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } } diff --git a/callbacks/row.go b/callbacks/row.go index 10e880e1..407c32d7 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -9,7 +9,8 @@ func RowQuery(db *gorm.DB) { BuildQuerySQL(db) if !db.DryRun { - if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index 771fa153..0f6440a3 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -79,7 +79,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) + tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { @@ -426,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance().InstanceSet("rows", false) + tx := db.getInstance().Set("rows", false) tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -436,7 +436,7 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().InstanceSet("rows", true) + tx := db.getInstance().Set("rows", true) tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/tests/count_test.go b/tests/count_test.go index 0fef82f7..dd25f8b6 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -124,7 +124,6 @@ func TestCount(t *testing.T) { var count9 int64 if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { - fmt.Println("kdkdkdkdk") return tx.Table("users") }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 0ec4783b..94fff308 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "testing" "gorm.io/gorm" @@ -62,4 +63,12 @@ func TestScopes(t *testing.T) { if result.RowsAffected != 2 { t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) } + + var maxId int64 + userTable := func(db *gorm.DB) *gorm.DB { + return db.WithContext(context.Background()).Table("users") + } + if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { + t.Errorf("select max(id)") + } }