Test Many2Many Association

This commit is contained in:
Jinzhu 2020-05-26 00:16:41 +08:00
parent cc064f26ee
commit dea48a8c59
5 changed files with 76 additions and 42 deletions

View File

@ -128,49 +128,40 @@ func (association *Association) Replace(values ...interface{}) error {
}
case schema.Many2Many:
var primaryFields, relPrimaryFields []*schema.Field
var foreignKeys, relForeignKeys []string
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
conds := []clause.Expression{}
var joinPrimaryKeys, joinRelPrimaryKeys []string
var conds []clause.Expression
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
conds = append(conds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
generateConds := func(rv reflect.Value) {
_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields)
column, queryValues := schema.ToQueryValues(foreignKeys, values)
var (
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
_, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
_, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
)
relValue := rel.Field.ReflectValueOf(rv)
_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields)
relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues)
conds = append(conds, clause.And(
clause.IN{Column: column, Values: queryValues},
clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}),
))
if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 {
conds = append(conds, clause.IN{Column: column, Values: values})
} else {
return ErrorPrimaryKeyRequired
}
switch reflectValue.Kind() {
case reflect.Struct:
generateConds(reflectValue)
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
generateConds(reflectValue.Index(i))
}
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues); len(relValues) > 0 {
conds = append(conds, clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
association.DB.Where(conds).Delete(modelValue)
association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue)
}
}
return association.Error
@ -227,9 +218,39 @@ func (association *Association) Delete(values ...interface{}) error {
tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs)
case schema.Many2Many:
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
conds := rel.ToQueryConditions(reflectValue)
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
var primaryFields, relPrimaryFields []*schema.Field
var joinPrimaryKeys, joinRelPrimaryKeys []string
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
var (
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
_, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
_, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
)
if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 {
conds = append(conds, clause.IN{Column: column, Values: values})
} else {
return ErrorPrimaryKeyRequired
}
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
tx.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue)
}
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)

View File

@ -5,6 +5,7 @@ import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
)
func BeforeDelete(db *gorm.DB) {
@ -37,13 +38,22 @@ func Delete(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Delete{})
values := []reflect.Value{db.Statement.ReflectValue}
if db.Statement.Dest != db.Statement.Model {
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
values = append(values, reflect.ValueOf(db.Statement.Model))
}
for _, field := range db.Statement.Schema.PrimaryFields {
for _, value := range values {
if value, isZero := field.ValueOf(value); !isZero {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Where(clause.IN{Column: column, Values: values})
} else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Where(clause.IN{Column: column, Values: values})
}
}
}

View File

@ -21,4 +21,6 @@ var (
ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPtrStructSupported only ptr of struct supported
ErrPtrStructSupported = errors.New("only ptr of struct supported")
// ErrorPrimaryKeyRequired primary keys required
ErrorPrimaryKeyRequired = errors.New("primary key required")
)

View File

@ -53,7 +53,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
} else {
rv := reflect.ValueOf(v)
if !rv.IsValid() || rv.IsNil() {
if !rv.IsValid() {
vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)

View File

@ -641,7 +641,6 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil {
t.Fatalf("Error happened when append account, got %v", err)
}
return
if toy.ID == 0 {
t.Fatalf("Toy's ID should be created")