From 5431da8caf09ad19256170df17e2e75eb541f4a5 Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Fri, 18 Mar 2022 13:38:46 +0800 Subject: [PATCH] fix: preload panic when model and dest different close #5130 commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e Author: Jinzhu Date: Fri Mar 18 13:37:22 2022 +0800 Refactor #5130 commit 40cbba49f374c9bae54f80daee16697ae45e905b Author: chenrui Date: Sat Mar 5 17:36:56 2022 +0800 test: fix test fail commit 66d3f078291102a30532b6a9d97c757228a9b543 Author: chenrui Date: Sat Mar 5 17:29:09 2022 +0800 test: drop table and auto migrate commit 7cbf019a930019476a97ac7ac0f5fc186e8d5b42 Author: chenrui Date: Sat Mar 5 15:27:45 2022 +0800 fix: preload panic when model and dest different --- callbacks/preload.go | 56 ++++++++++++++++++------------------- callbacks/query.go | 15 ++++++++-- chainable_api.go | 5 +++- tests/preload_suits_test.go | 2 +- tests/preload_test.go | 18 ++++++++++++ 5 files changed, 63 insertions(+), 33 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 2363a8ca..888f832d 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -10,10 +10,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( - reflectValue = db.Statement.ReflectValue - tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + reflectValue = tx.Statement.ReflectValue relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field @@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload inlineConds []interface{} ) - db.Statement.Settings.Range(func(k, v interface{}) bool { - tx.Statement.Settings.Store(k, v) - return true - }) - if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) @@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { - return + return nil } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { + return err + } // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { - return + return nil } } @@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + return err + } } fieldValues := make([]interface{}, len(relForeignFields)) @@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if !ok { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", - elem.Interface())) - continue + return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) + reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(db.Statement.Context, data, elem.Interface()) + rel.Field.Set(tx.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } } + + return tx.Error } diff --git a/callbacks/query.go b/callbacks/query.go index c4c80406..6ba3dd38 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -237,9 +237,20 @@ func Preload(db *gorm.DB) { } sort.Strings(preloadNames) + preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + preloadDB.Statement.Settings.Store(k, v) + return true + }) + + if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + return + } + preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + for _, name := range preloadNames { - if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { - preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) + if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/chainable_api.go b/chainable_api.go index 173479d3..38ad5cde 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] - } else { + } else if name != "" { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = name + } else { + tx.Statement.TableExpr = nil + tx.Statement.Table = "" } return } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 0ef8890b..b5b6a70f 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) { } if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) } if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { diff --git a/tests/preload_test.go b/tests/preload_test.go index adb54ee1..cb4343ec 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) { } wg.Wait() } + +func TestPreloadWithDiffModel(t *testing.T) { + user := *GetUser("preload_with_diff_model", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var result struct { + Something string + User + } + + DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( + "users.*, 'yo' as something").First(&result, "name = ?", user.Name) + + CheckUser(t, user, result.User) +}