forked from mirror/gorm
Fix failed to guess relations for embedded types, close #3224
This commit is contained in:
parent
c11c939b95
commit
ff985b90cc
|
@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,7 @@ type Field struct {
|
||||||
TagSettings map[string]string
|
TagSettings map[string]string
|
||||||
Schema *Schema
|
Schema *Schema
|
||||||
EmbeddedSchema *Schema
|
EmbeddedSchema *Schema
|
||||||
|
OwnerSchema *Schema
|
||||||
ReflectValueOf func(reflect.Value) reflect.Value
|
ReflectValueOf func(reflect.Value) reflect.Value
|
||||||
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
||||||
Set func(reflect.Value, interface{}) error
|
Set func(reflect.Value, interface{}) error
|
||||||
|
@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
}
|
}
|
||||||
for _, ef := range field.EmbeddedSchema.Fields {
|
for _, ef := range field.EmbeddedSchema.Fields {
|
||||||
ef.Schema = schema
|
ef.Schema = schema
|
||||||
|
ef.OwnerSchema = field.EmbeddedSchema
|
||||||
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
||||||
// index is negative means is pointer
|
// index is negative means is pointer
|
||||||
if field.FieldType.Kind() == reflect.Struct {
|
if field.FieldType.Kind() == reflect.Struct {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
@ -66,10 +67,17 @@ func (schema *Schema) parseRelation(field *Field) {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if field.OwnerSchema != nil {
|
||||||
|
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
|
||||||
|
schema.err = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
schema.buildPolymorphicRelation(relation, field, polymorphic)
|
schema.buildPolymorphicRelation(relation, field, polymorphic)
|
||||||
|
@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) {
|
||||||
} else {
|
} else {
|
||||||
switch field.IndirectFieldType.Kind() {
|
switch field.IndirectFieldType.Kind() {
|
||||||
case reflect.Struct, reflect.Slice:
|
case reflect.Struct, reflect.Slice:
|
||||||
schema.guessRelation(relation, field, true)
|
schema.guessRelation(relation, field, guessHas)
|
||||||
default:
|
default:
|
||||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
||||||
}
|
}
|
||||||
|
@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
|
type guessLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
guessHas guessLevel = iota
|
||||||
|
guessEmbeddedHas
|
||||||
|
guessBelongs
|
||||||
|
guessEmbeddedBelongs
|
||||||
|
)
|
||||||
|
|
||||||
|
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
|
||||||
var (
|
var (
|
||||||
primaryFields, foreignFields []*Field
|
primaryFields, foreignFields []*Field
|
||||||
primarySchema, foreignSchema = schema, relation.FieldSchema
|
primarySchema, foreignSchema = schema, relation.FieldSchema
|
||||||
)
|
)
|
||||||
|
|
||||||
if !guessHas {
|
reguessOrErr := func(err string, args ...interface{}) {
|
||||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
switch gl {
|
||||||
|
case guessHas:
|
||||||
|
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||||
|
case guessEmbeddedHas:
|
||||||
|
schema.guessRelation(relation, field, guessBelongs)
|
||||||
|
case guessBelongs:
|
||||||
|
schema.guessRelation(relation, field, guessEmbeddedBelongs)
|
||||||
|
default:
|
||||||
|
schema.err = fmt.Errorf(err, args...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reguessOrErr := func(err string, args ...interface{}) {
|
switch gl {
|
||||||
if guessHas {
|
case guessEmbeddedHas:
|
||||||
schema.guessRelation(relation, field, false)
|
if field.OwnerSchema != nil {
|
||||||
|
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||||
} else {
|
} else {
|
||||||
schema.err = fmt.Errorf(err, args...)
|
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case guessBelongs:
|
||||||
|
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||||
|
case guessEmbeddedBelongs:
|
||||||
|
if field.OwnerSchema != nil {
|
||||||
|
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||||
|
} else {
|
||||||
|
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, primaryField := range primarySchema.PrimaryFields {
|
for _, primaryField := range primarySchema.PrimaryFields {
|
||||||
lookUpName := schema.Name + primaryField.Name
|
lookUpName := primarySchema.Name + primaryField.Name
|
||||||
if !guessHas {
|
if gl == guessBelongs {
|
||||||
lookUpName = field.Name + primaryField.Name
|
lookUpName = field.Name + primaryField.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(foreignFields) == 0 {
|
if len(foreignFields) == 0 {
|
||||||
reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas)
|
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||||
return
|
return
|
||||||
} else if len(relation.primaryKeys) > 0 {
|
} else if len(relation.primaryKeys) > 0 {
|
||||||
for idx, primaryKey := range relation.primaryKeys {
|
for idx, primaryKey := range relation.primaryKeys {
|
||||||
|
@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
||||||
relation.References = append(relation.References, &Reference{
|
relation.References = append(relation.References, &Reference{
|
||||||
PrimaryKey: primaryFields[idx],
|
PrimaryKey: primaryFields[idx],
|
||||||
ForeignKey: foreignField,
|
ForeignKey: foreignField,
|
||||||
OwnPrimaryKey: schema == primarySchema && guessHas,
|
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if guessHas {
|
if gl == guessHas || gl == guessEmbeddedHas {
|
||||||
relation.Type = "has"
|
relation.Type = "has"
|
||||||
} else {
|
} else {
|
||||||
relation.Type = BelongsTo
|
relation.Type = BelongsTo
|
||||||
|
|
|
@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) {
|
||||||
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
|
||||||
results: []string{"c5", "c1", "c2", "c4", "c3"},
|
results: []string{"c3", "c5", "c1", "c2", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
|
||||||
|
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmbeddedStruct(t *testing.T) {
|
func TestEmbeddedStruct(t *testing.T) {
|
||||||
|
@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) {
|
||||||
t.Errorf("Failed to create got error %v", err)
|
t.Errorf("Failed to create got error %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedRelations(t *testing.T) {
|
||||||
|
type AdvancedUser struct {
|
||||||
|
User `gorm:"embedded"`
|
||||||
|
Advanced bool
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Debug().Migrator().DropTable(&AdvancedUser{})
|
||||||
|
|
||||||
|
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
|
||||||
|
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue