From c13e2f18f8d6f571e4ab4229dbc782cccc2f4125 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Mar 2015 11:47:11 +0800 Subject: [PATCH] New JoinTableHandler --- association.go | 8 ++--- callback_shared.go | 7 ++-- join_table.go | 86 ++++++++++++++++++++++++++++++++-------------- main.go | 27 --------------- model_struct.go | 4 +-- scope_private.go | 22 ++++-------- 6 files changed, 76 insertions(+), 78 deletions(-) diff --git a/association.go b/association.go index b011971a..60763f8c 100644 --- a/association.go +++ b/association.go @@ -77,7 +77,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -132,7 +132,7 @@ func (association *Association) Replace(values ...interface{}) *Association { sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) } else { association.setErr(errors.New("replace only support many to many")) } @@ -145,7 +145,7 @@ func (association *Association) Clear() *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey) - if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -165,7 +165,7 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) - scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count) + relationship.JoinTableHandler.JoinWith(query, association.Scope.Value).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/callback_shared.go b/callback_shared.go index 99ad8f50..ae75c250 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -38,7 +38,7 @@ func SaveAfterAssociations(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTable == "" && relationship.ForeignFieldName != "" { + if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) } @@ -48,9 +48,8 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) - if joinTable := relationship.JoinTable; joinTable != "" { - scope.Err(scope.db.GetJoinTableHandler(joinTable). - Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue())) + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value)) } } default: diff --git a/join_table.go b/join_table.go index 3ffa4f87..b6b33d9a 100644 --- a/join_table.go +++ b/join_table.go @@ -5,44 +5,78 @@ import ( "strings" ) -type JoinTableHandler interface { - Table(*DB, *Relationship) string - Add(*DB, *Relationship, interface{}, interface{}) error - Delete(*DB, *Relationship) error - Scope(*DB, *Relationship) *DB +type JoinTableHandlerInterface interface { + Table(db *DB) string + Add(db *DB, source1 interface{}, source2 interface{}) error + Delete(db *DB, sources ...interface{}) error + JoinWith(db *DB, source interface{}) *DB } -type defaultJoinTableHandler struct{} - -func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string { - return relationship.JoinTable +type JoinTableSource struct { + ForeignKey string + ForeignKeyPrefix string + ModelStruct } -func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { +type JoinTableHandler struct { + TableName string + Source1 JoinTableSource + Source2 JoinTableSource +} + +func (jt JoinTableHandler) Table(*DB) string { + return jt.TableName +} + +func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} { + values := map[string]interface{}{} + for _, source := range sources { + scope := db.NewScope(source) + for _, primaryField := range scope.GetModelStruct().PrimaryFields { + if field, ok := scope.Fields()[primaryField.DBName]; ok { + values[primaryField.DBName] = field.Field.Interface() + } + } + } + return values +} + +func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") - quotedForeignDBName := scope.Quote(relationship.ForeignDBName) - quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName) - table := s.Table(db, relationship) + valueMap := jt.GetValueMap(db, source1, source2) + var setColumns, setBinVars, queryConditions []string + var values []interface{} + for key, value := range valueMap { + setColumns = append(setColumns, key) + setBinVars = append(setBinVars, `?`) + queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + values = append(values, value) + } + + for _, value := range valueMap { + values = append(values, value) + } + + quotedTable := jt.Table(db) sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);", - scope.Quote(table), - strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","), + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", + quotedTable, + strings.Join(setColumns, ","), + strings.Join(setBinVars, ","), scope.Dialect().SelectFromDummyTable(), - scope.Quote(table), - quotedForeignDBName, - quotedAssociationDBName, + quotedTable, + strings.Join(queryConditions, " AND "), ) - return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error + return db.Exec(sql, values...).Error } -func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { - return db.Table(s.Table(db, relationship)).Delete("").Error +func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error { + // return db.Table(jt.Table(db)).Delete("").Error + return nil } -func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { - return db.Table(s.Table(db, relationship)) +func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB { + return db } - -var DefaultJoinTableHandler = &defaultJoinTableHandler{} diff --git a/main.go b/main.go index 87fecd59..377ef582 100644 --- a/main.go +++ b/main.go @@ -469,30 +469,3 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { value, ok = s.values[name] return } - -func (s *DB) GetJoinTableHandler(table string) JoinTableHandler { - if s.parent.joinTableHandlers != nil { - if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok { - return joinTableHandler - } - if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok { - return joinTableHandler - } - } - return DefaultJoinTableHandler -} - -func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) { - if s.parent.joinTableHandlers == nil { - s.parent.joinTableHandlers = map[string]JoinTableHandler{} - } - - if len(tables) > 0 { - for _, table := range tables { - s.parent.joinTableHandlers[table] = joinTableHandler - s.Table(table).AutoMigrate(joinTableHandler) - } - } else { - s.parent.joinTableHandlers["*"] = joinTableHandler - } -} diff --git a/model_struct.go b/model_struct.go index 02e23dfe..cce28330 100644 --- a/model_struct.go +++ b/model_struct.go @@ -60,7 +60,7 @@ type Relationship struct { ForeignDBName string AssociationForeignFieldName string AssociationForeignDBName string - JoinTable string + JoinTableHandler JoinTableHandlerInterface } var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} @@ -205,7 +205,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := gormSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - relationship.JoinTable = many2many + relationship.JoinTableHandler = JoinTableHandler{} associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] if associationForeignKey == "" { diff --git a/scope_private.go b/scope_private.go index 6d700cb9..5755a60c 100644 --- a/scope_private.go +++ b/scope_private.go @@ -402,18 +402,11 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if fromField != nil { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { - joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) - quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship)) - - joinSql := fmt.Sprintf( - "INNER JOIN %v ON %v.%v = %v.%v", - quotedJoinTable, - quotedJoinTable, - scope.Quote(relationship.AssociationForeignDBName), - toScope.QuotedTableName(), - scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)) - scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error) + joinTableHandler := relationship.JoinTableHandler + quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db)) + scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value). + Where(fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)), scope.PrimaryKeyValue()). + Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() @@ -443,9 +436,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" { - joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) - joinTable := joinTableHandler.Table(scope.db, relationship) + if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { + joinTable := relationship.JoinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",