diff --git a/association.go b/association.go index dbc928e8..e34a10bd 100644 --- a/association.go +++ b/association.go @@ -78,7 +78,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -134,7 +134,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if len(addedPrimaryKeys) > 0 { sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } } else { association.setErr(errors.New("replace only support many to many")) @@ -148,7 +148,7 @@ func (association *Association) Clear() *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey) - if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -166,7 +166,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/callback_shared.go b/callback_shared.go index 88158cfc..c1b9bd00 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -55,7 +55,7 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value)) + scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) } } default: diff --git a/join_table_handler.go b/join_table_handler.go index 27051cbd..dceb4277 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -10,9 +10,9 @@ import ( type JoinTableHandlerInterface interface { Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Table(db *DB) string - Add(db *DB, source interface{}, destination interface{}) error - Delete(db *DB, sources ...interface{}) error - JoinWith(db *DB, source interface{}) *DB + Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error + Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error + JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB } type JoinTableForeignKey struct { @@ -74,8 +74,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } } -func (s JoinTableHandler) Table(*DB) string { - return s.TableName +func (s JoinTableHandler) Table(db *DB) string { + if draftMode, ok := db.Get("publish:draft_mode"); ok && draftMode.(bool) { + return s.TableName + "_draft" + } else { + return s.TableName + } } func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { @@ -98,7 +102,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin return values } -func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") searchMap := s.GetSearchMap(db, source1, source2) @@ -115,7 +119,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) values = append(values, value) } - quotedTable := s.Table(db) + quotedTable := handler.Table(db) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", quotedTable, @@ -129,7 +133,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) return db.Exec(sql, values...).Error } -func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { +func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var conditions []string var values []interface{} @@ -138,11 +142,11 @@ func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { values = append(values, value) } - return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error + return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } -func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { - quotedTable := s.Table(db) +func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { + quotedTable := handler.Table(db) scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType diff --git a/join_table_test.go b/join_table_test.go index f8b097b6..3353aee2 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -22,7 +22,7 @@ type PersonAddress struct { CreatedAt time.Time } -func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), "address_id": db.NewScope(associationValue).PrimaryKeyValue(), @@ -33,11 +33,11 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu }).FirstOrCreate(&PersonAddress{}).Error } -func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { +func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { return db.Delete(&PersonAddress{}).Error } -func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { +func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } diff --git a/scope_private.go b/scope_private.go index 63fcea46..5faebe2e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -413,7 +413,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()