diff --git a/association.go b/association.go index e34a10bd..f62e712b 100644 --- a/association.go +++ b/association.go @@ -4,14 +4,14 @@ import ( "errors" "fmt" "reflect" + "strings" ) type Association struct { - Scope *Scope - PrimaryKey interface{} - Column string - Error error - Field *Field + Scope *Scope + Column string + Error error + Field *Field } func (association *Association) setErr(err error) *Association { @@ -45,60 +45,43 @@ func (association *Association) Append(values ...interface{}) *Association { return association.setErr(scope.db.Error) } -func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} { - primaryKeys := []interface{}{} +func (association *Association) Delete(values ...interface{}) *Association { scope := association.Scope + relationship := association.Field.Relationship - for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank { - primaryKeys = append(primaryKeys, primaryField.Field.Interface()) - } - } - } else if reflectValue.Kind() == reflect.Struct { - if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank { - primaryKeys = append(primaryKeys, primaryField.Field.Interface()) + // many to many + if relationship.Kind == "many_to_many" { + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) } } - } - return primaryKeys -} -func (association *Association) Delete(values ...interface{}) *Association { - primaryKeys := association.getPrimaryKeys(values...) + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) + query = query.Where(sql, toQueryValues(primaryKeys)...) - if len(primaryKeys) == 0 { - association.setErr(errors.New("no primary key found")) - } else { - scope := association.Scope - relationship := association.Field.Relationship - // many to many - if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { - leftValues := reflect.Zero(association.Field.Field.Type()) - for i := 0; i < association.Field.Field.Len(); i++ { - value := association.Field.Field.Index(i) - if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { - var included = false - for _, primaryKey := range primaryKeys { - if equalAsString(primaryKey, primaryField.Field.Interface()) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, value) + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { + leftValues := reflect.Zero(association.Field.Field.Type()) + for i := 0; i < association.Field.Field.Len(); i++ { + value := association.Field.Field.Index(i) + if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { + var included = false + for _, primaryKey := range primaryKeys { + if equalAsString(primaryKey, primaryField.Field.Interface()) { + included = true } } + if !included { + leftValues = reflect.Append(leftValues, value) + } } - association.Field.Set(leftValues) } - } else { - association.setErr(errors.New("delete only support many to many")) + association.Field.Set(leftValues) } + } else { + association.setErr(errors.New("delete only support many to many")) } return association } @@ -109,12 +92,12 @@ func (association *Association) Replace(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { field := association.Field.Field - oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) + oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(field.Interface()) + newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) - var addedPrimaryKeys = []interface{}{} + var addedPrimaryKeys = [][]interface{}{} for _, newKey := range newPrimaryKeys { hasEqual := false for _, oldKey := range oldPrimaryKeys { @@ -127,13 +110,21 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, newKey) } } - for _, primaryKey := range association.getPrimaryKeys(values...) { + + for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } } else { @@ -146,8 +137,13 @@ func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey) + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { @@ -168,18 +164,103 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) + query := scope.DB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + field.Field.Interface()) + } + } + if relationship.PolymorphicType != "" { - countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - countScope.Count(&count) + query.Count(&count) } else if relationship.Kind == "belongs_to" { - if v, ok := scope.FieldByName(association.Column); ok { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count) + query := scope.DB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + field.Field.Interface()) + } } + query.Table(newScope.TableName()).Count(&count) } return count } + +func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} { + results := [][]interface{}{} + scope := association.Scope + + for _, value := range values { + primaryKeys := []interface{}{} + + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + if reflectValue.Kind() == reflect.Slice { + for i := 0; i < reflectValue.Len(); i++ { + newScope := scope.New(reflectValue.Index(i).Interface()) + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + } + } else if reflectValue.Kind() == reflect.Struct { + newScope := scope.New(value) + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + } + + results = append(results, primaryKeys) + } + return results +} + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } else { + return strings.Join(columns, ",") + } +} + +func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { + for _, primaryValue := range primaryValues { + for _, value := range primaryValue { + values = append(values, value) + } + } + return values +} diff --git a/main.go b/main.go index 7c4c4df4..e7f93a02 100644 --- a/main.go +++ b/main.go @@ -448,7 +448,7 @@ func (s *DB) Association(column string) *Association { if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { - return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} + return &Association{Scope: scope, Column: column, Field: field} } } else { err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)