Keep Refactoring Association Mode

This commit is contained in:
Jinzhu 2016-01-16 06:00:18 +08:00
parent 822e895d4d
commit 67874f9232
2 changed files with 64 additions and 76 deletions

View File

@ -57,16 +57,23 @@ func (association *Association) Replace(values ...interface{}) *Association {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
} }
// Relations except new created // Delete Relations except new created
if len(values) > 0 { if len(values) > 0 {
var associationForeignFieldNames []string var associationForeignFieldNames []string
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
associationForeignFieldNames = relationship.AssociationForeignFieldNames // if many to many relations, get association fields name from association foreign keys
associationFields := scope.New(reflect.New(field.Type()).Interface()).Fields()
for _, dbName := range relationship.AssociationForeignFieldNames {
associationForeignFieldNames = append(associationForeignFieldNames, associationFields[dbName].Name)
}
} else { } else {
associationForeignFieldNames = relationship.AssociationForeignDBNames // If other relations, use primary keys
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
} }
newPrimaryKeys := association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 { if len(newPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
@ -75,12 +82,25 @@ func (association *Association) Replace(values ...interface{}) *Association {
} }
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
if sourcePrimaryKeys := association.getPrimaryKeys(relationship.ForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { // if many to many relations, delete related relations from join table
// get source fields name from source foreign keys
var (
sourceFields = scope.Fields()
sourceForeignFieldNames []string
)
for _, dbName := range relationship.ForeignFieldNames {
sourceForeignFieldNames = append(sourceForeignFieldNames, sourceFields[dbName].Name)
}
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
} }
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
var foreignKeyMap = map[string]interface{}{} var foreignKeyMap = map[string]interface{}{}
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil foreignKeyMap[foreignKey] = nil
@ -110,11 +130,9 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() { for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
if field.IsPrimaryKey { deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
}
} }
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
@ -127,8 +145,15 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
} }
// get association's foreign fields name
var associationFields = scope.New(reflect.New(field.Type()).Interface()).Fields()
var associationForeignFieldNames []string
for _, associationDBName := range relationship.AssociationForeignFieldNames {
associationForeignFieldNames = append(associationForeignFieldNames, associationFields[associationDBName].Name)
}
// association value's foreign keys // association value's foreign keys
deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
@ -147,7 +172,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,
) )
// set foreign key to be null // set foreign key to be null if there are some records affected
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 { if results.RowsAffected > 0 {
@ -176,28 +201,29 @@ func (association *Association) Delete(values ...interface{}) *Association {
} }
} }
// Remove deleted records from field // Remove deleted records from source's field
if association.Error == nil { if association.Error == nil {
if association.Field.Field.Kind() == reflect.Slice { if association.Field.Field.Kind() == reflect.Slice {
leftValues := reflect.Zero(association.Field.Field.Type()) leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ { for i := 0; i < association.Field.Field.Len(); i++ {
reflectValue := association.Field.Field.Index(i) reflectValue := association.Field.Field.Index(i)
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var included = false var isDeleted = false
for _, pk := range deletingPrimaryKeys { for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) { if equalAsString(primaryKey, pk) {
included = true isDeleted = true
break
} }
} }
if !included { if !isDeleted {
leftValues = reflect.Append(leftValues, reflectValue) leftValues = reflect.Append(leftValues, reflectValue)
} }
} }
association.Field.Set(leftValues) association.Field.Set(leftValues)
} else if association.Field.Field.Kind() == reflect.Struct { } else if association.Field.Field.Kind() == reflect.Struct {
primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0] primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0]
for _, pk := range deletingPrimaryKeys { for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) { if equalAsString(primaryKey, pk) {
association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Field.Set(reflect.Zero(association.Field.Field.Type()))
@ -222,34 +248,32 @@ func (association *Association) Count() int {
relationship = association.Field.Relationship relationship = association.Field.Relationship
scope = association.Scope scope = association.Scope
fieldValue = association.Field.Field.Interface() fieldValue = association.Field.Field.Interface()
newScope = scope.New(fieldValue) query = scope.DB()
) )
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := scope.DB() primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
for idx, foreignKey := range relationship.ForeignDBNames { query = query.Where(
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), toQueryValues(primaryKeys)...,
field.Field.Interface()) )
}
}
if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
query.Model(fieldValue).Count(&count)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
query := scope.DB() primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
for idx, primaryKey := range relationship.AssociationForeignDBNames { query = query.Where(
if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), toQueryValues(primaryKeys)...,
field.Field.Interface()) )
}
}
query.Model(fieldValue).Count(&count)
} }
if relationship.PolymorphicType != "" {
query = query.Where(
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
scope.TableName(),
)
}
query.Model(fieldValue).Count(&count)
return count return count
} }

View File

@ -82,42 +82,6 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
return association return association
} }
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
scope := association.Scope
for _, value := range values {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ {
primaryKeys := []interface{}{}
newScope := scope.New(reflectValue.Index(i).Interface())
for _, column := range columns {
if field, ok := newScope.FieldByName(column); ok {
primaryKeys = append(primaryKeys, field.Field.Interface())
} else {
primaryKeys = append(primaryKeys, "")
}
}
results = append(results, primaryKeys)
}
} else if reflectValue.Kind() == reflect.Struct {
newScope := scope.New(value)
var primaryKeys []interface{}
for _, column := range columns {
if field, ok := newScope.FieldByName(column); ok {
primaryKeys = append(primaryKeys, field.Field.Interface())
} else {
primaryKeys = append(primaryKeys, "")
}
}
results = append(results, primaryKeys)
}
}
return
}
func toQueryMarks(primaryValues [][]interface{}) string { func toQueryMarks(primaryValues [][]interface{}) string {
var results []string var results []string