diff --git a/join_table.go b/join_table.go index 2aeb1c4a..9b23f89f 100644 --- a/join_table.go +++ b/join_table.go @@ -13,12 +13,36 @@ type JoinTableHandlerInterface interface { JoinWith(db *DB, source interface{}) *DB } +type JoinTableForeignKey struct { + DBName string + AssociationDBName string +} + +func updateJoinTableHandler(relationship *Relationship) { + handler := relationship.JoinTableHandler.(*JoinTableHandler) + + destinationScope := &Scope{Value: reflect.New(handler.Destination.ModelType).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + db := relationship.ForeignDBName + handler.Destination.ForeignKeys = append(handler.Destination.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } + + sourceScope := &Scope{Value: reflect.New(handler.Source.ModelType).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + db := relationship.AssociationForeignDBName + handler.Source.ForeignKeys = append(handler.Source.ForeignKeys, JoinTableForeignKey{ + DBName: db, + AssociationDBName: primaryField.DBName, + }) + } +} + type JoinTableSource struct { ModelType reflect.Type - ForeignKeys []struct { - DBName string - AssociationDBName string - } + ForeignKeys []JoinTableForeignKey } type JoinTableHandler struct { diff --git a/model_struct.go b/model_struct.go index cce28330..50940472 100644 --- a/model_struct.go +++ b/model_struct.go @@ -146,143 +146,150 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - for _, field := range fields { - if !field.IsIgnored { - fieldStruct := field.Struct - fieldType, indirectType := fieldStruct.Type, fieldStruct.Type - if indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { - field.IsScanner, field.IsNormal = true, true - } - - if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { - field.IsNormal = true - } - - if !field.IsNormal { - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) - - getForeignField := func(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == ToDBName(column) { - return field - } - } - return nil + defer func() { + for _, field := range fields { + if !field.IsIgnored { + fieldStruct := field.Struct + fieldType, indirectType := fieldStruct.Type, fieldStruct.Type + if indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() } - var relationship = &Relationship{} - - foreignKey := gormSettings["FOREIGNKEY"] - if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldName = polymorphicField.Name - relationship.ForeignDBName = polymorphicField.DBName - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } - } + if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { + field.IsScanner, field.IsNormal = true, true } - switch indirectType.Kind() { - case reflect.Slice: - elemType := indirectType.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } - - if many2many := gormSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" - relationship.JoinTableHandler = JoinTableHandler{} - - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] - if associationForeignKey == "" { - associationForeignKey = elemType.Name() + "Id" - } - - relationship.ForeignFieldName = foreignKey - relationship.ForeignDBName = ToDBName(foreignKey) - relationship.AssociationForeignFieldName = associationForeignKey - relationship.AssociationForeignDBName = ToDBName(associationForeignKey) - field.Relationship = relationship - } else { - relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, toField := range toScope.GetStructFields() { - toField = toField.clone() - toField.Names = append([]string{fieldStruct.Name}, toField.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, toField) - if toField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) - } - } - continue - } else { - belongsToForeignKey := foreignKey - if belongsToForeignKey == "" { - belongsToForeignKey = field.Name + "Id" - } - - if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { - relationship.Kind = "belongs_to" - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else { - if foreignKey == "" { - foreignKey = modelStruct.ModelType.Name() + "Id" - } - relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { - field.Relationship = relationship - } - } - } - default: + if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { field.IsNormal = true } - } - if field.IsNormal { - if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + if !field.IsNormal { + gormSettings := parseTagSetting(field.Tag.Get("gorm")) + toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) + + getForeignField := func(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == ToDBName(column) { + return field + } + } + return nil + } + + var relationship = &Relationship{} + + foreignKey := gormSettings["FOREIGNKEY"] + if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { + relationship.ForeignFieldName = polymorphicField.Name + relationship.ForeignDBName = polymorphicField.DBName + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + polymorphicField.IsForeignKey = true + } + } + } + + switch indirectType.Kind() { + case reflect.Slice: + elemType := indirectType.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + + if elemType.Kind() == reflect.Struct { + if foreignKey == "" { + foreignKey = scopeType.Name() + "Id" + } + + if many2many := gormSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" + associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] + if associationForeignKey == "" { + associationForeignKey = elemType.Name() + "Id" + } + + relationship.ForeignFieldName = foreignKey + relationship.ForeignDBName = ToDBName(foreignKey) + relationship.AssociationForeignFieldName = associationForeignKey + relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + relationship.JoinTableHandler = &JoinTableHandler{ + TableName: many2many, + Source: JoinTableSource{ModelType: scopeType}, + Destination: JoinTableSource{ModelType: elemType}, + } + updateJoinTableHandler(relationship) + + field.Relationship = relationship + } else { + relationship.Kind = "has_many" + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } + } + } else { + field.IsNormal = true + } + case reflect.Struct: + if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + for _, toField := range toScope.GetStructFields() { + toField = toField.clone() + toField.Names = append([]string{fieldStruct.Name}, toField.Names...) + modelStruct.StructFields = append(modelStruct.StructFields, toField) + if toField.IsPrimaryKey { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) + } + } + continue + } else { + belongsToForeignKey := foreignKey + if belongsToForeignKey == "" { + belongsToForeignKey = field.Name + "Id" + } + + if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { + relationship.Kind = "belongs_to" + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else { + if foreignKey == "" { + foreignKey = modelStruct.ModelType.Name() + "Id" + } + relationship.Kind = "has_one" + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } + } + } + default: + field.IsNormal = true + } + } + + if field.IsNormal { + if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } } } + modelStruct.StructFields = append(modelStruct.StructFields, field) } - modelStruct.StructFields = append(modelStruct.StructFields, field) - } + }() modelStructs[scopeType] = &modelStruct