From 26504f5caeb8c31dff62e8ddab68cee6b85a6580 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 17 Nov 2020 15:41:17 +0800 Subject: [PATCH] Use NewDB to replace WithConditions for Session --- association.go | 4 ++-- callbacks/associations.go | 14 +++++++------- callbacks/callmethod.go | 2 +- callbacks/delete.go | 4 ++-- callbacks/preload.go | 2 +- finisher_api.go | 8 ++++---- gorm.go | 9 ++++----- migrator.go | 2 +- migrator/migrator.go | 8 ++++---- statement.go | 2 +- tests/count_test.go | 2 +- tests/hooks_test.go | 7 +++++++ 12 files changed, 35 insertions(+), 29 deletions(-) diff --git a/association.go b/association.go index 140ae6ac..0f2102f7 100644 --- a/association.go +++ b/association.go @@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 1e6f62c5..1702f442 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -66,7 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,7 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -141,7 +141,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), ).Create(elems.Interface()).Error) } @@ -163,7 +163,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), ).Create(f.Interface()).Error) } @@ -224,7 +224,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), ).Create(elems.Interface()).Error) } @@ -291,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -299,7 +299,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) } } } diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index b81fc915..bcaa03f3 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -7,7 +7,7 @@ import ( ) func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { - tx := db.Session(&gorm.Session{}) + tx := db.Session(&gorm.Session{NewDB: true}) if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/callbacks/delete.go b/callbacks/delete.go index 0f4bcd6b..4a289e0c 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -34,7 +34,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { case schema.HasOne, schema.HasMany: queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{}).Model(modelValue) + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false if len(db.Statement.Selects) > 0 { @@ -71,7 +71,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { relForeignKeys []string modelValue = reflect.New(rel.JoinTable.ModelType).Interface() table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) ) for _, ref := range rel.References { diff --git a/callbacks/preload.go b/callbacks/preload.go index d60079e4..e1dfdace 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -13,7 +13,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue rel = rels[len(rels)-1] - tx = db.Session(&gorm.Session{}) + tx = db.Session(&gorm.Session{NewDB: true}) relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field diff --git a/finisher_api.go b/finisher_api.go index 211e2f8f..d1390a15 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -78,7 +78,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } @@ -144,7 +144,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var ( tx = db.Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, - }).Session(&Session{WithConditions: true}) + }).Session(&Session{}) queryDB = tx rowsAffected int64 batch int @@ -480,7 +480,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(db.Session(&Session{WithConditions: true})) + err = fc(db.Session(&Session{})) } else { tx := db.Begin(opts...) @@ -506,7 +506,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.Session(&Session{WithConditions: true, Context: db.Statement.Context}) + tx = db.Session(&Session{Context: db.Statement.Context}) opt *sql.TxOptions err error ) diff --git a/gorm.go b/gorm.go index 3bf2479a..f7c18b08 100644 --- a/gorm.go +++ b/gorm.go @@ -63,7 +63,7 @@ type DB struct { type Session struct { DryRun bool PrepareStmt bool - WithConditions bool + NewDB bool SkipHooks bool SkipDefaultTransaction bool AllowGlobalUpdate bool @@ -196,7 +196,7 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.UpdatingColumn = true } - if config.WithConditions { + if !config.NewDB { tx.clone = 2 } @@ -217,14 +217,13 @@ func (db *DB) Session(config *Session) *DB { // WithContext change current instance db's context to ctx func (db *DB) WithContext(ctx context.Context) *DB { - return db.Session(&Session{WithConditions: true, Context: ctx}) + return db.Session(&Session{Context: ctx}) } // Debug start debug mode func (db *DB) Debug() (tx *DB) { return db.Session(&Session{ - WithConditions: true, - Logger: db.Logger.LogMode(logger.Info), + Logger: db.Logger.LogMode(logger.Info), }) } diff --git a/migrator.go b/migrator.go index ac06a144..28ac35e7 100644 --- a/migrator.go +++ b/migrator.go @@ -7,7 +7,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) + return db.Dialector.Migrator(db.Session(&Session{})) } // AutoMigrate run auto migration for given models diff --git a/migrator/migrator.go b/migrator/migrator.go index 016ebfc7..5de820a8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -82,7 +82,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { return err @@ -154,7 +154,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" @@ -237,7 +237,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { @@ -404,7 +404,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + rows, err := m.DB.Session(&gorm.Session{NewDB: true}).Table(stmt.Table).Limit(1).Rows() if err == nil { defer rows.Close() rawColumnTypes, err := rows.ColumnTypes() diff --git a/statement.go b/statement.go index 7c0af59c..3f46ae0a 100644 --- a/statement.go +++ b/statement.go @@ -190,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) diff --git a/tests/count_test.go b/tests/count_test.go index 41bad71d..55fb71e2 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -41,7 +41,7 @@ func TestCount(t *testing.T) { t.Errorf("multiple count in chain should works") } - tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{WithConditions: true}) + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{}) tx.Count(&count1) tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 7e3ae4e4..fe3f7d08 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -380,6 +380,13 @@ func TestSetColumn(t *testing.T) { DB.First(&result2, product.ID) AssertEqual(t, result2, product) + + product2 := Product3{Name: "Product", Price: 0} + DB.Session(&gorm.Session{SkipHooks: true}).Create(&product2) + + if product2.Price != 0 { + t.Errorf("invalid price after create without hooks, got %+v", product2) + } } func TestHooksForSlice(t *testing.T) {