mirror of https://github.com/go-gorm/gorm.git
Add Table method for JoinTableHandler
This commit is contained in:
parent
3740a3beaa
commit
80576bbbbc
|
@ -75,10 +75,8 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
relationship := association.Field.Relationship
|
relationship := association.Field.Relationship
|
||||||
// many to many
|
// many to many
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)",
|
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName),
|
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
||||||
scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
|
|
||||||
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys)
|
|
||||||
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||||
|
@ -132,8 +130,8 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
|
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
|
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
||||||
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
|
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
|
||||||
} else {
|
} else {
|
||||||
association.setErr(errors.New("replace only support many to many"))
|
association.setErr(errors.New("replace only support many to many"))
|
||||||
|
@ -145,8 +143,8 @@ func (association *Association) Clear() *Association {
|
||||||
relationship := association.Field.Relationship
|
relationship := association.Field.Relationship
|
||||||
scope := association.Scope
|
scope := association.Scope
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||||
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey)
|
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
||||||
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||||
} else {
|
} else {
|
||||||
|
@ -165,8 +163,7 @@ func (association *Association) Count() int {
|
||||||
newScope := scope.New(association.Field.Field.Interface())
|
newScope := scope.New(association.Field.Field.Interface())
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
query := scope.DB().Table(relationship.JoinTable).
|
query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
||||||
Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
|
||||||
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
|
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
|
||||||
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
|
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type JoinTableHandler interface {
|
type JoinTableHandler interface {
|
||||||
|
Table(*DB, *Relationship) string
|
||||||
Add(*DB, *Relationship, interface{}, interface{}) error
|
Add(*DB, *Relationship, interface{}, interface{}) error
|
||||||
Delete(*DB, *Relationship) error
|
Delete(*DB, *Relationship) error
|
||||||
Scope(*DB, *Relationship) *DB
|
Scope(*DB, *Relationship) *DB
|
||||||
|
@ -13,17 +14,22 @@ type JoinTableHandler interface {
|
||||||
|
|
||||||
type defaultJoinTableHandler struct{}
|
type defaultJoinTableHandler struct{}
|
||||||
|
|
||||||
func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
|
func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string {
|
||||||
|
return relationship.JoinTable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
|
||||||
scope := db.NewScope("")
|
scope := db.NewScope("")
|
||||||
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
|
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
|
||||||
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
|
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
|
||||||
|
table := s.Table(db, relationship)
|
||||||
|
|
||||||
sql := fmt.Sprintf(
|
sql := fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
|
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
|
||||||
scope.Quote(relationship.JoinTable),
|
scope.Quote(table),
|
||||||
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
|
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
|
||||||
scope.Dialect().SelectFromDummyTable(),
|
scope.Dialect().SelectFromDummyTable(),
|
||||||
scope.Quote(relationship.JoinTable),
|
scope.Quote(table),
|
||||||
quotedForeignDBName,
|
quotedForeignDBName,
|
||||||
quotedAssociationDBName,
|
quotedAssociationDBName,
|
||||||
)
|
)
|
||||||
|
@ -31,12 +37,12 @@ func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignV
|
||||||
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
|
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
|
func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
|
||||||
return db.Delete("").Error
|
return db.Table(s.Table(db, relationship)).Delete("").Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
|
func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
|
||||||
return db
|
return db.Table(s.Table(db, relationship))
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultJoinTableHandler = &defaultJoinTableHandler{}
|
var DefaultJoinTableHandler = &defaultJoinTableHandler{}
|
||||||
|
|
|
@ -21,6 +21,10 @@ type PersonAddress struct {
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
|
||||||
|
return relationship.JoinTable
|
||||||
|
}
|
||||||
|
|
||||||
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
|
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
|
||||||
return db.Where(map[string]interface{}{
|
return db.Where(map[string]interface{}{
|
||||||
relationship.ForeignDBName: foreignValue,
|
relationship.ForeignDBName: foreignValue,
|
||||||
|
@ -36,8 +40,9 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error
|
||||||
return db.Delete(&PersonAddress{}).Error
|
return db.Delete(&PersonAddress{}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
|
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
|
||||||
return db.Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", relationship.JoinTable, relationship.JoinTable))
|
table := pa.Table(db, relationship)
|
||||||
|
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinTable(t *testing.T) {
|
func TestJoinTable(t *testing.T) {
|
||||||
|
|
|
@ -402,14 +402,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
if fromField != nil {
|
if fromField != nil {
|
||||||
if relationship := fromField.Relationship; relationship != nil {
|
if relationship := fromField.Relationship; relationship != nil {
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
|
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
||||||
|
quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship))
|
||||||
|
|
||||||
joinSql := fmt.Sprintf(
|
joinSql := fmt.Sprintf(
|
||||||
"INNER JOIN %v ON %v.%v = %v.%v",
|
"INNER JOIN %v ON %v.%v = %v.%v",
|
||||||
scope.Quote(relationship.JoinTable),
|
quotedJoinTable,
|
||||||
scope.Quote(relationship.JoinTable),
|
quotedJoinTable,
|
||||||
scope.Quote(relationship.AssociationForeignDBName),
|
scope.Quote(relationship.AssociationForeignDBName),
|
||||||
toScope.QuotedTableName(),
|
toScope.QuotedTableName(),
|
||||||
scope.Quote(toScope.PrimaryKey()))
|
scope.Quote(toScope.PrimaryKey()))
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName))
|
whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName))
|
||||||
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
|
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
|
||||||
} else if relationship.Kind == "belongs_to" {
|
} else if relationship.Kind == "belongs_to" {
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||||
|
@ -441,16 +444,18 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" {
|
||||||
if !scope.Dialect().HasTable(scope, relationship.JoinTable) {
|
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
||||||
|
joinTable := joinTableHandler.Table(scope.db, relationship)
|
||||||
|
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
scope.Quote(relationship.JoinTable),
|
scope.Quote(joinTable),
|
||||||
strings.Join([]string{
|
strings.Join([]string{
|
||||||
scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType,
|
scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType,
|
||||||
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
||||||
).Error)
|
).Error)
|
||||||
}
|
}
|
||||||
scope.NewDB().Table(relationship.JoinTable).AutoMigrate(scope.db.GetJoinTableHandler(relationship.JoinTable))
|
scope.NewDB().Table(joinTable).AutoMigrate()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue