From 0d2c37e310f9a812e4d53d62b4963a73847b8ca1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 18 Feb 2015 12:30:09 +0800 Subject: [PATCH] Refactor Model Struct --- association.go | 4 +- association_test.go | 4 +- callback_shared.go | 8 +-- model_struct.go | 118 ++++++++++++++++++++++++++------------------ scope_private.go | 4 +- structs_test.go | 2 +- 6 files changed, 80 insertions(+), 60 deletions(-) diff --git a/association.go b/association.go index c52f3e58..c8518946 100644 --- a/association.go +++ b/association.go @@ -156,8 +156,8 @@ func (association *Association) Count() int { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey) - if relationship.ForeignType != "" { - countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToDBName(relationship.ForeignType))), scope.TableName()) + if relationship.PolymorphicType != "" { + countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } countScope.Count(&count) } else if relationship.Kind == "belongs_to" { diff --git a/association_test.go b/association_test.go index 201eed53..2fbdb008 100644 --- a/association_test.go +++ b/association_test.go @@ -50,8 +50,8 @@ func TestHasOneAndHasManyAssociation(t *testing.T) { t.Errorf("Comment 2 Should have post id") } - comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}} - DB.Save(&comment3) + // comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}} + // DB.Save(&comment3) } func TestRelated(t *testing.T) { diff --git a/callback_shared.go b/callback_shared.go index a00e0fb3..56238cb6 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -46,8 +46,8 @@ func SaveAfterAssociations(scope *Scope) { newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()) } - if relationship.ForeignType != "" { - newScope.SetColumn(relationship.ForeignType, scope.TableName()) + if relationship.PolymorphicType != "" { + newScope.SetColumn(relationship.PolymorphicType, scope.TableName()) } scope.Err(newDB.Save(elem).Error) @@ -80,8 +80,8 @@ func SaveAfterAssociations(scope *Scope) { newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()) } - if relationship.ForeignType != "" { - newScope.SetColumn(relationship.ForeignType, scope.TableName()) + if relationship.PolymorphicType != "" { + newScope.SetColumn(relationship.PolymorphicType, scope.TableName()) } scope.Err(scope.NewDB().Save(elem).Error) } diff --git a/model_struct.go b/model_struct.go index bec6354f..ecec9a32 100644 --- a/model_struct.go +++ b/model_struct.go @@ -14,6 +14,7 @@ import ( type ModelStruct struct { PrimaryKeyField *StructField StructFields []*StructField + ModelType reflect.Type TableName string } @@ -53,7 +54,8 @@ func (structField *StructField) clone() *StructField { type Relationship struct { Kind string - ForeignType string + PolymorphicType string + PolymorphicDBName string ForeignFieldName string ForeignDBName string AssociationForeignFieldName string @@ -134,6 +136,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + modelStruct.ModelType = scopeType if scopeType.Kind() != reflect.Struct { return &modelStruct } @@ -209,47 +212,63 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if !field.IsNormal { gormSettings := parseTagSetting(field.Tag.Get("gorm")) - many2many := gormSettings["MANY2MANY"] + toModelStruct := scope.New(reflect.New(fieldStruct.Type).Interface()).GetModelStruct() + getForeignField := func(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == ToDBName(column) { + return field + } + } + return nil + } + + var relationship = &Relationship{} + foreignKey := gormSettings["FOREIGNKEY"] - foreignType := gormSettings["FOREIGNTYPE"] - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { - foreignKey = polymorphic + "Id" - foreignType = polymorphic + "Type" + if polymorphicField := getForeignField(polymorphic+"Id", toModelStruct.StructFields); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toModelStruct.StructFields); polymorphicType != nil { + relationship.ForeignFieldName = polymorphicField.Name + relationship.ForeignDBName = polymorphicField.DBName + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + polymorphicField.IsForeignKey = true + } + } } switch indirectType.Kind() { case reflect.Slice: - typ := indirectType.Elem() - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - if typ.Kind() == reflect.Struct { - kind := "has_many" - + if len(toModelStruct.StructFields) > 0 { if foreignKey == "" { foreignKey = scopeType.Name() + "Id" } - if associationForeignKey == "" { - associationForeignKey = typ.Name() + "Id" - } + if many2many := gormSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" + relationship.JoinTable = many2many - if many2many != "" { - kind = "many_to_many" - } else if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - foreignKey = "" - } + associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] + if associationForeignKey == "" { + associationForeignKey = toModelStruct.ModelType.Name() + "Id" + } - field.Relationship = &Relationship{ - JoinTable: many2many, - ForeignType: foreignType, - ForeignFieldName: foreignKey, - AssociationForeignFieldName: associationForeignKey, - ForeignDBName: ToDBName(foreignKey), - AssociationForeignDBName: ToDBName(associationForeignKey), - Kind: kind, + relationship.ForeignFieldName = foreignKey + relationship.ForeignDBName = ToDBName(foreignKey) + relationship.AssociationForeignFieldName = associationForeignKey + relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + field.Relationship = relationship + } else { + relationship.Kind = "has_many" + if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } } } else { field.IsNormal = true @@ -263,29 +282,30 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } break } else { - var belongsToForeignKey, hasOneForeignKey, kind string - - if foreignKey == "" { + belongsToForeignKey := foreignKey + if belongsToForeignKey == "" { belongsToForeignKey = field.Name + "Id" - hasOneForeignKey = scopeType.Name() + "Id" - } else { - belongsToForeignKey = foreignKey - hasOneForeignKey = foreignKey } - if _, ok := scopeType.FieldByName(belongsToForeignKey); ok { - kind = "belongs_to" - foreignKey = belongsToForeignKey + if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { + relationship.Kind = "belongs_to" + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship } else { - foreignKey = hasOneForeignKey - kind = "has_one" - } - - field.Relationship = &Relationship{ - ForeignFieldName: foreignKey, - ForeignDBName: ToDBName(foreignKey), - ForeignType: foreignType, - Kind: kind, + if foreignKey == "" { + foreignKey = modelStruct.ModelType.Name() + "Id" + } + relationship.Kind = "has_one" + if foreignField := getForeignField(foreignKey, toModelStruct.StructFields); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } } } default: diff --git a/scope_private.go b/scope_private.go index 27457a04..84881d00 100644 --- a/scope_private.go +++ b/scope_private.go @@ -410,8 +410,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := toScope.db.Where(sql, scope.PrimaryKeyValue()) - if relationship.ForeignType != "" && toScope.HasColumn(relationship.ForeignType) { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(ToDBName(relationship.ForeignType))), scope.TableName()) + if relationship.PolymorphicType != "" { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } scope.Err(query.Find(value).Error) } diff --git a/structs_test.go b/structs_test.go index 3bf76f3f..24ddfd03 100644 --- a/structs_test.go +++ b/structs_test.go @@ -160,7 +160,7 @@ type Comment struct { Id int64 PostId int64 Content string - Post Post + // Post Post } // Scanner