Add migrator tests for mssql

This commit is contained in:
Jinzhu 2020-02-23 01:02:07 +08:00
parent ce84e82c9e
commit 1d803dfdd9
4 changed files with 59 additions and 11 deletions

View File

@ -9,6 +9,17 @@ type Migrator struct {
migrator.Migrator migrator.Migrator
} }
func (m Migrator) HasTable(value interface{}) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?",
stmt.Table, m.CurrentDatabase(),
).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) HasIndex(value interface{}, name string) bool { func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {

View File

@ -3,6 +3,7 @@ package mssql
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"strconv"
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -31,15 +32,16 @@ func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{ return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db, DB: db,
Dialector: dialector, Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}} }}}
} }
func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?" return "@p" + strconv.Itoa(len(stmt.Vars))
} }
func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) QuoteChars() [2]byte {
return [2]byte{'[', ']'} // `name` return [2]byte{'"', '"'} // `name`
} }
func (dialector Dialector) DataTypeOf(field *schema.Field) string { func (dialector Dialector) DataTypeOf(field *schema.Field) string {
@ -64,8 +66,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.Float: case schema.Float:
return "decimal" return "decimal"
case schema.String: case schema.String:
if field.Size > 0 && field.Size <= 4000 { size := field.Size
return fmt.Sprintf("nvarchar(%d)", field.Size) if field.PrimaryKey {
size = 256
}
if size > 0 && size <= 4000 {
return fmt.Sprintf("nvarchar(%d)", size)
} }
return "ntext" return "ntext"
case schema.Time: case schema.Time:

View File

@ -0,0 +1,29 @@
package mssql_test
import (
"fmt"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/mssql"
"github.com/jinzhu/gorm/tests"
)
var (
DB *gorm.DB
err error
)
func init() {
if DB, err = gorm.Open(mssql.Open("sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"), &gorm.Config{}); err != nil {
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
}
}
func TestCURD(t *testing.T) {
tests.RunTestsSuit(t, DB)
}
func TestMigrate(t *testing.T) {
tests.TestMigrate(t, DB)
}

View File

@ -189,6 +189,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false) values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- { for i := len(values) - 1; i >= 0; i-- {
value := values[i] value := values[i]
if m.DB.Migrator().HasTable(value) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error
@ -196,6 +197,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
return err return err
} }
} }
}
return nil return nil
} }