gorm/join_table_handler.go

181 lines
5.8 KiB
Go
Raw Normal View History

2015-02-28 06:48:18 +03:00
package gorm
import (
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
)
2015-03-18 06:47:11 +03:00
type JoinTableHandlerInterface interface {
2015-03-20 06:11:30 +03:00
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
2015-03-18 06:47:11 +03:00
Table(db *DB) string
2015-06-19 06:32:11 +03:00
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
2015-06-29 13:04:15 +03:00
SourceForeignKeys() []JoinTableForeignKey
DestinationForeignKeys() []JoinTableForeignKey
2015-02-28 06:48:18 +03:00
}
2015-03-19 10:02:15 +03:00
// JoinTableForeignKey pairs a column of the join table (DBName) with the
// column it refers to on the related model (AssociationDBName).
type JoinTableForeignKey struct {
	DBName, AssociationDBName string
}
2015-03-19 13:23:54 +03:00
// JoinTableSource describes one side (source or destination) of a join
// table: the model type on that side, and how join-table columns map onto
// that model's columns.
type JoinTableSource struct {
// ModelType is compared against a scope's model type to decide which side a
// value belongs to.
ModelType reflect.Type
// ForeignKeys lists the join-table column / model column pairs for this side.
ForeignKeys []JoinTableForeignKey
}
// JoinTableHandler stores the join table name and both sides' key mappings;
// *JoinTableHandler provides every method of JoinTableHandlerInterface.
//
// NOTE(review): the `sql:"-"` tags presumably keep gorm from treating these
// fields as columns if the handler is embedded in a model — confirm.
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
2015-06-29 13:04:15 +03:00
// SourceForeignKeys returns the join-table/source-model column pairs
// recorded in s.Source (populated by Setup).
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
	side := s.Source
	return side.ForeignKeys
}
// DestinationForeignKeys returns the join-table/destination-model column
// pairs recorded in s.Destination (populated by Setup).
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
	side := s.Destination
	return side.ForeignKeys
}
2015-03-20 06:11:30 +03:00
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
2015-03-19 10:02:15 +03:00
2015-03-20 06:11:30 +03:00
s.Source = JoinTableSource{ModelType: source}
2015-07-30 09:26:48 +03:00
for idx, dbName := range relationship.ForeignFieldNames {
2015-03-20 06:11:30 +03:00
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
2015-07-30 09:26:48 +03:00
DBName: relationship.ForeignDBNames[idx],
AssociationDBName: dbName,
2015-03-19 10:02:15 +03:00
})
}
2015-03-20 06:11:30 +03:00
s.Destination = JoinTableSource{ModelType: destination}
2015-07-30 09:26:48 +03:00
for idx, dbName := range relationship.AssociationForeignFieldNames {
2015-03-20 06:11:30 +03:00
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
2015-07-30 09:26:48 +03:00
DBName: relationship.AssociationForeignDBNames[idx],
AssociationDBName: dbName,
2015-03-19 10:02:15 +03:00
})
}
}
2015-06-19 06:32:11 +03:00
func (s JoinTableHandler) Table(db *DB) string {
2015-06-23 06:31:56 +03:00
return s.TableName
2015-03-18 06:47:11 +03:00
}
2015-02-28 06:48:18 +03:00
2015-03-18 13:14:28 +03:00
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
2015-03-18 06:47:11 +03:00
values := map[string]interface{}{}
2015-03-18 13:14:28 +03:00
2015-03-18 06:47:11 +03:00
for _, source := range sources {
scope := db.NewScope(source)
2015-03-18 13:14:28 +03:00
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
2015-03-18 06:47:11 +03:00
}
}
}
return values
2015-03-04 07:16:16 +03:00
}
2015-06-19 06:32:11 +03:00
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
2015-02-28 06:48:18 +03:00
scope := db.NewScope("")
2015-03-18 13:14:28 +03:00
searchMap := s.GetSearchMap(db, source1, source2)
2015-03-18 06:47:11 +03:00
2015-03-18 13:14:28 +03:00
var assignColumns, binVars, conditions []string
2015-03-18 06:47:11 +03:00
var values []interface{}
2015-03-18 13:14:28 +03:00
for key, value := range searchMap {
assignColumns = append(assignColumns, scope.Quote(key))
2015-03-18 13:14:28 +03:00
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
2015-03-18 06:47:11 +03:00
values = append(values, value)
}
2015-02-28 06:48:18 +03:00
2015-03-19 13:23:54 +03:00
for _, value := range values {
2015-03-18 06:47:11 +03:00
values = append(values, value)
}
quotedTable := scope.Quote(handler.Table(db))
2015-02-28 06:48:18 +03:00
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
2015-03-18 06:47:11 +03:00
quotedTable,
2015-03-18 13:14:28 +03:00
strings.Join(assignColumns, ","),
strings.Join(binVars, ","),
2015-02-28 06:48:18 +03:00
scope.Dialect().SelectFromDummyTable(),
2015-03-18 06:47:11 +03:00
quotedTable,
2015-03-18 13:14:28 +03:00
strings.Join(conditions, " AND "),
2015-02-28 06:48:18 +03:00
)
2015-03-18 06:47:11 +03:00
return db.Exec(sql, values...).Error
2015-02-28 06:48:18 +03:00
}
2015-06-19 06:32:11 +03:00
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
)
2015-03-18 13:14:28 +03:00
for key, value := range s.GetSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
2015-03-18 13:14:28 +03:00
values = append(values, value)
}
2015-06-19 06:32:11 +03:00
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
2015-02-28 06:48:18 +03:00
}
2015-06-19 06:32:11 +03:00
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var (
scope = db.NewScope(source)
tableName = handler.Table(db)
quotedTableName = scope.Quote(tableName)
joinConditions []string
values []interface{}
)
2015-03-18 13:14:28 +03:00
if s.Source.ModelType == scope.GetModelStruct().ModelType {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
2015-08-16 10:10:11 +03:00
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
2015-08-16 10:10:11 +03:00
}
var foreignDBNames []string
var foreignFieldNames []string
2015-08-16 10:10:11 +03:00
for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name)
}
2015-08-16 10:10:11 +03:00
2016-01-15 17:53:09 +03:00
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
2015-08-16 10:10:11 +03:00
var condString string
if len(foreignFieldValues) > 0 {
var quotedForeignDBNames []string
for _, dbName := range foreignDBNames {
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
}
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
2016-01-15 17:53:09 +03:00
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
values = append(values, toQueryValues(keys))
} else {
condString = fmt.Sprintf("1 <> 1")
}
2015-08-16 10:10:11 +03:00
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...)
2015-08-16 10:10:11 +03:00
}
2016-01-15 16:03:35 +03:00
db.Error = errors.New("wrong source type for join table handler")
return db
2015-08-16 10:10:11 +03:00
}