Implement parse many2many relation

This commit is contained in:
Jinzhu 2020-02-01 18:02:19 +08:00
parent a9c20291e4
commit fd9b688084
5 changed files with 133 additions and 69 deletions

View File

@ -103,11 +103,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DBName = dbName
}
if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && checkTruth(val) {
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && checkTruth(val) {
field.PrimaryKey = true
}
if val, ok := field.TagSettings["AUTO_INCREMENT"]; ok && checkTruth(val) {
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && checkTruth(val) {
field.AutoIncrement = true
field.HasDefaultValue = true
}
@ -180,7 +180,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
for _, ef := range field.EmbeddedSchema.Fields {
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
ef.DBName = prefix + ef.DBName
}

View File

@ -13,7 +13,6 @@ type Namer interface {
TableName(table string) string
ColumnName(column string) string
JoinTableName(table string) string
JoinTableColumnName(table, column string) string
}
// NamingStrategy tables, columns naming strategy
@ -40,11 +39,6 @@ func (ns NamingStrategy) JoinTableName(str string) string {
return ns.TablePrefix + toDBName(str)
}
// JoinTableColumnName convert string to join table column name
func (ns NamingStrategy) JoinTableColumnName(referenceTable, referenceColumn string) string {
return inflection.Singular(toDBName(referenceTable)) + toDBName(referenceColumn)
}
var (
smap sync.Map
// https://github.com/golang/lint/blob/master/lint.go#L770

View File

@ -57,7 +57,7 @@ func (schema *Schema) parseRelation(field *Field) {
Field: field,
Schema: schema,
ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
PrimaryKeys: toColumns(field.TagSettings["PRIMARYKEY"]),
PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]),
}
)
@ -65,63 +65,13 @@ func (schema *Schema) parseRelation(field *Field) {
return
}
// Parse Polymorphic relations
//
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
}
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
relation.References = append(relation.References, Reference{
PriamryValue: relation.Polymorphic.Value,
ForeignKey: relation.Polymorphic.PolymorphicType,
})
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.ForeignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name)
}
}
relation.References = append(relation.References, Reference{
PriamryKey: primaryKeyField,
ForeignKey: relation.Polymorphic.PolymorphicType,
OwnPriamryKey: true,
})
}
relation.Type = "has"
schema.buildPolymorphicRelation(relation, field, polymorphic)
} else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else {
switch field.FieldType.Kind() {
case reflect.Struct:
schema.guessRelation(relation, field, true)
case reflect.Slice:
case reflect.Struct, reflect.Slice:
schema.guessRelation(relation, field, true)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
@ -138,6 +88,102 @@ func (schema *Schema) parseRelation(field *Field) {
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
}
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
relation.References = append(relation.References, Reference{
PriamryValue: relation.Polymorphic.Value,
ForeignKey: relation.Polymorphic.PolymorphicType,
})
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.ForeignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name)
}
}
relation.References = append(relation.References, Reference{
PriamryKey: primaryKeyField,
ForeignKey: relation.Polymorphic.PolymorphicType,
OwnPriamryKey: true,
})
}
relation.Type = "has"
}
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
relation.Type = Many2Many
var (
joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{}
)
for _, s := range []*Schema{schema, relation.Schema} {
for _, primaryField := range s.PrimaryFields {
fieldName := s.Name + primaryField.Name
if _, ok := fieldsMap[fieldName]; ok {
if field.Name != s.Name {
fieldName = field.Name + primaryField.Name
} else {
fieldName = s.Name + primaryField.Name + "Reference"
}
}
fieldsMap[fieldName] = primaryField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: fieldName,
PkgPath: primaryField.StructField.PkgPath,
Type: primaryField.StructField.Type,
Tag: removeSettingFromTag(primaryField.StructField.Tag, "column"),
})
}
}
relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer)
relation.JoinTable.Name = many2many
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
// build references
for _, f := range relation.JoinTable.Fields {
relation.References = append(relation.References, Reference{
PriamryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPriamryKey: schema == fieldsMap[f.Name].Schema,
})
}
return
}
func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
var (
primaryFields, foreignFields []*Field
@ -214,10 +260,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
if guessHas {
relation.Type = "has"
} else {
relation.Type = "belongs_to"
relation.Type = BelongsTo
}
}
func (schema *Schema) parseMany2ManyRelation(relation *Relationship, field *Field) error {
return nil
}

View File

@ -2,6 +2,7 @@ package schema
import (
"reflect"
"regexp"
"strings"
)
@ -38,3 +39,7 @@ func toColumns(val string) (results []string) {
}
return
}
func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag {
return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}"))
}

23
schema/utils_test.go Normal file
View File

@ -0,0 +1,23 @@
package schema
import (
"reflect"
"testing"
)
func TestRemoveSettingFromTag(t *testing.T) {
tags := map[string]string{
`gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`,
`gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`,
`gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`,
`gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`,
`gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`,
`gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`,
}
for k, v := range tags {
if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v {
t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column"))
}
}
}