From a4a0895a8589acc0116fc84eb4ce0139f52917a7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Feb 2020 21:48:06 +0800 Subject: [PATCH] Test parse schema relations --- logger/logger.go | 8 +-- schema/field.go | 7 +- schema/relationship.go | 58 ++++++++++++----- schema/schema.go | 62 ++++++++++++------ schema/schema_helper_test.go | 123 +++++++++++++++++++++++++++++++++++ schema/schema_test.go | 77 +++++++--------------- tests/model.go | 2 +- 7 files changed, 242 insertions(+), 95 deletions(-) create mode 100644 schema/schema_helper_test.go diff --git a/logger/logger.go b/logger/logger.go index 9d6e70bf..cad9be16 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -8,7 +8,7 @@ import ( type LogLevel int -var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", 0)} +var Default Interface = Logger{Writer: log.New(os.Stdout, "\r\n", log.LstdFlags)} const ( Info LogLevel = iota + 1 @@ -40,21 +40,21 @@ func (logger Logger) LogMode(level LogLevel) Interface { // Info print info func (logger Logger) Info(msg string, data ...interface{}) { - if logger.logLevel >= Info { + if logger.logLevel <= Info { logger.Print("[info] " + fmt.Sprintf(msg, data...)) } } // Warn print warn messages func (logger Logger) Warn(msg string, data ...interface{}) { - if logger.logLevel >= Warn { + if logger.logLevel <= Warn { logger.Print("[warn] " + fmt.Sprintf(msg, data...)) } } // Error print error messages func (logger Logger) Error(msg string, data ...interface{}) { - if logger.logLevel >= Error { + if logger.logLevel <= Error { logger.Print("[error] " + fmt.Sprintf(msg, data...)) } } diff --git a/schema/field.go b/schema/field.go index 47250aa8..f1cd022b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -176,7 +176,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - field.EmbeddedSchema, schema.err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer) + var err error + field.Creatable = false + field.Updatable = false + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + schema.err = err + } for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) diff --git a/schema/relationship.go b/schema/relationship.go index 5195589d..358d13e7 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -33,7 +33,7 @@ type Relationship struct { Schema *Schema FieldSchema *Schema JoinTable *Schema - ForeignKeys, PrimaryKeys []string + foreignKeys, primaryKeys []string } type Polymorphic struct { @@ -51,17 +51,19 @@ type Reference struct { func (schema *Schema) parseRelation(field *Field) { var ( + err error fieldValue = reflect.New(field.FieldType).Interface() relation = &Relationship{ Name: field.Name, Field: field, Schema: schema, - ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), - PrimaryKeys: toColumns(field.TagSettings["REFERENCES"]), + foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + primaryKeys: toColumns(field.TagSettings["REFERENCES"]), } ) - if relation.FieldSchema, schema.err = Parse(fieldValue, schema.cacheStore, schema.namer); schema.err != nil { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + schema.err = err return } @@ -86,6 +88,20 @@ func (schema *Schema) parseRelation(field *Field) { relation.Type = HasMany } } + + if schema.err == nil { + schema.Relationships.Relations[relation.Name] = relation + switch relation.Type { + case HasOne: + schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) + case HasMany: + schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) + case BelongsTo: + schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) + case Many2Many: + schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) + } + } } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` @@ -125,9 +141,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi }) primaryKeyField := schema.PrioritizedPrimaryField - if len(relation.ForeignKeys) > 0 { - if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + if len(relation.foreignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) } } relation.References = append(relation.References, Reference{ @@ -144,6 +160,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel relation.Type = Many2Many var ( + err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} ) @@ -169,7 +186,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - relation.JoinTable, schema.err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer) + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + schema.err = err + } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) @@ -202,18 +221,23 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } } - if len(relation.ForeignKeys) > 0 { - for _, foreignKey := range relation.ForeignKeys { + if len(relation.foreignKeys) > 0 { + for _, foreignKey := range relation.foreignKeys { if f := foreignSchema.LookUpField(foreignKey); f != nil { foreignFields = append(foreignFields, f) } else { - reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.ForeignKeys) + reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys) return } } } else { for _, primaryField := range primarySchema.PrimaryFields { - if f := foreignSchema.LookUpField(field.Name + primaryField.Name); f != nil { + lookUpName := schema.Name + primaryField.Name + if !guessHas { + lookUpName = field.Name + primaryField.Name + } + + if f := foreignSchema.LookUpField(lookUpName); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) } @@ -221,19 +245,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v", relation.FieldSchema, schema, field.Name) + reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) return - } else if len(relation.PrimaryKeys) > 0 { - for idx, primaryKey := range relation.PrimaryKeys { + } else if len(relation.primaryKeys) > 0 { + for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys) + reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) return } } else { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.PrimaryKeys) + reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) return } } diff --git a/schema/schema.go b/schema/schema.go index 0b5548e3..d3404312 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -4,7 +4,6 @@ import ( "fmt" "go/ast" "reflect" - "strings" "sync" "github.com/jinzhu/gorm/logger" @@ -26,7 +25,7 @@ type Schema struct { } func (schema Schema) String() string { - return schema.ModelType.PkgPath() + return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) } func (schema Schema) LookUpField(name string) *Field { @@ -63,6 +62,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) Table: namer.TableName(modelType.Name()), FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, } @@ -76,10 +76,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { - field := schema.ParseField(fieldStruct) - schema.Fields = append(schema.Fields, field) - if field.EmbeddedSchema != nil { + if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) + } else { + schema.Fields = append(schema.Fields, field) } } } @@ -94,6 +94,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field + + if v != nil && v.PrimaryKey { + if schema.PrioritizedPrimaryField == v { + schema.PrioritizedPrimaryField = nil + } + + for idx, f := range schema.PrimaryFields { + if f == v { + schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) + } else if schema.PrioritizedPrimaryField == nil { + schema.PrioritizedPrimaryField = f + } + } + } + + if field.PrimaryKey { + if schema.PrioritizedPrimaryField == nil { + schema.PrioritizedPrimaryField = field + } + schema.PrimaryFields = append(schema.PrimaryFields, field) + } } } @@ -102,23 +123,26 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - for db, field := range schema.FieldsByDBName { - if strings.ToLower(db) == "id" { - schema.PrioritizedPrimaryField = field - } - - if field.PrimaryKey { - if schema.PrioritizedPrimaryField == nil { - schema.PrioritizedPrimaryField = field - } - schema.PrimaryFields = append(schema.PrimaryFields, field) - } - - if field.DataType == "" { - defer schema.parseRelation(field) + if f := schema.LookUpField("id"); f != nil { + if f.PrimaryKey { + schema.PrioritizedPrimaryField = f + } else if len(schema.PrimaryFields) == 0 { + f.PrimaryKey = true + schema.PrioritizedPrimaryField = f + schema.PrimaryFields = append(schema.PrimaryFields, f) } } cacheStore.Store(modelType, schema) + + // parse relations for unidentified fields + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } + } + } + return schema, schema.err } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go new file mode 100644 index 00000000..eb0085c2 --- /dev/null +++ b/schema/schema_helper_test.go @@ -0,0 +1,123 @@ +package schema_test + +import ( + "reflect" + "testing" + + "github.com/jinzhu/gorm/schema" +) + +func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { + equalFieldNames := []string{"Name", "Table"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(v).FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) + } + } + + for idx, field := range primaryFields { + var found bool + for _, f := range s.PrimaryFields { + if f.Name == field { + found = true + } + } + + if idx == 0 { + if field != s.PrioritizedPrimaryField.Name { + t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) + } + } + + if !found { + t.Errorf("schema %v failed to found priamry key: %v", s, field) + } + } +} + +func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { + if fc != nil { + fc(f) + } + + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag) + } else { + f.TagSettings = map[string]string{} + } + } + + if parsedField, ok := s.FieldsByName[f.Name]; !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} + + for _, name := range equalFieldNames { + got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() + expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() + if !reflect.DeepEqual(got, expects) { + t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) + } + } + + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + + for _, name := range []string{f.DBName, f.Name} { + if field := s.LookUpField(name); field == nil || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + + if f.PrimaryKey { + var found bool + for _, primaryField := range s.PrimaryFields { + if primaryField == parsedField { + found = true + } + } + + if !found { + t.Errorf("schema %v doesn't include field %v", s, f.Name) + } + } + } +} + +type Relation struct { + Name string + Type schema.RelationshipType + Polymorphic schema.Polymorphic + Schema string + FieldSchema string + JoinTable string + JoinTableFields []schema.Field + References []Reference +} + +type Reference struct { + PrimaryKey string + PrimarySchema string + ForeignKey string + ForeignSchema string + OwnPriamryKey bool +} + +func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { + if r, ok := s.Relationships.Relations[relation.Name]; ok { + if r.Name != relation.Name { + t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Name, r.Name) + } + + if r.Type != relation.Type { + t.Errorf("schema %v relation name expects %v, but got %v", s, relation.Type, r.Type) + } + } else { + t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go index eefac98b..8ea219e1 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,7 +1,6 @@ package schema_test import ( - "reflect" "sync" "testing" @@ -11,68 +10,40 @@ import ( func TestParseSchema(t *testing.T) { cacheMap := sync.Map{} - user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } - checkSchemaFields(t, user) -} + // check schema + checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) -func checkSchemaFields(t *testing.T, s *schema.Schema) { + // check fields fields := []schema.Field{ - schema.Field{ - Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, - PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, - }, - schema.Field{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, - schema.Field{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, - schema.Field{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, - schema.Field{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, - schema.Field{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, - schema.Field{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, - schema.Field{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - schema.Field{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, + {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, } for _, f := range fields { - f.Creatable = true - f.Updatable = true - if f.TagSettings == nil { - if f.Tag != "" { - f.TagSettings = schema.ParseTagSetting(f.Tag) - } else { - f.TagSettings = map[string]string{} - } - } + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + }) + } - if foundField, ok := s.FieldsByName[f.Name]; !ok { - t.Errorf("schema %v failed to look up field with name %v", s, f.Name) - } else { - checkSchemaField(t, foundField, f) - - if field, ok := s.FieldsByDBName[f.DBName]; !ok || foundField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - - for _, name := range []string{f.DBName, f.Name} { - if field := s.LookUpField(name); field == nil || foundField != field { - t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) - } - } - } - } -} - -func checkSchemaField(t *testing.T, parsedField *schema.Field, field schema.Field) { - equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(field).FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) - } + // check relations + relations := []Relation{ + {Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", References: []Reference{{"ID", "User", "UserID", "Pet", true}}}, + } + for _, relation := range relations { + checkSchemaRelation(t, user, relation) } } diff --git a/tests/model.go b/tests/model.go index 0be3e97a..e2b69abc 100644 --- a/tests/model.go +++ b/tests/model.go @@ -23,7 +23,7 @@ type User struct { Company Company ManagerID uint Manager *User - Team []User `foreignkey:ManagerID` + Team []User `gorm:"foreignkey:ManagerID"` Friends []*User `gorm:"many2many:user_friends"` Languages []Language `gorm:"many2many:user_speaks"` }