From d3c63a03cbed09c07d4c5a19189d25768f3204ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 00:18:12 +0800 Subject: [PATCH] Handle constraint dependencies smartly --- migrator/migrator.go | 77 ++++++++++++++++++++++++++++++++++++++++++-- tests/migrate.go | 12 +++---- 2 files changed, 80 insertions(+), 9 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index a5ec1a62..318c2fb8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -48,7 +48,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { // TODO smart migrate data type - for _, value := range values { + for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { @@ -100,7 +100,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } func (m Migrator) CreateTable(values ...interface{}) error { - for _, value := range values { + for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( @@ -186,7 +186,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { } func (m Migrator) DropTable(values ...interface{}) error { - for _, value := range values { + values = m.ReorderModels(values, false) + for i := len(values) - 1; i >= 0; i-- { + value := values[i] tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error @@ -475,3 +477,72 @@ func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return } + +// ReorderModels reorder models according to constraint dependencies +func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { + type Dependency struct { + Table string + Depends []*schema.Schema + } + + var ( + modelNames, orderedModelNames []string + orderedModelNamesMap = map[string]bool{} + valuesMap = map[string]*gorm.Statement{} + dependencies = map[string]Dependency{} + insertIntoOrderedMap func(name string) + ) + + parseDependence := func(value interface{}, addToMap bool) { + stmt := &gorm.Statement{DB: m.DB, Dest: value} + stmt.Parse(value) + dep := Dependency{Table: stmt.Schema.Table} + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil { + dep.Depends = append(dep.Depends, constraint.ReferenceSchema) + } + } + dependencies[stmt.Schema.Table] = dep + + if addToMap { + modelNames = append(modelNames, stmt.Schema.Table) + valuesMap[stmt.Schema.Table] = stmt + } + } + + for _, value := range values { + parseDependence(value, true) + } + + insertIntoOrderedMap = func(name string) { + // avoid loop + if _, ok := orderedModelNamesMap[name]; ok { + return + } + + dep := dependencies[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + if _, ok := orderedModelNamesMap[d.Table]; !ok && name != d.Table { + insertIntoOrderedMap(d.Table) + } + } else if autoAdd { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedMap(d.Table) + } + } + + orderedModelNames = append(orderedModelNames, name) + orderedModelNamesMap[name] = true + } + + for _, name := range modelNames { + insertIntoOrderedMap(name) + } + + for _, name := range orderedModelNames { + results = append(results, valuesMap[name].Dest) + } + return +} diff --git a/tests/migrate.go b/tests/migrate.go index 477f0ad6..fa8a89e8 100644 --- a/tests/migrate.go +++ b/tests/migrate.go @@ -1,20 +1,20 @@ package tests import ( + "math/rand" "testing" + "time" "github.com/jinzhu/gorm" ) func TestMigrate(t *testing.T, db *gorm.DB) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - for _, m := range allModels { - if db.Migrator().HasTable(m) { - if err := db.Migrator().DropTable(m); err != nil { - t.Errorf("Failed to drop table, got error %v", err) - } - } + if err := db.Migrator().DropTable(allModels...); err != nil { + t.Errorf("Failed to drop table, got error %v", err) } if err := db.AutoMigrate(allModels...); err != nil {