From fab7d96da5d0308a77684acb9b39eb558b6ea58e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 22 Feb 2020 17:53:57 +0800 Subject: [PATCH] Add DataTypeOf for dialector --- dialects/mssql/migrator.go | 37 ++++++ dialects/mssql/mssql.go | 75 ++++++++++++ dialects/mysql/migrator.go | 43 +++++++ dialects/mysql/mysql.go | 83 ++++++++++++- dialects/postgres/migrator.go | 89 ++++++++++++++ dialects/postgres/postgres.go | 51 +++++++- dialects/sqlite/migrator.go | 122 +++++++++++++++++++ dialects/sqlite/sqlite.go | 32 ++++- interfaces.go | 5 +- migrator.go | 4 +- migrator/migrator.go | 223 +++++++++++++++++----------------- schema/field.go | 5 +- 12 files changed, 640 insertions(+), 129 deletions(-) create mode 100644 dialects/mssql/migrator.go create mode 100644 dialects/mssql/mssql.go create mode 100644 dialects/mysql/migrator.go create mode 100644 dialects/postgres/migrator.go create mode 100644 dialects/sqlite/migrator.go diff --git a/dialects/mssql/migrator.go b/dialects/mssql/migrator.go new file mode 100644 index 00000000..43eaf573 --- /dev/null +++ b/dialects/mssql/migrator.go @@ -0,0 +1,37 @@ +package mssql + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/migrator" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", + name, stmt.Table, + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`, + name, stmt.Table, m.CurrentDatabase(), + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) + return +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go new file mode 100644 index 00000000..bdca667d --- /dev/null +++ b/dialects/mssql/mssql.go @@ -0,0 +1,75 @@ +package mssql + +import ( + "database/sql" + "fmt" + + _ "github.com/denisenkom/go-mssqldb" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" +) + +type Dialector struct { + DSN string +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + db.DB, err = sql.Open("sqlserver", dialector.DSN) + return +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} +} + +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { + return "?" +} + +func (dialector Dialector) QuoteChars() [2]byte { + return [2]byte{'[', ']'} // `name` +} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "bit" + case schema.Int, schema.Uint: + var sqlType string + switch { + case field.Size < 16: + sqlType = "smallint" + case field.Size < 31: + sqlType = "int" + default: + sqlType = "bigint" + } + + if field.AutoIncrement { + return sqlType + " IDENTITY(1,1)" + } + return sqlType + case schema.Float: + return "decimal" + case schema.String: + if field.Size > 0 && field.Size <= 4000 { + return fmt.Sprintf("nvarchar(%d)", field.Size) + } + return "ntext" + case schema.Time: + return "datetimeoffset" + case schema.Bytes: + return "binary" + } + + return "" +} diff --git a/dialects/mysql/migrator.go b/dialects/mysql/migrator.go new file mode 100644 index 00000000..2c11af94 --- /dev/null +++ b/dialects/mysql/migrator.go @@ -0,0 +1,43 @@ +package mysql + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/migrator" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return m.DB.Exec( + "ALTER TABLE ? MODIFY COLUMN ? TYPE ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + ).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if chk.Name == name { + return m.DB.Exec( + "ALTER TABLE ? DROP CHECK ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error + } + } + + return m.DB.Exec( + "ALTER TABLE ? DROP FOREIGN KEY ?", + clause.Table{Name: stmt.Table}, clause.Column{Name: name}, + ).Error + }) +} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index b402ef95..e2fea53c 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -1,33 +1,104 @@ package mysql import ( + "database/sql" + "fmt" + "math" + _ "github.com/go-sql-driver/mysql" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" ) type Dialector struct { + DSN string } func Open(dsn string) gorm.Dialector { - return &Dialector{} + return &Dialector{DSN: dsn} } -func (Dialector) Initialize(db *gorm.DB) error { +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) + db.DB, err = sql.Open("sqlite3", dialector.DSN) return nil } -func (Dialector) Migrator() gorm.Migrator { - return nil +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} } -func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (Dialector) QuoteChars() [2]byte { +func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + sqlType := "int" + switch { + case field.Size <= 8: + sqlType = "tinyint" + case field.Size <= 16: + sqlType = "smallint" + case field.Size <= 32: + sqlType = "int" + default: + sqlType = "bigint" + } + + if field.DataType == schema.Uint { + sqlType += " unsigned" + } + + if field.AutoIncrement { + sqlType += " AUTO_INCREMENT" + } + return sqlType + case schema.Float: + if field.Size <= 32 { + return "float" + } + return "double" + case schema.String: + size := field.Size + if size >= 65536 && size <= int(math.Pow(2, 24)) { + return "mediumtext" + } else if size > int(math.Pow(2, 24)) || size < 0 { + return "longtext" + } + return fmt.Sprintf("varchar(%d)", size) + case schema.Time: + precision := "" + if field.Precision > 0 { + precision = fmt.Sprintf("(%d)", field.Precision) + } + + if field.NotNull || field.PrimaryKey { + return "datetime" + precision + } + return "datetime" + precision + " NULL" + case schema.Bytes: + if field.Size > 0 && field.Size < 65536 { + return fmt.Sprintf("varbinary(%d)", field.Size) + } + + if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { + return "mediumblob" + } + + return "longblob" + } + + return "" +} diff --git a/dialects/postgres/migrator.go b/dialects/postgres/migrator.go new file mode 100644 index 00000000..35101bf3 --- /dev/null +++ b/dialects/postgres/migrator.go @@ -0,0 +1,89 @@ +package postgres + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) HasIndex(value interface{}, indexName string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + err := fmt.Errorf("failed to create index with name %v", name) + indexes := stmt.Schema.ParseIndexes() + + if idx, ok := indexes[name]; ok { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } else if field := stmt.Schema.LookUpField(name); field != nil { + for _, idx := range indexes { + for _, idxOpt := range idx.Fields { + if idxOpt.Field == field { + if err = m.CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + } + } + return err + }) +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 9ea0048a..a3eeefb9 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -2,9 +2,12 @@ package postgres import ( "database/sql" + "fmt" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" _ "github.com/lib/pq" ) @@ -24,14 +27,54 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { return } -func (Dialector) Migrator() gorm.Migrator { - return nil +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} } -func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (Dialector) QuoteChars() [2]byte { +func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'"', '"'} // "name" } + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "boolean" + case schema.Int, schema.Uint: + if field.AutoIncrement { + switch { + case field.Size < 16: + return "smallserial" + case field.Size < 31: + return "serial" + default: + return "bigserial" + } + } else { + switch { + case field.Size < 16: + return "smallint" + case field.Size < 31: + return "integer" + default: + return "bigint" + } + } + case schema.Float: + return "decimal" + case schema.String: + if field.Size > 0 { + return fmt.Sprintf("varchar(%d)", field.Size) + } + return "text" + case schema.Time: + return "timestamp with time zone" + case schema.Bytes: + return "bytea" + } + + return "" +} diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go new file mode 100644 index 00000000..07e189ad --- /dev/null +++ b/dialects/sqlite/migrator.go @@ -0,0 +1,122 @@ +package sqlite + +import ( + "fmt" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", + stmt.Table, `%"`+name+`" %`, `%`+name+` %`, + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE ?", + stmt.Table, "%INDEX "+name+" ON%", + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) CreateConstraint(interface{}, string) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) DropConstraint(interface{}, string) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) CurrentDatabase() (name string) { + var null interface{} + m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + err := fmt.Errorf("failed to create index with name %v", name) + indexes := stmt.Schema.ParseIndexes() + + if idx, ok := indexes[name]; ok { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } else if field := stmt.Schema.LookUpField(name); field != nil { + for _, idx := range indexes { + for _, idxOpt := range idx.Fields { + if idxOpt.Field == field { + if err = m.CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + } + } + return err + }) +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 80a18cfb..b77226db 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -5,6 +5,8 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" + "github.com/jinzhu/gorm/migrator" + "github.com/jinzhu/gorm/schema" _ "github.com/mattn/go-sqlite3" ) @@ -24,14 +26,36 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { return } -func (Dialector) Migrator() gorm.Migrator { - return nil +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} } -func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { +func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } -func (Dialector) QuoteChars() [2]byte { +func (dialector Dialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "NUMERIC" + case schema.Int, schema.Uint: + if field.AutoIncrement { + // https://www.sqlite.org/autoinc.html + return "INTEGER PRIMARY KEY AUTOINCREMENT" + } else { + return "INTEGER" + } + case schema.Float: + return "REAL" + case schema.String, schema.Time: + return "TEXT" + case schema.Bytes: + return "BLOB" + } + + return "" +} diff --git a/interfaces.go b/interfaces.go index 71522455..8f0f3085 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + + "github.com/jinzhu/gorm/schema" ) // Dialector GORM database dialector type Dialector interface { Initialize(*DB) error - Migrator() Migrator + Migrator(db *DB) Migrator + DataTypeOf(*schema.Field) string BindVar(stmt *Statement, v interface{}) string QuoteChars() [2]byte } diff --git a/migrator.go b/migrator.go index a5ea4d8f..d90c362f 100644 --- a/migrator.go +++ b/migrator.go @@ -6,7 +6,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator() + return db.Dialector.Migrator(db) } // ViewOption view option @@ -26,7 +26,7 @@ type Migrator interface { // Tables CreateTable(dst ...interface{}) error DropTable(dst ...interface{}) error - HasTable(dst ...interface{}) bool + HasTable(dst interface{}) bool RenameTable(oldName, newName string) error // Columns diff --git a/migrator/migrator.go b/migrator/migrator.go index 7e749037..9e94cc68 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -11,21 +11,21 @@ import ( "github.com/jinzhu/gorm/schema" ) -// Migrator migrator struct +// Migrator m struct type Migrator struct { - *Config + Config } // Config schema config type Config struct { - CheckExistsBeforeDropping bool - DB *gorm.DB + DB *gorm.DB + gorm.Dialector } -func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { - stmt := migrator.DB.Statement +func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := m.DB.Statement if stmt == nil { - stmt = &gorm.Statement{DB: migrator.DB} + stmt = &gorm.Statement{DB: m.DB} } if err := stmt.Parse(value); err != nil { @@ -35,20 +35,28 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement return fc(stmt) } +func (m Migrator) DataTypeOf(field *schema.Field) string { + if field.DBDataType != "" { + return field.DBDataType + } + + return m.Dialector.DataTypeOf(field) +} + // AutoMigrate -func (migrator Migrator) AutoMigrate(values ...interface{}) error { +func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type for _, value := range values { - if !migrator.DB.Migrator().HasTable(value) { - if err := migrator.DB.Migrator().CreateTable(value); err != nil { + if !m.DB.Migrator().HasTable(value) { + if err := m.DB.Migrator().CreateTable(value); err != nil { return err } } else { - if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, field := range stmt.Schema.FieldsByDBName { - if !migrator.DB.Migrator().HasColumn(value, field.DBName) { - if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil { + if !m.DB.Migrator().HasColumn(value, field.DBName) { + if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil { return err } } @@ -56,16 +64,16 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) { - if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { + if !m.DB.Migrator().HasConstraint(value, constraint.Name) { + if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !migrator.DB.Migrator().HasConstraint(value, chk.Name) { - if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !m.DB.Migrator().HasConstraint(value, chk.Name) { + if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } @@ -73,8 +81,8 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { // create join table joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !migrator.DB.Migrator().HasTable(joinValue) { - defer migrator.DB.Migrator().CreateTable(joinValue) + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) } } return nil @@ -87,9 +95,9 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { return nil } -func (migrator Migrator) CreateTable(values ...interface{}) error { +func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range values { - if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{clause.Table{Name: stmt.Table}} @@ -100,7 +108,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType}) + values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)}) if field.AutoIncrement { createTableSQL += " AUTO_INCREMENT" @@ -133,7 +141,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { createTableSQL += "INDEX ? ?," - values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } for _, rel := range stmt.Schema.Relationships.Relations { @@ -145,8 +153,8 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { // create join table joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !migrator.DB.Migrator().HasTable(joinValue) { - defer migrator.DB.Migrator().CreateTable(joinValue) + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) } } @@ -158,7 +166,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL += ")" - return migrator.DB.Exec(createTableSQL, values...).Error + return m.DB.Exec(createTableSQL, values...).Error }); err != nil { return err } @@ -166,10 +174,10 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { return nil } -func (migrator Migrator) DropTable(values ...interface{}) error { +func (m Migrator) DropTable(values ...interface{}) error { for _, value := range values { - if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error }); err != nil { return err } @@ -177,42 +185,36 @@ func (migrator Migrator) DropTable(values ...interface{}) error { return nil } -func (migrator Migrator) HasTable(values ...interface{}) bool { +func (m Migrator) HasTable(value interface{}) bool { var count int64 - for _, value := range values { - err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error - }) + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) + }) - if err != nil || count == 0 { - return false - } - } - - return true + return count > 0 } -func (migrator Migrator) RenameTable(oldName, newName string) error { - return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error +func (m Migrator) RenameTable(oldName, newName string) error { + return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error } -func (migrator Migrator) AddColumn(value interface{}, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) AddColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) } -func (migrator Migrator) DropColumn(value interface{}, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) DropColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, ).Error } @@ -220,44 +222,41 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error { }) } -func (migrator Migrator) AlterColumn(value interface{}, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, ).Error } return fmt.Errorf("failed to look up field with name: %s", field) }) } -func (migrator Migrator) HasColumn(value interface{}, field string) bool { +func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 - migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() name := field if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName } - return migrator.DB.Raw( + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, stmt.Table, name, - ).Scan(&count).Error + ).Row().Scan(&count) }) - if count != 0 { - return true - } - return false + return count > 0 } -func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) - return migrator.DB.Exec( + oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName) + return m.DB.Exec( "ALTER TABLE ? RENAME COLUMN ? TO ?", clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, ).Error @@ -266,15 +265,15 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) }) } -func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { +func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { return nil, gorm.ErrNotImplemented } -func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error { +func (m Migrator) CreateView(name string, option gorm.ViewOption) error { return gorm.ErrNotImplemented } -func (migrator Migrator) DropView(name string) error { +func (m Migrator) DropView(name string) error { return gorm.ErrNotImplemented } @@ -300,11 +299,11 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } -func (migrator Migrator) CreateConstraint(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { - return migrator.DB.Exec( + return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error @@ -313,21 +312,21 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { sql, values := buildConstraint(constraint) - return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error } } err := fmt.Errorf("failed to create constraint with name %v", name) if field := stmt.Schema.LookUpField(name); field != nil { for _, cc := range checkConstraints { - if err = migrator.CreateIndex(value, cc.Name); err != nil { + if err = m.CreateIndex(value, cc.Name); err != nil { return err } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = migrator.CreateIndex(value, constraint.Name); err != nil { + if err = m.CreateIndex(value, constraint.Name); err != nil { return err } } @@ -338,32 +337,29 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error }) } -func (migrator Migrator) DropConstraint(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec( +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( "ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, ).Error }) } -func (migrator Migrator) HasConstraint(value interface{}, name string) bool { +func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 - migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw( + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", currentDatabase, stmt.Table, name, - ).Scan(&count).Error + ).Row().Scan(&count) }) - if count != 0 { - return true - } - return false + return count > 0 } -func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { @@ -372,6 +368,10 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results str += fmt.Sprintf("(%d)", opt.Length) } + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + if opt.Sort != "" { str += " " + opt.Sort } @@ -380,13 +380,17 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results return } -func (migrator Migrator) CreateIndex(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { +type BuildIndexOptionsInterface interface { + BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { err := fmt.Errorf("failed to create index with name %v", name) indexes := stmt.Schema.ParseIndexes() if idx, ok := indexes[name]; ok { - opts := buildIndexOptions(idx.Fields, stmt) + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} createIndexSQL := "CREATE " @@ -404,12 +408,12 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " USING " + idx.Type } - return migrator.DB.Raw(createIndexSQL, values...).Error + return m.DB.Exec(createIndexSQL, values...).Error } else if field := stmt.Schema.LookUpField(name); field != nil { for _, idx := range indexes { for _, idxOpt := range idx.Fields { if idxOpt.Field == field { - if err = migrator.CreateIndex(value, idx.Name); err != nil { + if err = m.CreateIndex(value, idx.Name); err != nil { return err } } @@ -420,38 +424,35 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error { }) } -func (migrator Migrator) DropIndex(value interface{}, name string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error }) } -func (migrator Migrator) HasIndex(value interface{}, name string) bool { +func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 - migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - currentDatabase := migrator.DB.Migrator().CurrentDatabase() - return migrator.DB.Raw( + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw( "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name, - ).Scan(&count).Error + ).Row().Scan(&count) }) - if count != 0 { - return true - } - return false + return count > 0 } -func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { - return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { - return migrator.DB.Exec( +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( "ALTER TABLE ? RENAME INDEX ? TO ?", clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } -func (migrator Migrator) CurrentDatabase() (name string) { - migrator.DB.Raw("SELECT DATABASE()").Scan(&name) +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return } diff --git a/schema/field.go b/schema/field.go index 60cfc2ab..ea4e6a40 100644 --- a/schema/field.go +++ b/schema/field.go @@ -138,7 +138,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if num, ok := field.TagSettings["SIZE"]; ok { - field.Size, _ = strconv.Atoi(num) + var err error + if field.Size, err = strconv.Atoi(num); err != nil { + field.Size = -1 + } } if p, ok := field.TagSettings["PRECISION"]; ok {