diff --git a/association.go b/association.go index e3aee8f2..49fd4558 100644 --- a/association.go +++ b/association.go @@ -128,49 +128,40 @@ func (association *Association) Replace(values ...interface{}) error { } case schema.Many2Many: var primaryFields, relPrimaryFields []*schema.Field - var foreignKeys, relForeignKeys []string - modelValue := reflect.New(rel.JoinTable.ModelType).Interface() - conds := []clause.Expression{} + var joinPrimaryKeys, joinRelPrimaryKeys []string + var conds []clause.Expression for _, ref := range rel.References { - if ref.OwnPrimaryKey { - primaryFields = append(primaryFields, ref.PrimaryKey) - foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - conds = append(conds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } } else { - relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } - generateConds := func(rv reflect.Value) { - _, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) - column, queryValues := schema.ToQueryValues(foreignKeys, values) + var ( + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + ) - relValue := rel.Field.ReflectValueOf(rv) - _, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) - relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) - - conds = append(conds, clause.And( - clause.IN{Column: column, Values: queryValues}, - clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}), - )) + if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 { + conds = append(conds, clause.IN{Column: column, Values: values}) + } else { + return ErrorPrimaryKeyRequired } - switch reflectValue.Kind() { - case reflect.Struct: - generateConds(reflectValue) - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { - generateConds(reflectValue.Index(i)) - } + if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues); len(relValues) > 0 { + conds = append(conds, clause.Not(clause.IN{Column: relColumn, Values: relValues})) } - association.DB.Where(conds).Delete(modelValue) + association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue) } } return association.Error @@ -227,9 +218,39 @@ func (association *Association) Delete(values ...interface{}) error { tx.Session(&Session{}).Model(modelValue).Clauses(conds...).UpdateColumns(updateAttrs) case schema.Many2Many: - modelValue := reflect.New(rel.JoinTable.ModelType).Interface() - conds := rel.ToQueryConditions(reflectValue) - tx.Clauses(clause.Where{Exprs: conds}).Delete(modelValue) + var primaryFields, relPrimaryFields []*schema.Field + var joinPrimaryKeys, joinRelPrimaryKeys []string + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + var ( + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + _, queryValues = schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, relQueryValues = schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + ) + + if column, values := schema.ToQueryValues(joinPrimaryKeys, queryValues); len(values) > 0 { + conds = append(conds, clause.IN{Column: column, Values: values}) + } else { + return ErrorPrimaryKeyRequired + } + + relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, relQueryValues) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + tx.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue) } relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) diff --git a/callbacks/delete.go b/callbacks/delete.go index 50b2880a..a88edcf8 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -5,6 +5,7 @@ import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) func BeforeDelete(db *gorm.DB) { @@ -37,13 +38,22 @@ func Delete(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Delete{}) values := []reflect.Value{db.Statement.ReflectValue} - if db.Statement.Dest != db.Statement.Model { + if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { values = append(values, reflect.ValueOf(db.Statement.Model)) } - for _, field := range db.Statement.Schema.PrimaryFields { - for _, value := range values { - if value, isZero := field.ValueOf(value); !isZero { - db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Where(clause.IN{Column: column, Values: values}) + } else if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Where(clause.IN{Column: column, Values: values}) } } } diff --git a/errors.go b/errors.go index 4f2bd4fa..140a5186 100644 --- a/errors.go +++ b/errors.go @@ -21,4 +21,6 @@ var ( ErrUnsupportedRelation = errors.New("unsupported relations") // ErrPtrStructSupported only ptr of struct supported ErrPtrStructSupported = errors.New("only ptr of struct supported") + // ErrorPrimaryKeyRequired primary keys required + ErrorPrimaryKeyRequired = errors.New("primary key required") ) diff --git a/logger/sql.go b/logger/sql.go index 219ae301..bb4e3e06 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -53,7 +53,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } else { rv := reflect.ValueOf(v) - if !rv.IsValid() || rv.IsNil() { + if !rv.IsValid() { + vars[idx] = "NULL" + } else if rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = "NULL" } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) diff --git a/tests/associations_test.go b/tests/associations_test.go index b6ddbd29..3ab69b42 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -641,7 +641,6 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { t.Fatalf("Error happened when append account, got %v", err) } - return if toy.ID == 0 { t.Fatalf("Toy's ID should be created")