Enhance migrator Columntype interface (#5088)

* Update Migrator ColumnType interface

* Update MigrateColumn Test

* Upgrade test drivers

* Fix typo
This commit is contained in:
Jinzhu 2022-02-19 16:59:22 +08:00
parent 39d84cba5f
commit 0af95f509a
5 changed files with 185 additions and 14 deletions

View File

@ -1,6 +1,8 @@
package gorm
import (
"reflect"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
@ -33,14 +35,23 @@ type ViewOption struct {
Query *DB
}
// ColumnType column type interface
type ColumnType interface {
Name() string
DatabaseTypeName() string
DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
// Migrator migrator interface
type Migrator interface {
// AutoMigrate
AutoMigrate(dst ...interface{}) error

107
migrator/column_type.go Normal file
View File

@ -0,0 +1,107 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimayKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column. lke `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid
}
// AutoIncrement returns the column is auto increment or not.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column may be unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}

View File

@ -30,10 +30,12 @@ type Config struct {
gorm.Dialector
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string
}
// RunWithValue run migration with statement value
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil {
@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
return fc(stmt)
}
// DataTypeOf return field's db data type
func (m Migrator) DataTypeOf(field *schema.Field) string {
fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
return m.Dialector.DataTypeOf(field)
}
// FullDataTypeOf returns field's db full data type
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field)
@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return
}
// AutoMigrate
// AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{})
@ -156,12 +160,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return nil
}
// GetTables returns tables
func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error
return
}
// CreateTable create table in database for values
func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
@ -252,6 +258,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
return nil
}
// DropTable drop table for values
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- {
@ -265,6 +272,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
return nil
}
// HasTable returns table exists or not for value, value could be a struct or string
func (m Migrator) HasTable(value interface{}) bool {
var count int64
@ -276,6 +284,7 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0
}
// RenameTable rename table from oldName to newName
func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable interface{}
if v, ok := oldName.(string); ok {
@ -303,12 +312,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
}
func (m Migrator) AddColumn(value interface{}, field string) error {
// AddColumn create `name` column for value
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field
f := stmt.Schema.LookUpField(field)
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", field)
return fmt.Errorf("failed to look up field with name: %s", name)
}
if !f.IgnoreMigration {
@ -322,6 +332,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
})
}
// DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
@ -334,6 +345,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
})
}
// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
@ -348,6 +360,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
})
}
// HasColumn check has column `field` for value or not
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -366,6 +379,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
return count > 0
}
// RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil {
@ -383,6 +397,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
})
}
// MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
// found, smart migrate
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
@ -448,7 +463,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
}
for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, c)
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
}
return
@ -457,10 +472,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr
}
// CreateView create view
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
return gorm.ErrNotImplemented
}
// DropView drop view
func (m Migrator) DropView(name string) error {
return gorm.ErrNotImplemented
}
@ -487,6 +504,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
return
}
// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil {
return nil, nil, stmt.Table
@ -531,6 +549,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
return nil, nil, stmt.Schema.Table
}
// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@ -554,6 +573,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
})
}
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@ -566,6 +586,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
})
}
// HasConstraint check has constraint or not
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -586,6 +607,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0
}
// BuildIndexOptions build index options
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
@ -607,10 +629,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return
}
// BuildIndexOptionsInterface build index options interface
type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
}
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@ -642,6 +666,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
})
}
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@ -652,6 +677,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
})
}
// HasIndex check has index `name` or not
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -669,6 +695,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0
}
// RenameIndex rename index from oldName to newName
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
@ -678,6 +705,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
})
}
// CurrentDatabase returns current database name
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
return
@ -781,6 +809,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
return
}
// CurrentTable returns current statement's table expression
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
if stmt.TableExpr != nil {
return *stmt.TableExpr

View File

@ -3,17 +3,16 @@ module gorm.io/gorm/tests
go 1.14
require (
github.com/denisenkom/go-mssqldb v0.12.0 // indirect
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.15.0 // indirect
github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4
github.com/mattn/go-sqlite3 v1.14.11 // indirect
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
gorm.io/driver/mysql v1.2.3
gorm.io/driver/postgres v1.2.3
gorm.io/driver/sqlite v1.2.6
gorm.io/driver/sqlserver v1.2.1
gorm.io/driver/mysql v1.3.0
gorm.io/driver/postgres v1.3.0
gorm.io/driver/sqlite v1.3.0
gorm.io/driver/sqlserver v1.3.0
gorm.io/gorm v1.22.5
)

View File

@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
}
func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()]
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
@ -313,9 +313,15 @@ func TestMigrateIndexes(t *testing.T) {
}
func TestMigrateColumns(t *testing.T) {
fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()]
sqlite := DB.Dialector.Name() == "sqlite"
sqlserver := DB.Dialector.Name() == "sqlserver"
type ColumnStruct struct {
gorm.Model
Name string
Age int `gorm:"default:18;comment:my age"`
Code string `gorm:"unique"`
}
DB.Migrator().DropTable(&ColumnStruct{})
@ -340,10 +346,29 @@ func TestMigrateColumns(t *testing.T) {
stmt.Parse(&ColumnStruct2{})
for _, columnType := range columnTypes {
if columnType.Name() == "name" {
switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "name":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType)
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
}
case "age":
if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code":
if v, ok := columnType.Unique(); (fullSupported || ok) && !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
}
}