Add IndirectValue for Scope

This commit is contained in:
Jinzhu 2014-07-30 14:58:00 +08:00
parent ba95de5c50
commit 0d3085393e
4 changed files with 112 additions and 102 deletions

View File

@ -2,7 +2,7 @@ package gorm_test
import "testing" import "testing"
func TestSubStruct(t *testing.T) { func TestHasOneAndHasManyAssociation(t *testing.T) {
db.DropTable(Category{}) db.DropTable(Category{})
db.DropTable(Post{}) db.DropTable(Post{})
db.DropTable(Comment{}) db.DropTable(Comment{})
@ -115,8 +115,8 @@ func TestRelated(t *testing.T) {
var creditcard CreditCard var creditcard CreditCard
var user3 User var user3 User
db.Debug().First(&creditcard, "number = ?", "1234567890") db.First(&creditcard, "number = ?", "1234567890")
db.Debug().Model(&creditcard).Related(&user3) db.Model(&creditcard).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name { if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly") t.Errorf("Should get user from credit card correctly")
} }
@ -126,7 +126,7 @@ func TestRelated(t *testing.T) {
} }
} }
func TestQueryManyToManyWithRelated(t *testing.T) { func TestManyToMany(t *testing.T) {
var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}} var languages = []Language{{Name: "ZH"}, {Name: "EN"}, {Name: "DE"}}
user := User{Name: "Many2Many", Languages: languages} user := User{Name: "Many2Many", Languages: languages}
db.Save(&user) db.Save(&user)

View File

@ -16,7 +16,7 @@ func Query(scope *Scope) {
destType reflect.Type destType reflect.Type
) )
var dest = reflect.Indirect(reflect.ValueOf(scope.Value)) var dest = scope.IndirectValue()
if value, ok := scope.Get("gorm:query_destination"); ok { if value, ok := scope.Get("gorm:query_destination"); ok {
dest = reflect.Indirect(reflect.ValueOf(value)) dest = reflect.Indirect(reflect.ValueOf(value))
} }

198
scope.go
View File

@ -12,14 +12,23 @@ import (
) )
type Scope struct { type Scope struct {
Value interface{} Value interface{}
Search *search indirectValue *reflect.Value
Sql string Search *search
SqlVars []interface{} Sql string
db *DB SqlVars []interface{}
_values map[string]interface{} db *DB
skipLeft bool _values map[string]interface{}
primaryKey string skipLeft bool
primaryKey string
}
func (scope *Scope) IndirectValue() reflect.Value {
if scope.indirectValue == nil {
value := reflect.Indirect(reflect.ValueOf(scope.Value))
scope.indirectValue = &value
}
return *scope.indirectValue
} }
// NewScope create scope for callbacks, including DB's search information // NewScope create scope for callbacks, including DB's search information
@ -93,10 +102,8 @@ func (scope *Scope) PrimaryKeyZero() bool {
// PrimaryKeyValue get the primary key's value // PrimaryKeyValue get the primary key's value
func (scope *Scope) PrimaryKeyValue() interface{} { func (scope *Scope) PrimaryKeyValue() interface{} {
data := reflect.Indirect(reflect.ValueOf(scope.Value)) if scope.IndirectValue().Kind() == reflect.Struct {
if field := scope.IndirectValue().FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() {
if data.Kind() == reflect.Struct {
if field := data.FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() {
return field.Interface() return field.Interface()
} }
} }
@ -120,8 +127,7 @@ func (scope *Scope) SetColumn(column string, value interface{}) {
return return
} }
data := reflect.Indirect(reflect.ValueOf(scope.Value)) setFieldValue(scope.IndirectValue().FieldByName(SnakeToUpperCamel(column)), value)
setFieldValue(data.FieldByName(SnakeToUpperCamel(column)), value)
} }
// CallMethod invoke method with necessary argument // CallMethod invoke method with necessary argument
@ -151,7 +157,7 @@ func (scope *Scope) CallMethod(name string) {
} }
} }
if values := reflect.Indirect(reflect.ValueOf(scope.Value)); values.Kind() == reflect.Slice { if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
call(values.Index(i).Addr().Interface()) call(values.Index(i).Addr().Interface())
} }
@ -178,8 +184,8 @@ func (scope *Scope) TableName() string {
scope.Err(errors.New("can't get table name")) scope.Err(errors.New("can't get table name"))
return "" return ""
} }
data := reflect.Indirect(reflect.ValueOf(scope.Value))
data := scope.IndirectValue()
if data.Kind() == reflect.Slice { if data.Kind() == reflect.Slice {
elem := data.Type().Elem() elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr { if elem.Kind() == reflect.Ptr {
@ -228,9 +234,89 @@ func (scope *Scope) CombinedConditionSql() string {
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
} }
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
var field Field
field.Name = fieldStruct.Name
field.DBName = ToSnake(fieldStruct.Name)
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value)
field.Value = value.Interface()
field.IsBlank = isBlank(value)
// Search for primary key tag identifier
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if _, ok := settings["PRIMARY_KEY"]; scope.PrimaryKey() == field.DBName || ok {
field.isPrimaryKey = true
}
if field.isPrimaryKey {
scope.primaryKey = field.DBName
}
if scope.db != nil {
field.Tag = fieldStruct.Tag
field.SqlTag = scope.sqlTagForField(&field)
// parse association
typ := indirectValue.Type()
foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
many2many := settings["MANY2MANY"]
scopeTyp := scope.IndirectValue().Type()
switch indirectValue.Kind() {
case reflect.Slice:
typ = typ.Elem()
if typ.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeTyp.Name() + "Id"
}
if associationForeignKey == "" {
associationForeignKey = typ.Name() + "Id"
}
// if not many to many, foreign key could be null
if many2many == "" {
if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
foreignKey = ""
}
}
field.AfterAssociation = true
field.JoinTable = &joinTable{
joinTable: many2many,
foreignKey: foreignKey,
associationForeignKey: associationForeignKey,
}
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.JoinTable = &joinTable{foreignKey: field.Name + "Id"}
field.BeforeAssociation = true
} else if scope.HasColumn(foreignKey) {
field.JoinTable = &joinTable{foreignKey: foreignKey}
field.BeforeAssociation = true
} else {
if foreignKey == "" {
foreignKey = scopeTyp.Name() + "Id"
}
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.JoinTable = &joinTable{foreignKey: foreignKey}
}
field.AfterAssociation = true
}
}
}
}
return &field
}
// Fields get value's fields // Fields get value's fields
func (scope *Scope) Fields() []*Field { func (scope *Scope) Fields() []*Field {
indirectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) indirectValue := scope.IndirectValue()
fields := []*Field{} fields := []*Field{}
if !indirectValue.IsValid() { if !indirectValue.IsValid() {
@ -243,83 +329,7 @@ func (scope *Scope) Fields() []*Field {
if !ast.IsExported(fieldStruct.Name) { if !ast.IsExported(fieldStruct.Name) {
continue continue
} }
fields = append(fields, scope.fieldFromStruct(fieldStruct))
var field Field
field.Name = fieldStruct.Name
field.DBName = ToSnake(fieldStruct.Name)
value := indirectValue.FieldByName(fieldStruct.Name)
field.Value = value.Interface()
field.IsBlank = isBlank(value)
// Search for primary key tag identifier
settings := parseTagSetting(fieldStruct.Tag.Get("gorm"))
if _, ok := settings["PRIMARY_KEY"]; scope.PrimaryKey() == field.DBName || ok {
field.isPrimaryKey = true
}
if field.isPrimaryKey {
scope.primaryKey = field.DBName
}
if scope.db != nil {
indirectValue := reflect.Indirect(value)
field.Tag = fieldStruct.Tag
field.SqlTag = scope.sqlTagForField(&field)
// parse association
typ := indirectValue.Type()
foreignKey := SnakeToUpperCamel(settings["FOREIGNKEY"])
associationForeignKey := SnakeToUpperCamel(settings["ASSOCIATIONFOREIGNKEY"])
many2many := settings["MANY2MANY"]
switch indirectValue.Kind() {
case reflect.Slice:
typ = typ.Elem()
if typ.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeTyp.Name() + "Id"
}
if associationForeignKey == "" {
associationForeignKey = typ.Name() + "Id"
}
// if not many to many, foreign key could be null
if many2many == "" {
if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
foreignKey = ""
}
}
field.AfterAssociation = true
field.JoinTable = &joinTable{
joinTable: many2many,
foreignKey: foreignKey,
associationForeignKey: associationForeignKey,
}
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.JoinTable = &joinTable{foreignKey: field.Name + "Id"}
field.BeforeAssociation = true
} else if scope.HasColumn(foreignKey) {
field.JoinTable = &joinTable{foreignKey: foreignKey}
field.BeforeAssociation = true
} else {
if foreignKey == "" {
foreignKey = scopeTyp.Name() + "Id"
}
if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() {
field.JoinTable = &joinTable{foreignKey: foreignKey}
}
field.AfterAssociation = true
}
}
}
}
fields = append(fields, &field)
} }
return fields return fields

View File

@ -259,7 +259,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
} }
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
data := reflect.Indirect(reflect.ValueOf(scope.Value)) data := scope.IndirectValue()
if !data.CanAddr() { if !data.CanAddr() {
return values, true return values, true
} }
@ -381,7 +381,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value)) dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search = scope.Search.clone().selects(column) scope.Search = scope.Search.clone().selects(column)
if dest.Kind() != reflect.Slice { if dest.Kind() != reflect.Slice {
scope.Err(errors.New("Results should be a slice")) scope.Err(errors.New("results should be a slice"))
return scope return scope
} }
@ -404,7 +404,7 @@ func (scope *Scope) count(value interface{}) *Scope {
} }
func (scope *Scope) typeName() string { func (scope *Scope) typeName() string {
value := reflect.Indirect(reflect.ValueOf(scope.Value)) value := scope.IndirectValue()
if value.Kind() == reflect.Slice { if value.Kind() == reflect.Slice {
return value.Type().Elem().Name() return value.Type().Elem().Name()
} else { } else {