diff --git a/association.go b/association.go index e59b8938..25e1fe8d 100644 --- a/association.go +++ b/association.go @@ -54,7 +54,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } - joinStmt.Build("WHERE", "LIMIT") + joinStmt.Build("WHERE") tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } @@ -112,7 +112,7 @@ func (association *Association) Replace(values ...interface{}) error { updateMap[ref.ForeignKey.DBName] = nil } - association.DB.UpdateColumns(updateMap) + association.Error = association.DB.UpdateColumns(updateMap).Error } case schema.HasOne, schema.HasMany: var ( diff --git a/schema/relationship.go b/schema/relationship.go index dad2e629..5132ff74 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -82,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { - case reflect.Struct, reflect.Slice: + case reflect.Struct: + schema.guessRelation(relation, field, guessBelongs) + case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -324,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessHas guessLevel = iota - guessEmbeddedHas - guessBelongs + guessBelongs guessLevel = iota guessEmbeddedBelongs + guessHas + guessEmbeddedHas ) func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { @@ -338,25 +340,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr := func() { switch gl { - case guessHas: - schema.guessRelation(relation, field, guessEmbeddedHas) - case guessEmbeddedHas: - schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: default: schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } switch gl { - case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { - reguessOrErr() - return - } case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: @@ -366,6 +362,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr() return } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() + return + } } if len(relation.foreignKeys) > 0 { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2c09f528..2e85c538 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema_test.go b/schema/schema_test.go index 8bd1e5ca..4d13ebd2 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -171,7 +171,7 @@ func TestNestedModel(t *testing.T) { fields := []schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, - {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, }