forked from mirror/gorm
Test Many2Many Association
This commit is contained in:
parent
cc064f26ee
commit
dea48a8c59
|
@ -128,49 +128,40 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||||
}
|
}
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
var primaryFields, relPrimaryFields []*schema.Field
|
var primaryFields, relPrimaryFields []*schema.Field
|
||||||
var foreignKeys, relForeignKeys []string
|
var joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||||
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
var conds []clause.Expression
|
||||||
conds := []clause.Expression{}
|
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.PrimaryValue == "" {
|
||||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
if ref.OwnPrimaryKey {
|
||||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||||
} else if ref.PrimaryValue != "" {
|
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
|
||||||
conds = append(conds, clause.Eq{
|
} else {
|
||||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||||
Value: ref.PrimaryValue,
|
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
|
||||||
})
|
}
|
||||||
} else {
|
} else {
|
||||||
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
generateConds := func(rv reflect.Value) {
|
var (
|
||||||
_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields)
|
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
column, queryValues := schema.ToQueryValues(foreignKeys, values)
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||||
|
_, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
||||||
|
)
|
||||||
|
|
||||||
relValue := rel.Field.ReflectValueOf(rv)
|
if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 {
|
||||||
_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields)
|
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||||
relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues)
|
} else {
|
||||||
|
return ErrorPrimaryKeyRequired
|
||||||
conds = append(conds, clause.And(
|
|
||||||
clause.IN{Column: column, Values: queryValues},
|
|
||||||
clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues); len(relValues) > 0 {
|
||||||
case reflect.Struct:
|
conds = append(conds, clause.Not(clause.IN{Column: relColumn, Values: relValues}))
|
||||||
generateConds(reflectValue)
|
|
||||||
case reflect.Slice, reflect.Array:
|
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
|
||||||
generateConds(reflectValue.Index(i))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
association.DB.Where(conds).Delete(modelValue)
|
association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return association.Error
|
return association.Error
|
||||||
|
@ -227,9 +218,39 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||||
|
|
||||||
tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs)
|
tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs)
|
||||||
case schema.Many2Many:
|
case schema.Many2Many:
|
||||||
modelValue := reflect.New(rel.JoinTable.ModelType).Interface()
|
var primaryFields, relPrimaryFields []*schema.Field
|
||||||
conds := rel.ToQueryConditions(reflectValue)
|
var joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||||
tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue)
|
|
||||||
|
for _, ref := range rel.References {
|
||||||
|
if ref.PrimaryValue == "" {
|
||||||
|
if ref.OwnPrimaryKey {
|
||||||
|
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||||
|
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
|
||||||
|
} else {
|
||||||
|
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||||
|
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||||
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||||
|
_, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
||||||
|
)
|
||||||
|
|
||||||
|
if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 {
|
||||||
|
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||||
|
} else {
|
||||||
|
return ErrorPrimaryKeyRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues)
|
||||||
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
|
tx.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BeforeDelete(db *gorm.DB) {
|
func BeforeDelete(db *gorm.DB) {
|
||||||
|
@ -37,13 +38,22 @@ func Delete(db *gorm.DB) {
|
||||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||||
|
|
||||||
values := []reflect.Value{db.Statement.ReflectValue}
|
values := []reflect.Value{db.Statement.ReflectValue}
|
||||||
if db.Statement.Dest != db.Statement.Model {
|
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||||
values = append(values, reflect.ValueOf(db.Statement.Model))
|
values = append(values, reflect.ValueOf(db.Statement.Model))
|
||||||
}
|
}
|
||||||
for _, field := range db.Statement.Schema.PrimaryFields {
|
|
||||||
for _, value := range values {
|
if db.Statement.Schema != nil {
|
||||||
if value, isZero := field.ValueOf(value); !isZero {
|
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
|
if len(values) > 0 {
|
||||||
|
db.Where(clause.IN{Column: column, Values: values})
|
||||||
|
} else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||||
|
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||||
|
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
|
if len(values) > 0 {
|
||||||
|
db.Where(clause.IN{Column: column, Values: values})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,4 +21,6 @@ var (
|
||||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||||
// ErrPtrStructSupported only ptr of struct supported
|
// ErrPtrStructSupported only ptr of struct supported
|
||||||
ErrPtrStructSupported = errors.New("only ptr of struct supported")
|
ErrPtrStructSupported = errors.New("only ptr of struct supported")
|
||||||
|
// ErrorPrimaryKeyRequired primary keys required
|
||||||
|
ErrorPrimaryKeyRequired = errors.New("primary key required")
|
||||||
)
|
)
|
||||||
|
|
|
@ -53,7 +53,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
||||||
} else {
|
} else {
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
|
|
||||||
if !rv.IsValid() || rv.IsNil() {
|
if !rv.IsValid() {
|
||||||
|
vars[idx] = "NULL"
|
||||||
|
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||||
vars[idx] = "NULL"
|
vars[idx] = "NULL"
|
||||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||||
|
|
|
@ -641,7 +641,6 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
|
||||||
if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil {
|
if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil {
|
||||||
t.Fatalf("Error happened when append account, got %v", err)
|
t.Fatalf("Error happened when append account, got %v", err)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
|
||||||
if toy.ID == 0 {
|
if toy.ID == 0 {
|
||||||
t.Fatalf("Toy's ID should be created")
|
t.Fatalf("Toy's ID should be created")
|
||||||
|
|
Loading…
Reference in New Issue