diff --git a/association.go b/association.go index 45d7367a..1ba1d519 100644 --- a/association.go +++ b/association.go @@ -126,14 +126,6 @@ func (association *Association) Replace(values ...interface{}) *Association { } } else { // Relations - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - if relationship.PolymorphicDBName != "" { newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } @@ -164,8 +156,22 @@ func (association *Association) Replace(values ...interface{}) *Association { } if relationship.Kind == "many_to_many" { + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + var foreignKeyMap = map[string]interface{}{} + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + fieldValue := reflect.New(association.Field.Field.Type()).Interface() association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) } diff --git a/field.go b/field.go index 6bfac0bb..7151f468 100644 --- a/field.go +++ b/field.go @@ -75,9 +75,7 @@ func (scope *Scope) Fields() map[string]*Field { } } - if modelStruct.cached { - scope.fields = fields - } + scope.fields = fields return fields } return scope.fields diff --git a/main.go b/main.go index 8e5ed1ac..9fe6cf4e 100644 --- a/main.go +++ b/main.go @@ -384,6 +384,10 @@ func (s *DB) CreateTable(values ...interface{}) *DB { func (s *DB) DropTable(values ...interface{}) *DB { db := s.clone() for _, value := range values { + if tableName, ok := value.(string); ok { + db = db.Table(tableName) + } + db = db.NewScope(value).dropTable().db } return db @@ -512,7 +516,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) diff --git a/model_struct.go b/model_struct.go index 89e7a169..204da5e7 100644 --- a/model_struct.go +++ b/model_struct.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "errors" "fmt" "go/ast" "reflect" @@ -45,7 +46,6 @@ type ModelStruct struct { StructFields []*StructField ModelType reflect.Type defaultTableName string - cached bool } func (s *ModelStruct) TableName(db *DB) string { @@ -62,6 +62,7 @@ type StructField struct { IsScanner bool HasDefaultValue bool Tag reflect.StructTag + TagSettings map[string]string Struct reflect.StructField IsForeignKey bool Relationship *Relationship @@ -84,6 +85,7 @@ func (structField *StructField) clone() *StructField { } } +// Relationship described the relationship between models type Relationship struct { Kind string PolymorphicType string @@ -95,244 +97,319 @@ type Relationship struct { JoinTableHandler JoinTableHandlerInterface } +func getForeignField(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == ToDBName(column) { + return field + } + } + return nil +} + +// GetModelStruct generate model struct & relationships based on struct and tag definition func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct - - reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) - if !reflectValue.IsValid() { + // Scope value can't be nil + if scope.Value == nil { return &modelStruct } - if reflectValue.Kind() == reflect.Slice { - reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) + reflectType := reflect.ValueOf(scope.Value).Type() + for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() } - scopeType := reflectValue.Type() - - if scopeType.Kind() == reflect.Ptr { - scopeType = scopeType.Elem() + // Scope value need to be a struct + if reflectType.Kind() != reflect.Struct { + return &modelStruct } - if value := modelStructsMap.Get(scopeType); value != nil { + // Get Cached model struct + if value := modelStructsMap.Get(reflectType); value != nil { return value } - modelStruct.ModelType = scopeType - if scopeType.Kind() != reflect.Struct { - return &modelStruct - } + modelStruct.ModelType = reflectType - if tabler, ok := reflect.New(scopeType).Interface().(interface { - TableName() string - }); ok { + // Set default table name + if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { modelStruct.defaultTableName = tabler.TableName() } else { - name := ToDBName(scopeType.Name()) + tableName := ToDBName(reflectType.Name()) if scope.db == nil || !scope.db.parent.singularTable { - name = inflection.Plural(name) + tableName = inflection.Plural(tableName) } - - modelStruct.defaultTableName = name + modelStruct.defaultTableName = tableName } // Get all fields - fields := []*StructField{} - for i := 0; i < scopeType.NumField(); i++ { - if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) { + for i := 0; i < reflectType.NumField(); i++ { + if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, + Struct: fieldStruct, + Name: fieldStruct.Name, + Names: []string{fieldStruct.Name}, + Tag: fieldStruct.Tag, + TagSettings: parseTagSetting(fieldStruct.Tag), } + // is ignored field if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true - } - - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - if _, ok := gormSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := sqlSettings["DEFAULT"]; ok { - field.HasDefaultValue = true - } - - if value, ok := gormSettings["COLUMN"]; ok { - field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) - } - - fields = append(fields, field) - } - } - - var finished = make(chan bool) - go func(finished chan bool) { - for _, field := range fields { - if !field.IsIgnored { - fieldStruct := field.Struct - indirectType := fieldStruct.Type - if indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() + if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, isScanner := reflect.New(indirectType).Interface().(sql.Scanner); isScanner { + if _, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + } + + fieldValue := reflect.New(fieldStruct.Type).Interface() + if _, isScanner := fieldValue.(sql.Scanner); isScanner { + // is scanner field.IsScanner, field.IsNormal = true, true - } - - if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { + } else if _, isTime := fieldValue.(*time.Time); isTime { + // is time field.IsNormal = true - } - - if !field.IsNormal { - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) - - getForeignField := func(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == ToDBName(column) { - return field - } + } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + // is embedded struct + for _, subField := range scope.New(fieldValue).GetStructFields() { + subField = subField.clone() + subField.Names = append([]string{fieldStruct.Name}, subField.Names...) + if subField.IsPrimaryKey { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) } - return nil + modelStruct.StructFields = append(modelStruct.StructFields, subField) + } + continue + } else { + // build relationships + indirectType := fieldStruct.Type + for indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() } - var relationship = &Relationship{} - - if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldNames = []string{polymorphicField.Name} - relationship.ForeignDBNames = []string{polymorphicField.DBName} - relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} - relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } - } - } - - var foreignKeys []string - if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { - foreignKeys = append(foreignKeys, foreignKey) - } switch indirectType.Kind() { case reflect.Slice: - elemType := indirectType.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } + defer func(field *StructField) { + var ( + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + foreignKeys []string + associationForeignKeys []string + elemType = field.Struct.Type + ) - if elemType.Kind() == reflect.Struct { - if many2many := gormSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" + if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + } - // foreign keys - if len(foreignKeys) == 0 { - for _, field := range scope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.DBName) - } - } + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + } - for _, foreignKey := range foreignKeys { - if field, ok := scope.FieldByName(foreignKey); ok { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) - joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } + for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } - // association foreign keys - var associationForeignKeys []string - if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} - } else { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } + if elemType.Kind() == reflect.Struct { + if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, scopeType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - relationship.Kind = "has_many" - - if len(foreignKeys) == 0 { - for _, field := range scope.PrimaryFields() { - if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) } } - } else { + for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + // join table foreign keys for source + joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) } } - } - if len(relationship.ForeignFieldNames) != 0 { + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for _, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } + + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, toField := range toScope.GetStructFields() { - toField = toField.clone() - toField.Names = append([]string{fieldStruct.Name}, toField.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, toField) - if toField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) - } - } - continue - } else { - if len(foreignKeys) == 0 { - for _, f := range scope.PrimaryFields() { - if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + // User has many comments, associationType is User, comment use UserID as foreign key + var associationType = reflectType.Name() + var toFields = toScope.GetStructFields() + relationship.Kind = "has_many" + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + // Dog has many toys, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('dogs') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + } + } + + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, associationType+field.Name) + associationForeignKeys = append(associationForeignKeys, field.Name) + } + } else { + // generate foreign keys from defined association foreign keys + for _, scopeFieldName := range associationForeignKeys { + if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, associationType+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + } else { + scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, associationType)) + } + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { + if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { + // source foreign keys + foreignField.IsForeignKey = true + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { + field.Relationship = relationship } } } else { - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + field.IsNormal = true + } + }(field) + case reflect.Struct: + defer func(field *StructField) { + var ( + // user has one profile, associationType is User, profile use UserID as foreign key + // user belongs to profile, associationType is Profile, user use ProfileID as foreign key + associationType = reflectType.Name() + relationship = &Relationship{} + toScope = scope.New(reflect.New(field.Struct.Type).Interface()) + toFields = toScope.GetStructFields() + tagForeignKeys []string + tagAssociationForeignKeys []string + ) + + if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { + tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + } + + if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + } + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + // Cat has one toy, tag polymorphic is Owner, then associationType is Owner + // Toy use OwnerID, OwnerType ('cats') as foreign key + if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { + associationType = polymorphic + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + } + } + + // Has One + { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, primaryField := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, associationType+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + // generate foreign keys form association foreign keys + for _, associationForeignKey := range tagAssociationForeignKeys { + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + foreignKeys = append(foreignKeys, associationType+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } + } + } + } else { + // generate association foreign keys from foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, associationType) { + associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + } else { + scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, associationType)) + } + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { + if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + foreignField.IsForeignKey = true + // source foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) + + // association foreign keys + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + } } } } @@ -341,24 +418,53 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.Kind = "has_one" field.Relationship = relationship } else { + var foreignKeys = tagForeignKeys + var associationForeignKeys = tagAssociationForeignKeys + if len(foreignKeys) == 0 { - for _, f := range toScope.PrimaryFields() { - if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, primaryField := range toScope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.Name+primaryField.Name) + associationForeignKeys = append(associationForeignKeys, primaryField.Name) + } + } else { + // generate foreign keys with association foreign keys + for _, associationForeignKey := range associationForeignKeys { + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { + foreignKeys = append(foreignKeys, field.Name+foreignField.Name) + associationForeignKeys = append(associationForeignKeys, foreignField.Name) + } } } } else { - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + // generate foreign keys & association foreign keys + if len(associationForeignKeys) == 0 { + for _, foreignKey := range foreignKeys { + if strings.HasPrefix(foreignKey, field.Name) { + associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name)) + } else { + scope.Err(fmt.Errorf("invalid foreign keys, foreign key %v should start with %v", foreignKey, field.Name)) + } + } + } else if len(foreignKeys) != len(associationForeignKeys) { + scope.Err(errors.New("invalid foreign keys, should have same length")) + return + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + foreignField.IsForeignKey = true + + // association foreign keys + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) + + // source foreign keys relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true } } } @@ -368,28 +474,32 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.Relationship = relationship } } - } + }(field) default: field.IsNormal = true } } - - if field.IsNormal { - if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } } + + // Even it is ignored, also possible to decode db value into the field + if value, ok := field.TagSettings["COLUMN"]; ok { + field.DBName = value + } else { + field.DBName = ToDBName(fieldStruct.Name) + } + modelStruct.StructFields = append(modelStruct.StructFields, field) } - finished <- true - }(finished) + } - modelStructsMap.Set(scopeType, &modelStruct) + if len(modelStruct.PrimaryFields) == 0 { + if field := getForeignField("id", modelStruct.StructFields); field != nil { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } + } - <-finished - modelStruct.cached = true + modelStructsMap.Set(reflectType, &modelStruct) return &modelStruct } @@ -405,14 +515,13 @@ func (scope *Scope) generateSqlTag(field *StructField) string { structType = structType.Elem() } reflectValue := reflect.Indirect(reflect.New(structType)) - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - if value, ok := sqlSettings["TYPE"]; ok { + if value, ok := field.TagSettings["TYPE"]; ok { sqlType = value } - additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] - if value, ok := sqlSettings["DEFAULT"]; ok { + additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { additionalType = additionalType + " DEFAULT " + value } @@ -430,11 +539,11 @@ func (scope *Scope) generateSqlTag(field *StructField) string { if sqlType == "" { var size = 255 - if value, ok := sqlSettings["SIZE"]; ok { + if value, ok := field.TagSettings["SIZE"]; ok { size, _ = strconv.Atoi(value) } - v, autoIncrease := sqlSettings["AUTO_INCREMENT"] + v, autoIncrease := field.TagSettings["AUTO_INCREMENT"] if field.IsPrimaryKey { autoIncrease = true } @@ -452,16 +561,18 @@ func (scope *Scope) generateSqlTag(field *StructField) string { } } -func parseTagSetting(str string) map[string]string { - tags := strings.Split(str, ";") +func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k + for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { + tags := strings.Split(str, ";") + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if len(v) >= 2 { + setting[k] = strings.Join(v[1:], ":") + } else { + setting[k] = k + } } } return setting diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 061822ef..ea80326e 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -8,17 +8,20 @@ import ( ) type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` } type Tag struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Value string + Blogs []*Blog `gorm:"many2many:"blogs_tags` } func compareTags(tags []Tag, contents []string) bool { @@ -114,3 +117,265 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } } } + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("shared_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Errorf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "SharedTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Errorf("Should find 3 tags with Related") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "SharedTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Related(&tags2, "SharedTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Errorf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Related(&tags3, "SharedTags") + if !compareTags(tags3, []string{"tag6"}) { + t.Errorf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Errorf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Related(&tags4, "SharedTags") + if !compareTags(tags4, []string{"tag6"}) { + t.Errorf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Errorf("All tags should be cleared") + } + } +} + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.DropTable(&Blog{}, &Tag{}) + DB.DropTable("locale_blog_tags") + DB.CreateTable(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if len(tags) != 0 { + t.Errorf("Should find 0 tags with Related for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("Should find 3 tags with Related for EN Blog") + } + + DB.Model(&blog2).Related(&tags, "LocaleTags") + if !compareTags(tags, []string{"tag4"}) { + t.Errorf("Should find 1 tags with Related for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Related(&tags2, "LocaleTags") + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Errorf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Errorf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Errorf("EN Blog's tags should be cleared") + } + } +} diff --git a/preload.go b/preload.go index 2c981a79..d12995f3 100644 --- a/preload.go +++ b/preload.go @@ -31,7 +31,7 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { - if scope.Search.preload == nil { + if scope.Search.preload == nil || scope.HasError() { return } @@ -277,10 +277,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } - var associationForeignStructFieldNames []string - for _, dbName := range relation.AssociationForeignFieldNames { + var foreignFieldNames []string + for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { - associationForeignStructFieldNames = append(associationForeignStructFieldNames, field.Name) + foreignFieldNames = append(foreignFieldNames, field.Name) } } @@ -288,7 +288,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, associationForeignStructFieldNames) + source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link)) @@ -296,7 +296,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } else { if object := scope.IndirectValue(); object.IsValid() { - source := getRealValue(object, associationForeignStructFieldNames) + source := getRealValue(object, foreignFieldNames) field := object.FieldByName(field.Name) for _, link := range linkHash[toString(source)] { field.Set(reflect.Append(field, link)) diff --git a/scope_private.go b/scope_private.go index eddcfcc3..cd90c8c2 100644 --- a/scope_private.go +++ b/scope_private.go @@ -630,15 +630,14 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - if name, ok := sqlSettings["INDEX"]; ok { + if name, ok := field.TagSettings["INDEX"]; ok { if name == "INDEX" { name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) } indexes[name] = append(indexes[name], field.DBName) } - if name, ok := sqlSettings["UNIQUE_INDEX"]; ok { + if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { if name == "UNIQUE_INDEX" { name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) }