Fix failed to guess relations for embedded types, close #3224

This commit is contained in:
Jinzhu 2020-08-04 12:10:19 +08:00
parent c11c939b95
commit ff985b90cc
5 changed files with 76 additions and 18 deletions

View File

@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
return nil
}); err != nil {
fmt.Println(err)
return err
}
}

View File

@ -62,6 +62,7 @@ type Field struct {
TagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
OwnerSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error
@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
for _, ef := range field.EmbeddedSchema.Fields {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
// index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct {

View File

@ -5,6 +5,7 @@ import (
"reflect"
"regexp"
"strings"
"sync"
"github.com/jinzhu/inflection"
"gorm.io/gorm/clause"
@ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) {
}
)
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
if field.OwnerSchema != nil {
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
schema.err = err
return
}
} else {
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
}
}
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) {
} else {
switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice:
schema.guessRelation(relation, field, true)
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
}
@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
type guessLevel int
const (
guessHas guessLevel = iota
guessEmbeddedHas
guessBelongs
guessEmbeddedBelongs
)
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
var (
primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema
)
if !guessHas {
primarySchema, foreignSchema = relation.FieldSchema, schema
reguessOrErr := func(err string, args ...interface{}) {
switch gl {
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
case guessEmbeddedHas:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
default:
schema.err = fmt.Errorf(err, args...)
}
}
reguessOrErr := func(err string, args ...interface{}) {
if guessHas {
schema.guessRelation(relation, field, false)
switch gl {
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
schema.err = fmt.Errorf(err, args...)
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
}
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
}
}
@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
}
} else {
for _, primaryField := range primarySchema.PrimaryFields {
lookUpName := schema.Name + primaryField.Name
if !guessHas {
lookUpName := primarySchema.Name + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
}
if len(foreignFields) == 0 {
reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas)
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
} else if len(relation.primaryKeys) > 0 {
for idx, primaryKey := range relation.primaryKeys {
@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],
ForeignKey: foreignField,
OwnPrimaryKey: schema == primarySchema && guessHas,
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
})
}
if guessHas {
if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = "has"
} else {
relation.Type = BelongsTo

View File

@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) {
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c4", "c3"},
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c3", "c5", "c1", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
}

View File

@ -7,6 +7,7 @@ import (
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
func TestEmbeddedStruct(t *testing.T) {
@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) {
t.Errorf("Failed to create got error %v", err)
}
}
func TestEmbeddedRelations(t *testing.T) {
type AdvancedUser struct {
User `gorm:"embedded"`
Advanced bool
}
DB.Debug().Migrator().DropTable(&AdvancedUser{})
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
}
}