Add Table method for JoinTableHandler

This commit is contained in:
Jinzhu 2015-03-04 12:16:16 +08:00
parent 3740a3beaa
commit 80576bbbbc
4 changed files with 38 additions and 25 deletions

View File

@ -75,10 +75,8 @@ func (association *Association) Delete(values ...interface{}) *Association {
relationship := association.Field.Relationship relationship := association.Field.Relationship
// many to many // many to many
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)", sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys)
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
leftValues := reflect.Zero(association.Field.Field.Type()) leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ { for i := 0; i < association.Field.Field.Len(); i++ {
@ -132,8 +130,8 @@ func (association *Association) Replace(values ...interface{}) *Association {
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
} }
sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName)) sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys) query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship)) association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
} else { } else {
association.setErr(errors.New("replace only support many to many")) association.setErr(errors.New("replace only support many to many"))
@ -145,8 +143,8 @@ func (association *Association) Clear() *Association {
relationship := association.Field.Relationship relationship := association.Field.Relationship
scope := association.Scope scope := association.Scope
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName)) sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey) query := scope.NewDB().Where(sql, association.PrimaryKey)
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else { } else {
@ -165,8 +163,7 @@ func (association *Association) Count() int {
newScope := scope.New(association.Field.Field.Interface()) newScope := scope.New(association.Field.Field.Interface())
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
query := scope.DB().Table(relationship.JoinTable). query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count) scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {

View File

@ -6,6 +6,7 @@ import (
) )
type JoinTableHandler interface { type JoinTableHandler interface {
Table(*DB, *Relationship) string
Add(*DB, *Relationship, interface{}, interface{}) error Add(*DB, *Relationship, interface{}, interface{}) error
Delete(*DB, *Relationship) error Delete(*DB, *Relationship) error
Scope(*DB, *Relationship) *DB Scope(*DB, *Relationship) *DB
@ -13,17 +14,22 @@ type JoinTableHandler interface {
type defaultJoinTableHandler struct{} type defaultJoinTableHandler struct{}
func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string {
return relationship.JoinTable
}
func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
scope := db.NewScope("") scope := db.NewScope("")
quotedForeignDBName := scope.Quote(relationship.ForeignDBName) quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName) quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
table := s.Table(db, relationship)
sql := fmt.Sprintf( sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);", "INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
scope.Quote(relationship.JoinTable), scope.Quote(table),
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","), strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
scope.Dialect().SelectFromDummyTable(), scope.Dialect().SelectFromDummyTable(),
scope.Quote(relationship.JoinTable), scope.Quote(table),
quotedForeignDBName, quotedForeignDBName,
quotedAssociationDBName, quotedAssociationDBName,
) )
@ -31,12 +37,12 @@ func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignV
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
} }
func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
return db.Delete("").Error return db.Table(s.Table(db, relationship)).Delete("").Error
} }
func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
return db return db.Table(s.Table(db, relationship))
} }
var DefaultJoinTableHandler = &defaultJoinTableHandler{} var DefaultJoinTableHandler = &defaultJoinTableHandler{}

View File

@ -21,6 +21,10 @@ type PersonAddress struct {
CreatedAt time.Time CreatedAt time.Time
} }
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
return relationship.JoinTable
}
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{ return db.Where(map[string]interface{}{
relationship.ForeignDBName: foreignValue, relationship.ForeignDBName: foreignValue,
@ -36,8 +40,9 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error
return db.Delete(&PersonAddress{}).Error return db.Delete(&PersonAddress{}).Error
} }
func (*PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
return db.Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", relationship.JoinTable, relationship.JoinTable)) table := pa.Table(db, relationship)
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
} }
func TestJoinTable(t *testing.T) { func TestJoinTable(t *testing.T) {

View File

@ -402,14 +402,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if fromField != nil { if fromField != nil {
if relationship := fromField.Relationship; relationship != nil { if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship))
joinSql := fmt.Sprintf( joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v", "INNER JOIN %v ON %v.%v = %v.%v",
scope.Quote(relationship.JoinTable), quotedJoinTable,
scope.Quote(relationship.JoinTable), quotedJoinTable,
scope.Quote(relationship.AssociationForeignDBName), scope.Quote(relationship.AssociationForeignDBName),
toScope.QuotedTableName(), toScope.QuotedTableName(),
scope.Quote(toScope.PrimaryKey())) scope.Quote(toScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName)) whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName))
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error) scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
@ -441,16 +444,18 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
func (scope *Scope) createJoinTable(field *StructField) { func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" { if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" {
if !scope.Dialect().HasTable(scope, relationship.JoinTable) { joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
joinTable := joinTableHandler.Table(scope.db, relationship)
if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255) primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
scope.Quote(relationship.JoinTable), scope.Quote(joinTable),
strings.Join([]string{ strings.Join([]string{
scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType, scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType,
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
).Error) ).Error)
} }
scope.NewDB().Table(relationship.JoinTable).AutoMigrate(scope.db.GetJoinTableHandler(relationship.JoinTable)) scope.NewDB().Table(joinTable).AutoMigrate()
} }
} }