diff --git a/callback_query_preload.go b/callback_query_preload.go index 4d2678bd..13a109dd 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -28,6 +28,10 @@ func preloadCallback(scope *Scope) { for idx, preloadField := range preloadFields { var currentPreloadConditions []interface{} + if currentScope == nil { + continue + } + // if not preloaded if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { @@ -67,7 +71,9 @@ func preloadCallback(scope *Scope) { // preload next level if idx < len(preloadFields)-1 { currentScope = currentScope.getColumnAsScope(preloadField) - currentFields = currentScope.Fields() + if currentScope != nil { + currentFields = currentScope.Fields() + } } } } diff --git a/preload_test.go b/preload_test.go index 5c49ecc2..e2fb35ff 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1288,17 +1288,23 @@ func TestNilPointerSlice(t *testing.T) { t.Error(err) } - want := Level1{Value: "Bob", Level2: &Level2{ - Value: "en", - Level3: &Level3{ - Value: "native", + want := Level1{ + Value: "Bob", + Level2: &Level2{ + Value: "en", + Level3: &Level3{ + Value: "native", + }, }, - }} + } if err := DB.Save(&want).Error; err != nil { t.Error(err) } - want2 := Level1{Value: "Tom", Level2: nil} + want2 := Level1{ + Value: "Tom", + Level2: nil, + } if err := DB.Save(&want2).Error; err != nil { t.Error(err) } @@ -1321,6 +1327,52 @@ func TestNilPointerSlice(t *testing.T) { } } +func TestNilPointerSlice2(t *testing.T) { + type ( + Level4 struct { + ID uint + } + Level3 struct { + ID uint + Level4ID sql.NullInt64 `sql:"index"` + Level4 *Level4 + } + Level2 struct { + ID uint + Level3s []*Level3 `gorm:"many2many:level2_level3s"` + } + Level1 struct { + ID uint + Level2ID sql.NullInt64 `sql:"index"` + Level2 *Level2 + } + ) + + DB.DropTableIfExists(new(Level4)) + DB.DropTableIfExists(new(Level3)) + DB.DropTableIfExists(new(Level2)) + DB.DropTableIfExists(new(Level1)) + + if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { + t.Error(err) + } + + want := new(Level1) + if err := DB.Save(want).Error; err != nil { + t.Error(err) + } + + got := new(Level1) + err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r