diff --git a/model_struct.go b/model_struct.go index 6e1ff055..7e4b683c 100644 --- a/model_struct.go +++ b/model_struct.go @@ -205,7 +205,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var foreignKeys []string if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { - foreignKeys := append(foreignKeys, gormSettings["FOREIGNKEY"]) + foreignKeys = append(foreignKeys, foreignKey) } switch indirectType.Kind() { case reflect.Slice: diff --git a/preload.go b/preload.go index 03910c44..0a302ab2 100644 --- a/preload.go +++ b/preload.go @@ -8,12 +8,15 @@ import ( "strings" ) -func getRealValue(value reflect.Value, field string) interface{} { - result := reflect.Indirect(value).FieldByName(field).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() +func getRealValue(value reflect.Value, columns []string) (results []interface{}) { + for _, column := range columns { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) } - return result + return } func equalAsString(a interface{}, b interface{}) bool { @@ -97,26 +100,23 @@ func makeSlice(typ reflect.Type) interface{} { } func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) break } @@ -131,27 +131,24 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { + if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) break @@ -165,25 +162,23 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - associationPrimaryKey := scope.New(results).PrimaryField().Name - - scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, associationPrimaryKey) + value := getRealValue(result, relation.AssociationForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } } @@ -193,15 +188,23 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } -func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { +func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: for i := 0; i < values.Len(); i++ { - columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + var result []interface{} + for _, column := range columns { + result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + } + results = append(results, result) } case reflect.Struct: - return []interface{}{values.FieldByName(column).Interface()} + var result []interface{} + for _, column := range columns { + result = append(result, values.FieldByName(column).Interface()) + } + return [][]interface{}{result} } return }