diff --git a/association.go b/association.go index 9ce64043..5181279c 100644 --- a/association.go +++ b/association.go @@ -71,15 +71,15 @@ func (association *Association) Delete(values ...interface{}) *Association { if len(primaryKeys) == 0 { association.setErr(errors.New("no primary key found")) } else { + scope := association.Scope relationship := association.Field.Relationship // many to many if relationship.Kind == "many_to_many" { - whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)", - relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName), - relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName)) - - if err := association.Scope.DB().Table(relationship.JoinTable). - Where(whereSql, association.PrimaryKey, primaryKeys).Delete("").Error; err == nil { + sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)", + scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), + scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName)) + query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys) + if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -132,11 +132,9 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } - whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", - relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName), - relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName)) - - scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey, addedPrimaryKeys).Delete("") + sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName)) + query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys) + association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship)) } else { association.setErr(errors.New("replace only support many to many")) } @@ -147,8 +145,9 @@ func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope if relationship.Kind == "many_to_many" { - whereSql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName)) - if err := scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("").Error; err == nil { + sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName)) + query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey) + if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -166,9 +165,10 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - scope.DB().Table(relationship.JoinTable). + query := scope.DB().Table(relationship.JoinTable). Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). - Where(relationship.ForeignDBName+" = ?", association.PrimaryKey).Row().Scan(&count) + Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) + scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/callback_shared.go b/callback_shared.go index ae7b99da..48f0b937 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -1,10 +1,6 @@ package gorm -import ( - "fmt" - "reflect" - "strings" -) +import "reflect" func BeginTransaction(scope *Scope) { scope.Begin() @@ -53,24 +49,8 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) if joinTable := relationship.JoinTable; joinTable != "" { - quotedForeignDBName := scope.Quote(relationship.ForeignDBName) - foreignValue := scope.PrimaryKeyValue() - quoteAssociationForeignDBName := scope.Quote(relationship.AssociationForeignDBName) - associationForeignValue := newScope.PrimaryKeyValue() - - newScope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);", - joinTable, - strings.Join([]string{quotedForeignDBName, quoteAssociationForeignDBName}, ","), - strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","), - scope.Dialect().SelectFromDummyTable(), - joinTable, - quotedForeignDBName, - newScope.AddToVars(foreignValue), - quoteAssociationForeignDBName, - newScope.AddToVars(associationForeignValue), - )) - scope.Err(scope.NewDB().Exec(newScope.Sql, newScope.SqlVars...).Error) + scope.Err(scope.db.GetJoinTableHandler(joinTable). + Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue())) } } default: diff --git a/join_table.go b/join_table.go new file mode 100644 index 00000000..7d9c419e --- /dev/null +++ b/join_table.go @@ -0,0 +1,42 @@ +package gorm + +import ( + "fmt" + "strings" +) + +type JoinTableHandler interface { + Add(*DB, *Relationship, interface{}, interface{}) error + Delete(*DB, *Relationship) error + Scope(*DB, *Relationship) *DB +} + +type defaultJoinTableHandler struct{} + +func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { + scope := db.NewScope("") + quotedForeignDBName := scope.Quote(relationship.ForeignDBName) + quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName) + + sql := fmt.Sprintf( + "INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);", + scope.Quote(relationship.JoinTable), + strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","), + scope.Dialect().SelectFromDummyTable(), + scope.Quote(relationship.JoinTable), + quotedForeignDBName, + quotedAssociationDBName, + ) + + return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error +} + +func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { + return db.Delete("").Error +} + +func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { + return db +} + +var DefaultJoinTableHandler = &defaultJoinTableHandler{} diff --git a/main.go b/main.go index cf416799..5f90a146 100644 --- a/main.go +++ b/main.go @@ -20,20 +20,21 @@ var NowFunc = func() time.Time { } type DB struct { - Value interface{} - Error error - RowsAffected int64 - ModelStructs map[reflect.Type]*ModelStruct - callback *callback - db sqlCommon - parent *DB - search *search - logMode int - logger logger - dialect Dialect - singularTable bool - source string - values map[string]interface{} + Value interface{} + Error error + RowsAffected int64 + ModelStructs map[reflect.Type]*ModelStruct + callback *callback + db sqlCommon + parent *DB + search *search + logMode int + logger logger + dialect Dialect + singularTable bool + source string + values map[string]interface{} + joinTableHandlers map[string]JoinTableHandler } func Open(dialect string, args ...interface{}) (DB, error) { @@ -91,20 +92,6 @@ func (db *DB) NewScope(value interface{}) *Scope { return &Scope{db: dbClone, Search: dbClone.search, Value: value} } -func (s *DB) FreshDB() *DB { - newDB := &DB{ - dialect: s.dialect, - logger: s.logger, - callback: s.parent.callback.clone(), - source: s.source, - values: map[string]interface{}{}, - db: s.db, - ModelStructs: map[reflect.Type]*ModelStruct{}, - } - newDB.parent = newDB - return newDB -} - // CommonDB Return the underlying sql.DB or sql.Tx instance. // Use of this method is discouraged. It's mainly intended to allow // coexistence with legacy non-GORM code. @@ -473,3 +460,29 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { value, ok = s.values[name] return } + +func (s *DB) GetJoinTableHandler(table string) JoinTableHandler { + if s.parent.joinTableHandlers != nil { + if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok { + return joinTableHandler + } + if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok { + return joinTableHandler + } + } + return DefaultJoinTableHandler +} + +func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) { + if s.parent.joinTableHandlers == nil { + s.parent.joinTableHandlers = map[string]JoinTableHandler{} + } + + if len(tables) > 0 { + for _, table := range tables { + s.parent.joinTableHandlers[table] = joinTableHandler + } + } else { + s.parent.joinTableHandlers["*"] = joinTableHandler + } +}