diff --git a/clause/clause.go b/clause/clause.go index b0507f44..1b4a7e85 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -59,7 +59,7 @@ type OverrideNameInterface interface { type Where struct { AndConditions AddConditions ORConditions []ORConditions - Builders []Expression + builders []Expression } func (where Where) Name() string { @@ -74,8 +74,8 @@ func (where Where) Build(builder Builder) { where.AndConditions.Build(builder) } - if len(where.Builders) > 0 { - for _, b := range where.Builders { + if len(where.builders) > 0 { + for _, b := range where.builders { if withConditions { builder.Write(" AND ") } @@ -122,9 +122,9 @@ func (where Where) MergeExpression(expr Expression) { if w, ok := expr.(Where); ok { where.AndConditions = append(where.AndConditions, w.AndConditions...) where.ORConditions = append(where.ORConditions, w.ORConditions...) - where.Builders = append(where.Builders, w.Builders...) + where.builders = append(where.builders, w.builders...) } else { - where.Builders = append(where.Builders, expr) + where.builders = append(where.builders, expr) } } @@ -135,6 +135,22 @@ type Select struct { // Join join clause type Join struct { + Table string + Type string // left join books on + ON []Expression + builders []Expression +} + +func (join Join) Build(builder Builder) { + // TODO +} + +func (join Join) MergeExpression(expr Expression) { + if j, ok := expr.(Join); ok { + join.builders = append(join.builders, j.builders...) + } else { + join.builders = append(join.builders, expr) + } } // GroupBy group by clause diff --git a/clause/query.go b/clause/query.go index 949678d9..7b5491e5 100644 --- a/clause/query.go +++ b/clause/query.go @@ -2,6 +2,12 @@ package clause import "strings" +// Column quote with name +type Column struct { + Table string + Name string +} + //////////////////////////////////////////////////////////////////////////////// // Query Expressions //////////////////////////////////////////////////////////////////////////////// diff --git a/schema/field.go b/schema/field.go index 88a0d3fb..005fd4e3 100644 --- a/schema/field.go +++ b/schema/field.go @@ -8,23 +8,23 @@ import ( "time" ) -type FieldType string +type DataType string const ( - Bool FieldType = "bool" - Int = "int" - Uint = "uint" - Float = "float" - String = "string" - Time = "time" - Bytes = "bytes" + Bool DataType = "bool" + Int = "int" + Uint = "uint" + Float = "float" + String = "string" + Time = "time" + Bytes = "bytes" ) type Field struct { Name string DBName string BindNames []string - DataType FieldType + DataType DataType DBDataType string PrimaryKey bool AutoIncrement bool @@ -42,8 +42,7 @@ type Field struct { Tag reflect.StructTag TagSettings map[string]string Schema *Schema - EmbeddedbSchema *Schema - Relationship string + EmbeddedSchema *Schema } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -177,8 +176,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - field.EmbeddedbSchema = Parse(fieldValue, sync.Map{}, schema.namer) - for _, ef := range field.EmbeddedbSchema.Fields { + field.EmbeddedSchema, schema.err = Parse(fieldValue, sync.Map{}, schema.namer) + for _, ef := range field.EmbeddedSchema.Fields { ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { @@ -189,13 +188,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - } else { - switch fieldValue.Kind() { - case reflect.Struct: - field.Relationship = "one" - case reflect.Slice: - field.Relationship = "many" - } } return field diff --git a/schema/naming.go b/schema/naming.go index 1baa8558..6df80d2a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -10,8 +10,10 @@ import ( // Namer namer interface type Namer interface { - TableName(string) string - ColumnName(string) string + TableName(table string) string + ColumnName(column string) string + JoinTableName(table string) string + JoinTableColumnName(table, column string) string } // NamingStrategy tables, columns naming strategy @@ -33,6 +35,16 @@ func (ns NamingStrategy) ColumnName(str string) string { return toDBName(str) } +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + return ns.TablePrefix + toDBName(str) +} + +// JoinTableColumnName convert string to join table column name +func (ns NamingStrategy) JoinTableColumnName(referenceTable, referenceColumn string) string { + return inflection.Singular(toDBName(referenceTable)) + toDBName(referenceColumn) +} + var ( smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 diff --git a/schema/relationship.go b/schema/relationship.go index b0c630be..95f56f6d 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,43 +1,143 @@ package schema +import ( + "fmt" + "reflect" + "strings" +) + // RelationshipType relationship type type RelationshipType string const ( - HasOneRel RelationshipType = "has_one" // HasOneRel has one relationship - HasManyRel RelationshipType = "has_many" // HasManyRel has many relationship - BelongsToRel RelationshipType = "belongs_to" // BelongsToRel belongs to relationship - Many2ManyRel RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + HasOne RelationshipType = "has_one" // HasOneRel has one relationship + HasMany RelationshipType = "has_many" // HasManyRel has many relationship + BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship ) type Relationships struct { - HasOne map[string]*Relationship - BelongsTo map[string]*Relationship - HasMany map[string]*Relationship - Many2Many map[string]*Relationship + HasOne []*Relationship + BelongsTo []*Relationship + HasMany []*Relationship + Many2Many []*Relationship + Relations map[string]*Relationship } type Relationship struct { - Type RelationshipType - ForeignKeys []*RelationField // self - AssociationForeignKeys []*RelationField // association - JoinTable *JoinTable + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []Reference + Schema *Schema + FieldSchema *Schema + JoinTable *Schema + ForeignKeys, AssociationForeignKeys []string } -type RelationField struct { - *Field - PolymorphicField *Field - PolymorphicValue string +type Polymorphic struct { + PolymorphicID *Field + PolymorphicType *Field + Value string } -type JoinTable struct { - Table string - ForeignKeys []*RelationField - AssociationForeignKeys []*RelationField +type Reference struct { + PriamryKey *Field + PriamryValue string + ForeignKey *Field + OwnPriamryKey bool } -func (schema *Schema) buildToOneRel(field *Field) { +func (schema *Schema) parseRelation(field *Field) { + var ( + fieldValue = reflect.New(field.FieldType).Interface() + relation = &Relationship{ + Name: field.Name, + Field: field, + Schema: schema, + Type: RelationshipType(strings.ToLower(strings.TrimSpace(field.TagSettings["REL"]))), + ForeignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + AssociationForeignKeys: toColumns(field.TagSettings["ASSOCIATION_FOREIGNKEY"]), + } + ) + + if relation.FieldSchema, schema.err = Parse(fieldValue, schema.cacheStore, schema.namer); schema.err != nil { + return + } + + // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` + // type User struct { + // Toys []Toy `gorm:"polymorphic:Owner;"` + // } + // type Pet struct { + // Toy Toy `gorm:"polymorphic:Owner;"` + // } + // type Toy struct { + // OwnerID int + // OwnerType string + // } + if polymorphic, _ := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type: %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type: %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, Reference{ + PriamryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.ForeignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.ForeignKeys[0]); primaryKeyField == nil || len(relation.ForeignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign key: %+v for %v on field %v", relation.ForeignKeys, schema, field.Name) + } + } + relation.References = append(relation.References, Reference{ + PriamryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicType, + OwnPriamryKey: true, + }) + } + + switch field.FieldType.Kind() { + case reflect.Struct: + relation.Type = HasOne + case reflect.Slice: + relation.Type = HasMany + } + return + } + + switch field.FieldType.Kind() { + case reflect.Struct: + schema.parseStructRelation(relation, field) + case reflect.Slice: + schema.parseSliceRelation(relation, field) + default: + schema.err = fmt.Errorf("unsupported data type: %v (in %v#%v ", field.FieldType.PkgPath(), schema, field.Name) + } } -func (schema *Schema) buildToManyRel(field *Field) { +func (schema *Schema) parseStructRelation(relation *Relationship, field *Field) error { + return nil +} + +func (schema *Schema) parseSliceRelation(relation *Relationship, field *Field) error { + return nil } diff --git a/schema/schema.go b/schema/schema.go index 5069bb44..f18cb7a6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,6 +1,7 @@ package schema import ( + "fmt" "go/ast" "reflect" "strings" @@ -8,6 +9,7 @@ import ( ) type Schema struct { + Name string ModelType reflect.Type Table string PrioritizedPrimaryField *Field @@ -16,42 +18,64 @@ type Schema struct { FieldsByName map[string]*Field FieldsByDBName map[string]*Field Relationships Relationships + err error namer Namer + cacheStore sync.Map +} + +func (schema Schema) String() string { + return schema.ModelType.PkgPath() +} + +func (schema Schema) LookUpField(name string) *Field { + if field, ok := schema.FieldsByDBName[name]; ok { + return field + } + if field, ok := schema.FieldsByName[name]; ok { + return field + } + return nil } // get data type from dialector -func Parse(dest interface{}, cacheStore sync.Map, namer Namer) *Schema { +func Parse(dest interface{}, cacheStore sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - return nil + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) + } + return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema) + return v.(*Schema), nil } schema := &Schema{ + Name: modelType.Name(), ModelType: modelType, Table: namer.TableName(modelType.Name()), FieldsByName: map[string]*Field{}, FieldsByDBName: map[string]*Field{}, + cacheStore: cacheStore, } - for i := 0; i < modelType.NumField(); i++ { - fieldStruct := modelType.Field(i) - if !ast.IsExported(fieldStruct.Name) { - continue + defer func() { + if schema.err != nil { + cacheStore.Delete(modelType) } + }() - field := schema.ParseField(fieldStruct) - schema.Fields = append(schema.Fields, field) - if field.EmbeddedbSchema != nil { - for _, f := range field.EmbeddedbSchema.Fields { - schema.Fields = append(schema.Fields, f) + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + field := schema.ParseField(fieldStruct) + schema.Fields = append(schema.Fields, field) + if field.EmbeddedSchema != nil { + schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) } } } @@ -85,7 +109,12 @@ func Parse(dest interface{}, cacheStore sync.Map, namer Namer) *Schema { } schema.PrimaryFields = append(schema.PrimaryFields, field) } + + if field.DataType == "" { + defer schema.parseRelation(field) + } } - return schema + cacheStore.Store(modelType, schema) + return schema, schema.err } diff --git a/schema/utils.go b/schema/utils.go index 1b0f5eac..4f4bfa50 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -29,3 +29,12 @@ func checkTruth(val string) bool { } return true } + +func toColumns(val string) (results []string) { + if val != "" { + for _, v := range strings.Split(val, ",") { + results = append(results, strings.TrimSpace(v)) + } + } + return +}