From f73f7b251f3d44a6e974c2ba0301060d7ab32502 Mon Sep 17 00:00:00 2001 From: Richard Knop Date: Sat, 13 Feb 2016 20:28:42 +0800 Subject: [PATCH 1/2] HasTable now works with table name passed as a string. Before, only HasTable(&Foo) would work but HasTable("foos") would always return false. This PR fixes that. --- main_test.go | 12 ++++++++++++ scope.go | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/main_test.go b/main_test.go index 65467d73..8722c7c0 100644 --- a/main_test.go +++ b/main_test.go @@ -165,12 +165,24 @@ func TestHasTable(t *testing.T) { Stuff string } DB.DropTable(&Foo{}) + + // Table should not exist at this point, HasTable should return false + if ok := DB.HasTable("foos"); ok { + t.Errorf("Table should not exist, but does") + } if ok := DB.HasTable(&Foo{}); ok { t.Errorf("Table should not exist, but does") } + + // We create the table if err := DB.CreateTable(&Foo{}).Error; err != nil { t.Errorf("Table should be created") } + + // And now it should exits, and HasTable should return true + if ok := DB.HasTable("foos"); !ok { + t.Errorf("Table should exist, but HasTable informs it does not") + } if ok := DB.HasTable(&Foo{}); !ok { t.Errorf("Table should exist, but HasTable informs it does not") } diff --git a/scope.go b/scope.go index a11d4ec4..2fa4acdb 100644 --- a/scope.go +++ b/scope.go @@ -267,6 +267,10 @@ type dbTabler interface { // TableName get table name func (scope *Scope) TableName() string { + if strTableName, ok := scope.Value.(string); ok { + return strTableName + } + if scope.Search != nil && len(scope.Search.tableName) > 0 { return scope.Search.tableName } From d37d1844017325a40fe80155bf32c1c99cc2622c Mon Sep 17 00:00:00 2001 From: Richard Knop Date: Wed, 10 Feb 2016 23:44:41 +0800 Subject: [PATCH 2/2] Fixed a nested preload panic bug. --- preload.go | 3 ++ preload_test.go | 104 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 81 insertions(+), 26 deletions(-) diff --git a/preload.go b/preload.go index d9f4e9c5..a855f794 100644 --- a/preload.go +++ b/preload.go @@ -294,6 +294,9 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) + if object.Kind() == reflect.Ptr { + object = object.Elem() + } source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { diff --git a/preload_test.go b/preload_test.go index 7f4b8fdb..c5a3a136 100644 --- a/preload_test.go +++ b/preload_test.go @@ -611,62 +611,114 @@ func TestNestedPreload9(t *testing.T) { } } -type Level1A struct { +type LevelA1 struct { ID uint Value string } -type Level1B struct { - ID uint - Value string - Level2s []*Level2 +type LevelA2 struct { + ID uint + Value string + LevelA3s []*LevelA3 } -type Level2 struct { +type LevelA3 struct { ID uint Value string - Level1AID sql.NullInt64 - Level1A *Level1A - Level1BID sql.NullInt64 - Level1B *Level1B + LevelA1ID sql.NullInt64 + LevelA1 *LevelA1 + LevelA2ID sql.NullInt64 + LevelA2 *LevelA2 } func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1B{}) - DB.DropTableIfExists(&Level1A{}) + DB.DropTableIfExists(&LevelA3{}) + DB.DropTableIfExists(&LevelA2{}) + DB.DropTableIfExists(&LevelA1{}) - if err := DB.AutoMigrate(&Level1A{}, &Level1B{}, &Level2{}).Error; err != nil { + if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { t.Error(err) } - level1A := &Level1A{Value: "foo"} - if err := DB.Save(&level1A).Error; err != nil { + levelA1 := &LevelA1{Value: "foo"} + if err := DB.Save(levelA1).Error; err != nil { t.Error(err) } - want := []*Level1B{ - &Level1B{ + want := []*LevelA2{ + &LevelA2{ Value: "bar", - Level2s: []*Level2{ - &Level2{ + LevelA3s: []*LevelA3{ + &LevelA3{ Value: "qux", - Level1A: level1A, + LevelA1: levelA1, }, }, }, - &Level1B{ + &LevelA2{ Value: "bar 2", }, } - for _, level1B := range want { - if err := DB.Save(level1B).Error; err != nil { + for _, levelA2 := range want { + if err := DB.Save(levelA2).Error; err != nil { t.Error(err) } } - var got []*Level1B - if err := DB.Preload("Level2s.Level1A").Find(&got).Error; err != nil { + var got []*LevelA2 + if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +type LevelB1 struct { + ID uint + Value string + LevelB3s []*LevelB3 +} + +type LevelB2 struct { + ID uint + Value string +} + +type LevelB3 struct { + ID uint + Value string + LevelB1ID sql.NullInt64 + LevelB1 *LevelB1 + LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` +} + +func TestNestedPreload11(t *testing.T) { + DB.DropTableIfExists(&LevelB2{}) + DB.DropTableIfExists(&LevelB3{}) + DB.DropTableIfExists(&LevelB1{}) + if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { + t.Error(err) + } + + levelB1 := &LevelB1{Value: "foo"} + if err := DB.Create(levelB1).Error; err != nil { + t.Error(err) + } + + levelB3 := &LevelB3{ + Value: "bar", + LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + } + if err := DB.Create(levelB3).Error; err != nil { + t.Error(err) + } + levelB1.LevelB3s = []*LevelB3{levelB3} + + want := []*LevelB1{levelB1} + var got []*LevelB1 + if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { t.Error(err) }