Refactor JoinTableHandler

This commit is contained in:
Jinzhu 2015-03-18 18:14:28 +08:00
parent c13e2f18f8
commit 6ba0c1661f
3 changed files with 80 additions and 37 deletions

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect"
"strings" "strings"
) )
@ -13,70 +14,114 @@ type JoinTableHandlerInterface interface {
} }
type JoinTableSource struct { type JoinTableSource struct {
ForeignKey string ModelType reflect.Type
ForeignKeyPrefix string ForeignKeys []struct {
ModelStruct DBName string
AssociationDBName string
}
} }
type JoinTableHandler struct { type JoinTableHandler struct {
TableName string TableName string `sql:"-"`
Source1 JoinTableSource Source JoinTableSource `sql:"-"`
Source2 JoinTableSource Destination JoinTableSource `sql:"-"`
} }
func (jt JoinTableHandler) Table(*DB) string { func (s JoinTableHandler) Table(*DB) string {
return jt.TableName return s.TableName
} }
func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} { func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{} values := map[string]interface{}{}
for _, source := range sources { for _, source := range sources {
scope := db.NewScope(source) scope := db.NewScope(source)
for _, primaryField := range scope.GetModelStruct().PrimaryFields { modelType := scope.GetModelStruct().ModelType
if field, ok := scope.Fields()[primaryField.DBName]; ok {
values[primaryField.DBName] = field.Field.Interface() if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
} }
} }
} }
return values return values
} }
func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("") scope := db.NewScope("")
valueMap := jt.GetValueMap(db, source1, source2) searchMap := s.GetSearchMap(db, source1, source2)
var setColumns, setBinVars, queryConditions []string var assignColumns, binVars, conditions []string
var values []interface{} var values []interface{}
for key, value := range valueMap { for key, value := range searchMap {
setColumns = append(setColumns, key) assignColumns = append(assignColumns, key)
setBinVars = append(setBinVars, `?`) binVars = append(binVars, `?`)
queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key))) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value) values = append(values, value)
} }
for _, value := range valueMap { for _, value := range searchMap {
values = append(values, value) values = append(values, value)
} }
quotedTable := jt.Table(db) quotedTable := s.Table(db)
sql := fmt.Sprintf( sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
quotedTable, quotedTable,
strings.Join(setColumns, ","), strings.Join(assignColumns, ","),
strings.Join(setBinVars, ","), strings.Join(binVars, ","),
scope.Dialect().SelectFromDummyTable(), scope.Dialect().SelectFromDummyTable(),
quotedTable, quotedTable,
strings.Join(queryConditions, " AND "), strings.Join(conditions, " AND "),
) )
return db.Exec(sql, values...).Error return db.Exec(sql, values...).Error
} }
func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error { func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
// return db.Table(jt.Table(db)).Delete("").Error var conditions []string
return nil var values []interface{}
for key, value := range s.GetSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
values = append(values, value)
}
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
} }
func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB { func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
return db quotedTable := s.Table(db)
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
var joinConditions []string
var queryConditions []string
var values []interface{}
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName)))
}
for _, foreignKey := range s.Source.ForeignKeys {
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName)))
}
for _, foreignKey := range s.Destination.ForeignKeys {
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
}
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", strings.Join(joinConditions, " AND "))).
Where(strings.Join(queryConditions, " AND "), values...)
} }

View File

@ -15,16 +15,13 @@ type Person struct {
} }
type PersonAddress struct { type PersonAddress struct {
gorm.JoinTableHandler
PersonID int PersonID int
AddressID int AddressID int
DeletedAt time.Time DeletedAt time.Time
CreatedAt time.Time CreatedAt time.Time
} }
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
return relationship.JoinTable
}
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{ return db.Where(map[string]interface{}{
relationship.ForeignDBName: foreignValue, relationship.ForeignDBName: foreignValue,
@ -41,14 +38,14 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error
} }
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
table := pa.Table(db, relationship) table := pa.Table(db)
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
} }
func TestJoinTable(t *testing.T) { func TestJoinTable(t *testing.T) {
DB.Exec("drop table person_addresses;") DB.Exec("drop table person_addresses;")
DB.AutoMigrate(&Person{}) DB.AutoMigrate(&Person{})
DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") // DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
address1 := &Address{Address1: "address 1"} address1 := &Address{Address1: "address 1"}
address2 := &Address{Address1: "address 2"} address2 := &Address{Address1: "address 2"}

View File

@ -437,7 +437,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
func (scope *Scope) createJoinTable(field *StructField) { func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTable := relationship.JoinTableHandler.Table(scope.db) joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(scope, joinTable) { if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
@ -447,7 +448,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
).Error) ).Error)
} }
scope.NewDB().Table(joinTable).AutoMigrate() scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
} }
} }