Pass anonymous struct tests

This commit is contained in:
Jinzhu 2014-08-28 18:25:05 +08:00
parent b2360c11da
commit d806b7084f
5 changed files with 26 additions and 12 deletions

View File

@ -3,29 +3,30 @@ package gorm_test
import "testing" import "testing"
type BasePost struct { type BasePost struct {
Id int64
Title string Title string
Url string Url string
} }
type HNPost struct { type HNPost struct {
Id int64
BasePost `gorm:"embedded"` BasePost `gorm:"embedded"`
Upvotes int32 Upvotes int32
} }
type EngadgetPost struct { type EngadgetPost struct {
Id int64
BasePost BasePost
ImageUrl string ImageUrl string
} }
func TestSaveAndQueryEmbeddedStruct(t *testing.T) { func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}}) DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else { } else {
if news.BasePost.Title == "hn_news" { if news.BasePost.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
} }

View File

@ -3,7 +3,6 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
) )
func Query(scope *Scope) { func Query(scope *Scope) {
@ -57,10 +56,10 @@ func Query(scope *Scope) {
columns, _ := rows.Columns() columns, _ := rows.Columns()
var values []interface{} var values []interface{}
fields := scope.New(elem.Addr().Interface()).Fields()
for _, value := range columns { for _, value := range columns {
field := elem.FieldByName(SnakeToUpperCamel(strings.ToLower(value))) if field, ok := fields[value]; ok {
if field.IsValid() { values = append(values, field.Field.Addr().Interface())
values = append(values, field.Addr().Interface())
} else { } else {
var ignore interface{} var ignore interface{}
values = append(values, &ignore) values = append(values, &ignore)

View File

@ -16,6 +16,7 @@ type relationship struct {
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
Field reflect.Value
Value interface{} Value interface{}
Tag reflect.StructTag Tag reflect.StructTag
SqlTag string SqlTag string

View File

@ -29,7 +29,7 @@ func runMigration() {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil { if err := DB.AutoMigrate(&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
} }

View File

@ -33,6 +33,10 @@ func (scope *Scope) IndirectValue() reflect.Value {
// NewScope create scope for callbacks, including DB's search information // NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope { func (db *DB) NewScope(value interface{}) *Scope {
// reflectKind := reflect.ValueOf(value).Kind()
// if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
// fmt.Printf("%v %v\n", fileWithLineNum(), "using unaddressable value")
// }
db.Value = value db.Value = value
return &Scope{db: db, Search: db.search, Value: value} return &Scope{db: db, Search: db.search, Value: value}
} }
@ -252,6 +256,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
value := scope.IndirectValue().FieldByName(fieldStruct.Name) value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value) indirectValue := reflect.Indirect(value)
field.Field = value
field.Value = value.Interface() field.Value = value.Interface()
field.IsBlank = isBlank(value) field.IsBlank = isBlank(value)
@ -315,7 +320,12 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
case reflect.Struct: case reflect.Struct:
embedded := settings["EMBEDDED"] embedded := settings["EMBEDDED"]
if embedded != "" { if embedded != "" {
return scope.New(field.Value).Fields() var fields []*Field
for _, field := range scope.New(field.Field.Addr().Interface()).Fields() {
field.DBName = prefix + field.DBName
fields = append(fields, field)
}
return fields
} else if !field.IsTime() && !field.IsScanner() { } else if !field.IsTime() && !field.IsScanner() {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") { if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
@ -336,7 +346,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
} }
// Fields get value's fields // Fields get value's fields
func (scope *Scope) Fields() (fields []*Field) { func (scope *Scope) Fields() map[string]*Field {
var fields = map[string]*Field{}
if scope.IndirectValue().IsValid() { if scope.IndirectValue().IsValid() {
scopeTyp := scope.IndirectValue().Type() scopeTyp := scope.IndirectValue().Type()
for i := 0; i < scopeTyp.NumField(); i++ { for i := 0; i < scopeTyp.NumField(); i++ {
@ -344,10 +355,12 @@ func (scope *Scope) Fields() (fields []*Field) {
if !ast.IsExported(fieldStruct.Name) { if !ast.IsExported(fieldStruct.Name) {
continue continue
} }
fields = append(fields, scope.fieldFromStruct(fieldStruct)...) for _, field := range scope.fieldFromStruct(fieldStruct) {
fields[field.DBName] = field
}
} }
} }
return return fields
} }
// Raw set sql // Raw set sql