diff --git a/association.go b/association.go index 928dcf3e..eeb11efe 100644 --- a/association.go +++ b/association.go @@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) error { ) if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { - if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } } @@ -138,7 +138,7 @@ func (association *Association) Replace(values ...interface{}) error { } if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { - column, values := schema.ToQueryValues(foreignKeys, pvs) + column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) } case schema.Many2Many: @@ -164,14 +164,14 @@ func (association *Association) Replace(values ...interface{}) error { } _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 { + if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrorPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 { + if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -208,11 +208,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs) + pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) - relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs) + relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error @@ -220,11 +220,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs) + pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) - relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error @@ -250,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error { } _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs) + pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs) + relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error diff --git a/callbacks/delete.go b/callbacks/delete.go index dea8bb5e..ff0f601a 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -35,7 +35,7 @@ func Delete(db *gorm.DB) { if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) - column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + column, values := schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) @@ -43,7 +43,7 @@ func Delete(db *gorm.DB) { if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) - column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + column, values = schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) diff --git a/callbacks/preload.go b/callbacks/preload.go index a9907d68..cd09a6d6 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -49,7 +49,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } joinResults := rel.JoinTable.MakeSlice().Elem() - column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) + column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { diff --git a/schema/relationship.go b/schema/relationship.go index 91c2ca8d..e3ff0307 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -462,10 +462,12 @@ func (rel *Relationship) ParseConstraint() *Constraint { } func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { + table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} if rel.JoinTable != nil { + table = rel.JoinTable.Table for _, ref := range rel.References { if ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.PrimaryKey) @@ -500,7 +502,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) - column, values := ToQueryValues(relForeignKeys, foreignValues) + column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) return diff --git a/schema/utils.go b/schema/utils.go index da236a18..defa83af 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -5,6 +5,7 @@ import ( "regexp" "strings" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) @@ -164,18 +165,23 @@ func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) } // ToQueryValues to query values -func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { +func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { queryValues := make([]interface{}, len(foreignValues)) if len(foreignKeys) == 1 { for idx, r := range foreignValues { queryValues[idx] = r[0] } - return foreignKeys[0], queryValues + return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues } else { + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + for idx, r := range foreignValues { queryValues[idx] = r } + return columns, queryValues } - return foreignKeys, queryValues } diff --git a/soft_delete.go b/soft_delete.go index 4ffceba6..e3e6e960 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -58,7 +58,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) - column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + column, values := schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) @@ -66,7 +66,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { if stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) - column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) + column, values = schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) diff --git a/statement.go b/statement.go index d6444fae..036b8297 100644 --- a/statement.go +++ b/statement.go @@ -107,6 +107,15 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteString(" AS ") stmt.DB.Dialector.QuoteTo(writer, v.Alias) } + case []clause.Column: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteString(",") + } + stmt.QuoteTo(writer, d) + } + writer.WriteByte(')') case string: stmt.DB.Dialector.QuoteTo(writer, v) case []string: diff --git a/tests/go.mod b/tests/go.mod index 3b17feac..3a5b4224 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.8 + gorm.io/driver/mysql v0.2.9 gorm.io/driver/postgres v0.2.5 gorm.io/driver/sqlite v1.0.8 gorm.io/driver/sqlserver v0.2.4 diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 617010c5..051e3ee2 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -13,7 +13,7 @@ type Blog struct { Locale string `gorm:"primary_key"` Subject string Body string - Tags []Tag `gorm:"many2many:blogs_tags;"` + Tags []Tag `gorm:"many2many:blog_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` } @@ -22,7 +22,7 @@ type Tag struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` + Blogs []*Blog `gorm:"many2many:blog_tags"` } func compareTags(tags []Tag, contents []string) bool {