From 1d803dfdd9fa106f329ff6247433e893d44cb152 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 01:02:07 +0800 Subject: [PATCH] Add migrator tests for mssql --- dialects/mssql/migrator.go | 11 +++++++++++ dialects/mssql/mssql.go | 18 ++++++++++++------ dialects/mssql/mssql_test.go | 29 +++++++++++++++++++++++++++++ migrator/migrator.go | 12 +++++++----- 4 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 dialects/mssql/mssql_test.go diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go index 43eaf573..412d86c6 100644 --- a/dialects/mssql/migrator.go +++ b/dialects/mssql/migrator.go @@ -9,6 +9,17 @@ type Migrator struct { migrator.Migrator } +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", + stmt.Table, m.CurrentDatabase(), + ).Row().Scan(&count) + }) + return count > 0 +} + func (m Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 78c048b4..ded49aae 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -3,6 +3,7 @@ package mssql import ( "database/sql" "fmt" + "strconv" _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" @@ -29,17 +30,18 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, }}} } func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { - return "?" + return "@p" + strconv.Itoa(len(stmt.Vars)) } func (dialector Dialector) QuoteChars() [2]byte { - return [2]byte{'[', ']'} // `name` + return [2]byte{'"', '"'} // `name` } func (dialector Dialector) DataTypeOf(field *schema.Field) string { @@ -64,8 +66,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Float: return "decimal" case schema.String: - if field.Size > 0 && field.Size <= 4000 { - return fmt.Sprintf("nvarchar(%d)", field.Size) + size := field.Size + if field.PrimaryKey { + size = 256 + } + if size > 0 && size <= 4000 { + return fmt.Sprintf("nvarchar(%d)", size) } return "ntext" case schema.Time: diff --git a/dialects/mssql/mssql_test.go b/dialects/mssql/mssql_test.go new file mode 100644 index 00000000..b56e7369 --- /dev/null +++ b/dialects/mssql/mssql_test.go @@ -0,0 +1,29 @@ +package mssql_test + +import ( + "fmt" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/mssql" + "github.com/jinzhu/gorm/tests" +) + +var ( + DB *gorm.DB + err error +) + +func init() { + if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestCURD(t *testing.T) { + tests.RunTestsSuit(t, DB) +} + +func TestMigrate(t *testing.T) { + tests.TestMigrate(t, DB) +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 318c2fb8..4b52193f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -189,11 +189,13 @@ func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { value := values[i] - tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error - }); err != nil { - return err + if m.DB.Migrator().HasTable(value) { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } } } return nil