diff --git a/association_test.go b/association_test.go index 485a68e4..994649a5 100644 --- a/association_test.go +++ b/association_test.go @@ -88,6 +88,12 @@ func TestRelated(t *testing.T) { t.Errorf("Should have two emails") } + var emails3 []*Email + DB.Model(&user).Related(&emails3) + if len(emails3) != 2 { + t.Errorf("Should have two emails") + } + var user1 User DB.Model(&user).Related(&user1.Emails) if len(user1.Emails) != 2 { diff --git a/scope_private.go b/scope_private.go index 4fd7149d..eddcfcc3 100644 --- a/scope_private.go +++ b/scope_private.go @@ -407,18 +407,20 @@ func (scope *Scope) count(value interface{}) *Scope { } func (scope *Scope) typeName() string { - value := scope.IndirectValue() - if value.Kind() == reflect.Slice { - return value.Type().Elem().Name() + typ := scope.IndirectValue().Type() + + for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { + typ = typ.Elem() } - return value.Type().Name() + return typ.Name() } func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) fromFields := scope.Fields() toFields := toScope.Fields() + for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { var fromField, toField *Field if field, ok := scope.FieldByName(foreignKey); ok {