Fix ambiguous column when using same column name in join table, close #3120

This commit is contained in:
Jinzhu 2020-07-09 09:03:48 +08:00
parent e1084e78d0
commit 2ae0653af2
9 changed files with 40 additions and 23 deletions

View File

@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) error {
) )
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values}) tx.Not(clause.IN{Column: column, Values: values})
} }
} }
@ -138,7 +138,7 @@ func (association *Association) Replace(values ...interface{}) error {
} }
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(foreignKeys, pvs) column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
} }
case schema.Many2Many: case schema.Many2Many:
@ -164,14 +164,14 @@ func (association *Association) Replace(values ...interface{}) error {
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values}) tx.Where(clause.IN{Column: column, Values: values})
} else { } else {
return ErrorPrimaryKeyRequired return ErrorPrimaryKeyRequired
} }
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 { if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
} }
@ -208,11 +208,11 @@ func (association *Association) Delete(values ...interface{}) error {
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs) pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
@ -220,11 +220,11 @@ func (association *Association) Delete(values ...interface{}) error {
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs) pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
@ -250,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error {
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs) pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error

View File

@ -35,7 +35,7 @@ func Delete(db *gorm.DB) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
@ -43,7 +43,7 @@ func Delete(db *gorm.DB) {
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})

View File

@ -49,7 +49,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
joinResults := rel.JoinTable.MakeSlice().Elem() joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
// convert join identity map to relation identity map // convert join identity map to relation identity map
@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
reflectResults := rel.FieldSchema.MakeSlice().Elem() reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(relForeignKeys, foreignValues) column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues)
for _, cond := range conds { for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {

View File

@ -462,10 +462,12 @@ func (rel *Relationship) ParseConstraint() *Constraint {
} }
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{} foreignFields := []*Field{}
relForeignKeys := []string{} relForeignKeys := []string{}
if rel.JoinTable != nil { if rel.JoinTable != nil {
table = rel.JoinTable.Table
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey) foreignFields = append(foreignFields, ref.PrimaryKey)
@ -500,7 +502,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
} }
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
column, values := ToQueryValues(relForeignKeys, foreignValues) column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})
return return

View File

@ -5,6 +5,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
@ -164,18 +165,23 @@ func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field)
} }
// ToQueryValues to query values // ToQueryValues to query values
func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
queryValues := make([]interface{}, len(foreignValues)) queryValues := make([]interface{}, len(foreignValues))
if len(foreignKeys) == 1 { if len(foreignKeys) == 1 {
for idx, r := range foreignValues { for idx, r := range foreignValues {
queryValues[idx] = r[0] queryValues[idx] = r[0]
} }
return foreignKeys[0], queryValues return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
} else { } else {
columns := make([]clause.Column, len(foreignKeys))
for idx, key := range foreignKeys {
columns[idx] = clause.Column{Table: table, Name: key}
}
for idx, r := range foreignValues { for idx, r := range foreignValues {
queryValues[idx] = r queryValues[idx] = r
} }
return columns, queryValues
} }
return foreignKeys, queryValues
} }

View File

@ -58,7 +58,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.Schema != nil { if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
@ -66,7 +66,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.Dest != stmt.Model && stmt.Model != nil { if stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})

View File

@ -107,6 +107,15 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteString(" AS ") writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias) stmt.DB.Dialector.QuoteTo(writer, v.Alias)
} }
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case string: case string:
stmt.DB.Dialector.QuoteTo(writer, v) stmt.DB.Dialector.QuoteTo(writer, v)
case []string: case []string:

View File

@ -6,7 +6,7 @@ require (
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.2.8 gorm.io/driver/mysql v0.2.9
gorm.io/driver/postgres v0.2.5 gorm.io/driver/postgres v0.2.5
gorm.io/driver/sqlite v1.0.8 gorm.io/driver/sqlite v1.0.8
gorm.io/driver/sqlserver v0.2.4 gorm.io/driver/sqlserver v0.2.4

View File

@ -13,7 +13,7 @@ type Blog struct {
Locale string `gorm:"primary_key"` Locale string `gorm:"primary_key"`
Subject string Subject string
Body string Body string
Tags []Tag `gorm:"many2many:blogs_tags;"` Tags []Tag `gorm:"many2many:blog_tags;"`
SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"`
LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"`
} }
@ -22,7 +22,7 @@ type Tag struct {
ID uint `gorm:"primary_key"` ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"` Locale string `gorm:"primary_key"`
Value string Value string
Blogs []*Blog `gorm:"many2many:blogs_tags"` Blogs []*Blog `gorm:"many2many:blog_tags"`
} }
func compareTags(tags []Tag, contents []string) bool { func compareTags(tags []Tag, contents []string) bool {