mirror of https://github.com/go-gorm/gorm.git
154 lines
5.1 KiB
Go
154 lines
5.1 KiB
Go
package gorm
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
type JoinTableHandlerInterface interface {
|
|
Table(db *DB) string
|
|
Add(db *DB, source1 interface{}, source2 interface{}) error
|
|
Delete(db *DB, sources ...interface{}) error
|
|
JoinWith(db *DB, source interface{}) *DB
|
|
}
|
|
|
|
// JoinTableForeignKey pairs a column of the join table with the column of
// the associated model's table that it is compared against.
type JoinTableForeignKey struct {
	// DBName is the column name inside the join table.
	DBName string
	// AssociationDBName is the matching column name on the model's table.
	AssociationDBName string
}
|
|
|
|
type JoinTableSource struct {
|
|
ModelType reflect.Type
|
|
ForeignKeys []JoinTableForeignKey
|
|
}
|
|
|
|
type JoinTableHandler struct {
|
|
TableName string `sql:"-"`
|
|
Source JoinTableSource `sql:"-"`
|
|
Destination JoinTableSource `sql:"-"`
|
|
}
|
|
|
|
func updateJoinTableHandler(relationship *Relationship) {
|
|
handler := relationship.JoinTableHandler.(*JoinTableHandler)
|
|
|
|
destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()}
|
|
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
|
|
db := relationship.AssociationForeignDBName
|
|
handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{
|
|
DBName: db,
|
|
AssociationDBName: primaryField.DBName,
|
|
})
|
|
}
|
|
|
|
sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()}
|
|
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
|
|
db := relationship.ForeignDBName
|
|
handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{
|
|
DBName: db,
|
|
AssociationDBName: primaryField.DBName,
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s JoinTableHandler) Table(*DB) string {
|
|
return s.TableName
|
|
}
|
|
|
|
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
|
values := map[string]interface{}{}
|
|
|
|
for _, source := range sources {
|
|
scope := db.NewScope(source)
|
|
modelType := scope.GetModelStruct().ModelType
|
|
|
|
if s.Source.ModelType == modelType {
|
|
for _, foreignKey := range s.Source.ForeignKeys {
|
|
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
|
}
|
|
} else if s.Destination.ModelType == modelType {
|
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
|
}
|
|
}
|
|
}
|
|
return values
|
|
}
|
|
|
|
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
|
scope := db.NewScope("")
|
|
searchMap := s.GetSearchMap(db, source1, source2)
|
|
|
|
var assignColumns, binVars, conditions []string
|
|
var values []interface{}
|
|
for key, value := range searchMap {
|
|
assignColumns = append(assignColumns, key)
|
|
binVars = append(binVars, `?`)
|
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
|
values = append(values, value)
|
|
}
|
|
|
|
for _, value := range values {
|
|
values = append(values, value)
|
|
}
|
|
|
|
quotedTable := s.Table(db)
|
|
sql := fmt.Sprintf(
|
|
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
|
|
quotedTable,
|
|
strings.Join(assignColumns, ","),
|
|
strings.Join(binVars, ","),
|
|
scope.Dialect().SelectFromDummyTable(),
|
|
quotedTable,
|
|
strings.Join(conditions, " AND "),
|
|
)
|
|
|
|
return db.Exec(sql, values...).Error
|
|
}
|
|
|
|
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
|
var conditions []string
|
|
var values []interface{}
|
|
|
|
for key, value := range s.GetSearchMap(db, sources...) {
|
|
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
|
|
values = append(values, value)
|
|
}
|
|
|
|
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
|
}
|
|
|
|
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
|
|
quotedTable := s.Table(db)
|
|
|
|
scope := db.NewScope(source)
|
|
modelType := scope.GetModelStruct().ModelType
|
|
var joinConditions []string
|
|
var queryConditions []string
|
|
var values []interface{}
|
|
if s.Source.ModelType == modelType {
|
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
|
}
|
|
|
|
for _, foreignKey := range s.Source.ForeignKeys {
|
|
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
|
|
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
|
|
}
|
|
} else if s.Destination.ModelType == modelType {
|
|
for _, foreignKey := range s.Source.ForeignKeys {
|
|
sourceTableName := scope.New(reflect.New(s.Source.ModelType).Interface()).QuotedTableName()
|
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), sourceTableName, scope.Quote(foreignKey.AssociationDBName)))
|
|
}
|
|
|
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
|
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
|
|
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
|
|
}
|
|
}
|
|
|
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
|
Where(strings.Join(queryConditions, " AND "), values...)
|
|
}
|