mirror of https://github.com/go-gorm/gorm.git
Add JoinTableHandler
This commit is contained in:
parent
aa0a4012be
commit
6d64e6837b
|
@ -71,15 +71,15 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
if len(primaryKeys) == 0 {
|
if len(primaryKeys) == 0 {
|
||||||
association.setErr(errors.New("no primary key found"))
|
association.setErr(errors.New("no primary key found"))
|
||||||
} else {
|
} else {
|
||||||
|
scope := association.Scope
|
||||||
relationship := association.Field.Relationship
|
relationship := association.Field.Relationship
|
||||||
// many to many
|
// many to many
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)",
|
sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)",
|
||||||
relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName),
|
scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName),
|
||||||
relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName))
|
scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
|
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys)
|
||||||
if err := association.Scope.DB().Table(relationship.JoinTable).
|
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
||||||
Where(whereSql, association.PrimaryKey, primaryKeys).Delete("").Error; err == nil {
|
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||||
value := association.Field.Field.Index(i)
|
value := association.Field.Field.Index(i)
|
||||||
|
@ -132,11 +132,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
|
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)",
|
sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName),
|
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
||||||
relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName))
|
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
|
||||||
|
|
||||||
scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey, addedPrimaryKeys).Delete("")
|
|
||||||
} else {
|
} else {
|
||||||
association.setErr(errors.New("replace only support many to many"))
|
association.setErr(errors.New("replace only support many to many"))
|
||||||
}
|
}
|
||||||
|
@ -147,8 +145,9 @@ func (association *Association) Clear() *Association {
|
||||||
relationship := association.Field.Relationship
|
relationship := association.Field.Relationship
|
||||||
scope := association.Scope
|
scope := association.Scope
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName))
|
sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName))
|
||||||
if err := scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("").Error; err == nil {
|
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey)
|
||||||
|
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||||
} else {
|
} else {
|
||||||
association.setErr(err)
|
association.setErr(err)
|
||||||
|
@ -166,9 +165,10 @@ func (association *Association) Count() int {
|
||||||
newScope := scope.New(association.Field.Field.Interface())
|
newScope := scope.New(association.Field.Field.Interface())
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
scope.DB().Table(relationship.JoinTable).
|
query := scope.DB().Table(relationship.JoinTable).
|
||||||
Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
||||||
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey).Row().Scan(&count)
|
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
|
||||||
|
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||||
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import "reflect"
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func BeginTransaction(scope *Scope) {
|
func BeginTransaction(scope *Scope) {
|
||||||
scope.Begin()
|
scope.Begin()
|
||||||
|
@ -53,24 +49,8 @@ func SaveAfterAssociations(scope *Scope) {
|
||||||
scope.Err(newDB.Save(elem).Error)
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
|
||||||
if joinTable := relationship.JoinTable; joinTable != "" {
|
if joinTable := relationship.JoinTable; joinTable != "" {
|
||||||
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
|
scope.Err(scope.db.GetJoinTableHandler(joinTable).
|
||||||
foreignValue := scope.PrimaryKeyValue()
|
Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue()))
|
||||||
quoteAssociationForeignDBName := scope.Quote(relationship.AssociationForeignDBName)
|
|
||||||
associationForeignValue := newScope.PrimaryKeyValue()
|
|
||||||
|
|
||||||
newScope.Raw(fmt.Sprintf(
|
|
||||||
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);",
|
|
||||||
joinTable,
|
|
||||||
strings.Join([]string{quotedForeignDBName, quoteAssociationForeignDBName}, ","),
|
|
||||||
strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","),
|
|
||||||
scope.Dialect().SelectFromDummyTable(),
|
|
||||||
joinTable,
|
|
||||||
quotedForeignDBName,
|
|
||||||
newScope.AddToVars(foreignValue),
|
|
||||||
quoteAssociationForeignDBName,
|
|
||||||
newScope.AddToVars(associationForeignValue),
|
|
||||||
))
|
|
||||||
scope.Err(scope.NewDB().Exec(newScope.Sql, newScope.SqlVars...).Error)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JoinTableHandler interface {
|
||||||
|
Add(*DB, *Relationship, interface{}, interface{}) error
|
||||||
|
Delete(*DB, *Relationship) error
|
||||||
|
Scope(*DB, *Relationship) *DB
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultJoinTableHandler struct{}
|
||||||
|
|
||||||
|
func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
|
||||||
|
scope := db.NewScope("")
|
||||||
|
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
|
||||||
|
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
|
||||||
|
|
||||||
|
sql := fmt.Sprintf(
|
||||||
|
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
|
||||||
|
scope.Quote(relationship.JoinTable),
|
||||||
|
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
|
||||||
|
scope.Dialect().SelectFromDummyTable(),
|
||||||
|
scope.Quote(relationship.JoinTable),
|
||||||
|
quotedForeignDBName,
|
||||||
|
quotedAssociationDBName,
|
||||||
|
)
|
||||||
|
|
||||||
|
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
|
||||||
|
return db.Delete("").Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultJoinTableHandler = &defaultJoinTableHandler{}
|
69
main.go
69
main.go
|
@ -20,20 +20,21 @@ var NowFunc = func() time.Time {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Error error
|
Error error
|
||||||
RowsAffected int64
|
RowsAffected int64
|
||||||
ModelStructs map[reflect.Type]*ModelStruct
|
ModelStructs map[reflect.Type]*ModelStruct
|
||||||
callback *callback
|
callback *callback
|
||||||
db sqlCommon
|
db sqlCommon
|
||||||
parent *DB
|
parent *DB
|
||||||
search *search
|
search *search
|
||||||
logMode int
|
logMode int
|
||||||
logger logger
|
logger logger
|
||||||
dialect Dialect
|
dialect Dialect
|
||||||
singularTable bool
|
singularTable bool
|
||||||
source string
|
source string
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
|
joinTableHandlers map[string]JoinTableHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func Open(dialect string, args ...interface{}) (DB, error) {
|
func Open(dialect string, args ...interface{}) (DB, error) {
|
||||||
|
@ -91,20 +92,6 @@ func (db *DB) NewScope(value interface{}) *Scope {
|
||||||
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
|
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) FreshDB() *DB {
|
|
||||||
newDB := &DB{
|
|
||||||
dialect: s.dialect,
|
|
||||||
logger: s.logger,
|
|
||||||
callback: s.parent.callback.clone(),
|
|
||||||
source: s.source,
|
|
||||||
values: map[string]interface{}{},
|
|
||||||
db: s.db,
|
|
||||||
ModelStructs: map[reflect.Type]*ModelStruct{},
|
|
||||||
}
|
|
||||||
newDB.parent = newDB
|
|
||||||
return newDB
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
// CommonDB Return the underlying sql.DB or sql.Tx instance.
|
||||||
// Use of this method is discouraged. It's mainly intended to allow
|
// Use of this method is discouraged. It's mainly intended to allow
|
||||||
// coexistence with legacy non-GORM code.
|
// coexistence with legacy non-GORM code.
|
||||||
|
@ -473,3 +460,29 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||||
value, ok = s.values[name]
|
value, ok = s.values[name]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DB) GetJoinTableHandler(table string) JoinTableHandler {
|
||||||
|
if s.parent.joinTableHandlers != nil {
|
||||||
|
if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok {
|
||||||
|
return joinTableHandler
|
||||||
|
}
|
||||||
|
if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok {
|
||||||
|
return joinTableHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return DefaultJoinTableHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) {
|
||||||
|
if s.parent.joinTableHandlers == nil {
|
||||||
|
s.parent.joinTableHandlers = map[string]JoinTableHandler{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tables) > 0 {
|
||||||
|
for _, table := range tables {
|
||||||
|
s.parent.joinTableHandlers[table] = joinTableHandler
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.parent.joinTableHandlers["*"] = joinTableHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue