Fix self-referential belongs to, close #3319

This commit is contained in:
Jinzhu 2020-08-28 11:31:13 +08:00
parent dacbaa5f02
commit c19a3abefb
4 changed files with 36 additions and 18 deletions

View File

@ -54,7 +54,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
for _, queryClause := range association.Relationship.JoinTable.QueryClauses { for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause) joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE", "LIMIT") joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
} }
@ -112,7 +112,7 @@ func (association *Association) Replace(values ...interface{}) error {
updateMap[ref.ForeignKey.DBName] = nil updateMap[ref.ForeignKey.DBName] = nil
} }
association.DB.UpdateColumns(updateMap) association.Error = association.DB.UpdateColumns(updateMap).Error
} }
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
var ( var (

View File

@ -82,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) {
schema.buildMany2ManyRelation(relation, field, many2many) schema.buildMany2ManyRelation(relation, field, many2many)
} else { } else {
switch field.IndirectFieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice: case reflect.Struct:
schema.guessRelation(relation, field, guessBelongs)
case reflect.Slice:
schema.guessRelation(relation, field, guessHas) schema.guessRelation(relation, field, guessHas)
default: default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
@ -324,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
type guessLevel int type guessLevel int
const ( const (
guessHas guessLevel = iota guessBelongs guessLevel = iota
guessEmbeddedHas
guessBelongs
guessEmbeddedBelongs guessEmbeddedBelongs
guessHas
guessEmbeddedHas
) )
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
@ -338,25 +340,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
reguessOrErr := func() { reguessOrErr := func() {
switch gl { switch gl {
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
case guessEmbeddedHas:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs: case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs) schema.guessRelation(relation, field, guessEmbeddedBelongs)
case guessEmbeddedBelongs:
schema.guessRelation(relation, field, guessHas)
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default: default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
} }
} }
switch gl { switch gl {
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr()
return
}
case guessBelongs: case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs: case guessEmbeddedBelongs:
@ -366,6 +362,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue
reguessOrErr() reguessOrErr()
return return
} }
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr()
return
}
} }
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {

View File

@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) {
}) })
} }
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatedBy *int32
Creator *User `gorm:"foreignKey:CreatedBy;references:ID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) { func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model

View File

@ -171,7 +171,7 @@ func TestNestedModel(t *testing.T) {
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
{Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64},
{Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64},
} }