diff --git a/association.go b/association.go index 2f26b74a..62d638b6 100644 --- a/association.go +++ b/association.go @@ -30,27 +30,12 @@ func (association *Association) Find(value interface{}) *Association { func (association *Association) Append(values ...interface{}) *Association { scope := association.Scope field := association.Field - fieldType := field.Field.Type() for _, value := range values { - reflectvalue := reflect.ValueOf(value) - if reflectvalue.Kind() == reflect.Ptr { - if reflectvalue.Elem().Kind() == reflect.Struct { - if fieldType.Elem().Kind() == reflect.Ptr { - field.Set(reflect.Append(field.Field, reflectvalue)) - } else if fieldType.Elem().Kind() == reflect.Struct { - field.Set(reflect.Append(field.Field, reflectvalue.Elem())) - } - } else if reflectvalue.Elem().Kind() == reflect.Slice { - if fieldType.Elem().Kind() == reflect.Ptr { - field.Set(reflect.AppendSlice(field.Field, reflectvalue)) - } else if fieldType.Elem().Kind() == reflect.Struct { - field.Set(reflect.AppendSlice(field.Field, reflectvalue.Elem())) - } - } - } else if reflectvalue.Kind() == reflect.Struct && fieldType.Elem().Kind() == reflect.Struct { + reflectvalue := reflect.Indirect(reflect.ValueOf(value)) + if reflectvalue.Kind() == reflect.Struct { field.Set(reflect.Append(field.Field, reflectvalue)) - } else if reflectvalue.Kind() == reflect.Slice && fieldType.Elem() == reflectvalue.Type().Elem() { + } else if reflectvalue.Kind() == reflect.Slice { field.Set(reflect.AppendSlice(field.Field, reflectvalue)) } else { association.setErr(errors.New("invalid association type")) @@ -65,23 +50,18 @@ func (association *Association) getPrimaryKeys(values ...interface{}) []interfac scope := association.Scope for _, value := range values { - reflectValue := reflect.ValueOf(value) - if reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } + reflectValue := reflect.Indirect(reflect.ValueOf(value)) if reflectValue.Kind() == reflect.Slice { for i := 0; i < reflectValue.Len(); i++ { - newScope := scope.New(reflectValue.Index(i).Interface()) - primaryKey := newScope.PrimaryKeyValue() - if !reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { - primaryKeys = append(primaryKeys, primaryKey) + primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryKeyField() + if !primaryField.IsBlank { + primaryKeys = append(primaryKeys, primaryField.Field.Interface()) } } } else if reflectValue.Kind() == reflect.Struct { - newScope := scope.New(value) - primaryKey := newScope.PrimaryKeyValue() - if !reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { - primaryKeys = append(primaryKeys, primaryKey) + primaryField := scope.New(value).PrimaryKeyField() + if !primaryField.IsBlank { + primaryKeys = append(primaryKeys, primaryField.Field.Interface()) } } } @@ -121,16 +101,16 @@ func (association *Association) Replace(values ...interface{}) *Association { newPrimaryKeys := association.getPrimaryKeys(field.Interface()) var addedPrimaryKeys = []interface{}{} - for _, new := range newPrimaryKeys { + for _, newKey := range newPrimaryKeys { hasEqual := false - for _, old := range oldPrimaryKeys { - if reflect.DeepEqual(new, old) { + for _, oldKey := range oldPrimaryKeys { + if reflect.DeepEqual(newKey, oldKey) { hasEqual = true break } } if !hasEqual { - addedPrimaryKeys = append(addedPrimaryKeys, new) + addedPrimaryKeys = append(addedPrimaryKeys, newKey) } } for _, primaryKey := range association.getPrimaryKeys(values...) {