Add JoinTableHandler

This commit is contained in:
Jinzhu 2015-02-28 11:48:18 +08:00
parent aa0a4012be
commit 6d64e6837b
4 changed files with 101 additions and 66 deletions

View File

@ -71,15 +71,15 @@ func (association *Association) Delete(values ...interface{}) *Association {
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
association.setErr(errors.New("no primary key found")) association.setErr(errors.New("no primary key found"))
} else { } else {
scope := association.Scope
relationship := association.Field.Relationship relationship := association.Field.Relationship
// many to many // many to many
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)", sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)",
relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName),
relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName)) scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys)
if err := association.Scope.DB().Table(relationship.JoinTable). if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
Where(whereSql, association.PrimaryKey, primaryKeys).Delete("").Error; err == nil {
leftValues := reflect.Zero(association.Field.Field.Type()) leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ { for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i) value := association.Field.Field.Index(i)
@ -132,11 +132,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
} }
whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName), query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys)
relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName)) association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey, addedPrimaryKeys).Delete("")
} else { } else {
association.setErr(errors.New("replace only support many to many")) association.setErr(errors.New("replace only support many to many"))
} }
@ -147,8 +145,9 @@ func (association *Association) Clear() *Association {
relationship := association.Field.Relationship relationship := association.Field.Relationship
scope := association.Scope scope := association.Scope
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
whereSql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName)) sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName))
if err := scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("").Error; err == nil { query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey)
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else { } else {
association.setErr(err) association.setErr(err)
@ -166,9 +165,10 @@ func (association *Association) Count() int {
newScope := scope.New(association.Field.Field.Interface()) newScope := scope.New(association.Field.Field.Interface())
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
scope.DB().Table(relationship.JoinTable). query := scope.DB().Table(relationship.JoinTable).
Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey).Row().Scan(&count) Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)

View File

@ -1,10 +1,6 @@
package gorm package gorm
import ( import "reflect"
"fmt"
"reflect"
"strings"
)
func BeginTransaction(scope *Scope) { func BeginTransaction(scope *Scope) {
scope.Begin() scope.Begin()
@ -53,24 +49,8 @@ func SaveAfterAssociations(scope *Scope) {
scope.Err(newDB.Save(elem).Error) scope.Err(newDB.Save(elem).Error)
if joinTable := relationship.JoinTable; joinTable != "" { if joinTable := relationship.JoinTable; joinTable != "" {
quotedForeignDBName := scope.Quote(relationship.ForeignDBName) scope.Err(scope.db.GetJoinTableHandler(joinTable).
foreignValue := scope.PrimaryKeyValue() Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue()))
quoteAssociationForeignDBName := scope.Quote(relationship.AssociationForeignDBName)
associationForeignValue := newScope.PrimaryKeyValue()
newScope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);",
joinTable,
strings.Join([]string{quotedForeignDBName, quoteAssociationForeignDBName}, ","),
strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","),
scope.Dialect().SelectFromDummyTable(),
joinTable,
quotedForeignDBName,
newScope.AddToVars(foreignValue),
quoteAssociationForeignDBName,
newScope.AddToVars(associationForeignValue),
))
scope.Err(scope.NewDB().Exec(newScope.Sql, newScope.SqlVars...).Error)
} }
} }
default: default:

42
join_table.go Normal file
View File

@ -0,0 +1,42 @@
package gorm
import (
"fmt"
"strings"
)
type JoinTableHandler interface {
Add(*DB, *Relationship, interface{}, interface{}) error
Delete(*DB, *Relationship) error
Scope(*DB, *Relationship) *DB
}
type defaultJoinTableHandler struct{}
func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
scope := db.NewScope("")
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
scope.Quote(relationship.JoinTable),
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
scope.Dialect().SelectFromDummyTable(),
scope.Quote(relationship.JoinTable),
quotedForeignDBName,
quotedAssociationDBName,
)
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
}
func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
return db.Delete("").Error
}
func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
return db
}
var DefaultJoinTableHandler = &defaultJoinTableHandler{}

69
main.go
View File

@ -20,20 +20,21 @@ var NowFunc = func() time.Time {
} }
type DB struct { type DB struct {
Value interface{} Value interface{}
Error error Error error
RowsAffected int64 RowsAffected int64
ModelStructs map[reflect.Type]*ModelStruct ModelStructs map[reflect.Type]*ModelStruct
callback *callback callback *callback
db sqlCommon db sqlCommon
parent *DB parent *DB
search *search search *search
logMode int logMode int
logger logger logger logger
dialect Dialect dialect Dialect
singularTable bool singularTable bool
source string source string
values map[string]interface{} values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler
} }
func Open(dialect string, args ...interface{}) (DB, error) { func Open(dialect string, args ...interface{}) (DB, error) {
@ -91,20 +92,6 @@ func (db *DB) NewScope(value interface{}) *Scope {
return &Scope{db: dbClone, Search: dbClone.search, Value: value} return &Scope{db: dbClone, Search: dbClone.search, Value: value}
} }
func (s *DB) FreshDB() *DB {
newDB := &DB{
dialect: s.dialect,
logger: s.logger,
callback: s.parent.callback.clone(),
source: s.source,
values: map[string]interface{}{},
db: s.db,
ModelStructs: map[reflect.Type]*ModelStruct{},
}
newDB.parent = newDB
return newDB
}
// CommonDB Return the underlying sql.DB or sql.Tx instance. // CommonDB Return the underlying sql.DB or sql.Tx instance.
// Use of this method is discouraged. It's mainly intended to allow // Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code. // coexistence with legacy non-GORM code.
@ -473,3 +460,29 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
value, ok = s.values[name] value, ok = s.values[name]
return return
} }
func (s *DB) GetJoinTableHandler(table string) JoinTableHandler {
if s.parent.joinTableHandlers != nil {
if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok {
return joinTableHandler
}
if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok {
return joinTableHandler
}
}
return DefaultJoinTableHandler
}
func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) {
if s.parent.joinTableHandlers == nil {
s.parent.joinTableHandlers = map[string]JoinTableHandler{}
}
if len(tables) > 0 {
for _, table := range tables {
s.parent.joinTableHandlers[table] = joinTableHandler
}
} else {
s.parent.joinTableHandlers["*"] = joinTableHandler
}
}