Fix failed to guess relations for embedded types, close #3224

This commit is contained in:
Jinzhu 2020-08-04 12:10:19 +08:00
parent c11c939b95
commit ff985b90cc
5 changed files with 76 additions and 18 deletions

View File

@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
return nil return nil
}); err != nil { }); err != nil {
fmt.Println(err)
return err return err
} }
} }

View File

@ -62,6 +62,7 @@ type Field struct {
TagSettings map[string]string TagSettings map[string]string
Schema *Schema Schema *Schema
EmbeddedSchema *Schema EmbeddedSchema *Schema
OwnerSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool) ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error Set func(reflect.Value, interface{}) error
@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
for _, ef := range field.EmbeddedSchema.Fields { for _, ef := range field.EmbeddedSchema.Fields {
ef.Schema = schema ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
// index is negative means is pointer // index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct { if field.FieldType.Kind() == reflect.Struct {

View File

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
"sync"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -66,10 +67,17 @@ func (schema *Schema) parseRelation(field *Field) {
} }
) )
if field.OwnerSchema != nil {
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
schema.err = err
return
}
} else {
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err schema.err = err
return return
} }
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field, polymorphic) schema.buildPolymorphicRelation(relation, field, polymorphic)
@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) {
} else { } else {
switch field.IndirectFieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice: case reflect.Struct, reflect.Slice:
schema.guessRelation(relation, field, true) schema.guessRelation(relation, field, guessHas)
default: default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
} }
@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
} }
func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { type guessLevel int
const (
guessHas guessLevel = iota
guessEmbeddedHas
guessBelongs
guessEmbeddedBelongs
)
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
var ( var (
primaryFields, foreignFields []*Field primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema primarySchema, foreignSchema = schema, relation.FieldSchema
) )
if !guessHas { reguessOrErr := func(err string, args ...interface{}) {
primarySchema, foreignSchema = relation.FieldSchema, schema switch gl {
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
case guessEmbeddedHas:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
default:
schema.err = fmt.Errorf(err, args...)
}
} }
reguessOrErr := func(err string, args ...interface{}) { switch gl {
if guessHas { case guessEmbeddedHas:
schema.guessRelation(relation, field, false) if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else { } else {
schema.err = fmt.Errorf(err, args...) reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
}
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
} }
} }
@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
} }
} else { } else {
for _, primaryField := range primarySchema.PrimaryFields { for _, primaryField := range primarySchema.PrimaryFields {
lookUpName := schema.Name + primaryField.Name lookUpName := primarySchema.Name + primaryField.Name
if !guessHas { if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name lookUpName = field.Name + primaryField.Name
} }
@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
} }
if len(foreignFields) == 0 { if len(foreignFields) == 0 {
reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return return
} else if len(relation.primaryKeys) > 0 { } else if len(relation.primaryKeys) > 0 {
for idx, primaryKey := range relation.primaryKeys { for idx, primaryKey := range relation.primaryKeys {
@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
relation.References = append(relation.References, &Reference{ relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx], PrimaryKey: primaryFields[idx],
ForeignKey: foreignField, ForeignKey: foreignField,
OwnPrimaryKey: schema == primarySchema && guessHas, OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
}) })
} }
if guessHas { if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = "has" relation.Type = "has"
} else { } else {
relation.Type = BelongsTo relation.Type = BelongsTo

View File

@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) {
results: []string{"c5", "c1", "c2", "c3", "c4"}, results: []string{"c5", "c1", "c2", "c3", "c4"},
}, },
{ {
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c4", "c3"}, results: []string{"c3", "c5", "c1", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
}, },
} }

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
) )
func TestEmbeddedStruct(t *testing.T) { func TestEmbeddedStruct(t *testing.T) {
@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) {
t.Errorf("Failed to create got error %v", err) t.Errorf("Failed to create got error %v", err)
} }
} }
func TestEmbeddedRelations(t *testing.T) {
type AdvancedUser struct {
User `gorm:"embedded"`
Advanced bool
}
DB.Debug().Migrator().DropTable(&AdvancedUser{})
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
}
}