diff --git a/association.go b/association.go index e0978f2a..f828aea4 100644 --- a/association.go +++ b/association.go @@ -1,12 +1,11 @@ package gorm import ( - "errors" "fmt" "reflect" - "strings" ) +// Association Association Mode contains some helper methods to handle relationship things easily. type Association struct { Scope *Scope Column string @@ -14,86 +13,13 @@ type Association struct { Field *Field } -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} - +// Find find out all related associations func (association *Association) Find(value interface{}) *Association { association.Scope.related(value, association.Column) return association.setErr(association.Scope.db.Error) } -func (association *Association) saveAssociations(values ...interface{}) *Association { - scope := association.Scope - field := association.Field - relationship := association.Field.Relationship - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - +// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to func (association *Association) Append(values ...interface{}) *Association { if relationship := association.Field.Relationship; relationship.Kind == "has_one" { return association.Replace(values...) @@ -101,6 +27,7 @@ func (association *Association) Append(values ...interface{}) *Association { return association.saveAssociations(values...) } +// Replace replace current associations with new one func (association *Association) Replace(values ...interface{}) *Association { var ( relationship = association.Field.Relationship @@ -115,7 +42,7 @@ func (association *Association) Replace(values ...interface{}) *Association { // Belongs To if relationship.Kind == "belongs_to" { - // Set foreign key to be null only when clearing value + // Set foreign key to be null when clearing value (length equals 0) if len(values) == 0 { // Set foreign key to be nil var foreignKeyMap = map[string]interface{}{} @@ -125,29 +52,21 @@ func (association *Association) Replace(values ...interface{}) *Association { association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) } } else { - // Relations + // Polymorphic Relations if relationship.PolymorphicDBName != "" { newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } // Relations except new created if len(values) > 0 { - var newPrimaryKeys [][]interface{} var associationForeignFieldNames []string - if relationship.Kind == "many_to_many" { - // If many to many relations, get it from foreign key associationForeignFieldNames = relationship.AssociationForeignFieldNames } else { - // If other relations, get real primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() { - if field.IsPrimaryKey { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } + associationForeignFieldNames = relationship.AssociationForeignDBNames } - newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) + newPrimaryKeys := association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) if len(newPrimaryKeys) > 0 { sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) @@ -156,13 +75,11 @@ func (association *Association) Replace(values ...interface{}) *Association { } if relationship.Kind == "many_to_many" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } + if sourcePrimaryKeys := association.getPrimaryKeys(relationship.ForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { + newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { var foreignKeyMap = map[string]interface{}{} for idx, foreignKey := range relationship.ForeignDBNames { @@ -179,6 +96,7 @@ func (association *Association) Replace(values ...interface{}) *Association { return association } +// Delete remove relationship between source & passed arguments, but won't delete those arguments func (association *Association) Delete(values ...interface{}) *Association { var ( relationship = association.Field.Relationship @@ -292,10 +210,12 @@ func (association *Association) Delete(values ...interface{}) *Association { return association } +// Clear remove relationship between source & current associations, won't delete those associations func (association *Association) Clear() *Association { return association.Replace() } +// Count return the count of current associations func (association *Association) Count() int { var ( count = 0 @@ -333,78 +253,3 @@ func (association *Association) Count() int { return count } - -func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) { - scope := association.Scope - - for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - primaryKeys := []interface{}{} - newScope := scope.New(reflectValue.Index(i).Interface()) - for _, column := range columns { - if field, ok := newScope.FieldByName(column); ok { - primaryKeys = append(primaryKeys, field.Field.Interface()) - } else { - primaryKeys = append(primaryKeys, "") - } - } - results = append(results, primaryKeys) - } - } else if reflectValue.Kind() == reflect.Struct { - newScope := scope.New(value) - var primaryKeys []interface{} - for _, column := range columns { - if field, ok := newScope.FieldByName(column); ok { - primaryKeys = append(primaryKeys, field.Field.Interface()) - } else { - primaryKeys = append(primaryKeys, "") - } - } - - results = append(results, primaryKeys) - } - } - - return -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for _ = range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { - for _, primaryValue := range primaryValues { - for _, value := range primaryValue { - values = append(values, value) - } - } - return values -} diff --git a/association_utils.go b/association_utils.go new file mode 100644 index 00000000..7ec2ab7f --- /dev/null +++ b/association_utils.go @@ -0,0 +1,158 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +func (association *Association) setErr(err error) *Association { + if err != nil { + association.Error = err + } + return association +} + +func (association *Association) saveAssociations(values ...interface{}) *Association { + scope := association.Scope + field := association.Field + relationship := association.Field.Relationship + + saveAssociation := func(reflectValue reflect.Value) { + // value has to been pointer + if reflectValue.Kind() != reflect.Ptr { + reflectPtr := reflect.New(reflectValue.Type()) + reflectPtr.Elem().Set(reflectValue) + reflectValue = reflectPtr + } + + // value has to been saved for many2many + if relationship.Kind == "many_to_many" { + if scope.New(reflectValue.Interface()).PrimaryKeyZero() { + association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) + } + } + + // Assign Fields + var fieldType = field.Field.Type() + var setFieldBackToValue, setSliceFieldBackToValue bool + if reflectValue.Type().AssignableTo(fieldType) { + field.Set(reflectValue) + } else if reflectValue.Type().Elem().AssignableTo(fieldType) { + // if field's type is struct, then need to set value back to argument after save + setFieldBackToValue = true + field.Set(reflectValue.Elem()) + } else if fieldType.Kind() == reflect.Slice { + if reflectValue.Type().AssignableTo(fieldType.Elem()) { + field.Set(reflect.Append(field.Field, reflectValue)) + } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { + // if field's type is slice of struct, then need to set value back to argument after save + setSliceFieldBackToValue = true + field.Set(reflect.Append(field.Field, reflectValue.Elem())) + } + } + + if relationship.Kind == "many_to_many" { + association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) + } else { + association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) + + if setFieldBackToValue { + reflectValue.Elem().Set(field.Field) + } else if setSliceFieldBackToValue { + reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) + } + } + } + + for _, value := range values { + reflectValue := reflect.ValueOf(value) + indirectReflectValue := reflect.Indirect(reflectValue) + if indirectReflectValue.Kind() == reflect.Struct { + saveAssociation(reflectValue) + } else if indirectReflectValue.Kind() == reflect.Slice { + for i := 0; i < indirectReflectValue.Len(); i++ { + saveAssociation(indirectReflectValue.Index(i)) + } + } else { + association.setErr(errors.New("invalid value type")) + } + } + return association +} + +func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) { + scope := association.Scope + + for _, value := range values { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + if reflectValue.Kind() == reflect.Slice { + for i := 0; i < reflectValue.Len(); i++ { + primaryKeys := []interface{}{} + newScope := scope.New(reflectValue.Index(i).Interface()) + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + results = append(results, primaryKeys) + } + } else if reflectValue.Kind() == reflect.Struct { + newScope := scope.New(value) + var primaryKeys []interface{} + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + + results = append(results, primaryKeys) + } + } + + return +} + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for _ = range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } + return strings.Join(newColumns, ",") +} + +func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { + for _, primaryValue := range primaryValues { + for _, value := range primaryValue { + values = append(values, value) + } + } + return values +}