From 5d2b9bfe3420c95932d1ee0f3ff274c3efd71637 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 09:46:07 +0800 Subject: [PATCH 01/10] Refactor GetModelStruct --- model_struct.go | 49 ++++++++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/model_struct.go b/model_struct.go index 89e7a169..7a47540e 100644 --- a/model_struct.go +++ b/model_struct.go @@ -97,48 +97,43 @@ type Relationship struct { func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct - - reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) - if !reflectValue.IsValid() { + // Scope value can't be nil + if scope.Value == nil { return &modelStruct } - if reflectValue.Kind() == reflect.Slice { - reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) + reflectType := reflect.ValueOf(scope.Value).Type() + for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() } - scopeType := reflectValue.Type() - - if scopeType.Kind() == reflect.Ptr { - scopeType = scopeType.Elem() + // Scope value need to be a struct + if reflectType.Kind() != reflect.Struct { + return &modelStruct } - if value := modelStructsMap.Get(scopeType); value != nil { + // Get Cached model struct + if value := modelStructsMap.Get(reflectType); value != nil { return value } - modelStruct.ModelType = scopeType - if scopeType.Kind() != reflect.Struct { - return &modelStruct - } + modelStruct.ModelType = reflectType - if tabler, ok := reflect.New(scopeType).Interface().(interface { - TableName() string - }); ok { + // Set default table name + if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { modelStruct.defaultTableName = tabler.TableName() } else { - name := ToDBName(scopeType.Name()) + tableName := ToDBName(reflectType.Name()) if scope.db == nil || !scope.db.parent.singularTable { - name = inflection.Plural(name) + tableName = inflection.Plural(tableName) } - - modelStruct.defaultTableName = name + modelStruct.defaultTableName = tableName } // Get all fields fields := []*StructField{} - for i := 0; i < scopeType.NumField(); i++ { - if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { + for i := 0; i < reflectType.NumField(); i++ { + if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { field := &StructField{ Struct: fieldStruct, Name: fieldStruct.Name, @@ -244,7 +239,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if field, ok := scope.FieldByName(foreignKey); ok { relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) - joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName + joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } } @@ -268,7 +263,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, scopeType, elemType) + joinTableHandler.Setup(relationship, many2many, reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -276,7 +271,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(foreignKeys) == 0 { for _, field := range scope.PrimaryFields() { - if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { + if foreignField := getForeignField(reflectType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -386,7 +381,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { finished <- true }(finished) - modelStructsMap.Set(scopeType, &modelStruct) + modelStructsMap.Set(reflectType, &modelStruct) <-finished modelStruct.cached = true From 19b85b1f1756187bc2a3dfa4b15e559d423a38a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 09:52:27 +0800 Subject: [PATCH 02/10] Compatible with both gorm, sql tag --- main.go | 2 +- model_struct.go | 29 +++++++++++++++-------------- scope_private.go | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/main.go b/main.go index 8e5ed1ac..f3e86506 100644 --- a/main.go +++ b/main.go @@ -512,7 +512,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { + if many2many := parseTagSetting(field.Tag)["MANY2MANY"]; many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) diff --git a/model_struct.go b/model_struct.go index 7a47540e..c4ad313b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -145,14 +145,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsIgnored = true } - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - gormSettings := parseTagSetting(field.Tag.Get("gorm")) + gormSettings := parseTagSetting(field.Tag) if _, ok := gormSettings["PRIMARY_KEY"]; ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := sqlSettings["DEFAULT"]; ok { + if _, ok := gormSettings["DEFAULT"]; ok { field.HasDefaultValue = true } @@ -185,7 +184,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if !field.IsNormal { - gormSettings := parseTagSetting(field.Tag.Get("gorm")) + gormSettings := parseTagSetting(field.Tag) toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) getForeignField := func(column string, fields []*StructField) *StructField { @@ -400,7 +399,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string { structType = structType.Elem() } reflectValue := reflect.Indirect(reflect.New(structType)) - sqlSettings := parseTagSetting(field.Tag.Get("sql")) + sqlSettings := parseTagSetting(field.Tag) if value, ok := sqlSettings["TYPE"]; ok { sqlType = value @@ -447,16 +446,18 @@ func (scope *Scope) generateSqlTag(field *StructField) string { } } -func parseTagSetting(str string) map[string]string { - tags := strings.Split(str, ";") +func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k + for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { + tags := strings.Split(str, ";") + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if len(v) >= 2 { + setting[k] = strings.Join(v[1:], ":") + } else { + setting[k] = k + } } } return setting diff --git a/scope_private.go b/scope_private.go index eddcfcc3..d301e80e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -630,7 +630,7 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - sqlSettings := parseTagSetting(field.Tag.Get("sql")) + sqlSettings := parseTagSetting(field.Tag) if name, ok := sqlSettings["INDEX"]; ok { if name == "INDEX" { name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) From 4e45e6dc2dc20c152a420fb40250e651bda981a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 10:00:18 +0800 Subject: [PATCH 03/10] Use field.TagSettings --- main.go | 2 +- model_struct.go | 41 ++++++++++++++++++++--------------------- scope_private.go | 5 ++--- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index f3e86506..6db37dab 100644 --- a/main.go +++ b/main.go @@ -512,7 +512,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := parseTagSetting(field.Tag)["MANY2MANY"]; many2many != "" { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) diff --git a/model_struct.go b/model_struct.go index c4ad313b..5bbf2e42 100644 --- a/model_struct.go +++ b/model_struct.go @@ -62,6 +62,7 @@ type StructField struct { IsScanner bool HasDefaultValue bool Tag reflect.StructTag + TagSettings map[string]string Struct reflect.StructField IsForeignKey bool Relationship *Relationship @@ -135,27 +136,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for i := 0; i < reflectType.NumField(); i++ { if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, + Struct: fieldStruct, + Name: fieldStruct.Name, + Names: []string{fieldStruct.Name}, + Tag: fieldStruct.Tag, + TagSettings: parseTagSetting(fieldStruct.Tag), } if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true } - gormSettings := parseTagSetting(field.Tag) - if _, ok := gormSettings["PRIMARY_KEY"]; ok { + if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := gormSettings["DEFAULT"]; ok { + if _, ok := field.TagSettings["DEFAULT"]; ok { field.HasDefaultValue = true } - if value, ok := gormSettings["COLUMN"]; ok { + if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { field.DBName = ToDBName(fieldStruct.Name) @@ -184,7 +185,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if !field.IsNormal { - gormSettings := parseTagSetting(field.Tag) toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) getForeignField := func(column string, fields []*StructField) *StructField { @@ -198,7 +198,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var relationship = &Relationship{} - if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { relationship.ForeignFieldNames = []string{polymorphicField.Name} @@ -214,7 +214,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } var foreignKeys []string - if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { + if foreignKey, ok := field.TagSettings["FOREIGNKEY"]; ok { foreignKeys = append(foreignKeys, foreignKey) } switch indirectType.Kind() { @@ -225,7 +225,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if many2many := gormSettings["MANY2MANY"]; many2many != "" { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" // foreign keys @@ -245,8 +245,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // association foreign keys var associationForeignKeys []string - if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = []string{foreignKey} } else { for _, field := range toScope.PrimaryFields() { associationForeignKeys = append(associationForeignKeys, field.DBName) @@ -298,7 +298,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.IsNormal = true } case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { for _, toField := range toScope.GetStructFields() { toField = toField.clone() toField.Names = append([]string{fieldStruct.Name}, toField.Names...) @@ -399,14 +399,13 @@ func (scope *Scope) generateSqlTag(field *StructField) string { structType = structType.Elem() } reflectValue := reflect.Indirect(reflect.New(structType)) - sqlSettings := parseTagSetting(field.Tag) - if value, ok := sqlSettings["TYPE"]; ok { + if value, ok := field.TagSettings["TYPE"]; ok { sqlType = value } - additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] - if value, ok := sqlSettings["DEFAULT"]; ok { + additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { additionalType = additionalType + " DEFAULT " + value } @@ -424,11 +423,11 @@ func (scope *Scope) generateSqlTag(field *StructField) string { if sqlType == "" { var size = 255 - if value, ok := sqlSettings["SIZE"]; ok { + if value, ok := field.TagSettings["SIZE"]; ok { size, _ = strconv.Atoi(value) } - v, autoIncrease := sqlSettings["AUTO_INCREMENT"] + v, autoIncrease := field.TagSettings["AUTO_INCREMENT"] if field.IsPrimaryKey { autoIncrease = true } diff --git a/scope_private.go b/scope_private.go index d301e80e..cd90c8c2 100644 --- a/scope_private.go +++ b/scope_private.go @@ -630,15 +630,14 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - sqlSettings := parseTagSetting(field.Tag) - if name, ok := sqlSettings["INDEX"]; ok { + if name, ok := field.TagSettings["INDEX"]; ok { if name == "INDEX" { name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) } indexes[name] = append(indexes[name], field.DBName) } - if name, ok := sqlSettings["UNIQUE_INDEX"]; ok { + if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { if name == "UNIQUE_INDEX" { name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) } From 6a5a2dbc55710c50e86e26341aa98988b60d870c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 11:11:30 +0800 Subject: [PATCH 04/10] Refactor GetModelStruct --- field.go | 4 +- model_struct.go | 351 ++++++++++++++++++++++-------------------------- 2 files changed, 164 insertions(+), 191 deletions(-) diff --git a/field.go b/field.go index 6bfac0bb..7151f468 100644 --- a/field.go +++ b/field.go @@ -75,9 +75,7 @@ func (scope *Scope) Fields() map[string]*Field { } } - if modelStruct.cached { - scope.fields = fields - } + scope.fields = fields return fields } return scope.fields diff --git a/model_struct.go b/model_struct.go index 5bbf2e42..1683d09b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -45,7 +45,6 @@ type ModelStruct struct { StructFields []*StructField ModelType reflect.Type defaultTableName string - cached bool } func (s *ModelStruct) TableName(db *DB) string { @@ -96,6 +95,15 @@ type Relationship struct { JoinTableHandler JoinTableHandlerInterface } +func getForeignField(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == ToDBName(column) { + return field + } + } + return nil +} + func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct // Scope value can't be nil @@ -132,7 +140,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get all fields - fields := []*StructField{} for i := 0; i < reflectType.NumField(); i++ { if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { field := &StructField{ @@ -145,189 +152,155 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true - } - - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettings["DEFAULT"]; ok { - field.HasDefaultValue = true - } - - if value, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) - } - - fields = append(fields, field) - } - } - - var finished = make(chan bool) - go func(finished chan bool) { - for _, field := range fields { - if !field.IsIgnored { - fieldStruct := field.Struct - indirectType := fieldStruct.Type - if indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() + if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner { + if _, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + } + + fieldValue := reflect.New(fieldStruct.Type).Interface() + if _, isScanner := fieldValue.(sql.Scanner); isScanner { + // is scanner field.IsScanner, field.IsNormal = true, true - } - - if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { + } else if _, isTime := fieldValue.(*time.Time); isTime { + // is time field.IsNormal = true - } - - if !field.IsNormal { - toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) - - getForeignField := func(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == ToDBName(column) { - return field - } + } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + // embedded struct + for _, subField := range scope.New(fieldValue).GetStructFields() { + subField = subField.clone() + subField.Names = append([]string{fieldStruct.Name}, subField.Names...) + if subField.IsPrimaryKey { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) } - return nil + modelStruct.StructFields = append(modelStruct.StructFields, subField) + } + continue + } else { + indirectType := fieldStruct.Type + for indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() } - var relationship = &Relationship{} - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldNames = []string{polymorphicField.Name} - relationship.ForeignDBNames = []string{polymorphicField.DBName} - relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} - relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } - } - } - - var foreignKeys []string - if foreignKey, ok := field.TagSettings["FOREIGNKEY"]; ok { - foreignKeys = append(foreignKeys, foreignKey) - } switch indirectType.Kind() { case reflect.Slice: - elemType := indirectType.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } + defer func(field *StructField) { + var ( + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + elemType = field.Struct.Type + ) - if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" + for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } - // foreign keys - if len(foreignKeys) == 0 { - for _, field := range scope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.DBName) - } - } + if elemType.Kind() == reflect.Struct { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" - for _, foreignKey := range foreignKeys { - if field, ok := scope.FieldByName(foreignKey); ok { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) - joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } - - // association foreign keys - var associationForeignKeys []string - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = []string{foreignKey} - } else { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - relationship.Kind = "has_many" - - if len(foreignKeys) == 0 { - for _, field := range scope.PrimaryFields() { - if foreignField := getForeignField(reflectType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + // if no foreign keys + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { // FIXME + foreignKeys = append(foreignKeys, field.DBName) } } + + for _, foreignKey := range foreignKeys { + if field, ok := scope.FieldByName(foreignKey); ok { // FIXME + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + // join table foreign keys for source + joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + } + } + + // if no association foreign keys + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for _, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } + + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + relationship.JoinTableHandler = &joinTableHandler + field.Relationship = relationship } else { + relationship.Kind = "has_many" + + // if no foreign keys + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { // FIXME + foreignKeys = append(foreignKeys, reflectType.Name()+field.Name) + } + } + for _, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + // source foreign keys + foreignField.IsForeignKey = true + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + + // association foreign keys relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true } } - } - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, toField := range toScope.GetStructFields() { - toField = toField.clone() - toField.Names = append([]string{fieldStruct.Name}, toField.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, toField) - if toField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) - } - } - continue - } else { - if len(foreignKeys) == 0 { - for _, f := range scope.PrimaryFields() { - if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + if len(relationship.ForeignFieldNames) != 0 { + field.Relationship = relationship } } } else { - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true - } + field.IsNormal = true + } + }(field) + case reflect.Struct: + defer func(field *StructField) { + var ( + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + foreignKeys = tagForeignKeys + ) + + if len(tagForeignKeys) == 0 { + for _, primaryField := range scope.PrimaryFields() { + foreignKeys = append(foreignKeys, modelStruct.ModelType.Name()+primaryField.Name) + } + } + + // if has one + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + foreignField.IsForeignKey = true + // source foreign keys + scopeField := getForeignField(foreignKey, modelStruct.StructFields) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) } } @@ -335,25 +308,23 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.Kind = "has_one" field.Relationship = relationship } else { + foreignKeys = tagForeignKeys if len(foreignKeys) == 0 { for _, f := range toScope.PrimaryFields() { - if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true - } + foreignKeys = append(foreignKeys, field.Name+f.Name) } - } else { - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true - } + } + + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // association foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + + // source foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true } } @@ -362,29 +333,33 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.Relationship = relationship } } - } + }(field) default: field.IsNormal = true } } - - if field.IsNormal { - if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } } + + // Even it is ignored, also possible to decode db value into the field + if value, ok := field.TagSettings["COLUMN"]; ok { + field.DBName = value + } else { + field.DBName = ToDBName(fieldStruct.Name) + } + modelStruct.StructFields = append(modelStruct.StructFields, field) } - finished <- true - }(finished) + } + + if len(modelStruct.PrimaryFields) == 0 { + if field := getForeignField("id", modelStruct.StructFields); field != nil { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } + } modelStructsMap.Set(reflectType, &modelStruct) - <-finished - modelStruct.cached = true - return &modelStruct } From 4bc06a21c1a809122652d42022170b15642f4a11 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 14:04:59 +0800 Subject: [PATCH 05/10] Refactor GetModelStruct --- model_struct.go | 140 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 98 insertions(+), 42 deletions(-) diff --git a/model_struct.go b/model_struct.go index 1683d09b..13d2696b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -192,11 +192,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var ( relationship = &Relationship{} toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + foreignKeys []string + associationForeignKeys []string elemType = field.Struct.Type ) + if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + } + + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + } + for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { elemType = elemType.Elem() } @@ -248,21 +256,33 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if no foreign keys if len(foreignKeys) == 0 { - for _, field := range scope.PrimaryFields() { // FIXME - foreignKeys = append(foreignKeys, reflectType.Name()+field.Name) + if len(associationForeignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, reflectType.Name()+field.Name) + associationForeignKeys = append(associationForeignKeys, field.Name) + } + } else { + for _, associationForeignKey := range associationForeignKeys { + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, reflectType.Name()+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } } } - for _, foreignKey := range foreignKeys { + for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - // source foreign keys - foreignField.IsForeignKey = true - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { + // source foreign keys + foreignField.IsForeignKey = true + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } } } @@ -277,30 +297,53 @@ func (scope *Scope) GetModelStruct() *ModelStruct { case reflect.Struct: defer func(field *StructField) { var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") - foreignKeys = tagForeignKeys + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + tagForeignKeys []string + tagAssociationForeignKeys []string ) - if len(tagForeignKeys) == 0 { - for _, primaryField := range scope.PrimaryFields() { - foreignKeys = append(foreignKeys, modelStruct.ModelType.Name()+primaryField.Name) - } + if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") } - // if has one - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - foreignField.IsForeignKey = true - // source foreign keys - scopeField := getForeignField(foreignKey, modelStruct.StructFields) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + } - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + // Has One + { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + if len(foreignKeys) == 0 { + if len(associationForeignKeys) == 0 { + for _, primaryField := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, modelStruct.ModelType.Name()+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + for _, associationForeignKey := range tagAssociationForeignKeys { + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, modelStruct.ModelType.Name()+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + foreignField.IsForeignKey = true + // source foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } } } @@ -308,23 +351,36 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.Kind = "has_one" field.Relationship = relationship } else { - foreignKeys = tagForeignKeys + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys if len(foreignKeys) == 0 { - for _, f := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+f.Name) + if len(associationForeignKeys) == 0 { + for _, f := range toScope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.Name+f.Name) + associationForeignKeys = append(associationForeignKeys, f.Name) + } + } else { + for _, associationForeignKey := range associationForeignKeys { + if foreignField := getForeignField(associationForeignKey, toScope.GetStructFields()); foreignField != nil { + foreignKeys = append(foreignKeys, field.Name+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } } } - for _, foreignKey := range foreignKeys { + for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + if associationField := getForeignField(associationForeignKeys[idx], toScope.GetStructFields()); associationField != nil { + // association foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + // source foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } From 8a0c77e5fcbaf11a940e3f9eeb3535e99f53dfcf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 14:16:58 +0800 Subject: [PATCH 06/10] Fix all tests for GetModelStruct --- model_struct.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/model_struct.go b/model_struct.go index 13d2696b..b4ab67ca 100644 --- a/model_struct.go +++ b/model_struct.go @@ -254,6 +254,21 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else { relationship.Kind = "has_many" + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphicField := getForeignField(polymorphic+"ID", toScope.GetStructFields()); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { + relationship.ForeignFieldNames = []string{polymorphicField.Name} + relationship.ForeignDBNames = []string{polymorphicField.DBName} + relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} + relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + polymorphicField.IsForeignKey = true + } + } + } + // if no foreign keys if len(foreignKeys) == 0 { if len(associationForeignKeys) == 0 { @@ -311,6 +326,21 @@ func (scope *Scope) GetModelStruct() *ModelStruct { tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") } + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphicField := getForeignField(polymorphic+"ID", toScope.GetStructFields()); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { + relationship.ForeignFieldNames = []string{polymorphicField.Name} + relationship.ForeignDBNames = []string{polymorphicField.DBName} + relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} + relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + polymorphicField.IsForeignKey = true + } + } + } + // Has One { var foreignKeys = tagForeignKeys From f53af2a236c6ad3d67104868f5fddb5128132dfa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 14:21:21 +0800 Subject: [PATCH 07/10] Don't preload if has any error --- preload.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preload.go b/preload.go index 2c981a79..9eaf0ab0 100644 --- a/preload.go +++ b/preload.go @@ -31,7 +31,7 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { - if scope.Search.preload == nil { + if scope.Search.preload == nil || scope.HasError() { return } From 0f5055471af45dbdd67e2a9dd34255e9a1079004 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 15:22:35 +0800 Subject: [PATCH 08/10] Keep refactoring get model struct --- model_struct.go | 143 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 99 insertions(+), 44 deletions(-) diff --git a/model_struct.go b/model_struct.go index b4ab67ca..15df1082 100644 --- a/model_struct.go +++ b/model_struct.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "errors" "fmt" "go/ast" "reflect" @@ -84,6 +85,7 @@ func (structField *StructField) clone() *StructField { } } +// Relationship described the relationship between models type Relationship struct { Kind string PolymorphicType string @@ -104,6 +106,7 @@ func getForeignField(column string, fields []*StructField) *StructField { return nil } +// GetModelStruct generate model struct & relationships based on struct and tag definition func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct // Scope value can't be nil @@ -150,6 +153,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { TagSettings: parseTagSetting(fieldStruct.Tag), } + // is ignored field if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true } else { @@ -170,7 +174,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // is time field.IsNormal = true } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - // embedded struct + // is embedded struct for _, subField := range scope.New(fieldValue).GetStructFields() { subField = subField.clone() subField.Names = append([]string{fieldStruct.Name}, subField.Names...) @@ -181,6 +185,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } continue } else { + // build relationships indirectType := fieldStruct.Type for indirectType.Kind() == reflect.Ptr { indirectType = indirectType.Elem() @@ -213,24 +218,24 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys + // if no foreign keys defined with tag if len(foreignKeys) == 0 { - for _, field := range scope.PrimaryFields() { // FIXME + for _, field := range modelStruct.PrimaryFields { foreignKeys = append(foreignKeys, field.DBName) } } for _, foreignKey := range foreignKeys { - if field, ok := scope.FieldByName(foreignKey); ok { // FIXME + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + field.DBName + joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } } - // if no association foreign keys + // if no association foreign keys defined with tag if len(associationForeignKeys) == 0 { for _, field := range toScope.PrimaryFields() { associationForeignKeys = append(associationForeignKeys, field.DBName) @@ -252,42 +257,57 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { + // User has many comments, associationType is User, comment use UserID as foreign key + var associationType = reflectType.Name() + var toFields = toScope.GetStructFields() relationship.Kind = "has_many" if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"ID", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldNames = []string{polymorphicField.Name} - relationship.ForeignDBNames = []string{polymorphicField.DBName} - relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} - relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } + // Dog has many toys, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('dogs') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true } } - // if no foreign keys + // if no foreign keys defined with tag if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag if len(associationForeignKeys) == 0 { for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, reflectType.Name()+field.Name) + foreignKeys = append(foreignKeys, associationType+field.Name) associationForeignKeys = append(associationForeignKeys, field.Name) } } else { - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, reflectType.Name()+foreignField.Name) + // generate foreign keys from defined association foreign keys + for _, scopeFieldName := range associationForeignKeys { + if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } } } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + } else { + scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, associationType)) + } + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } } for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { // source foreign keys foreignField.IsForeignKey = true @@ -312,8 +332,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { case reflect.Struct: defer func(field *StructField) { var ( + // user has one profile, associationType is User, profile use UserID as foreign key + // user belongs to profile, associationType is Profile, user use ProfileID as foreign key + associationType = reflectType.Name() relationship = &Relationship{} toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + toFields = toScope.GetStructFields() tagForeignKeys []string tagAssociationForeignKeys []string ) @@ -327,17 +351,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"ID", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldNames = []string{polymorphicField.Name} - relationship.ForeignDBNames = []string{polymorphicField.DBName} - relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} - relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } + // Cat has one toy, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('cats') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true } } @@ -345,24 +365,41 @@ func (scope *Scope) GetModelStruct() *ModelStruct { { var foreignKeys = tagForeignKeys var associationForeignKeys = tagAssociationForeignKeys + // if no foreign keys defined with tag if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag if len(associationForeignKeys) == 0 { for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, modelStruct.ModelType.Name()+primaryField.Name) + foreignKeys = append(foreignKeys, associationType+primaryField.Name) associationForeignKeys = append(associationForeignKeys, primaryField.Name) } } else { + // generate foreign keys form association foreign keys for _, associationForeignKey := range tagAssociationForeignKeys { if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, modelStruct.ModelType.Name()+foreignField.Name) + foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } } } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + } else { + scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, associationType)) + } + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } } for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { foreignField.IsForeignKey = true // source foreign keys @@ -383,25 +420,44 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else { var foreignKeys = tagForeignKeys var associationForeignKeys = tagAssociationForeignKeys + if len(foreignKeys) == 0 { + // generate foreign keys & association foreign keys if len(associationForeignKeys) == 0 { - for _, f := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+f.Name) - associationForeignKeys = append(associationForeignKeys, f.Name) + for _, primaryField := range toScope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.Name+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) } } else { + // generate foreign keys with association foreign keys for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toScope.GetStructFields()); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { foreignKeys = append(foreignKeys, field.Name+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } } } + } else { + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, field.Name) { + associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name)) + } else { + scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, field.Name)) + } + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } } for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toScope.GetStructFields()); associationField != nil { + if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + foreignField.IsForeignKey = true + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) @@ -409,7 +465,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // source foreign keys relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true } } } From b907500a477e6040523c942451b874606d09a86b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 16:18:51 +0800 Subject: [PATCH 09/10] Add test for many2many relations with customized foreign keys --- multi_primary_keys_test.go | 129 +++++++++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 5 deletions(-) diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 061822ef..27997fb7 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -8,17 +8,19 @@ import ( ) type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` } type Tag struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Value string + Blogs []*Blog `gorm:"many2many:"blogs_tags` } func compareTags(tags []Tag, contents []string) bool { @@ -114,3 +116,120 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } } } + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("shared_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "ZH", + } + DB.Save(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Errorf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Errorf("Should find 3 tags with Related") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "SharedTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Related(&tags2, "SharedTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Errorf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Related(&tags3, "SharedTags") + if !compareTags(tags3, []string{"tag6"}) { + t.Errorf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Errorf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Related(&tags4, "SharedTags") + if !compareTags(tags4, []string{"tag6"}) { + t.Errorf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Errorf("All tags should be cleared") + } + } +} From aa55bd3fd2591965994281e1d5e50ef305af6f30 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 3 Jan 2016 17:20:24 +0800 Subject: [PATCH 10/10] Add more tests for customized foreign keys for many2many relations --- association.go | 22 ++++-- main.go | 4 + model_struct.go | 8 +- multi_primary_keys_test.go | 150 ++++++++++++++++++++++++++++++++++++- preload.go | 10 +-- 5 files changed, 175 insertions(+), 19 deletions(-) diff --git a/association.go b/association.go index 45d7367a..1ba1d519 100644 --- a/association.go +++ b/association.go @@ -126,14 +126,6 @@ func (association *Association) Replace(values ...interface{}) *Association { } } else { // Relations - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - if relationship.PolymorphicDBName != "" { newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } @@ -164,8 +156,22 @@ func (association *Association) Replace(values ...interface{}) *Association { } if relationship.Kind == "many_to_many" { + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + var foreignKeyMap = map[string]interface{}{} + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + fieldValue := reflect.New(association.Field.Field.Type()).Interface() association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) } diff --git a/main.go b/main.go index 6db37dab..9fe6cf4e 100644 --- a/main.go +++ b/main.go @@ -384,6 +384,10 @@ func (s *DB) CreateTable(values ...interface{}) *DB { func (s *DB) DropTable(values ...interface{}) *DB { db := s.clone() for _, value := range values { + if tableName, ok := value.(string); ok { + db = db.Table(tableName) + } + db = db.NewScope(value).dropTable().db } return db diff --git a/model_struct.go b/model_struct.go index 15df1082..204da5e7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -203,11 +203,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") } if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -343,11 +343,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ";") + tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") } if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ";") + tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 27997fb7..ea80326e 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -14,6 +14,7 @@ type Blog struct { Body string Tags []Tag `gorm:"many2many:blog_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` } type Tag struct { @@ -135,9 +136,9 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { blog2 := Blog{ ID: blog.ID, - Locale: "ZH", + Locale: "EN", } - DB.Save(&blog2) + DB.Create(&blog2) if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { t.Errorf("Blog should has two tags") @@ -233,3 +234,148 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } } } + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("locale_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if len(tags) != 0 { + t.Errorf("Should find 0 tags with Related for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related for EN Blog") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag4"}) { + t.Errorf("Should find 1 tags with Related for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Errorf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared") + } + } +} diff --git a/preload.go b/preload.go index 9eaf0ab0..d12995f3 100644 --- a/preload.go +++ b/preload.go @@ -277,10 +277,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } - var associationForeignStructFieldNames []string - for _, dbName := range relation.AssociationForeignFieldNames { + var foreignFieldNames []string + for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { - associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name) + foreignFieldNames = append(foreignFieldNames, field.Name) } } @@ -288,7 +288,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, associationForeignStructFieldNames) + source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link)) @@ -296,7 +296,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } else { if object := scope.IndirectValue(); object.IsValid() { - source := getRealValue(object, associationForeignStructFieldNames) + source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link))