From f8e2f0456223494021cecb97cb387940309d063a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2015 23:09:07 +0800 Subject: [PATCH] Some Tweaks for Preload Many2Many, Add tests with inline conditions --- join_table_handler.go | 18 ++++-------------- preload.go | 27 +++++++++++++++++++-------- preload_test.go | 22 +++++++++++++++++++--- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/join_table_handler.go b/join_table_handler.go index 1fb25e5d..162af4a8 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,10 +13,9 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB + PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB SourceForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey - DestinationType() reflect.Type } type JoinTableForeignKey struct { @@ -139,8 +138,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so var queryConditions []string var values []interface{} if s.Source.ModelType == modelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } @@ -156,7 +155,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so } } -func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}, conditions ...interface{}) *DB { +func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { quotedTable := handler.Table(db) scope := db.NewScope(source) @@ -165,8 +164,8 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db var queryConditions []string var values []interface{} if s.Source.ModelType == modelType { + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).inlineCondition(conditions...).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } @@ -179,11 +178,6 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db queryConditions = append(queryConditions, condString) } - if len(conditions) > 0 { - queryConditions = append(queryConditions, toString(conditions[0])) - values = append(values, conditions[1:]...) - } - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). Where(strings.Join(queryConditions, " AND "), values...) } else { @@ -191,7 +185,3 @@ func (s JoinTableHandler) PreloadWithJoin(handler JoinTableHandlerInterface, db return db } } - -func (s JoinTableHandler) DestinationType() reflect.Type { - return s.Destination.ModelType -} diff --git a/preload.go b/preload.go index c7810b63..c0869a2a 100644 --- a/preload.go +++ b/preload.go @@ -195,14 +195,16 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf relation := field.Relationship joinTableHandler := relation.JoinTableHandler - destType := joinTableHandler.DestinationType() - - db := scope.NewDB().Table(scope.db.NewScope(reflect.New(destType).Elem().Interface()).TableName()) + destType := field.StructField.Struct.Type.Elem() + var isPtr bool + if destType.Kind() == reflect.Ptr { + isPtr = true + destType = destType.Elem() + } var destKeys []string var sourceKeys []string - - linkHash := make(map[string][]string) + var linkHash = make(map[string][]string) for _, key := range joinTableHandler.DestinationForeignKeys() { destKeys = append(destKeys, key.DBName) @@ -213,7 +215,13 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf } results := reflect.New(field.Struct.Type).Elem() - rows, err := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value, conditions...).Rows() + + db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()) + preloadJoinDB := joinTableHandler.PreloadWithJoin(joinTableHandler, db, scope.Value) + if len(conditions) > 0 { + preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) + } + rows, err := preloadJoinDB.Rows() if scope.Err(err) != nil { return @@ -264,8 +272,11 @@ func (scope *Scope) handleHasManyToManyPreload(field *Field, conditions []interf linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], toString(destKey)) } - results = reflect.Append(results, elem) - + if isPtr { + results = reflect.Append(results, elem.Addr()) + } else { + results = reflect.Append(results, elem) + } } if scope.IndirectValue().Kind() == reflect.Slice { diff --git a/preload_test.go b/preload_test.go index db37f37d..043e24d6 100644 --- a/preload_test.go +++ b/preload_test.go @@ -612,7 +612,7 @@ func TestManyToManyPreload(t *testing.T) { Level2 struct { ID uint `gorm:"primary_key;"` Value string - Level1s []Level1 `gorm:"many2many:levels;"` + Level1s []*Level1 `gorm:"many2many:levels;"` } ) @@ -623,7 +623,7 @@ func TestManyToManyPreload(t *testing.T) { panic(err) } - want := Level2{Value: "Bob", Level1s: []Level1{ + want := Level2{Value: "Bob", Level1s: []*Level1{ {Value: "ru"}, {Value: "en"}, }} @@ -631,7 +631,7 @@ func TestManyToManyPreload(t *testing.T) { panic(err) } - want2 := Level2{Value: "Tom", Level1s: []Level1{ + want2 := Level2{Value: "Tom", Level1s: []*Level1{ {Value: "zh"}, {Value: "de"}, }} @@ -665,6 +665,22 @@ func TestManyToManyPreload(t *testing.T) { if !reflect.DeepEqual(got3, []Level2{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } } func toJSONString(v interface{}) []byte {