diff --git a/association.go b/association.go index 027f327e..a889157b 100644 --- a/association.go +++ b/association.go @@ -1,6 +1,7 @@ package gorm import ( + "errors" "fmt" "reflect" @@ -34,16 +35,119 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { + var ( + tx = association.DB + queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + ) + + if association.Relationship.JoinTable != nil { + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + tx.Clauses(queryClause) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + association.Error = tx.Find(out, conds...).Error } return association.Error } func (association *Association) Append(values ...interface{}) error { + if association.Error == nil { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + if len(values) > 0 { + association.Error = association.Replace(values...) + } + default: + association.saveAssociation(false, values...) + } + } + return association.Error } func (association *Association) Replace(values ...interface{}) error { + if association.Error == nil { + association.saveAssociation(true, values...) + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + switch rel.Type { + case schema.HasOne, schema.HasMany: + var ( + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + } else { + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateMap[ref.ForeignKey.DBName] = nil + } + } + + _, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + column, queryValues := schema.ToQueryValues(foreignKeys, values) + association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) + case schema.Many2Many: + var primaryFields, relPrimaryFields []*schema.Field + var foreignKeys, relForeignKeys []string + modelValue := reflect.New(rel.JoinTable.ModelType).Interface() + conds := []clause.Expression{} + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } + } + + generateConds := func(rv reflect.Value) { + _, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) + column, queryValues := schema.ToQueryValues(foreignKeys, values) + + relValue := rel.Field.ReflectValueOf(rv) + _, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) + relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) + + conds = append(conds, clause.And( + clause.IN{Column: column, Values: queryValues}, + clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}), + )) + } + + switch reflectValue.Kind() { + case reflect.Struct: + generateConds(reflectValue) + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + generateConds(reflectValue.Index(i)) + } + } + + association.DB.Where(conds).Delete(modelValue) + } + } return association.Error } @@ -78,7 +182,7 @@ func (association *Association) Delete(values ...interface{}) error { column, values := schema.ToQueryValues(foreignKeys, relQueryValues) tx.Where(clause.IN{Column: column, Values: values}) - switch association.Relationship.Type { + switch rel.Type { case schema.HasOne, schema.HasMany: modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) @@ -164,3 +268,95 @@ func (association *Association) Count() (count int) { return } + +func (association *Association) saveAssociation(clear bool, values ...interface{}) { + reflectValue := association.DB.Statement.ReflectValue + + appendToRelations := func(source, rv reflect.Value, clear bool) { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + association.Error = association.Relationship.Field.Set(source, rv.Index(0)) + } + case reflect.Struct: + association.Error = association.Relationship.Field.Set(source, rv) + } + case schema.HasMany, schema.Many2Many: + elemType := association.Relationship.Field.IndirectFieldType.Elem() + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue)) + if clear { + fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType) + } + + appendToFieldValues := func(ev reflect.Value) { + if ev.Type().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev) + } else if ev.Type().Elem().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev.Elem()) + } else { + association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + } + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + appendToFieldValues(reflect.Indirect(rv.Index(i))) + } + case reflect.Struct: + appendToFieldValues(rv) + } + + if association.Error == nil { + association.Error = association.Relationship.Field.Set(source, fieldValue) + } + } + } + + selectedColumns := []string{association.Relationship.Name} + hasZero := false + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + selectedColumns = append(selectedColumns, ref.ForeignKey.Name) + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(values) != reflectValue.Len() { + if clear && len(values) == 0 { + for i := 0; i < reflectValue.Len(); i++ { + association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) + } + break + } + association.Error = errors.New("invalid association values, length doesn't match") + } + + for i := 0; i < reflectValue.Len(); i++ { + appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + + if !hasZero { + _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) + } + } + case reflect.Struct: + if clear && len(values) == 0 { + association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) + } + + for idx, value := range values { + appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) + } + + _, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) + } + + if hasZero { + association.DB.Save(reflectValue.Interface()) + } else { + association.DB.Select(selectedColumns).Save(reflectValue.Interface()) + } +} diff --git a/callbacks/associations.go b/callbacks/associations.go index df19d5f5..a0c296e3 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -28,7 +28,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: var ( objs []reflect.Value fieldType = rel.Field.FieldType @@ -92,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: var ( fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr @@ -193,7 +193,7 @@ func SaveAfterAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { appendToElems(db.Statement.ReflectValue.Index(i)) } @@ -260,7 +260,7 @@ func SaveAfterAssociations(db *gorm.DB) { } switch db.Statement.ReflectValue.Kind() { - case reflect.Slice: + case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { appendToElems(db.Statement.ReflectValue.Index(i)) }