diff --git a/association.go b/association.go index 862276d5..d1984229 100644 --- a/association.go +++ b/association.go @@ -1,27 +1,28 @@ package gorm import ( + "errors" "fmt" "reflect" ) -// Association Association Mode contains some helper methods to handle relationship things easily. +// Association Mode contains some helper methods to handle relationship things easily. type Association struct { - Scope *Scope - Column string Error error - Field *Field + scope *Scope + column string + field *Field } // Find find out all related associations func (association *Association) Find(value interface{}) *Association { - association.Scope.related(value, association.Column) - return association.setErr(association.Scope.db.Error) + association.scope.related(value, association.column) + return association.setErr(association.scope.db.Error) } -// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to +// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to func (association *Association) Append(values ...interface{}) *Association { - if relationship := association.Field.Relationship; relationship.Kind == "has_one" { + if relationship := association.field.Relationship; relationship.Kind == "has_one" { return association.Replace(values...) } return association.saveAssociations(values...) @@ -30,14 +31,14 @@ func (association *Association) Append(values ...interface{}) *Association { // Replace replace current associations with new one func (association *Association) Replace(values ...interface{}) *Association { var ( - relationship = association.Field.Relationship - scope = association.Scope - field = association.Field.Field + relationship = association.field.Relationship + scope = association.scope + field = association.field.Field newDB = scope.NewDB() ) // Append new values - association.Field.Set(reflect.Zero(association.Field.Field.Type())) + association.field.Set(reflect.Zero(association.field.Field.Type())) association.saveAssociations(values...) // Belongs To @@ -109,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - fieldValue := reflect.New(association.Field.Field.Type()).Interface() + fieldValue := reflect.New(association.field.Field.Type()).Interface() association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) } } @@ -119,9 +120,9 @@ func (association *Association) Replace(values ...interface{}) *Association { // Delete remove relationship between source & passed arguments, but won't delete those arguments func (association *Association) Delete(values ...interface{}) *Association { var ( - relationship = association.Field.Relationship - scope = association.Scope - field = association.Field.Field + relationship = association.field.Relationship + scope = association.scope + field = association.field.Field newDB = scope.NewDB() ) @@ -196,18 +197,18 @@ func (association *Association) Delete(values ...interface{}) *Association { ) // set matched relation's foreign key to be null - fieldValue := reflect.New(association.Field.Field.Type()).Interface() + fieldValue := reflect.New(association.field.Field.Type()).Interface() association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) } } // Remove deleted records from source's field if association.Error == nil { - if association.Field.Field.Kind() == reflect.Slice { - leftValues := reflect.Zero(association.Field.Field.Type()) + if field.Kind() == reflect.Slice { + leftValues := reflect.Zero(field.Type()) - for i := 0; i < association.Field.Field.Len(); i++ { - reflectValue := association.Field.Field.Index(i) + for i := 0; i < field.Len(); i++ { + reflectValue := field.Index(i) primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] var isDeleted = false for _, pk := range deletingPrimaryKeys { @@ -221,12 +222,12 @@ func (association *Association) Delete(values ...interface{}) *Association { } } - association.Field.Set(leftValues) - } else if association.Field.Field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0] + association.field.Set(leftValues) + } else if field.Kind() == reflect.Struct { + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] for _, pk := range deletingPrimaryKeys { if equalAsString(primaryKey, pk) { - association.Field.Set(reflect.Zero(association.Field.Field.Type())) + association.field.Set(reflect.Zero(field.Type())) break } } @@ -245,14 +246,14 @@ func (association *Association) Clear() *Association { func (association *Association) Count() int { var ( count = 0 - relationship = association.Field.Relationship - scope = association.Scope - fieldValue = association.Field.Field.Interface() + relationship = association.field.Relationship + scope = association.scope + fieldValue = association.field.Field.Interface() query = scope.DB() ) if relationship.Kind == "many_to_many" { - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value) + query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( @@ -277,3 +278,81 @@ func (association *Association) Count() int { query.Model(fieldValue).Count(&count) return count } + +// saveAssociations save passed values as associations +func (association *Association) saveAssociations(values ...interface{}) *Association { + var ( + scope = association.scope + field = association.field + relationship = field.Relationship + ) + + saveAssociation := func(reflectValue reflect.Value) { + // value has to been pointer + if reflectValue.Kind() != reflect.Ptr { + reflectPtr := reflect.New(reflectValue.Type()) + reflectPtr.Elem().Set(reflectValue) + reflectValue = reflectPtr + } + + // value has to been saved for many2many + if relationship.Kind == "many_to_many" { + if scope.New(reflectValue.Interface()).PrimaryKeyZero() { + association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) + } + } + + // Assign Fields + var fieldType = field.Field.Type() + var setFieldBackToValue, setSliceFieldBackToValue bool + if reflectValue.Type().AssignableTo(fieldType) { + field.Set(reflectValue) + } else if reflectValue.Type().Elem().AssignableTo(fieldType) { + // if field's type is struct, then need to set value back to argument after save + setFieldBackToValue = true + field.Set(reflectValue.Elem()) + } else if fieldType.Kind() == reflect.Slice { + if reflectValue.Type().AssignableTo(fieldType.Elem()) { + field.Set(reflect.Append(field.Field, reflectValue)) + } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { + // if field's type is slice of struct, then need to set value back to argument after save + setSliceFieldBackToValue = true + field.Set(reflect.Append(field.Field, reflectValue.Elem())) + } + } + + if relationship.Kind == "many_to_many" { + association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) + } else { + association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) + + if setFieldBackToValue { + reflectValue.Elem().Set(field.Field) + } else if setSliceFieldBackToValue { + reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) + } + } + } + + for _, value := range values { + reflectValue := reflect.ValueOf(value) + indirectReflectValue := reflect.Indirect(reflectValue) + if indirectReflectValue.Kind() == reflect.Struct { + saveAssociation(reflectValue) + } else if indirectReflectValue.Kind() == reflect.Slice { + for i := 0; i < indirectReflectValue.Len(); i++ { + saveAssociation(indirectReflectValue.Index(i)) + } + } else { + association.setErr(errors.New("invalid value type")) + } + } + return association +} + +func (association *Association) setErr(err error) *Association { + if err != nil { + association.Error = err + } + return association +} diff --git a/association_utils.go b/association_utils.go deleted file mode 100644 index 912c9ca4..00000000 --- a/association_utils.go +++ /dev/null @@ -1,122 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} - -func (association *Association) saveAssociations(values ...interface{}) *Association { - scope := association.Scope - field := association.Field - relationship := association.Field.Relationship - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for _ = range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { - for _, primaryValue := range primaryValues { - for _, value := range primaryValue { - values = append(values, value) - } - } - return values -} diff --git a/main.go b/main.go index 376a967f..ce3f7fb1 100644 --- a/main.go +++ b/main.go @@ -480,7 +480,7 @@ func (s *DB) Association(column string) *Association { if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { - return &Association{Scope: scope, Column: column, Field: field} + return &Association{scope: scope, column: column, field: field} } } else { err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) diff --git a/utils.go b/utils.go index a4cf0b8c..58b14ac4 100644 --- a/utils.go +++ b/utils.go @@ -2,6 +2,7 @@ package gorm import ( "bytes" + "fmt" "strings" "sync" ) @@ -100,3 +101,42 @@ type expr struct { func Expr(expression string, args ...interface{}) *expr { return &expr{expr: expression, args: args} } + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for _ = range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } + return strings.Join(newColumns, ",") +} + +func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { + for _, primaryValue := range primaryValues { + for _, value := range primaryValue { + values = append(values, value) + } + } + return values +}