Test Many2Many Association

This commit is contained in:
Jinzhu 2020-05-26 00:16:41 +08:00
parent cc064f26ee
commit dea48a8c59
5 changed files with 76 additions and 42 deletions

View File

@ -128,49 +128,40 @@ func (association *Association) Replace(values ...interface{}) error {
} }
case schema.Many2Many: case schema.Many2Many:
var primaryFields, relPrimaryFields []*schema.Field var primaryFields, relPrimaryFields []*schema.Field
var foreignKeys, relForeignKeys []string var joinPrimaryKeys, joinRelPrimaryKeys []string
modelValue := reflect.New(rel.JoinTable.ModelType).Interface() var conds []clause.Expression
conds := []clause.Expression{}
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.PrimaryValue == "" {
primaryFields = append(primaryFields, ref.PrimaryKey) if ref.OwnPrimaryKey {
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) primaryFields = append(primaryFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" { joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
conds = append(conds, clause.Eq{ } else {
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
Value: ref.PrimaryValue, joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}) }
} else { } else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} }
} }
generateConds := func(rv reflect.Value) { var (
_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
column, queryValues := schema.ToQueryValues(foreignKeys, values) _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
_, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
)
relValue := rel.Field.ReflectValueOf(rv) if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 {
_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) conds = append(conds, clause.IN{Column: column, Values: values})
relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) } else {
return ErrorPrimaryKeyRequired
conds = append(conds, clause.And(
clause.IN{Column: column, Values: queryValues},
clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}),
))
} }
switch reflectValue.Kind() { if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues); len(relValues) > 0 {
case reflect.Struct: conds = append(conds, clause.Not(clause.IN{Column: relColumn, Values: relValues}))
generateConds(reflectValue)
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
generateConds(reflectValue.Index(i))
}
} }
association.DB.Where(conds).Delete(modelValue) association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue)
} }
} }
return association.Error return association.Error
@ -227,9 +218,39 @@ func (association *Association) Delete(values ...interface{}) error {
tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs)
case schema.Many2Many: case schema.Many2Many:
modelValue := reflect.New(rel.JoinTable.ModelType).Interface() var primaryFields, relPrimaryFields []*schema.Field
conds := rel.ToQueryConditions(reflectValue) var joinPrimaryKeys, joinRelPrimaryKeys []string
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
var (
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
_, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
_, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
)
if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 {
conds = append(conds, clause.IN{Column: column, Values: values})
} else {
return ErrorPrimaryKeyRequired
}
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
tx.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue)
} }
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)

View File

@ -5,6 +5,7 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
) )
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
@ -37,13 +38,22 @@ func Delete(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Delete{}) db.Statement.AddClauseIfNotExists(clause.Delete{})
values := []reflect.Value{db.Statement.ReflectValue} values := []reflect.Value{db.Statement.ReflectValue}
if db.Statement.Dest != db.Statement.Model { if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
values = append(values, reflect.ValueOf(db.Statement.Model)) values = append(values, reflect.ValueOf(db.Statement.Model))
} }
for _, field := range db.Statement.Schema.PrimaryFields {
for _, value := range values { if db.Statement.Schema != nil {
if value, isZero := field.ValueOf(value); !isZero { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Where(clause.IN{Column: column, Values: values})
} else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Where(clause.IN{Column: column, Values: values})
} }
} }
} }

View File

@ -21,4 +21,6 @@ var (
ErrUnsupportedRelation = errors.New("unsupported relations") ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPtrStructSupported only ptr of struct supported // ErrPtrStructSupported only ptr of struct supported
ErrPtrStructSupported = errors.New("only ptr of struct supported") ErrPtrStructSupported = errors.New("only ptr of struct supported")
// ErrorPrimaryKeyRequired primary keys required
ErrorPrimaryKeyRequired = errors.New("primary key required")
) )

View File

@ -53,7 +53,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
} else { } else {
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
if !rv.IsValid() || rv.IsNil() { if !rv.IsValid() {
vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = "NULL" vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx) convertParams(reflect.Indirect(rv).Interface(), idx)

View File

@ -641,7 +641,6 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil {
t.Fatalf("Error happened when append account, got %v", err) t.Fatalf("Error happened when append account, got %v", err)
} }
return
if toy.ID == 0 { if toy.ID == 0 {
t.Fatalf("Toy's ID should be created") t.Fatalf("Toy's ID should be created")