fix: preload panic when model and dest different close #5130

commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 13:37:22 2022 +0800

    Refactor #5130

commit 40cbba49f3
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:36:56 2022 +0800

    test: fix test fail

commit 66d3f07829
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:29:09 2022 +0800

    test: drop table and auto migrate

commit 7cbf019a93
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 15:27:45 2022 +0800

    fix: preload panic when model and dest different
This commit is contained in:
chenrui 2022-03-18 13:38:46 +08:00 committed by Jinzhu
parent c2e36ebe62
commit 5431da8caf
5 changed files with 63 additions and 33 deletions

View File

@ -10,10 +10,9 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var ( var (
reflectValue = db.Statement.ReflectValue reflectValue = tx.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string relForeignKeys []string
relForeignFields []*schema.Field relForeignFields []*schema.Field
foreignFields []*schema.Field foreignFields []*schema.Field
@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{} inlineConds []interface{}
) )
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if rel.JoinTable != nil { if rel.JoinTable != nil {
var ( var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References)) joinForeignFields = make([]*schema.Field, 0, len(rel.References))
@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 { if len(joinForeignValues) == 0 {
return return nil
} }
joinResults := rel.JoinTable.MakeSlice().Elem() joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map // convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields)) fieldValues := make([]interface{}, len(joinForeignFields))
@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < joinResults.Len(); i++ { for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i) joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields { for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
} }
for idx, field := range joinRelForeignFields { for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
_, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else { } else {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 { if len(foreignValues) == 0 {
return return nil
} }
} }
@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
} }
fieldValues := make([]interface{}, len(relForeignFields)) fieldValues := make([]interface{}, len(relForeignFields))
@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct: case reflect.Struct:
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default: default:
rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default: default:
rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
} }
} }
} }
@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
} }
datas, ok := identityMap[utils.ToStringKey(fieldValues...)] datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok { if !ok {
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
elem.Interface()))
continue
} }
for _, data := range datas { for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
} }
@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue) reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() { switch reflectFieldValue.Kind() {
case reflect.Struct: case reflect.Struct:
rel.Field.Set(db.Statement.Context, data, elem.Interface()) rel.Field.Set(tx.Statement.Context, data, elem.Interface())
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())
} else { } else {
rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
} }
} }
} }
} }
return tx.Error
} }

View File

@ -237,9 +237,20 @@ func Preload(db *gorm.DB) {
} }
sort.Strings(preloadNames) sort.Strings(preloadNames)
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
for _, name := range preloadNames { for _, name := range preloadNames {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else { } else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
} }

View File

@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
} else if tables := strings.Split(name, "."); len(tables) == 2 { } else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1] tx.Statement.Table = tables[1]
} else { } else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
} }
return return
} }

View File

@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) {
} }
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
} }
if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {

View File

@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
func TestPreloadWithDiffModel(t *testing.T) {
user := *GetUser("preload_with_diff_model", Config{Account: true})
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
var result struct {
Something string
User
}
DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select(
"users.*, 'yo' as something").First(&result, "name = ?", user.Name)
CheckUser(t, user, result.User)
}