mirror of https://github.com/go-gorm/gorm.git
New JoinTableHandler
This commit is contained in:
parent
17b20bcdd8
commit
c13e2f18f8
|
@ -77,7 +77,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
||||||
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||||
value := association.Field.Field.Index(i)
|
value := association.Field.Field.Index(i)
|
||||||
|
@ -132,7 +132,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
|
|
||||||
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
||||||
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(query, relationship))
|
||||||
} else {
|
} else {
|
||||||
association.setErr(errors.New("replace only support many to many"))
|
association.setErr(errors.New("replace only support many to many"))
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,7 @@ func (association *Association) Clear() *Association {
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||||
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
||||||
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||||
} else {
|
} else {
|
||||||
association.setErr(err)
|
association.setErr(err)
|
||||||
|
@ -165,7 +165,7 @@ func (association *Association) Count() int {
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
||||||
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
|
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
|
||||||
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
|
relationship.JoinTableHandler.JoinWith(query, association.Scope.Value).Count(&count)
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||||
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
||||||
|
|
|
@ -38,7 +38,7 @@ func SaveAfterAssociations(scope *Scope) {
|
||||||
elem := value.Index(i).Addr().Interface()
|
elem := value.Index(i).Addr().Interface()
|
||||||
newScope := newDB.NewScope(elem)
|
newScope := newDB.NewScope(elem)
|
||||||
|
|
||||||
if relationship.JoinTable == "" && relationship.ForeignFieldName != "" {
|
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
|
||||||
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
|
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,9 +48,8 @@ func SaveAfterAssociations(scope *Scope) {
|
||||||
|
|
||||||
scope.Err(newDB.Save(elem).Error)
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
|
||||||
if joinTable := relationship.JoinTable; joinTable != "" {
|
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||||
scope.Err(scope.db.GetJoinTableHandler(joinTable).
|
scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value))
|
||||||
Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -5,44 +5,78 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JoinTableHandler interface {
|
type JoinTableHandlerInterface interface {
|
||||||
Table(*DB, *Relationship) string
|
Table(db *DB) string
|
||||||
Add(*DB, *Relationship, interface{}, interface{}) error
|
Add(db *DB, source1 interface{}, source2 interface{}) error
|
||||||
Delete(*DB, *Relationship) error
|
Delete(db *DB, sources ...interface{}) error
|
||||||
Scope(*DB, *Relationship) *DB
|
JoinWith(db *DB, source interface{}) *DB
|
||||||
}
|
}
|
||||||
|
|
||||||
type defaultJoinTableHandler struct{}
|
type JoinTableSource struct {
|
||||||
|
ForeignKey string
|
||||||
func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string {
|
ForeignKeyPrefix string
|
||||||
return relationship.JoinTable
|
ModelStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
|
type JoinTableHandler struct {
|
||||||
|
TableName string
|
||||||
|
Source1 JoinTableSource
|
||||||
|
Source2 JoinTableSource
|
||||||
|
}
|
||||||
|
|
||||||
|
func (jt JoinTableHandler) Table(*DB) string {
|
||||||
|
return jt.TableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||||
|
values := map[string]interface{}{}
|
||||||
|
for _, source := range sources {
|
||||||
|
scope := db.NewScope(source)
|
||||||
|
for _, primaryField := range scope.GetModelStruct().PrimaryFields {
|
||||||
|
if field, ok := scope.Fields()[primaryField.DBName]; ok {
|
||||||
|
values[primaryField.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
||||||
scope := db.NewScope("")
|
scope := db.NewScope("")
|
||||||
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
|
valueMap := jt.GetValueMap(db, source1, source2)
|
||||||
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
|
|
||||||
table := s.Table(db, relationship)
|
|
||||||
|
|
||||||
|
var setColumns, setBinVars, queryConditions []string
|
||||||
|
var values []interface{}
|
||||||
|
for key, value := range valueMap {
|
||||||
|
setColumns = append(setColumns, key)
|
||||||
|
setBinVars = append(setBinVars, `?`)
|
||||||
|
queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range valueMap {
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
quotedTable := jt.Table(db)
|
||||||
sql := fmt.Sprintf(
|
sql := fmt.Sprintf(
|
||||||
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
|
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
|
||||||
scope.Quote(table),
|
quotedTable,
|
||||||
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
|
strings.Join(setColumns, ","),
|
||||||
|
strings.Join(setBinVars, ","),
|
||||||
scope.Dialect().SelectFromDummyTable(),
|
scope.Dialect().SelectFromDummyTable(),
|
||||||
scope.Quote(table),
|
quotedTable,
|
||||||
quotedForeignDBName,
|
strings.Join(queryConditions, " AND "),
|
||||||
quotedAssociationDBName,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
|
return db.Exec(sql, values...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
|
func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
||||||
return db.Table(s.Table(db, relationship)).Delete("").Error
|
// return db.Table(jt.Table(db)).Delete("").Error
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
|
func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB {
|
||||||
return db.Table(s.Table(db, relationship))
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultJoinTableHandler = &defaultJoinTableHandler{}
|
|
||||||
|
|
27
main.go
27
main.go
|
@ -469,30 +469,3 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||||
value, ok = s.values[name]
|
value, ok = s.values[name]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) GetJoinTableHandler(table string) JoinTableHandler {
|
|
||||||
if s.parent.joinTableHandlers != nil {
|
|
||||||
if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok {
|
|
||||||
return joinTableHandler
|
|
||||||
}
|
|
||||||
if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok {
|
|
||||||
return joinTableHandler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return DefaultJoinTableHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) {
|
|
||||||
if s.parent.joinTableHandlers == nil {
|
|
||||||
s.parent.joinTableHandlers = map[string]JoinTableHandler{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tables) > 0 {
|
|
||||||
for _, table := range tables {
|
|
||||||
s.parent.joinTableHandlers[table] = joinTableHandler
|
|
||||||
s.Table(table).AutoMigrate(joinTableHandler)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
s.parent.joinTableHandlers["*"] = joinTableHandler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ type Relationship struct {
|
||||||
ForeignDBName string
|
ForeignDBName string
|
||||||
AssociationForeignFieldName string
|
AssociationForeignFieldName string
|
||||||
AssociationForeignDBName string
|
AssociationForeignDBName string
|
||||||
JoinTable string
|
JoinTableHandler JoinTableHandlerInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||||
|
@ -205,7 +205,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
|
|
||||||
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
|
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
|
||||||
relationship.Kind = "many_to_many"
|
relationship.Kind = "many_to_many"
|
||||||
relationship.JoinTable = many2many
|
relationship.JoinTableHandler = JoinTableHandler{}
|
||||||
|
|
||||||
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
||||||
if associationForeignKey == "" {
|
if associationForeignKey == "" {
|
||||||
|
|
|
@ -402,18 +402,11 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
if fromField != nil {
|
if fromField != nil {
|
||||||
if relationship := fromField.Relationship; relationship != nil {
|
if relationship := fromField.Relationship; relationship != nil {
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship))
|
quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db))
|
||||||
|
scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).
|
||||||
joinSql := fmt.Sprintf(
|
Where(fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)), scope.PrimaryKeyValue()).
|
||||||
"INNER JOIN %v ON %v.%v = %v.%v",
|
Find(value).Error)
|
||||||
quotedJoinTable,
|
|
||||||
quotedJoinTable,
|
|
||||||
scope.Quote(relationship.AssociationForeignDBName),
|
|
||||||
toScope.QuotedTableName(),
|
|
||||||
scope.Quote(toScope.PrimaryKey()))
|
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName))
|
|
||||||
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
|
|
||||||
} else if relationship.Kind == "belongs_to" {
|
} else if relationship.Kind == "belongs_to" {
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||||
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
|
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
|
||||||
|
@ -443,9 +436,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
joinTable := relationship.JoinTableHandler.Table(scope.db)
|
||||||
joinTable := joinTableHandler.Table(scope.db, relationship)
|
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
|
|
Loading…
Reference in New Issue