diff --git a/callbacks/update.go b/callbacks/update.go index 03d5c1e9..1ea77552 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -137,6 +137,32 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { updatingValue = updatingValue.Elem() } + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var priamryKeyExprs []clause.Expression + for i := 0; i < stmt.ReflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + switch value := updatingValue.Interface().(type) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) @@ -218,31 +244,5 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { - switch stmt.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var priamryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { - var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) - } - } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) - case reflect.Struct: - for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) - } - } - } - } - return } diff --git a/migrator/migrator.go b/migrator/migrator.go index 8f872ee4..a4cc99a6 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -103,9 +103,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - if !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err + if constraint.Schema == stmt.Schema { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } } } } @@ -177,9 +179,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil { - sql, vars := buildConstraint(constraint) - createTableSQL += sql + "," - values = append(values, vars...) + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } } // create join table @@ -360,7 +364,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter } if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate + sql += " ON UPDATE " + constraint.OnUpdate } var foreignKeys, references []interface{} @@ -550,7 +554,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep.Parse(value) for _, rel := range dep.Schema.Relationships.Relations { - if c := rel.ParseConstraint(); c != nil && c.Schema != c.ReferenceSchema { + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } } diff --git a/schema/relationship.go b/schema/relationship.go index efa44554..afa083ed 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -85,6 +85,10 @@ func (schema *Schema) parseRelation(field *Field) { } if relation.Type == "has" { + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil { + relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation + } + switch field.IndirectFieldType.Kind() { case reflect.Struct: relation.Type = HasOne @@ -384,18 +388,24 @@ func (rel *Relationship) ParseConstraint() *Constraint { Field: rel.Field, OnUpdate: settings["ONUPDATE"], OnDelete: settings["ONDELETE"], - Schema: rel.Schema, } for _, ref := range rel.References { - if ref.PrimaryKey != nil && !ref.OwnPrimaryKey { + if ref.PrimaryKey != nil { constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) constraint.References = append(constraint.References, ref.PrimaryKey) - constraint.ReferenceSchema = ref.PrimaryKey.Schema + + if ref.OwnPrimaryKey { + constraint.Schema = ref.ForeignKey.Schema + constraint.ReferenceSchema = rel.Schema + } else { + constraint.Schema = rel.Schema + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } } } - if rel.JoinTable != nil || constraint.ReferenceSchema == nil { + if rel.JoinTable != nil { return nil } diff --git a/tests/associations_test.go b/tests/associations_test.go index 44262109..9b4dd105 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -31,3 +31,112 @@ func TestInvalidAssociation(t *testing.T) { t.Fatalf("should return errors for invalid association, but got nil") } } + +func TestForeignKeyConstraints(t *testing.T) { + type Profile struct { + ID uint + Name string + MemberID uint + } + + type Member struct { + ID uint + Refer uint `gorm:"unique_index"` + Name string + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Refer: 1, Name: "foreign_key_constraints", Profile: Profile{Name: "my_profile"}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.MemberID != member.ID { + t.Fatalf("member id is not equal: expects: %v, got: %v", member.ID, profile.MemberID) + } + + member.Profile = Profile{} + DB.Model(&member).Update("Refer", 100) + + var profile2 Profile + if err := DB.First(&profile2, "id = ?", profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile2.MemberID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, profile2.MemberID) + } + + if r := DB.Delete(&member); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile2, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} + +func TestForeignKeyConstraintsBelongsTo(t *testing.T) { + type Profile struct { + ID uint + Name string + Refer uint `gorm:"unique_index"` + } + + type Member struct { + ID uint + Name string + ProfileID uint + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:ProfileID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Name: "foreign_key_constraints_belongs_to", Profile: Profile{Name: "my_profile_belongs_to", Refer: 1}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.Refer != member.ProfileID { + t.Fatalf("member id is not equal: expects: %v, got: %v", profile.Refer, member.ProfileID) + } + + DB.Model(&profile).Update("Refer", 100) + + var member2 Member + if err := DB.First(&member2, "id = ?", member.ID).Error; err != nil { + t.Fatalf("failed to find member, got error: %v", err) + } else if member2.ProfileID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, member2.ProfileID) + } + + if r := DB.Delete(&profile); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 8f678b21..4a25a69b 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -433,8 +433,8 @@ func TestNestedPreload9(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint - Level2_1ID uint + Level2ID *uint + Level2_1ID *uint Level0s []Level0 `json:",omitempty"` } Level2 struct { diff --git a/tests/tests_test.go b/tests/tests_test.go index c80fb849..9e135b4e 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -66,6 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db.Exec("PRAGMA foreign_keys = ON") } if debug := os.Getenv("DEBUG"); debug == "true" { diff --git a/utils/tests/models.go b/utils/tests/models.go index 878129e8..021b0229 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -37,7 +37,7 @@ type Account struct { type Pet struct { gorm.Model - UserID uint + UserID *uint Name string Toy Toy `gorm:"polymorphic:Owner;"` }