mirror of https://github.com/go-gorm/gorm.git
Fix failed to guess relations for embedded types, close #3224
This commit is contained in:
parent
c11c939b95
commit
ff985b90cc
|
@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,6 +62,7 @@ type Field struct {
|
|||
TagSettings map[string]string
|
||||
Schema *Schema
|
||||
EmbeddedSchema *Schema
|
||||
OwnerSchema *Schema
|
||||
ReflectValueOf func(reflect.Value) reflect.Value
|
||||
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
||||
Set func(reflect.Value, interface{}) error
|
||||
|
@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
}
|
||||
for _, ef := range field.EmbeddedSchema.Fields {
|
||||
ef.Schema = schema
|
||||
ef.OwnerSchema = field.EmbeddedSchema
|
||||
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
||||
// index is negative means is pointer
|
||||
if field.FieldType.Kind() == reflect.Struct {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"gorm.io/gorm/clause"
|
||||
|
@ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) {
|
|||
}
|
||||
)
|
||||
|
||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
if field.OwnerSchema != nil {
|
||||
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
|
@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) {
|
|||
} else {
|
||||
switch field.IndirectFieldType.Kind() {
|
||||
case reflect.Struct, reflect.Slice:
|
||||
schema.guessRelation(relation, field, true)
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
|
||||
}
|
||||
|
@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
}
|
||||
}
|
||||
|
||||
func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
|
||||
type guessLevel int
|
||||
|
||||
const (
|
||||
guessHas guessLevel = iota
|
||||
guessEmbeddedHas
|
||||
guessBelongs
|
||||
guessEmbeddedBelongs
|
||||
)
|
||||
|
||||
func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
|
||||
var (
|
||||
primaryFields, foreignFields []*Field
|
||||
primarySchema, foreignSchema = schema, relation.FieldSchema
|
||||
)
|
||||
|
||||
if !guessHas {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
reguessOrErr := func(err string, args ...interface{}) {
|
||||
switch gl {
|
||||
case guessHas:
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
case guessEmbeddedHas:
|
||||
schema.guessRelation(relation, field, guessBelongs)
|
||||
case guessBelongs:
|
||||
schema.guessRelation(relation, field, guessEmbeddedBelongs)
|
||||
default:
|
||||
schema.err = fmt.Errorf(err, args...)
|
||||
}
|
||||
}
|
||||
|
||||
reguessOrErr := func(err string, args ...interface{}) {
|
||||
if guessHas {
|
||||
schema.guessRelation(relation, field, false)
|
||||
switch gl {
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
schema.err = fmt.Errorf(err, args...)
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||
return
|
||||
}
|
||||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
} else {
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||
}
|
||||
} else {
|
||||
for _, primaryField := range primarySchema.PrimaryFields {
|
||||
lookUpName := schema.Name + primaryField.Name
|
||||
if !guessHas {
|
||||
lookUpName := primarySchema.Name + primaryField.Name
|
||||
if gl == guessBelongs {
|
||||
lookUpName = field.Name + primaryField.Name
|
||||
}
|
||||
|
||||
|
@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||
}
|
||||
|
||||
if len(foreignFields) == 0 {
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas)
|
||||
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
|
||||
return
|
||||
} else if len(relation.primaryKeys) > 0 {
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
|
@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
|
|||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryFields[idx],
|
||||
ForeignKey: foreignField,
|
||||
OwnPrimaryKey: schema == primarySchema && guessHas,
|
||||
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
|
||||
})
|
||||
}
|
||||
|
||||
if guessHas {
|
||||
if gl == guessHas || gl == guessEmbeddedHas {
|
||||
relation.Type = "has"
|
||||
} else {
|
||||
relation.Type = BelongsTo
|
||||
|
|
|
@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) {
|
|||
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
|
||||
results: []string{"c5", "c1", "c2", "c4", "c3"},
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
|
||||
results: []string{"c3", "c5", "c1", "c2", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
|
||||
results: []string{"c5", "c1", "c2", "c3", "c4"},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestEmbeddedStruct(t *testing.T) {
|
||||
|
@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) {
|
|||
t.Errorf("Failed to create got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedRelations(t *testing.T) {
|
||||
type AdvancedUser struct {
|
||||
User `gorm:"embedded"`
|
||||
Advanced bool
|
||||
}
|
||||
|
||||
DB.Debug().Migrator().DropTable(&AdvancedUser{})
|
||||
|
||||
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
|
||||
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue