From 1895d281bf7a183e5d679c1962737eb74ab19546 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 23:08:20 +0800 Subject: [PATCH] Add migrator tests for mysql --- dialects/mysql/mysql.go | 11 ++++++---- dialects/mysql/mysql_test.go | 21 ++++++++++++++++++ dialects/sqlite/sqlite.go | 1 - finisher_api.go | 2 +- migrator/migrator.go | 41 +++++++++++++++++++----------------- tests/migrate.go | 2 +- tests/model.go | 4 ++-- 7 files changed, 54 insertions(+), 28 deletions(-) diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 3b456891..5fcc2d69 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -23,9 +23,8 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("sqlite3", dialector.DSN) - - return nil + db.DB, err = sql.Open("mysql", dialector.DSN) + return } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { @@ -75,9 +74,13 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return "double" case schema.String: size := field.Size + if field.PrimaryKey { + size = 256 + } + if size >= 65536 && size <= int(math.Pow(2, 24)) { return "mediumtext" - } else if size > int(math.Pow(2, 24)) || size < 0 { + } else if size > int(math.Pow(2, 24)) || size <= 0 { return "longtext" } return fmt.Sprintf("varchar(%d)", size) diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index 49c26915..7fd5e373 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -1,12 +1,33 @@ package mysql_test import ( + "fmt" "testing" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/tests" ) func TestOpen(t *testing.T) { gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), nil) } + +var ( + DB *gorm.DB + err error +) + +func init() { + if DB, err = gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestCURD(t *testing.T) { + tests.RunTestsSuit(t, DB) +} + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 38cd760b..54fa7de0 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -21,7 +21,6 @@ func Open(dsn string) gorm.Dialector { func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - db.DB, err = sql.Open("sqlite3", dialector.DSN) return } diff --git a/finisher_api.go b/finisher_api.go index c9b58861..2c5d4f65 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -140,7 +140,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(tx) + err = fc(tx.Session(&Session{})) if err == nil { err = tx.Commit().Error diff --git a/migrator/migrator.go b/migrator/migrator.go index e3097abd..a5ec1a62 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,8 +18,9 @@ type Migrator struct { // Config schema config type Config struct { - CreateIndexAfterCreateTable bool - DB *gorm.DB + CreateIndexAfterCreateTable bool + AllowDeferredConstraintsWhenAutoMigrate bool + DB *gorm.DB gorm.Dialector } @@ -47,17 +48,17 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type - for _, value := range values { - if !m.DB.Migrator().HasTable(value) { - if err := m.DB.Migrator().CreateTable(value); err != nil { + tx := m.DB.Session(&gorm.Session{}) + if !tx.Migrator().HasTable(value) { + if err := tx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, field := range stmt.Schema.FieldsByDBName { - if !m.DB.Migrator().HasColumn(value, field.DBName) { - if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil { + if !tx.Migrator().HasColumn(value, field.DBName) { + if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } } @@ -65,16 +66,16 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - if !m.DB.Migrator().HasConstraint(value, constraint.Name) { - if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !m.DB.Migrator().HasConstraint(value, chk.Name) { - if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } @@ -83,8 +84,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(joinValue) { + defer tx.Migrator().CreateTable(joinValue) } } } @@ -100,6 +101,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range values { + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( createTableSQL = "CREATE TABLE ? (" @@ -144,10 +146,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - m.DB.Migrator().CreateIndex(value, idx.Name) + tx.Migrator().CreateIndex(value, idx.Name) } else { createTableSQL += "INDEX ? ?," - values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } @@ -161,8 +163,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { // create join table if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if !tx.Migrator().HasTable(joinValue) { + defer tx.Migrator().CreateTable(joinValue) } } } @@ -175,7 +177,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL += ")" - return m.DB.Exec(createTableSQL, values...).Error + return tx.Exec(createTableSQL, values...).Error }); err != nil { return err } @@ -185,8 +187,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { for _, value := range values { + tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error + return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error }); err != nil { return err } diff --git a/tests/migrate.go b/tests/migrate.go index 9f7e2d67..477f0ad6 100644 --- a/tests/migrate.go +++ b/tests/migrate.go @@ -7,7 +7,7 @@ import ( ) func TestMigrate(t *testing.T, db *gorm.DB) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} for _, m := range allModels { if db.Migrator().HasTable(m) { diff --git a/tests/model.go b/tests/model.go index ac2156c7..b2d5efe1 100644 --- a/tests/model.go +++ b/tests/model.go @@ -21,7 +21,7 @@ type User struct { Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company - ManagerID int + ManagerID uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak"` @@ -49,7 +49,7 @@ type Toy struct { } type Company struct { - ID uint + ID int Name string }