diff --git a/migrator/migrator.go b/migrator/migrator.go index c6d0947a..91dd8e83 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -451,50 +451,80 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { + if stmt.Schema == nil { + return nil, nil, stmt.Table + } + + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return nil, &chk, stmt.Table + } + + getTable := func(rel *schema.Relationship) string { + switch rel.Type { + case schema.HasOne, schema.HasMany: + return rel.FieldSchema.Table + case schema.Many2Many: + return rel.JoinTable.Table + } + return stmt.Table + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + return constraint, nil, getTable(rel) + } + } + + if field := stmt.Schema.LookUpField(name); field != nil { + for _, cc := range checkConstraints { + if cc.Field == field { + return nil, &cc, stmt.Table + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { + return constraint, nil, getTable(rel) + } + } + } + return nil, nil, "" +} + func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - checkConstraints := stmt.Schema.ParseCheckConstraints() - if chk, ok := checkConstraints[name]; ok { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { return m.DB.Exec( "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, ).Error } - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - sql, values := buildConstraint(constraint) - return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{m.CurrentTable(stmt)}, values...)...).Error + if constraint != nil { + var vars = []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr } + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - err := fmt.Errorf("failed to create constraint with name %v", name) - if field := stmt.Schema.LookUpField(name); field != nil { - for _, cc := range checkConstraints { - if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { - return err - } - } - - for _, rel := range stmt.Schema.Relationships.Relations { - if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { - return err - } - } - } - } - - return err + return nil }) } func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Exec( - "ALTER TABLE ? DROP CONSTRAINT ?", - m.CurrentTable(stmt), clause.Column{Name: name}, - ).Error + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) } @@ -502,9 +532,16 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", - currentDatabase, stmt.Table, name, + currentDatabase, table, name, ).Row().Scan(&count) }) diff --git a/schema/relationship.go b/schema/relationship.go index 41e0b9bd..9b7d803c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -519,7 +519,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { } for _, ref := range rel.References { - if ref.PrimaryKey != nil { + if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) constraint.References = append(constraint.References, ref.PrimaryKey) @@ -533,10 +533,6 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } - if rel.JoinTable != nil { - return nil - } - return &constraint } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 275fe634..ca28dfbc 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -323,3 +323,33 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Found deleted column") } } + +func TestMigrateConstraint(t *testing.T) { + if DB.Dialector.Name() == "sqlite" { + t.Skip() + } + + names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Manager", "fk_users_manager", "Team", "fk_users_team", "Languages", "fk_users_languages"} + + for _, name := range names { + if !DB.Migrator().HasConstraint(&User{}, name) { + DB.Migrator().CreateConstraint(&User{}, name) + } + + if err := DB.Migrator().DropConstraint(&User{}, name); err != nil { + t.Fatalf("failed to drop constraint %v, got error %v", name, err) + } + + if DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("constraint %v should been deleted", name) + } + + if err := DB.Migrator().CreateConstraint(&User{}, name); err != nil { + t.Fatalf("failed to create constraint %v, got error %v", name, err) + } + + if !DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("failed to found constraint %v", name) + } + } +}