forked from mirror/gorm
fix: preload panic when model and dest different close #5130
commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e Author: Jinzhu <wosmvp@gmail.com> Date: Fri Mar 18 13:37:22 2022 +0800 Refactor #5130 commit40cbba49f3
Author: chenrui <chenrui@jingdaka.com> Date: Sat Mar 5 17:36:56 2022 +0800 test: fix test fail commit66d3f07829
Author: chenrui <chenrui@jingdaka.com> Date: Sat Mar 5 17:29:09 2022 +0800 test: drop table and auto migrate commit7cbf019a93
Author: chenrui <chenrui@jingdaka.com> Date: Sat Mar 5 15:27:45 2022 +0800 fix: preload panic when model and dest different
This commit is contained in:
parent
c2e36ebe62
commit
5431da8caf
|
@ -10,10 +10,9 @@ import (
|
|||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
relForeignKeys []string
|
||||
relForeignFields []*schema.Field
|
||||
foreignFields []*schema.Field
|
||||
|
@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
inlineConds []interface{}
|
||||
)
|
||||
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
var (
|
||||
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
|
@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
}
|
||||
}
|
||||
|
||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields)
|
||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||
if len(joinForeignValues) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
|
||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert join identity map to relation identity map
|
||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||
|
@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
for i := 0; i < joinResults.Len(); i++ {
|
||||
joinIndexValue := joinResults.Index(i)
|
||||
for idx, field := range joinForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue)
|
||||
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||
}
|
||||
|
||||
for idx, field := range joinRelForeignFields {
|
||||
joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue)
|
||||
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||
}
|
||||
|
||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||
|
@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
}
|
||||
}
|
||||
|
||||
_, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields)
|
||||
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
|
||||
} else {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
|
@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
}
|
||||
}
|
||||
|
||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields)
|
||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||
if len(foreignValues) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
}
|
||||
}
|
||||
|
||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fieldValues := make([]interface{}, len(relForeignFields))
|
||||
|
@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
case reflect.Struct:
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
||||
rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
||||
default:
|
||||
rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())
|
||||
rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
||||
rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
||||
default:
|
||||
rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
|
||||
rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
for i := 0; i < reflectResults.Len(); i++ {
|
||||
elem := reflectResults.Index(i)
|
||||
for idx, field := range relForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem)
|
||||
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
|
||||
}
|
||||
|
||||
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
|
||||
if !ok {
|
||||
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists",
|
||||
elem.Interface()))
|
||||
continue
|
||||
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
|
||||
}
|
||||
|
||||
for _, data := range datas {
|
||||
reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data)
|
||||
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
|
||||
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
|
||||
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
|
||||
}
|
||||
|
@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||
reflectFieldValue = reflect.Indirect(reflectFieldValue)
|
||||
switch reflectFieldValue.Kind() {
|
||||
case reflect.Struct:
|
||||
rel.Field.Set(db.Statement.Context, data, elem.Interface())
|
||||
rel.Field.Set(tx.Statement.Context, data, elem.Interface())
|
||||
case reflect.Slice, reflect.Array:
|
||||
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
|
||||
rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())
|
||||
rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())
|
||||
} else {
|
||||
rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
|
||||
rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Error
|
||||
}
|
||||
|
|
|
@ -237,9 +237,20 @@ func Preload(db *gorm.DB) {
|
|||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
preloadDB.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
||||
return
|
||||
}
|
||||
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])
|
||||
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
}
|
||||
|
|
|
@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
|||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
tx.Statement.Table = tables[1]
|
||||
} else {
|
||||
} else if name != "" {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
tx.Statement.Table = name
|
||||
} else {
|
||||
tx.Statement.TableExpr = nil
|
||||
tx.Statement.Table = ""
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) {
|
|||
}
|
||||
|
||||
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
|
||||
t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
|
||||
t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {
|
||||
|
|
|
@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) {
|
|||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPreloadWithDiffModel(t *testing.T) {
|
||||
user := *GetUser("preload_with_diff_model", Config{Account: true})
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Something string
|
||||
User
|
||||
}
|
||||
|
||||
DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select(
|
||||
"users.*, 'yo' as something").First(&result, "name = ?", user.Name)
|
||||
|
||||
CheckUser(t, user, result.User)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue