From 3cbd233758499f55bebf640264a2158aafe07096 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 2 Feb 2020 00:03:56 +0800 Subject: [PATCH] Add more tests for parse schema relations --- schema/field.go | 2 + schema/naming.go | 6 +-- schema/relationship.go | 31 ++++++----- schema/schema.go | 5 +- schema/schema_helper_test.go | 100 +++++++++++++++++++++++++++++++---- schema/schema_test.go | 55 ++++++++++++++++++- tests/model.go | 4 +- 7 files changed, 172 insertions(+), 31 deletions(-) diff --git a/schema/field.go b/schema/field.go index f1cd022b..570b3c50 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Updatable: true, Tag: fieldStruct.Tag, TagSettings: ParseTagSetting(fieldStruct.Tag), + Schema: schema, } for field.FieldType.Kind() == reflect.Ptr { @@ -183,6 +184,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { diff --git a/schema/naming.go b/schema/naming.go index 5a2311b6..e6a5625e 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -11,7 +11,7 @@ import ( // Namer namer interface type Namer interface { TableName(table string) string - ColumnName(column string) string + ColumnName(table, column string) string JoinTableName(table string) string } @@ -30,13 +30,13 @@ func (ns NamingStrategy) TableName(str string) string { } // ColumnName convert string to column name -func (ns NamingStrategy) ColumnName(str string) string { +func (ns NamingStrategy) ColumnName(table, str string) string { return toDBName(str) } // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { - return ns.TablePrefix + toDBName(str) + return ns.TablePrefix + inflection.Plural(toDBName(str)) } var ( diff --git a/schema/relationship.go b/schema/relationship.go index 358d13e7..b6aaefbd 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -4,6 +4,8 @@ import ( "fmt" "reflect" "strings" + + "github.com/jinzhu/inflection" ) // RelationshipType relationship type @@ -43,10 +45,10 @@ type Polymorphic struct { } type Reference struct { - PriamryKey *Field - PriamryValue string + PrimaryKey *Field + PrimaryValue string ForeignKey *Field - OwnPriamryKey bool + OwnPrimaryKey bool } func (schema *Schema) parseRelation(field *Field) { @@ -136,7 +138,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi if schema.err == nil { relation.References = append(relation.References, Reference{ - PriamryValue: relation.Polymorphic.Value, + PrimaryValue: relation.Polymorphic.Value, ForeignKey: relation.Polymorphic.PolymorphicType, }) @@ -147,9 +149,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } } relation.References = append(relation.References, Reference{ - PriamryKey: primaryKeyField, - ForeignKey: relation.Polymorphic.PolymorphicType, - OwnPriamryKey: true, + PrimaryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicID, + OwnPrimaryKey: true, }) } @@ -163,17 +165,20 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} + ownFieldsMap = map[string]bool{} // fix self join many2many ) - for _, s := range []*Schema{schema, relation.Schema} { + for _, s := range []*Schema{schema, relation.FieldSchema} { for _, primaryField := range s.PrimaryFields { fieldName := s.Name + primaryField.Name if _, ok := fieldsMap[fieldName]; ok { if field.Name != s.Name { - fieldName = field.Name + primaryField.Name + fieldName = inflection.Singular(field.Name) + primaryField.Name } else { fieldName = s.Name + primaryField.Name + "Reference" } + } else { + ownFieldsMap[fieldName] = true } fieldsMap[fieldName] = primaryField @@ -195,9 +200,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { relation.References = append(relation.References, Reference{ - PriamryKey: fieldsMap[f.Name], + PrimaryKey: fieldsMap[f.Name], ForeignKey: f, - OwnPriamryKey: schema == fieldsMap[f.Name].Schema, + OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], }) } return @@ -275,9 +280,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH // build references for idx, foreignField := range foreignFields { relation.References = append(relation.References, Reference{ - PriamryKey: primaryFields[idx], + PrimaryKey: primaryFields[idx], ForeignKey: foreignField, - OwnPriamryKey: schema == primarySchema, + OwnPrimaryKey: schema == primarySchema && guessHas, }) } diff --git a/schema/schema.go b/schema/schema.go index d3404312..5cd6146b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -25,6 +25,9 @@ type Schema struct { } func (schema Schema) String() string { + if schema.ModelType.Name() == "" { + return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + } return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) } @@ -86,7 +89,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for _, field := range schema.Fields { if field.DBName == "" { - field.DBName = namer.ColumnName(field.Name) + field.DBName = namer.ColumnName(schema.Table, field.Name) } if field.DBName != "" { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index eb0085c2..ce91d8d1 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,7 +1,9 @@ package schema_test import ( + "fmt" "reflect" + "strings" "testing" "github.com/jinzhu/gorm/schema" @@ -90,14 +92,25 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } type Relation struct { - Name string - Type schema.RelationshipType - Polymorphic schema.Polymorphic - Schema string - FieldSchema string - JoinTable string - JoinTableFields []schema.Field - References []Reference + Name string + Type schema.RelationshipType + Schema string + FieldSchema string + Polymorphic Polymorphic + JoinTable JoinTable + References []Reference +} + +type Polymorphic struct { + ID string + Type string + Value string +} + +type JoinTable struct { + Name string + Table string + Fields []schema.Field } type Reference struct { @@ -105,17 +118,82 @@ type Reference struct { PrimarySchema string ForeignKey string ForeignSchema string - OwnPriamryKey bool + PrimaryValue string + OwnPrimaryKey bool } func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { if r, ok := s.Relationships.Relations[relation.Name]; ok { if r.Name != relation.Name { - t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Name, r.Name) + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) } if r.Type != relation.Type { - t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Type, r.Type) + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) + } + + if r.Schema.Name != relation.Schema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + } + + if r.FieldSchema.Name != relation.FieldSchema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + } + + if r.Polymorphic != nil { + if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { + t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + } + + if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { + t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + } + + if r.Polymorphic.Value != relation.Polymorphic.Value { + t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) + } + } + + if r.JoinTable != nil { + if r.JoinTable.Name != relation.JoinTable.Name { + t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + } + + if r.JoinTable.Table != relation.JoinTable.Table { + t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + } + + for _, f := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &f, nil) + } + } + + if len(relation.References) != len(r.References) { + t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) + } + + for _, ref := range relation.References { + var found bool + for _, rf := range r.References { + if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { + found = true + } + } + + if !found { + var refs []string + for _, rf := range r.References { + var primaryKey, primaryKeySchema string + if rf.PrimaryKey != nil { + primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + } + refs = append(refs, fmt.Sprintf( + "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", + primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, + )) + } + t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + } } } else { t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) diff --git a/schema/schema_test.go b/schema/schema_test.go index 8ea219e1..526a98bd 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -41,8 +41,61 @@ func TestParseSchema(t *testing.T) { // check relations relations := []Relation{ - {Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", References: []Reference{{"ID", "User", "UserID", "Pet", true}}}, + { + Name: "Account", Type: schema.HasOne, Schema: "User", FieldSchema: "Account", + References: []Reference{{"ID", "User", "UserID", "Account", "", true}}, + }, + { + Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", + References: []Reference{{"ID", "User", "UserID", "Pet", "", true}}, + }, + { + Name: "Toys", Type: schema.HasMany, Schema: "User", FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{{"ID", "User", "OwnerID", "Toy", "", true}, {"", "", "OwnerType", "Toy", "users", false}}, + }, + { + Name: "Company", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Company", + References: []Reference{{"ID", "Company", "CompanyID", "User", "", false}}, + }, + { + Name: "Manager", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", false}}, + }, + { + Name: "Team", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", true}}, + }, + { + Name: "Languages", Type: schema.Many2Many, Schema: "User", FieldSchema: "Language", + JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + { + Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "UserSpeak", "", true}, {"Code", "Language", "LanguageCode", "UserSpeak", "", false}}, + }, + { + Name: "Friends", Type: schema.Many2Many, Schema: "User", FieldSchema: "User", + JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + { + Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, + }, } + for _, relation := range relations { checkSchemaRelation(t, user, relation) } diff --git a/tests/model.go b/tests/model.go index e2b69abc..62000352 100644 --- a/tests/model.go +++ b/tests/model.go @@ -24,8 +24,8 @@ type User struct { ManagerID uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` + Languages []Language `gorm:"many2many:UserSpeak"` Friends []*User `gorm:"many2many:user_friends"` - Languages []Language `gorm:"many2many:user_speaks"` } type Account struct { @@ -53,6 +53,6 @@ type Company struct { } type Language struct { - Code string `gorm:primarykey` + Code string `gorm:"primarykey"` Name string }