forked from mirror/gorm
Refactor JoinTableHandler
This commit is contained in:
parent
c13e2f18f8
commit
6ba0c1661f
103
join_table.go
103
join_table.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,70 +14,114 @@ type JoinTableHandlerInterface interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type JoinTableSource struct {
|
type JoinTableSource struct {
|
||||||
ForeignKey string
|
ModelType reflect.Type
|
||||||
ForeignKeyPrefix string
|
ForeignKeys []struct {
|
||||||
ModelStruct
|
DBName string
|
||||||
|
AssociationDBName string
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type JoinTableHandler struct {
|
type JoinTableHandler struct {
|
||||||
TableName string
|
TableName string `sql:"-"`
|
||||||
Source1 JoinTableSource
|
Source JoinTableSource `sql:"-"`
|
||||||
Source2 JoinTableSource
|
Destination JoinTableSource `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinTableHandler) Table(*DB) string {
|
func (s JoinTableHandler) Table(*DB) string {
|
||||||
return jt.TableName
|
return s.TableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} {
|
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||||
values := map[string]interface{}{}
|
values := map[string]interface{}{}
|
||||||
|
|
||||||
for _, source := range sources {
|
for _, source := range sources {
|
||||||
scope := db.NewScope(source)
|
scope := db.NewScope(source)
|
||||||
for _, primaryField := range scope.GetModelStruct().PrimaryFields {
|
modelType := scope.GetModelStruct().ModelType
|
||||||
if field, ok := scope.Fields()[primaryField.DBName]; ok {
|
|
||||||
values[primaryField.DBName] = field.Field.Interface()
|
if s.Source.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||||
|
}
|
||||||
|
} else if s.Destination.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
||||||
scope := db.NewScope("")
|
scope := db.NewScope("")
|
||||||
valueMap := jt.GetValueMap(db, source1, source2)
|
searchMap := s.GetSearchMap(db, source1, source2)
|
||||||
|
|
||||||
var setColumns, setBinVars, queryConditions []string
|
var assignColumns, binVars, conditions []string
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
for key, value := range valueMap {
|
for key, value := range searchMap {
|
||||||
setColumns = append(setColumns, key)
|
assignColumns = append(assignColumns, key)
|
||||||
setBinVars = append(setBinVars, `?`)
|
binVars = append(binVars, `?`)
|
||||||
queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range valueMap {
|
for _, value := range searchMap {
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
quotedTable := jt.Table(db)
|
quotedTable := s.Table(db)
|
||||||
sql := fmt.Sprintf(
|
sql := fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
|
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
|
||||||
quotedTable,
|
quotedTable,
|
||||||
strings.Join(setColumns, ","),
|
strings.Join(assignColumns, ","),
|
||||||
strings.Join(setBinVars, ","),
|
strings.Join(binVars, ","),
|
||||||
scope.Dialect().SelectFromDummyTable(),
|
scope.Dialect().SelectFromDummyTable(),
|
||||||
quotedTable,
|
quotedTable,
|
||||||
strings.Join(queryConditions, " AND "),
|
strings.Join(conditions, " AND "),
|
||||||
)
|
)
|
||||||
|
|
||||||
return db.Exec(sql, values...).Error
|
return db.Exec(sql, values...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
||||||
// return db.Table(jt.Table(db)).Delete("").Error
|
var conditions []string
|
||||||
return nil
|
var values []interface{}
|
||||||
|
|
||||||
|
for key, value := range s.GetSearchMap(db, sources...) {
|
||||||
|
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
|
||||||
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB {
|
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||||
return db
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
|
||||||
|
quotedTable := s.Table(db)
|
||||||
|
|
||||||
|
scope := db.NewScope(source)
|
||||||
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
var joinConditions []string
|
||||||
|
var queryConditions []string
|
||||||
|
var values []interface{}
|
||||||
|
if s.Source.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
|
||||||
|
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
|
||||||
|
}
|
||||||
|
} else if s.Destination.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
|
||||||
|
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", strings.Join(joinConditions, " AND "))).
|
||||||
|
Where(strings.Join(queryConditions, " AND "), values...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,16 +15,13 @@ type Person struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type PersonAddress struct {
|
type PersonAddress struct {
|
||||||
|
gorm.JoinTableHandler
|
||||||
PersonID int
|
PersonID int
|
||||||
AddressID int
|
AddressID int
|
||||||
DeletedAt time.Time
|
DeletedAt time.Time
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
|
|
||||||
return relationship.JoinTable
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
|
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
|
||||||
return db.Where(map[string]interface{}{
|
return db.Where(map[string]interface{}{
|
||||||
relationship.ForeignDBName: foreignValue,
|
relationship.ForeignDBName: foreignValue,
|
||||||
|
@ -41,14 +38,14 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
|
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
|
||||||
table := pa.Table(db, relationship)
|
table := pa.Table(db)
|
||||||
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinTable(t *testing.T) {
|
func TestJoinTable(t *testing.T) {
|
||||||
DB.Exec("drop table person_addresses;")
|
DB.Exec("drop table person_addresses;")
|
||||||
DB.AutoMigrate(&Person{})
|
DB.AutoMigrate(&Person{})
|
||||||
DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
|
// DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
|
||||||
|
|
||||||
address1 := &Address{Address1: "address 1"}
|
address1 := &Address{Address1: "address 1"}
|
||||||
address2 := &Address{Address1: "address 2"}
|
address2 := &Address{Address1: "address 2"}
|
||||||
|
|
|
@ -437,7 +437,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTable := relationship.JoinTableHandler.Table(scope.db)
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
|
@ -447,7 +448,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
||||||
).Error)
|
).Error)
|
||||||
}
|
}
|
||||||
scope.NewDB().Table(joinTable).AutoMigrate()
|
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue