From 0d3085393e977f036f73521d88fdf9c3fb869818 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 30 Jul 2014 14:58:00 +0800 Subject: [PATCH] Add IndirectValue for Scope --- association_test.go | 8 +- callback_query.go | 2 +- scope.go | 198 +++++++++++++++++++++++--------------------- scope_private.go | 6 +- 4 files changed, 112 insertions(+), 102 deletions(-) diff --git a/association_test.go b/association_test.go index 650fb6a9..e2dd8142 100644 --- a/association_test.go +++ b/association_test.go @@ -2,7 +2,7 @@ package gorm_test import "testing" -func TestSubStruct(t *testing.T) { +func TestHasOneAndHasManyAssociation(t *testing.T) { db.DropTable(Category{}) db.DropTable(Post{}) db.DropTable(Comment{}) @@ -115,8 +115,8 @@ func TestRelated(t *testing.T) { var creditcard CreditCard var user3 User - db.Debug().First(&creditcard, "number = ?", "1234567890") - db.Debug().Model(&creditcard).Related(&user3) + db.First(&creditcard, "number = ?", "1234567890") + db.Model(&creditcard).Related(&user3) if user3.Id != user.Id || user3.Name != user.Name { t.Errorf("Should get user from credit card correctly") } @@ -126,7 +126,7 @@ func TestRelated(t *testing.T) { } } -func TestQueryManyToManyWithRelated(t *testing.T) { +func TestManyToMany(t *testing.T) { var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}} user := User{Name: "Many2Many", Languages: languages} db.Save(&user) diff --git a/callback_query.go b/callback_query.go index 6e592ac9..acdbde20 100644 --- a/callback_query.go +++ b/callback_query.go @@ -16,7 +16,7 @@ func Query(scope *Scope) { destType reflect.Type ) - var dest = reflect.Indirect(reflect.ValueOf(scope.Value)) + var dest = scope.IndirectValue() if value, ok := scope.Get("gorm:query_destination"); ok { dest = reflect.Indirect(reflect.ValueOf(value)) } diff --git a/scope.go b/scope.go index 3856a2fc..287a803a 100644 --- a/scope.go +++ b/scope.go @@ -12,14 +12,23 @@ import ( ) type Scope struct { - Value interface{} - Search *search - Sql string - SqlVars []interface{} - db *DB - _values map[string]interface{} - skipLeft bool - primaryKey string + Value interface{} + indirectValue *reflect.Value + Search *search + Sql string + SqlVars []interface{} + db *DB + _values map[string]interface{} + skipLeft bool + primaryKey string +} + +func (scope *Scope) IndirectValue() reflect.Value { + if scope.indirectValue == nil { + value := reflect.Indirect(reflect.ValueOf(scope.Value)) + scope.indirectValue = &value + } + return *scope.indirectValue } // NewScope create scope for callbacks, including DB's search information @@ -93,10 +102,8 @@ func (scope *Scope) PrimaryKeyZero() bool { // PrimaryKeyValue get the primary key's value func (scope *Scope) PrimaryKeyValue() interface{} { - data := reflect.Indirect(reflect.ValueOf(scope.Value)) - - if data.Kind() == reflect.Struct { - if field := data.FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() { + if scope.IndirectValue().Kind() == reflect.Struct { + if field := scope.IndirectValue().FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() { return field.Interface() } } @@ -120,8 +127,7 @@ func (scope *Scope) SetColumn(column string, value interface{}) { return } - data := reflect.Indirect(reflect.ValueOf(scope.Value)) - setFieldValue(data.FieldByName(SnakeToUpperCamel(column)), value) + setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value) } // CallMethod invoke method with necessary argument @@ -151,7 +157,7 @@ func (scope *Scope) CallMethod(name string) { } } - if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice { + if values := scope.IndirectValue(); values.Kind() == reflect.Slice { for i := 0; i < values.Len(); i++ { call(values.Index(i).Addr().Interface()) } @@ -178,8 +184,8 @@ func (scope *Scope) TableName() string { scope.Err(errors.New("can't get table name")) return "" } - data := reflect.Indirect(reflect.ValueOf(scope.Value)) + data := scope.IndirectValue() if data.Kind() == reflect.Slice { elem := data.Type().Elem() if elem.Kind() == reflect.Ptr { @@ -228,9 +234,89 @@ func (scope *Scope) CombinedConditionSql() string { scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() } +func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { + var field Field + field.Name = fieldStruct.Name + field.DBName = ToSnake(fieldStruct.Name) + + value := scope.IndirectValue().FieldByName(fieldStruct.Name) + indirectValue := reflect.Indirect(value) + field.Value = value.Interface() + field.IsBlank = isBlank(value) + + // Search for primary key tag identifier + settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) + if _, ok := settings["PRIMARY_KEY"]; scope.PrimaryKey() == field.DBName || ok { + field.isPrimaryKey = true + } + + if field.isPrimaryKey { + scope.primaryKey = field.DBName + } + + if scope.db != nil { + field.Tag = fieldStruct.Tag + field.SqlTag = scope.sqlTagForField(&field) + + // parse association + typ := indirectValue.Type() + foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) + associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) + many2many := settings["MANY2MANY"] + scopeTyp := scope.IndirectValue().Type() + + switch indirectValue.Kind() { + case reflect.Slice: + typ = typ.Elem() + + if typ.Kind() == reflect.Struct { + if foreignKey == "" { + foreignKey = scopeTyp.Name() + "Id" + } + if associationForeignKey == "" { + associationForeignKey = typ.Name() + "Id" + } + + // if not many to many, foreign key could be null + if many2many == "" { + if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + foreignKey = "" + } + } + + field.AfterAssociation = true + field.JoinTable = &joinTable{ + joinTable: many2many, + foreignKey: foreignKey, + associationForeignKey: associationForeignKey, + } + } + case reflect.Struct: + if !field.IsTime() && !field.IsScanner() { + if foreignKey == "" && scope.HasColumn(field.Name+"Id") { + field.JoinTable = &joinTable{foreignKey: field.Name + "Id"} + field.BeforeAssociation = true + } else if scope.HasColumn(foreignKey) { + field.JoinTable = &joinTable{foreignKey: foreignKey} + field.BeforeAssociation = true + } else { + if foreignKey == "" { + foreignKey = scopeTyp.Name() + "Id" + } + if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + field.JoinTable = &joinTable{foreignKey: foreignKey} + } + field.AfterAssociation = true + } + } + } + } + return &field +} + // Fields get value's fields func (scope *Scope) Fields() []*Field { - indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) + indirectValue := scope.IndirectValue() fields := []*Field{} if !indirectValue.IsValid() { @@ -243,83 +329,7 @@ func (scope *Scope) Fields() []*Field { if !ast.IsExported(fieldStruct.Name) { continue } - - var field Field - field.Name = fieldStruct.Name - field.DBName = ToSnake(fieldStruct.Name) - - value := indirectValue.FieldByName(fieldStruct.Name) - field.Value = value.Interface() - field.IsBlank = isBlank(value) - - // Search for primary key tag identifier - settings := parseTagSetting(fieldStruct.Tag.Get("gorm")) - if _, ok := settings["PRIMARY_KEY"]; scope.PrimaryKey() == field.DBName || ok { - field.isPrimaryKey = true - } - - if field.isPrimaryKey { - scope.primaryKey = field.DBName - } - - if scope.db != nil { - indirectValue := reflect.Indirect(value) - field.Tag = fieldStruct.Tag - field.SqlTag = scope.sqlTagForField(&field) - - // parse association - typ := indirectValue.Type() - foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"]) - associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"]) - many2many := settings["MANY2MANY"] - - switch indirectValue.Kind() { - case reflect.Slice: - typ = typ.Elem() - - if typ.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeTyp.Name() + "Id" - } - if associationForeignKey == "" { - associationForeignKey = typ.Name() + "Id" - } - - // if not many to many, foreign key could be null - if many2many == "" { - if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - foreignKey = "" - } - } - - field.AfterAssociation = true - field.JoinTable = &joinTable{ - joinTable: many2many, - foreignKey: foreignKey, - associationForeignKey: associationForeignKey, - } - } - case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { - if foreignKey == "" && scope.HasColumn(field.Name+"Id") { - field.JoinTable = &joinTable{foreignKey: field.Name + "Id"} - field.BeforeAssociation = true - } else if scope.HasColumn(foreignKey) { - field.JoinTable = &joinTable{foreignKey: foreignKey} - field.BeforeAssociation = true - } else { - if foreignKey == "" { - foreignKey = scopeTyp.Name() + "Id" - } - if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - field.JoinTable = &joinTable{foreignKey: foreignKey} - } - field.AfterAssociation = true - } - } - } - } - fields = append(fields, &field) + fields = append(fields, scope.fieldFromStruct(fieldStruct)) } return fields diff --git a/scope_private.go b/scope_private.go index e2be0d01..ae87dbcd 100644 --- a/scope_private.go +++ b/scope_private.go @@ -259,7 +259,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { } func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { - data := reflect.Indirect(reflect.ValueOf(scope.Value)) + data := scope.IndirectValue() if !data.CanAddr() { return values, true } @@ -381,7 +381,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) scope.Search = scope.Search.clone().selects(column) if dest.Kind() != reflect.Slice { - scope.Err(errors.New("Results should be a slice")) + scope.Err(errors.New("results should be a slice")) return scope } @@ -404,7 +404,7 @@ func (scope *Scope) count(value interface{}) *Scope { } func (scope *Scope) typeName() string { - value := reflect.Indirect(reflect.ValueOf(scope.Value)) + value := scope.IndirectValue() if value.Kind() == reflect.Slice { return value.Type().Elem().Name() } else {