From fd9b688084d3021927721b8925a655d19762918f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 18:02:19 +0800 Subject: [PATCH] Implement parse many2many relation --- schema/field.go | 6 +- schema/naming.go | 6 -- schema/relationship.go | 162 ++++++++++++++++++++++++++--------------- schema/utils.go | 5 ++ schema/utils_test.go | 23 ++++++ 5 files changed, 133 insertions(+), 69 deletions(-) create mode 100644 schema/utils_test.go diff --git a/schema/field.go b/schema/field.go index 005fd4e3..d2747100 100644 --- a/schema/field.go +++ b/schema/field.go @@ -103,11 +103,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBName = dbName } - if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) { field.PrimaryKey = true } - if val, ok := field.TagSettings["AUTO_INCREMENT"]; ok && checkTruth(val) { + if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) { field.AutoIncrement = true field.HasDefaultValue = true } @@ -180,7 +180,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) - if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { ef.DBName = prefix + ef.DBName } diff --git a/schema/naming.go b/schema/naming.go index 6df80d2a..5a2311b6 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -13,7 +13,6 @@ type Namer interface { TableName(table string) string ColumnName(column string) string JoinTableName(table string) string - JoinTableColumnName(table, column string) string } // NamingStrategy tables, columns naming strategy @@ -40,11 +39,6 @@ func (ns NamingStrategy) JoinTableName(str string) string { return ns.TablePrefix + toDBName(str) } -// JoinTableColumnName convert string to join table column name -func (ns NamingStrategy) JoinTableColumnName(referenceTable, referenceColumn string) string { - return inflection.Singular(toDBName(referenceTable)) + toDBName(referenceColumn) -} - var ( smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 diff --git a/schema/relationship.go b/schema/relationship.go index 5081d540..5195589d 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -57,7 +57,7 @@ func (schema *Schema) parseRelation(field *Field) { Field: field, Schema: schema, ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), - PrimaryKeys: toColumns(field.TagSettings["PRIMARYKEY"]), + PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]), } ) @@ -65,63 +65,13 @@ func (schema *Schema) parseRelation(field *Field) { return } - // Parse Polymorphic relations - // - // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` - // type User struct { - // Toys []Toy `gorm:"polymorphic:Owner;"` - // } - // type Pet struct { - // Toy Toy `gorm:"polymorphic:Owner;"` - // } - // type Toy struct { - // OwnerID int - // OwnerType string - // } if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], - } - - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { - relation.Polymorphic.Value = strings.TrimSpace(value) - } - - if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") - } - - if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") - } - - if schema.err == nil { - relation.References = append(relation.References, Reference{ - PriamryValue: relation.Polymorphic.Value, - ForeignKey: relation.Polymorphic.PolymorphicType, - }) - - primaryKeyField := schema.PrioritizedPrimaryField - if len(relation.ForeignKeys) > 0 { - if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) - } - } - relation.References = append(relation.References, Reference{ - PriamryKey: primaryKeyField, - ForeignKey: relation.Polymorphic.PolymorphicType, - OwnPriamryKey: true, - }) - } - - relation.Type = "has" + schema.buildPolymorphicRelation(relation, field, polymorphic) + } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { + schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.FieldType.Kind() { - case reflect.Struct: - schema.guessRelation(relation, field, true) - case reflect.Slice: + case reflect.Struct, reflect.Slice: schema.guessRelation(relation, field, true) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -138,6 +88,102 @@ func (schema *Schema) parseRelation(field *Field) { } } +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, Reference{ + PriamryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.ForeignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + } + } + relation.References = append(relation.References, Reference{ + PriamryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicType, + OwnPriamryKey: true, + }) + } + + relation.Type = "has" +} + +func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { + relation.Type = Many2Many + + var ( + joinTableFields []reflect.StructField + fieldsMap = map[string]*Field{} + ) + + for _, s := range []*Schema{schema, relation.Schema} { + for _, primaryField := range s.PrimaryFields { + fieldName := s.Name + primaryField.Name + if _, ok := fieldsMap[fieldName]; ok { + if field.Name != s.Name { + fieldName = field.Name + primaryField.Name + } else { + fieldName = s.Name + primaryField.Name + "Reference" + } + } + + fieldsMap[fieldName] = primaryField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: fieldName, + PkgPath: primaryField.StructField.PkgPath, + Type: primaryField.StructField.Type, + Tag: removeSettingFromTag(primaryField.StructField.Tag, "column"), + }) + } + } + + relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer) + relation.JoinTable.Name = many2many + relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + + // build references + for _, f := range relation.JoinTable.Fields { + relation.References = append(relation.References, Reference{ + PriamryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPriamryKey: schema == fieldsMap[f.Name].Schema, + }) + } + return +} + func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { var ( primaryFields, foreignFields []*Field @@ -214,10 +260,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH if guessHas { relation.Type = "has" } else { - relation.Type = "belongs_to" + relation.Type = BelongsTo } } - -func (schema *Schema) parseMany2ManyRelation(relation *Relationship, field *Field) error { - return nil -} diff --git a/schema/utils.go b/schema/utils.go index 4f4bfa50..f2dd90af 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -2,6 +2,7 @@ package schema import ( "reflect" + "regexp" "strings" ) @@ -38,3 +39,7 @@ func toColumns(val string) (results []string) { } return } + +func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { + return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) +} diff --git a/schema/utils_test.go b/schema/utils_test.go new file mode 100644 index 00000000..e70169bf --- /dev/null +++ b/schema/utils_test.go @@ -0,0 +1,23 @@ +package schema + +import ( + "reflect" + "testing" +) + +func TestRemoveSettingFromTag(t *testing.T) { + tags := map[string]string{ + `gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + } + + for k, v := range tags { + if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v { + t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column")) + } + } +}