Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions (#1939)

* Edit DB.clone(), DB.Dialect(), and Scope.Dialect() preserve transactions.

* Adds a test case for tables creations and autoMigrate in the same transaction.
This commit is contained in:
Kevin 2018-07-26 19:35:53 -04:00 committed by Jinzhu
parent dbb25e9487
commit ac3ec858a6
3 changed files with 52 additions and 2 deletions

View File

@ -119,7 +119,7 @@ func (s *DB) CommonDB() SQLCommon {
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
return s.dialect
}
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
@ -484,6 +484,8 @@ func (s *DB) Begin() *DB {
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
@ -748,6 +750,7 @@ func (s *DB) clone() *DB {
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
}
for key, value := range s.values {

View File

@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) {
}
}
func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin()
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
func() {
type Bar struct {
Stuff string
}
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
tx.Rollback()
}
type MultipleIndexes struct {
ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`

View File

@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
// Dialect get dialect
func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect
return scope.db.dialect
}
// Quote used to quote string to escape them for database