Fix ambiguous column when using same column name in join table, close #3120

This commit is contained in:
Jinzhu 2020-07-09 09:03:48 +08:00
parent e1084e78d0
commit 2ae0653af2
9 changed files with 40 additions and 23 deletions

View File

@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) error {
)
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
}
@ -138,7 +138,7 @@ func (association *Association) Replace(values ...interface{}) error {
}
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(foreignKeys, pvs)
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
}
case schema.Many2Many:
@ -164,14 +164,14 @@ func (association *Association) Replace(values ...interface{}) error {
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 {
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrorPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 {
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
@ -208,11 +208,11 @@ func (association *Association) Delete(values ...interface{}) error {
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
@ -220,11 +220,11 @@ func (association *Association) Delete(values ...interface{}) error {
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
@ -250,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error {
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error

View File

@ -35,7 +35,7 @@ func Delete(db *gorm.DB) {
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
column, values := schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
@ -43,7 +43,7 @@ func Delete(db *gorm.DB) {
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
column, values = schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})

View File

@ -49,7 +49,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues)
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
// convert join identity map to relation identity map
@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues)
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {

View File

@ -462,10 +462,12 @@ func (rel *Relationship) ParseConstraint() *Constraint {
}
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}
if rel.JoinTable != nil {
table = rel.JoinTable.Table
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
@ -500,7 +502,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
}
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
column, values := ToQueryValues(relForeignKeys, foreignValues)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values})
return

View File

@ -5,6 +5,7 @@ import (
"regexp"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)
@ -164,18 +165,23 @@ func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field)
}
// ToQueryValues to query values
func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
queryValues := make([]interface{}, len(foreignValues))
if len(foreignKeys) == 1 {
for idx, r := range foreignValues {
queryValues[idx] = r[0]
}
return foreignKeys[0], queryValues
return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
} else {
columns := make([]clause.Column, len(foreignKeys))
for idx, key := range foreignKeys {
columns[idx] = clause.Column{Table: table, Name: key}
}
for idx, r := range foreignValues {
queryValues[idx] = r
}
return columns, queryValues
}
return foreignKeys, queryValues
}

View File

@ -58,7 +58,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
column, values := schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
@ -66,7 +66,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
column, values = schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})

View File

@ -107,6 +107,15 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
}
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:

View File

@ -6,7 +6,7 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.2.8
gorm.io/driver/mysql v0.2.9
gorm.io/driver/postgres v0.2.5
gorm.io/driver/sqlite v1.0.8
gorm.io/driver/sqlserver v0.2.4

View File

@ -13,7 +13,7 @@ type Blog struct {
Locale string `gorm:"primary_key"`
Subject string
Body string
Tags []Tag `gorm:"many2many:blogs_tags;"`
Tags []Tag `gorm:"many2many:blog_tags;"`
SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"`
LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"`
}
@ -22,7 +22,7 @@ type Tag struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Value string
Blogs []*Blog `gorm:"many2many:blogs_tags"`
Blogs []*Blog `gorm:"many2many:blog_tags"`
}
func compareTags(tags []Tag, contents []string) bool {