diff --git a/preload.go b/preload.go index 541f7b95..0c8d70ad 100644 --- a/preload.go +++ b/preload.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" ) func getRealValue(value reflect.Value, field string) interface{} { @@ -20,90 +21,139 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { + preloadMap := map[string]bool{} if scope.Search.preload != nil { fields := scope.Fields() isSlice := scope.IndirectValue().Kind() == reflect.Slice - for key, conditions := range scope.Search.preload { - for _, field := range fields { - if field.Name == key && field.Relationship != nil { - results := makeSlice(field.Struct.Type) - relation := field.Relationship - primaryName := scope.PrimaryField().Name - associationPrimaryKey := scope.New(results).PrimaryField().Name - - switch relation.Kind { - case "has_one": - if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break - } - } - } else { - scope.SetColumn(field, result) - } - } - } - case "has_many": - if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if isSlice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) - break - } - } - } - } else { - scope.SetColumn(field, resultValues) - } - } - case "belongs_to": - if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { - scope.NewDB().Where(primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, associationPrimaryKey) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.SetColumn(field, result) - } - } - } - case "many_to_many": - scope.Err(errors.New("not supported relation")) - default: - scope.Err(errors.New("not supported relation")) + for _, preload := range scope.Search.preload { + schema, conditions := preload.schema, preload.conditions + keys := strings.Split(schema, ".") + currentScope := scope + currentFields := fields + currentIsSlice := isSlice + originalConditions := conditions + conditions = []interface{}{} + for i, key := range keys { + // log.Printf("--> %+v\n", key) + if !preloadMap[strings.Join(keys[:i+1], ".")] { + if i == len(keys)-1 { + // log.Printf("--> %+v\n", originalConditions) + conditions = originalConditions } - break + + var found bool + for _, field := range currentFields { + if field.Name == key && field.Relationship != nil { + found = true + // log.Printf("--> %+v\n", field.Name) + results := makeSlice(field.Struct.Type) + relation := field.Relationship + primaryName := currentScope.PrimaryField().Name + associationPrimaryKey := currentScope.New(results).PrimaryField().Name + + switch relation.Kind { + case "has_one": + if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { + condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) + currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if currentIsSlice { + value := getRealValue(result, relation.ForeignFieldName) + objects := currentScope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + break + } + } + } else { + // log.Printf("--> %+v\n", result.Interface()) + err := currentScope.SetColumn(field, result) + if err != nil { + scope.Err(err) + return + } + // printutils.PrettyPrint(currentScope.Value) + } + } + // printutils.PrettyPrint(currentScope.Value) + } + case "has_many": + // log.Printf("--> %+v\n", key) + if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { + condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) + currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + if currentIsSlice { + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getRealValue(result, relation.ForeignFieldName) + objects := currentScope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } + } + } + // printutils.PrettyPrint(currentScope.IndirectValue().Interface()) + } else { + currentScope.SetColumn(field, resultValues) + } + } + case "belongs_to": + if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { + currentScope.NewDB().Where(primaryKeys).Find(results, conditions...) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if currentIsSlice { + value := getRealValue(result, associationPrimaryKey) + objects := currentScope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + currentScope.SetColumn(field, result) + } + } + } + case "many_to_many": + // currentScope.Err(errors.New("not supported relation")) + fallthrough + default: + currentScope.Err(errors.New("not supported relation")) + } + break + } + } + + if !found { + value := reflect.ValueOf(currentScope.Value) + if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { + value = value.Index(0).Elem() + } + scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) + return + } + + preloadMap[strings.Join(keys[:i+1], ".")] = true + } + + if i < len(keys)-1 { + // TODO: update current scope + currentScope = currentScope.getColumnsAsScope(key) + currentFields = currentScope.Fields() + currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice } } } @@ -120,19 +170,44 @@ func makeSlice(typ reflect.Type) interface{} { return slice.Interface() } -func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) { +func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: - primaryKeyMap := map[interface{}]bool{} for i := 0; i < values.Len(); i++ { - primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true - } - for key := range primaryKeyMap { - primaryKeys = append(primaryKeys, key) + columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) } case reflect.Struct: return []interface{}{values.FieldByName(column).Interface()} } return } + +func (scope *Scope) getColumnsAsScope(column string) *Scope { + values := scope.IndirectValue() + // log.Println(values.Type(), column) + switch values.Kind() { + case reflect.Slice: + fieldType, _ := values.Type().Elem().FieldByName(column) + var columns reflect.Value + if fieldType.Type.Kind() == reflect.Slice { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() + } else { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type))).Elem() + } + for i := 0; i < values.Len(); i++ { + column := reflect.Indirect(values.Index(i)).FieldByName(column) + if column.Kind() == reflect.Slice { + for i := 0; i < column.Len(); i++ { + columns = reflect.Append(columns, column.Index(i).Addr()) + } + } else { + columns = reflect.Append(columns, column.Addr()) + } + } + return scope.New(columns.Interface()) + case reflect.Struct: + return scope.New(values.FieldByName(column).Addr().Interface()) + } + return nil +} diff --git a/preload_test.go b/preload_test.go index 2547933b..c5d395d4 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,6 +1,11 @@ package gorm_test -import "testing" +import ( + "encoding/json" + "log" + "reflect" + "testing" +) func getPreloadUser(name string) *User { return getPreparedUser(name, "Preload") @@ -85,3 +90,417 @@ func TestPreload(t *testing.T) { } } } + +func TestNestedPreload(t *testing.T) { + log.SetFlags(log.Lshortfile) + // Struct: Level3 + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + + // Slice: []Level3 + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } +} + +func toJSONString(v interface{}) []byte { + r, _ := json.MarshalIndent(v, "", " ") + return r +} diff --git a/search.go b/search.go index a180fb92..9411af43 100644 --- a/search.go +++ b/search.go @@ -14,7 +14,7 @@ type search struct { omits []string orders []string joins string - preload map[string][]interface{} + preload []searchPreload offset string limit string group string @@ -23,6 +23,11 @@ type search struct { Unscoped bool } +type searchPreload struct { + schema string + conditions []interface{} +} + func (s *search) clone() *search { clone := *s return &clone @@ -97,11 +102,15 @@ func (s *search) Joins(query string) *search { return s } -func (s *search) Preload(column string, values ...interface{}) *search { - if s.preload == nil { - s.preload = map[string][]interface{}{} +func (s *search) Preload(schema string, values ...interface{}) *search { + var preloads []searchPreload + for _, preload := range s.preload { + if preload.schema != schema { + preloads = append(preloads, preload) + } } - s.preload[column] = values + preloads = append(preloads, searchPreload{schema, values}) + s.preload = preloads return s }