From 0ce635cc679baf86d3e611b12a39c13065d64f44 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Dec 2015 15:13:40 +0800 Subject: [PATCH] Add Append support for other associations --- association.go | 53 +++++++++++++++++++++++---------------- association_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 21 deletions(-) diff --git a/association.go b/association.go index 1773f61f..9838f440 100644 --- a/association.go +++ b/association.go @@ -29,42 +29,53 @@ func (association *Association) Find(value interface{}) *Association { func (association *Association) Append(values ...interface{}) *Association { scope := association.Scope field := association.Field + relationship := association.Field.Relationship - createJoinTable := func(reflectValue reflect.Value) { - var value = reflectValue.Interface() + saveAssociation := func(reflectValue reflect.Value) { + // value has to been pointer if reflectValue.Kind() != reflect.Ptr { reflectPtr := reflect.New(reflectValue.Type()) reflectPtr.Elem().Set(reflectValue) - value = reflectPtr.Interface() + reflectValue = reflectPtr } - if scope.New(value).PrimaryKeyZero() { - scope.NewDB().Save(value) + // value has to been saved + if scope.New(reflectValue.Interface()).PrimaryKeyZero() { + scope.NewDB().Save(reflectValue.Interface()) } - relationship := association.Field.Relationship - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, value)) + // Assign Fields + fieldType := field.Field.Type() + if reflectValue.Type().AssignableTo(fieldType) { + field.Set(reflectValue) + } else if reflectValue.Type().Elem().AssignableTo(fieldType) { + field.Set(reflectValue.Elem()) + } else if fieldType.Kind() == reflect.Slice { + if reflectValue.Type().AssignableTo(fieldType.Elem()) { + field.Set(reflect.Append(field.Field, reflectValue)) + } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { + field.Set(reflect.Append(field.Field, reflectValue.Elem())) + } + } - result := reflect.ValueOf(value) - fieldElemType := field.Field.Type().Elem() - if result.Type().AssignableTo(fieldElemType) { - field.Set(reflect.Append(field.Field, result)) - } else if result.Type().Elem().AssignableTo(fieldElemType) { - field.Set(reflect.Append(field.Field, result.Elem())) + if relationship.Kind == "many_to_many" { + association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) + } else { + association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) } } for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - - if reflectValue.Kind() == reflect.Struct { - createJoinTable(reflectValue) - } else if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - createJoinTable(reflectValue.Index(i)) + reflectValue := reflect.ValueOf(value) + indirectReflectValue := reflect.Indirect(reflectValue) + if indirectReflectValue.Kind() == reflect.Struct { + saveAssociation(reflectValue) + } else if indirectReflectValue.Kind() == reflect.Slice { + for i := 0; i < indirectReflectValue.Len(); i++ { + saveAssociation(indirectReflectValue.Index(i)) } } else { - association.setErr(errors.New("invalid association type")) + association.setErr(errors.New("invalid value type")) } } return association diff --git a/association_test.go b/association_test.go index 994649a5..a5cc1860 100644 --- a/association_test.go +++ b/association_test.go @@ -5,6 +5,66 @@ import ( "testing" ) +func TestHasOne(t *testing.T) { + DB.DropTable(Category{}, Post{}) + DB.CreateTable(Category{}, Post{}) + + post := Post{ + Title: "post 1", + Body: "body 1", + Category: Category{Name: "Category 1"}, + MainCategory: Category{Name: "Main Category 1"}, + } + + if err := DB.Save(&post).Error; err != nil { + t.Errorf("Got errors when save post", err.Error()) + } + + // Query + var category Category + DB.Model(&post).Association("Category").Find(&category) + if category.Name != "Category 1" { + t.Errorf("Query has one relations with Association") + } + + var category1 Category + DB.Model(&post).Related(&category1) + if category1.Name != "Category 1" { + t.Errorf("Query has one relations with Related") + } + + // Append + var category2 = Category{ + Name: "Category 2", + } + DB.Model(&post).Association("Category").Append(&category2) + + if category2.Id == 0 { + t.Errorf("Category should has ID when created with Append") + } + + var category21 Category + DB.Model(&post).Related(&category21) + + if category21.Name != "Category 2" { + t.Errorf("Category should be updated with Append") + } + + // Replace + // DB.Model(&post).Association("Category").Replace(&Category{ + // Name: "Category 3", + // }) + + // var category3 Category + // DB.Model(&post).Related(&category3) + // if category3.Name != "Category 3" { + // t.Errorf("Category should be updated with Replace") + // } + + // Delete + // Clear +} + func TestHasOneAndHasManyAssociation(t *testing.T) { DB.DropTable(Category{}, Post{}, Comment{}) DB.CreateTable(Category{}, Post{}, Comment{})