diff --git a/join_table.go b/join_table_handler.go similarity index 74% rename from join_table.go rename to join_table_handler.go index 163bb4e2..21e88fe1 100644 --- a/join_table.go +++ b/join_table_handler.go @@ -1,14 +1,16 @@ package gorm import ( + "errors" "fmt" "reflect" "strings" ) type JoinTableHandlerInterface interface { + Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Table(db *DB) string - Add(db *DB, source1 interface{}, source2 interface{}) error + Add(db *DB, source interface{}, destination interface{}) error Delete(db *DB, sources ...interface{}) error JoinWith(db *DB, source interface{}) *DB } @@ -29,22 +31,24 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } -func updateJoinTableHandler(relationship *Relationship) { - handler := relationship.JoinTableHandler.(*JoinTableHandler) +func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { + s.TableName = tableName - destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - db := relationship.AssociationForeignDBName - handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ + s.Source = JoinTableSource{ModelType: source} + sourceScope := &Scope{Value: reflect.New(source).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + db := relationship.ForeignDBName + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ DBName: db, AssociationDBName: primaryField.DBName, }) } - sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { - db := relationship.ForeignDBName - handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ + s.Destination = JoinTableSource{ModelType: destination} + destinationScope := &Scope{Value: reflect.New(destination).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + db := relationship.AssociationForeignDBName + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ DBName: db, AssociationDBName: primaryField.DBName, }) @@ -136,18 +140,10 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - sourceTableName := scope.New(reflect.New(s.Source.ModelType).Interface()).QuotedTableName() - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), sourceTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - for _, foreignKey := range s.Destination.ForeignKeys { - queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) - values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) - } + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) + } else { + db.Error = errors.New("wrong source type for join table handler") + return db } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). - Where(strings.Join(queryConditions, " AND "), values...) } diff --git a/join_table_test.go b/join_table_test.go index 40f36799..f8b097b6 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,16 +15,13 @@ type Person struct { } type PersonAddress struct { + gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Table(db *gorm.DB) string { - return "person_addresses" -} - func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), @@ -32,7 +29,7 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu }).Assign(map[string]interface{}{ "person_id": foreignValue, "address_id": associationValue, - "DeletedAt": gorm.Expr("NULL"), + "deleted_at": gorm.Expr("NULL"), }).FirstOrCreate(&PersonAddress{}).Error } diff --git a/main.go b/main.go index 5f7db05a..b66ceda3 100644 --- a/main.go +++ b/main.go @@ -473,8 +473,13 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { for _, field := range s.NewScope(source).GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - field.Relationship.JoinTableHandler = handler - s.Table(handler.Table(s)).AutoMigrate(handler) + if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { + source := (&Scope{Value: source}).GetModelStruct().ModelType + destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType + handler.Setup(field.Relationship, many2many, source, destination) + field.Relationship.JoinTableHandler = handler + s.Table(handler.Table(s)).AutoMigrate(handler) + } } } } diff --git a/model_struct.go b/model_struct.go index 50940472..b7c44414 100644 --- a/model_struct.go +++ b/model_struct.go @@ -215,13 +215,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.ForeignDBName = ToDBName(foreignKey) relationship.AssociationForeignFieldName = associationForeignKey relationship.AssociationForeignDBName = ToDBName(associationForeignKey) - relationship.JoinTableHandler = &JoinTableHandler{ - TableName: many2many, - Source: JoinTableSource{ModelType: scopeType}, - Destination: JoinTableSource{ModelType: elemType}, - } - updateJoinTableHandler(relationship) + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, many2many, scopeType, elemType) + relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { relationship.Kind = "has_many"