From 5d692a6bf2f16bd2161bf6f764d5739a786b8e54 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 15 Feb 2015 23:01:09 +0800 Subject: [PATCH] Try to cache struct fields --- field.go | 61 +++++--------- scope.go | 179 ++------------------------------------- scope_private.go | 26 ++---- struct_field.go | 212 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 244 insertions(+), 234 deletions(-) create mode 100644 struct_field.go diff --git a/field.go b/field.go index b1deab9b..4245b6cb 100644 --- a/field.go +++ b/field.go @@ -4,51 +4,11 @@ import ( "database/sql" "errors" "reflect" - "time" ) -type relationship struct { - JoinTable string - ForeignKey string - ForeignType string - AssociationForeignKey string - Kind string -} - -// FIXME -func (r relationship) ForeignDBName() string { - return ToSnake(r.ForeignKey) -} - -func (r relationship) AssociationForeignDBName(name string) string { - return ToSnake(r.AssociationForeignKey) -} - type Field struct { - Name string - DBName string - Field reflect.Value - Tag reflect.StructTag - Relationship *relationship - IsNormal bool - IsBlank bool - IsIgnored bool - IsPrimaryKey bool - DefaultValue interface{} -} - -func (field *Field) IsScanner() bool { - _, isScanner := reflect.New(field.Field.Type()).Interface().(sql.Scanner) - return isScanner -} - -func (field *Field) IsTime() bool { - reflectValue := field.Field - if reflectValue.Kind() == reflect.Ptr { - reflectValue = reflect.New(reflectValue.Type().Elem()).Elem() - } - _, isTime := reflectValue.Interface().(time.Time) - return isTime + *StructField + Field reflect.Value } func (field *Field) Set(value interface{}) (err error) { @@ -76,3 +36,20 @@ func (field *Field) Set(value interface{}) (err error) { return } + +type relationship struct { + JoinTable string + ForeignKey string + ForeignType string + AssociationForeignKey string + Kind string +} + +// FIXME +func (r relationship) ForeignDBName() string { + return ToSnake(r.ForeignKey) +} + +func (r relationship) AssociationForeignDBName(name string) string { + return ToSnake(r.AssociationForeignKey) +} diff --git a/scope.go b/scope.go index d7db341e..5850a517 100644 --- a/scope.go +++ b/scope.go @@ -3,7 +3,6 @@ package gorm import ( "errors" "fmt" - "go/ast" "strings" "time" @@ -154,7 +153,7 @@ func (scope *Scope) HasColumn(column string) bool { dbName := ToSnake(column) - field, hasColumn := clone.Fields(false)[dbName] + field, hasColumn := clone.Fields()[dbName] return hasColumn && !field.IsIgnored } @@ -298,177 +297,13 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { return nil, false } -func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelation bool) []*Field { - var field Field - field.Name = fieldStruct.Name - - value := scope.IndirectValue().FieldByName(fieldStruct.Name) - indirectValue := reflect.Indirect(value) - field.Field = value - field.IsBlank = isBlank(value) - - // Search for primary key tag identifier - settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) - if _, ok := settings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - } - - if def, ok := parseTagSetting(fieldStruct.Tag.Get("sql"))["DEFAULT"]; ok { - field.DefaultValue = def - } - - field.Tag = fieldStruct.Tag - - if value, ok := settings["COLUMN"]; ok { - field.DBName = value - } else { - field.DBName = ToSnake(fieldStruct.Name) - } - - tagIdentifier := "sql" - if scope.db != nil { - tagIdentifier = scope.db.parent.tagIdentifier - } - if fieldStruct.Tag.Get(tagIdentifier) == "-" { - field.IsIgnored = true - } - - if !field.IsIgnored { - // parse association - if !indirectValue.IsValid() { - indirectValue = reflect.New(value.Type()) - } - typ := indirectValue.Type() - scopeTyp := scope.IndirectValue().Type() - - foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) - foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"]) - associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) - many2many := settings["MANY2MANY"] - polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"]) - - if polymorphic != "" { - foreignKey = polymorphic + "Id" - foreignType = polymorphic + "Type" - } - - switch indirectValue.Kind() { - case reflect.Slice: - typ = typ.Elem() - - if field.IsScanner() { - field.IsNormal = true - } else if (typ.Kind() == reflect.Struct) && withRelation { - if foreignKey == "" { - foreignKey = scopeTyp.Name() + "Id" - } - if associationForeignKey == "" { - associationForeignKey = typ.Name() + "Id" - } - - // if not many to many, foreign key could be null - if many2many == "" { - if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - foreignKey = "" - } - } - - field.Relationship = &relationship{ - JoinTable: many2many, - ForeignKey: foreignKey, - ForeignType: foreignType, - AssociationForeignKey: associationForeignKey, - Kind: "has_many", - } - - if many2many != "" { - field.Relationship.Kind = "many_to_many" - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if field.IsTime() || field.IsScanner() { - field.IsNormal = true - } else if _, ok := settings["EMBEDDED"]; ok || fieldStruct.Anonymous { - var fields []*Field - if field.Field.CanAddr() { - for _, field := range scope.New(field.Field.Addr().Interface()).Fields() { - field.DBName = field.DBName - fields = append(fields, field) - } - } - return fields - } else if withRelation { - var belongsToForeignKey, hasOneForeignKey, kind string - - if foreignKey == "" { - belongsToForeignKey = field.Name + "Id" - hasOneForeignKey = scopeTyp.Name() + "Id" - } else { - belongsToForeignKey = foreignKey - hasOneForeignKey = foreignKey - } - - if scope.HasColumn(belongsToForeignKey) { - foreignKey = belongsToForeignKey - kind = "belongs_to" - } else { - foreignKey = hasOneForeignKey - kind = "has_one" - } - - field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind} - } - default: - field.IsNormal = true - } - } - return []*Field{&field} -} - // Fields get value's fields -func (scope *Scope) Fields(noRelations ...bool) map[string]*Field { - var withRelation = len(noRelations) == 0 - - if withRelation && scope.fields != nil { - return scope.fields - } - - var fields = map[string]*Field{} - if scope.IndirectValue().IsValid() && scope.IndirectValue().Kind() == reflect.Struct { - scopeTyp := scope.IndirectValue().Type() - var hasPrimaryKey = false - for i := 0; i < scopeTyp.NumField(); i++ { - fieldStruct := scopeTyp.Field(i) - if !ast.IsExported(fieldStruct.Name) { - continue - } - for _, field := range scope.fieldFromStruct(fieldStruct, withRelation) { - if field.IsPrimaryKey { - hasPrimaryKey = true - } - if value, ok := fields[field.DBName]; ok { - if value.IsIgnored { - fields[field.DBName] = field - } else { - panic(fmt.Sprintf("Duplicated column name for %v (%v)\n", scope.typeName(), fileWithLineNum())) - } - } else { - fields[field.DBName] = field - } - } - } - - if !hasPrimaryKey { - if field, ok := fields["id"]; ok { - field.IsPrimaryKey = true - } - } - } - - if withRelation { - scope.fields = fields +func (scope *Scope) Fields() map[string]*Field { + fields := map[string]*Field{} + structFields := scope.GetStructFields() + for _, structField := range structFields { + field := Field{StructField: structField} + fields[field.DBName] = &field } return fields diff --git a/scope_private.go b/scope_private.go index 5670b8a1..2e44c7b0 100644 --- a/scope_private.go +++ b/scope_private.go @@ -5,7 +5,6 @@ import ( "database/sql/driver" "errors" "fmt" - "go/ast" "reflect" "regexp" "strconv" @@ -403,7 +402,7 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) { return typ + " " + additionalType } case reflect.Struct: - if field.IsScanner() { + if field.IsScanner { var getScannerValue func(reflect.Value) getScannerValue = func(value reflect.Value) { reflectValue = value @@ -412,7 +411,7 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) { } } getScannerValue(reflectValue.Field(0)) - } else if !field.IsTime() { + } else if !field.IsTime { return typ + " " + additionalType } } @@ -578,24 +577,11 @@ func (scope *Scope) createJoinTable(field *Field) { func (scope *Scope) createTable() *Scope { var sqls []string - fields := scope.Fields() - scopeType := scope.IndirectValue().Type() - for i := 0; i < scopeType.NumField(); i++ { - if !ast.IsExported(scopeType.Field(i).Name) { - continue - } - for _, field := range scope.fieldFromStruct(scopeType.Field(i), false) { - name := field.Name - for _, field := range fields { - if field.Name == name { - if field.IsNormal { - sqlTag := scope.sqlTagForField(field) - sqls = append(sqls, scope.Quote(field.DBName)+" "+sqlTag) - } - scope.createJoinTable(field) - } - } + for _, structField := range scope.GetStructFields() { + if structField.IsNormal { + sqls = append(sqls, scope.Quote(structField.DBName)+" "+structField.SqlTag) } + scope.createJoinTable(structField) } scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec() return scope diff --git a/struct_field.go b/struct_field.go new file mode 100644 index 00000000..4957dd1d --- /dev/null +++ b/struct_field.go @@ -0,0 +1,212 @@ +package gorm + +import ( + "database/sql" + "go/ast" + "reflect" + "strconv" + "time" +) + +type StructField struct { + Name string + DBName string + IsBlank bool + IsPrimaryKey bool + IsScanner bool + IsTime bool + IsNormal bool + IsIgnored bool + DefaultValue *string + SqlTag string + Relationship *relationship +} + +func (scope *Scope) GetStructFields() (fields []*StructField) { + reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) + if reflectValue.Kind() == reflect.Slice { + reflectValue = reflect.Indirect(reflect.New(reflectValue.Elem().Type())) + } + + scopeTyp := reflectValue.Type() + hasPrimaryKey := false + for i := 0; i < scopeTyp.NumField(); i++ { + fieldStruct := scopeTyp.Field(i) + if !ast.IsExported(fieldStruct.Name) { + continue + } + var field *StructField + + if fieldStruct.Tag.Get("sql") == "-" { + field.IsIgnored = true + } else { + sqlSettings := parseTagSetting(fieldStruct.Tag.Get("sql")) + settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) + if _, ok := settings["PRIMARY_KEY"]; ok { + field.IsPrimaryKey = true + hasPrimaryKey = true + } + + if value, ok := sqlSettings["DEFAULT"]; ok { + field.DefaultValue = &value + } + + if value, ok := settings["COLUMN"]; ok { + field.DBName = value + } else { + field.DBName = ToSnake(fieldStruct.Name) + } + + fieldType, indirectType := fieldStruct.Type, fieldStruct.Type + if indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() + } + + if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { + field.IsScanner, field.IsNormal = true, true + } + + if _, isTime := reflect.New(indirectType).Interface().(time.Time); isTime { + field.IsTime, field.IsNormal = true, true + } + + many2many := settings["MANY2MANY"] + foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) + foreignType := SnakeToUpperCamel(settings["FOREIGNTYPE"]) + associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) + if polymorphic := SnakeToUpperCamel(settings["POLYMORPHIC"]); polymorphic != "" { + foreignKey = polymorphic + "Id" + foreignType = polymorphic + "Type" + } + + if !field.IsNormal { + switch indirectType.Kind() { + case reflect.Slice: + typ := indirectType.Elem() + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + if typ.Kind() == reflect.Struct { + kind := "has_many" + + if foreignKey == "" { + foreignKey = indirectType.Name() + "Id" + } + + if associationForeignKey == "" { + associationForeignKey = typ.Name() + "Id" + } + + if many2many != "" { + kind = "many_to_many" + } else if !reflect.New(typ).FieldByName(foreignKey).IsValid() { + foreignKey = "" + } + + field.Relationship = &relationship{ + JoinTable: many2many, + ForeignKey: foreignKey, + ForeignType: foreignType, + AssociationForeignKey: associationForeignKey, + Kind: kind, + } + } else { + field.IsNormal = true + } + case reflect.Struct: + if _, ok := settings["EMBEDDED"]; ok || fieldStruct.Anonymous { + for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { + fields = append(fields, field) + } + break + } else { + var belongsToForeignKey, hasOneForeignKey, kind string + + if foreignKey == "" { + belongsToForeignKey = indirectType.Name() + "Id" + hasOneForeignKey = scopeTyp.Name() + "Id" + } else { + belongsToForeignKey = foreignKey + hasOneForeignKey = foreignKey + } + + if _, ok := scopeTyp.FieldByName(belongsToForeignKey); ok { + foreignKey = belongsToForeignKey + kind = "belongs_to" + } else { + foreignKey = hasOneForeignKey + kind = "has_one" + } + + field.Relationship = &relationship{ForeignKey: foreignKey, ForeignType: foreignType, Kind: kind} + } + + default: + field.IsNormal = true + } + } + } + fields = append(fields, field) + } + + if !hasPrimaryKey { + for _, field := range fields { + if field.DBName == "id" { + field.IsPrimaryKey = true + } + } + } + + for _, field := range fields { + var sqlType string + size := 255 + sqlTag := field.Tag.Get("sql") + sqlSetting = parseTagSetting(sqlTag) + + if value, ok := sqlSetting["SIZE"]; ok { + if i, err := strconv.Atoi(value); err == nil { + size = i + } else { + size = 0 + } + } + + if value, ok := sqlSetting["TYPE"]; ok { + typ = value + } + + additionalType := sqlSetting["NOT NULL"] + " " + sqlSetting["UNIQUE"] + if value, ok := sqlSetting["DEFAULT"]; ok { + additionalType = additionalType + "DEFAULT " + value + } + + if field.IsScanner { + var getScannerValue func(reflect.Value) + getScannerValue = func(reflectValue reflect.Value) { + if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner { + getScannerValue(reflectValue.Field(0)) + } + } + getScannerValue(reflectValue.Field(0)) + } + if field.IsNormal { + typ + " " + additionalType + } + } else if !field.IsTime { + return typ + " " + additionalType + } + } + + if len(typ) == 0 { + if field.IsPrimaryKey { + typ = scope.Dialect().PrimaryKeyTag(reflectValue, size) + } else { + typ = scope.Dialect().SqlTag(reflectValue, size) + } + } + + return typ + " " + additionalType + } + return +}