From 6ba0c1661f356dea7e2c97aa0bda83362275a797 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Mar 2015 18:14:28 +0800 Subject: [PATCH] Refactor JoinTableHandler --- join_table.go | 103 ++++++++++++++++++++++++++++++++------------- join_table_test.go | 9 ++-- scope_private.go | 5 ++- 3 files changed, 80 insertions(+), 37 deletions(-) diff --git a/join_table.go b/join_table.go index b6b33d9a..2aeb1c4a 100644 --- a/join_table.go +++ b/join_table.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "reflect" "strings" ) @@ -13,70 +14,114 @@ type JoinTableHandlerInterface interface { } type JoinTableSource struct { - ForeignKey string - ForeignKeyPrefix string - ModelStruct + ModelType reflect.Type + ForeignKeys []struct { + DBName string + AssociationDBName string + } } type JoinTableHandler struct { - TableName string - Source1 JoinTableSource - Source2 JoinTableSource + TableName string `sql:"-"` + Source JoinTableSource `sql:"-"` + Destination JoinTableSource `sql:"-"` } -func (jt JoinTableHandler) Table(*DB) string { - return jt.TableName +func (s JoinTableHandler) Table(*DB) string { + return s.TableName } -func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} { +func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { values := map[string]interface{}{} + for _, source := range sources { scope := db.NewScope(source) - for _, primaryField := range scope.GetModelStruct().PrimaryFields { - if field, ok := scope.Fields()[primaryField.DBName]; ok { - values[primaryField.DBName] = field.Field.Interface() + modelType := scope.GetModelStruct().ModelType + + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Source.ForeignKeys { + values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + } + } else if s.Destination.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() } } } return values } -func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { +func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") - valueMap := jt.GetValueMap(db, source1, source2) + searchMap := s.GetSearchMap(db, source1, source2) - var setColumns, setBinVars, queryConditions []string + var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range valueMap { - setColumns = append(setColumns, key) - setBinVars = append(setBinVars, `?`) - queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + for key, value := range searchMap { + assignColumns = append(assignColumns, key) + binVars = append(binVars, `?`) + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } - for _, value := range valueMap { + for _, value := range searchMap { values = append(values, value) } - quotedTable := jt.Table(db) + quotedTable := s.Table(db) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);", quotedTable, - strings.Join(setColumns, ","), - strings.Join(setBinVars, ","), + strings.Join(assignColumns, ","), + strings.Join(binVars, ","), scope.Dialect().SelectFromDummyTable(), quotedTable, - strings.Join(queryConditions, " AND "), + strings.Join(conditions, " AND "), ) return db.Exec(sql, values...).Error } -func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error { - // return db.Table(jt.Table(db)).Delete("").Error - return nil +func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { + var conditions []string + var values []interface{} + + for key, value := range s.GetSearchMap(db, sources...) { + conditions = append(conditions, fmt.Sprintf("%v = ?", key)) + values = append(values, value) + } + + return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } -func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB { - return db +func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { + quotedTable := s.Table(db) + + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + var joinConditions []string + var queryConditions []string + var values []interface{} + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Source.ForeignKeys { + queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) + values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) + } + } else if s.Destination.ModelType == modelType { + for _, foreignKey := range s.Source.ForeignKeys { + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Destination.ForeignKeys { + queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) + values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) + } + } + + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) } diff --git a/join_table_test.go b/join_table_test.go index 2624fdb2..429e46e1 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,16 +15,13 @@ type Person struct { } type PersonAddress struct { + gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string { - return relationship.JoinTable -} - func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ relationship.ForeignDBName: foreignValue, @@ -41,14 +38,14 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error } func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { - table := pa.Table(db, relationship) + table := pa.Table(db) return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { DB.Exec("drop table person_addresses;") DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") + // DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") address1 := &Address{Address1: "address 1"} address2 := &Address{Address1: "address 2"} diff --git a/scope_private.go b/scope_private.go index 5755a60c..da78d3f2 100644 --- a/scope_private.go +++ b/scope_private.go @@ -437,7 +437,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTable := relationship.JoinTableHandler.Table(scope.db) + joinTableHandler := relationship.JoinTableHandler + joinTable := joinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", @@ -447,7 +448,7 @@ func (scope *Scope) createJoinTable(field *StructField) { scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), ).Error) } - scope.NewDB().Table(joinTable).AutoMigrate() + scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } }