package migrator

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"regexp"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// Patterns used by MigrateColumn to pull a numeric size out of a data type
// string. They are compiled once here instead of on every MigrateColumn call.
var (
	columnSizeInRealTypeRe = regexp.MustCompile(`[^\d](\d+)[^\d]`)
	columnSizeInFullTypeRe = regexp.MustCompile(`[^\d]*(\d+)[^\d]`)
)

// Migrator is the base migrator implementation. It embeds Config, so the
// database handle and the dialector are reachable directly from a Migrator.
type Migrator struct {
	Config
}

// Config holds the settings a Migrator works with.
type Config struct {
	// CreateIndexAfterCreateTable makes CreateTable issue separate CreateIndex
	// calls after the CREATE TABLE statement instead of inlining INDEX clauses.
	CreateIndexAfterCreateTable bool
	// DB is the handle every statement is executed on.
	DB *gorm.DB
	gorm.Dialector
}

// GormDataTypeInterface may be implemented by a field's type to supply its own
// database column type. DataTypeOf prefers a non-empty GormDBDataType result
// over the dialector's answer.
type GormDataTypeInterface interface {
	GormDBDataType(*gorm.DB, *schema.Field) string
}

// RunWithValue builds a fresh statement for value and passes it to fc.
//
// When value is a string it is used verbatim as the table name and nothing is
// parsed; otherwise value is parsed as a model and a parse error is returned
// without calling fc. The table name of m.DB.Statement, when present, is used
// as the initial table name before either of those steps.
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
	stmt := &gorm.Statement{DB: m.DB}
	if m.DB.Statement != nil {
		stmt.Table = m.DB.Statement.Table
	}

	if table, ok := value.(string); ok {
		stmt.Table = table
	} else if err := stmt.Parse(value); err != nil {
		return err
	}

	return fc(stmt)
}

// DataTypeOf returns the column data type for field. A field type that
// implements GormDataTypeInterface and returns a non-empty string wins;
// otherwise the dialector decides.
func (m Migrator) DataTypeOf(field *schema.Field) string {
	fieldValue := reflect.New(field.IndirectFieldType)
	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
		if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" {
			return dataType
		}
	}

	return m.Dialector.DataTypeOf(field)
}

// FullDataTypeOf returns the data type of field followed by its NOT NULL,
// UNIQUE and DEFAULT modifiers, as a SQL expression.
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
	expr.SQL = m.DataTypeOf(field)

	if field.NotNull {
		expr.SQL += " NOT NULL"
	}

	if field.Unique {
		expr.SQL += " UNIQUE"
	}

	if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
		if field.DefaultValueInterface != nil {
			// Let the dialector render the default value as a literal.
			defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
			m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
			expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
		} else {
			expr.SQL += " DEFAULT " + field.DefaultValue
		}
	}

	return
}

// AutoMigrate creates the table of every given model that does not exist yet.
// For tables that already exist it adds missing columns, calls MigrateColumn
// for existing ones, and creates missing foreign key and check constraints.
// Models are processed in the order computed by ReorderModels. The first error
// aborts the migration and is returned.
func (m Migrator) AutoMigrate(values ...interface{}) error {
	for _, value := range m.ReorderModels(values, true) {
		tx := m.DB.Session(&gorm.Session{})
		if !tx.Migrator().HasTable(value) {
			if err := tx.Migrator().CreateTable(value); err != nil {
				return err
			}
			continue
		}

		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
			// Without the column list every column would look missing and be
			// re-added, so a failed introspection is reported, not ignored.
			columnTypes, err := m.DB.Migrator().ColumnTypes(value)
			if err != nil {
				return err
			}

			// Walk DBNames (as CreateTable does) rather than ranging over the
			// FieldsByDBName map, so columns are handled in a stable order.
			for _, dbName := range stmt.Schema.DBNames {
				field := stmt.Schema.FieldsByDBName[dbName]

				var foundColumn *sql.ColumnType
				for _, columnType := range columnTypes {
					if columnType.Name() == field.DBName {
						foundColumn = columnType
						break
					}
				}

				if foundColumn == nil {
					// not found, add column
					if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
						return err
					}
				} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
					// found, smart migrate
					return err
				}
			}

			if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
				for _, rel := range stmt.Schema.Relationships.Relations {
					constraint := rel.ParseConstraint()
					if constraint == nil || constraint.Schema != stmt.Schema {
						continue
					}

					if !tx.Migrator().HasConstraint(value, constraint.Name) {
						if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
							return err
						}
					}
				}
			}

			// Check constraints do not depend on relationships, so they are
			// handled in their own loop; models without any relation get
			// their check constraints migrated as well.
			for _, chk := range stmt.Schema.ParseCheckConstraints() {
				if !tx.Migrator().HasConstraint(value, chk.Name) {
					if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
						return err
					}
				}
			}

			return nil
		}); err != nil {
			return err
		}
	}

	return nil
}

// CreateTable issues one CREATE TABLE statement per model, in the order
// computed by ReorderModels. The statement contains the columns, the primary
// key, the indexes (unless CreateIndexAfterCreateTable is set), the foreign
// key constraints (unless disabled on the DB config) and the check
// constraints. The value stored under "gorm:table_options", if any, is
// appended after the closing parenthesis.
func (m Migrator) CreateTable(values ...interface{}) error {
	for _, value := range m.ReorderModels(values, false) {
		tx := m.DB.Session(&gorm.Session{})
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
			var (
				createTableSQL          = "CREATE TABLE ? ("
				values                  = []interface{}{clause.Table{Name: stmt.Table}}
				hasPrimaryKeyInDataType bool
			)

			for _, dbName := range stmt.Schema.DBNames {
				field := stmt.Schema.FieldsByDBName[dbName]
				createTableSQL += "? ?"
				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
				values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
				createTableSQL += ","
			}

			// Add a PRIMARY KEY clause unless a column data type already
			// declares one.
			if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
				createTableSQL += "PRIMARY KEY ?,"
				primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
				for _, field := range stmt.Schema.PrimaryFields {
					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
				}

				values = append(values, primaryKeys)
			}

			for _, idx := range stmt.Schema.ParseIndexes() {
				if m.CreateIndexAfterCreateTable {
					// Deferred so it runs after the CREATE TABLE below. It is
					// skipped once anything has failed, so the error of the
					// table creation or of an earlier index is not overwritten.
					defer func(value interface{}, name string) {
						if errr == nil {
							errr = tx.Migrator().CreateIndex(value, name)
						}
					}(value, idx.Name)
				} else {
					if idx.Class != "" {
						createTableSQL += idx.Class + " "
					}
					createTableSQL += "INDEX ? ?,"
					values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
				}
			}

			if !m.DB.DisableForeignKeyConstraintWhenMigrating {
				for _, rel := range stmt.Schema.Relationships.Relations {
					constraint := rel.ParseConstraint()
					if constraint == nil || constraint.Schema != stmt.Schema {
						continue
					}

					sql, vars := buildConstraint(constraint)
					createTableSQL += sql + ","
					values = append(values, vars...)
				}
			}

			for _, chk := range stmt.Schema.ParseCheckConstraints() {
				createTableSQL += "CONSTRAINT ? CHECK (?),"
				values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
			}

			// Every clause above ends with a comma; drop the last one.
			createTableSQL = strings.TrimSuffix(createTableSQL, ",")

			createTableSQL += ")"

			if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
				createTableSQL += fmt.Sprint(tableOption)
			}

			errr = tx.Exec(createTableSQL, values...).Error
			return errr
		}); err != nil {
			return err
		}
	}

	return nil
}

// DropTable drops the tables of the given models with DROP TABLE IF EXISTS,
// walking the ReorderModels result backwards.
func (m Migrator) DropTable(values ...interface{}) error {
	values = m.ReorderModels(values, false)
	for i := len(values) - 1; i >= 0; i-- {
		tx := m.DB.Session(&gorm.Session{})
		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
			return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
		}); err != nil {
			return err
		}
	}

	return nil
}

// HasTable reports whether information_schema.tables lists a base table for
// value in the current database. The signature has no error result: any
// failure leaves the count at zero and therefore yields false.
func (m Migrator) HasTable(value interface{}) bool {
	var count int64

	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		currentDatabase := m.DB.Migrator().CurrentDatabase()
		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
	})

	return count > 0
}

// resolveTableName returns value itself when it is a string, and otherwise the
// table name obtained by parsing value as a model on db.
func resolveTableName(db *gorm.DB, value interface{}) (string, error) {
	if name, ok := value.(string); ok {
		return name, nil
	}

	stmt := &gorm.Statement{DB: db}
	if err := stmt.Parse(value); err != nil {
		return "", err
	}

	return stmt.Table, nil
}

// RenameTable renames a table. Both oldName and newName may be a table name
// string or a model; a model that fails to parse aborts with that error.
func (m Migrator) RenameTable(oldName, newName interface{}) error {
	oldTable, err := resolveTableName(m.DB, oldName)
	if err != nil {
		return err
	}

	newTable, err := resolveTableName(m.DB, newName)
	if err != nil {
		return err
	}

	return m.DB.Exec("ALTER TABLE ? RENAME TO ?", clause.Table{Name: oldTable}, clause.Table{Name: newTable}).Error
}

// AddColumn adds the column of the named field to the table of value. An
// error is returned when the field cannot be looked up, which includes the
// case where value is a plain table name and no schema was parsed.
func (m Migrator) AddColumn(value interface{}, field string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if f := stmt.Schema.LookUpField(field); f != nil {
				return m.DB.Exec(
					"ALTER TABLE ? ADD ? ?",
					clause.Table{Name: stmt.Table}, clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
				).Error
			}
		}

		return fmt.Errorf("failed to look up field with name: %s", field)
	})
}

// DropColumn drops a column from the table of value. name is mapped to the
// field's database column name when the schema knows it, and used as given
// otherwise.
func (m Migrator) DropColumn(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(name); field != nil {
				name = field.DBName
			}
		}

		return m.DB.Exec(
			"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name},
		).Error
	})
}

// AlterColumn changes the column type of the named field to the type reported
// by DataTypeOf. An error is returned when the field cannot be looked up,
// which includes the case where value is a plain table name.
func (m Migrator) AlterColumn(value interface{}, field string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if f := stmt.Schema.LookUpField(field); f != nil {
				fieldType := clause.Expr{SQL: m.DataTypeOf(f)}
				return m.DB.Exec(
					"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
					clause.Table{Name: stmt.Table}, clause.Column{Name: f.DBName}, fieldType,
				).Error
			}
		}

		return fmt.Errorf("failed to look up field with name: %s", field)
	})
}

// HasColumn reports whether INFORMATION_SCHEMA.columns lists the column for
// the table of value in the current database. field is mapped to the database
// column name when the schema knows it. Any failure yields false.
func (m Migrator) HasColumn(value interface{}, field string) bool {
	var count int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		currentDatabase := m.DB.Migrator().CurrentDatabase()
		name := field
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if f := stmt.Schema.LookUpField(field); f != nil {
				name = f.DBName
			}
		}

		return m.DB.Raw(
			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
			currentDatabase, stmt.Table, name,
		).Row().Scan(&count)
	})

	return count > 0
}

// RenameColumn renames a column of the table of value. oldName and newName are
// each mapped to a database column name when the schema knows a matching
// field, and used as given otherwise.
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(oldName); field != nil {
				oldName = field.DBName
			}

			if field := stmt.Schema.LookUpField(newName); field != nil {
				newName = field.DBName
			}
		}

		return m.DB.Exec(
			"ALTER TABLE ? RENAME COLUMN ? TO ?",
			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
		).Error
	})
}

// MigrateColumn compares an existing column (columnType) with the model field
// and calls AlterColumn when size, precision or nullability differ according
// to the checks below. It returns nil when no alteration is needed.
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error {
	// found, smart migrate
	fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
	realDataType := strings.ToLower(columnType.DatabaseTypeName())

	alterColumn := false

	// check size
	if length, _ := columnType.Length(); length != int64(field.Size) {
		if length > 0 && field.Size > 0 {
			alterColumn = true
		} else {
			// has size in data type and not equal
			matches := columnSizeInRealTypeRe.FindAllStringSubmatch(realDataType, -1)
			matches2 := columnSizeInFullTypeRe.FindAllStringSubmatch(fullDataType, -1)
			if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
				alterColumn = true
			}
		}
	}

	// check precision
	if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
		if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) {
			alterColumn = true
		}
	}

	// check nullable
	if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
		// not primary key & database is nullable
		if !field.PrimaryKey && nullable {
			alterColumn = true
		}
	}

	if alterColumn {
		return m.DB.Migrator().AlterColumn(value, field.Name)
	}

	return nil
}

// ColumnTypes returns the column types of the table of value, obtained from a
// "select * ... limit 1" query against it.
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) {
	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
		rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ? limit 1", clause.Table{Name: stmt.Table}).Rows()
		if err != nil {
			return err
		}
		// Only the column metadata is needed; a Close error is not reported.
		defer rows.Close()

		columnTypes, err = rows.ColumnTypes()
		return err
	})
	return
}

// CreateView is not implemented by the base migrator and always returns
// gorm.ErrNotImplemented.
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
	return gorm.ErrNotImplemented
}

// DropView is not implemented by the base migrator and always returns
// gorm.ErrNotImplemented.
func (m Migrator) DropView(name string) error {
	return gorm.ErrNotImplemented
}

// buildConstraint returns the "CONSTRAINT ... FOREIGN KEY ... REFERENCES ..."
// clause for constraint together with the values for its placeholders, in
// order: constraint name, foreign key columns, referenced table, referenced
// columns. ON DELETE and ON UPDATE are appended when set on the constraint.
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
	if constraint.OnDelete != "" {
		sql += " ON DELETE " + constraint.OnDelete
	}

	if constraint.OnUpdate != "" {
		sql += " ON UPDATE " + constraint.OnUpdate
	}

	var foreignKeys, references []interface{}
	for _, field := range constraint.ForeignKeys {
		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
	}

	for _, field := range constraint.References {
		references = append(references, clause.Column{Name: field.DBName})
	}
	results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
	return
}

// CreateConstraint adds a constraint to the table of value. name is resolved
// in this order:
//
//  1. a check constraint with that name,
//  2. a relationship (foreign key) constraint with that name,
//  3. a field with that name, in which case every check constraint and every
//     relationship constraint attached to that field is created.
//
// An error is returned when nothing matches or when a statement fails.
func (m Migrator) CreateConstraint(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		checkConstraints := stmt.Schema.ParseCheckConstraints()
		if chk, ok := checkConstraints[name]; ok {
			return m.DB.Exec(
				"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
				clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
			).Error
		}

		for _, rel := range stmt.Schema.Relationships.Relations {
			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
				sql, values := buildConstraint(constraint)
				return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error
			}
		}

		// name matched no constraint; treat it as a field name. err stays set
		// unless at least one constraint of that field is created.
		err := fmt.Errorf("failed to create constraint with name %v", name)
		if field := stmt.Schema.LookUpField(name); field != nil {
			// These are constraints, so they are created through
			// CreateConstraint (by their exact names, which hit the branches
			// above) and only when they belong to the requested field.
			for chkName, cc := range checkConstraints {
				if cc.Field != field {
					continue
				}

				if err = m.DB.Migrator().CreateConstraint(value, chkName); err != nil {
					return err
				}
			}

			for _, rel := range stmt.Schema.Relationships.Relations {
				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
					if err = m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil {
						return err
					}
				}
			}
		}

		return err
	})
}

// DropConstraint drops the named constraint from the table of value.
func (m Migrator) DropConstraint(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		return m.DB.Exec(
			"ALTER TABLE ? DROP CONSTRAINT ?",
			clause.Table{Name: stmt.Table}, clause.Column{Name: name},
		).Error
	})
}

// HasConstraint reports whether INFORMATION_SCHEMA.table_constraints lists the
// named constraint for the table of value in the current database. Any
// failure yields false.
func (m Migrator) HasConstraint(value interface{}, name string) bool {
	var count int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		currentDatabase := m.DB.Migrator().CurrentDatabase()
		return m.DB.Raw(
			"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
			currentDatabase, stmt.Table, name,
		).Row().Scan(&count)
	})

	return count > 0
}

// BuildIndexOptions converts index field options into SQL expressions, one per
// option. Each expression is the option's Expression when set, otherwise the
// quoted column name with an optional "(length)" suffix; COLLATE and the sort
// direction are appended when present.
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
	for _, opt := range opts {
		str := stmt.Quote(opt.DBName)
		if opt.Expression != "" {
			str = opt.Expression
		} else if opt.Length > 0 {
			str += fmt.Sprintf("(%d)", opt.Length)
		}

		if opt.Collate != "" {
			str += " COLLATE " + opt.Collate
		}

		if opt.Sort != "" {
			str += " " + opt.Sort
		}
		results = append(results, clause.Expr{SQL: str})
	}
	return
}

// BuildIndexOptionsInterface is the capability CreateTable and CreateIndex
// expect from the active migrator in order to render index columns.
type BuildIndexOptionsInterface interface {
	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
}

// CreateIndex creates the index called name, as defined on the model of
// value. An error is returned when the schema has no such index, which
// includes the case where value is a plain table name.
func (m Migrator) CreateIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
				values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}

				createIndexSQL := "CREATE "
				if idx.Class != "" {
					createIndexSQL += idx.Class + " "
				}
				createIndexSQL += "INDEX ? ON ??"

				if idx.Type != "" {
					createIndexSQL += " USING " + idx.Type
				}

				return m.DB.Exec(createIndexSQL, values...).Error
			}
		}

		return fmt.Errorf("failed to create index with name %v", name)
	})
}

// DropIndex drops an index from the table of value. name is replaced by the
// schema's index name when the schema knows it, and used as given otherwise.
func (m Migrator) DropIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}

		return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
	})
}

// HasIndex reports whether information_schema.statistics lists the index for
// the table of value in the current database. name is replaced by the schema's
// index name when the schema knows it. Any failure yields false.
func (m Migrator) HasIndex(value interface{}, name string) bool {
	var count int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		currentDatabase := m.DB.Migrator().CurrentDatabase()
		// RunWithValue parses no schema for a plain table name.
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}

		return m.DB.Raw(
			"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
			currentDatabase, stmt.Table, name,
		).Row().Scan(&count)
	})

	return count > 0
}

// RenameIndex renames an index of the table of value from oldName to newName.
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		return m.DB.Exec(
			"ALTER TABLE ? RENAME INDEX ? TO ?",
			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
		).Error
	})
}

// CurrentDatabase returns the result of "SELECT DATABASE()". The scan error is
// not reported; on failure name stays empty.
func (m Migrator) CurrentDatabase() (name string) {
	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
	return
}

// ReorderModels reorders models according to their constraint dependencies, so
// that a model comes after the models it depends on.
//
// String values (plain table names) are not analysed; they are placed at the
// front of the result in their original order. When autoAdd is true, models
// that are only reachable as dependencies or join tables are parsed and added
// to the result as well.
func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
	type Dependency struct {
		*gorm.Statement
		Depends []*schema.Schema
	}

	var (
		modelNames, orderedModelNames []string
		orderedModelNamesMap          = map[string]bool{}
		parsedSchemas                 = map[*schema.Schema]bool{}
		valuesMap                     = map[string]Dependency{}
		insertIntoOrderedList         func(name string)
		parseDependence               func(value interface{}, addToList bool)
	)

	// parseDependence parses value once per schema, records the schemas it
	// depends on in valuesMap and, when addToList is set, queues its table
	// name for ordering.
	parseDependence = func(value interface{}, addToList bool) {
		dep := Dependency{
			Statement: &gorm.Statement{DB: m.DB, Dest: value},
		}
		beDependedOn := map[*schema.Schema]bool{}
		// NOTE(review): a parse failure is only logged; the code below still
		// dereferences dep.Schema — confirm Parse leaves it usable on error.
		if err := dep.Parse(value); err != nil {
			m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
		}
		if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
			return
		}
		parsedSchemas[dep.Statement.Schema] = true

		for _, rel := range dep.Schema.Relationships.Relations {
			// A constraint owned by this schema that references another
			// schema is a dependency on that schema.
			if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
				dep.Depends = append(dep.Depends, c.ReferenceSchema)
			}

			if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
				beDependedOn[rel.FieldSchema] = true
			}

			if rel.JoinTable != nil {
				// append join value
				// Deferred on purpose: it runs when parseDependence returns,
				// i.e. after this model has been stored in valuesMap below.
				// NOTE(review): valuesMap holds a copy of dep, so the append
				// to dep.Depends made here is not visible through valuesMap.
				defer func(rel *schema.Relationship, joinValue interface{}) {
					if !beDependedOn[rel.FieldSchema] {
						dep.Depends = append(dep.Depends, rel.FieldSchema)
					} else {
						fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
						parseDependence(fieldValue, autoAdd)
					}
					parseDependence(joinValue, autoAdd)
				}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
			}
		}

		valuesMap[dep.Schema.Table] = dep

		if addToList {
			modelNames = append(modelNames, dep.Schema.Table)
		}
	}

	// insertIntoOrderedList appends name to the ordered list after all of its
	// dependencies (depth first).
	insertIntoOrderedList = func(name string) {
		if _, ok := orderedModelNamesMap[name]; ok {
			return // avoid loop
		}
		orderedModelNamesMap[name] = true

		dep := valuesMap[name]
		for _, d := range dep.Depends {
			if _, ok := valuesMap[d.Table]; ok {
				insertIntoOrderedList(d.Table)
			} else if autoAdd {
				parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
				insertIntoOrderedList(d.Table)
			}
		}

		orderedModelNames = append(orderedModelNames, name)
	}

	for _, value := range values {
		if v, ok := value.(string); ok {
			results = append(results, v)
		} else {
			parseDependence(value, true)
		}
	}

	for _, name := range modelNames {
		insertIntoOrderedList(name)
	}

	for _, name := range orderedModelNames {
		results = append(results, valuesMap[name].Statement.Dest)
	}
	return
}