Pass anonymous struct tests

This commit is contained in:
Jinzhu 2014-08-28 18:25:05 +08:00
parent b2360c11da
commit d806b7084f
5 changed files with 26 additions and 12 deletions

View File

@ -3,29 +3,30 @@ package gorm_test
import "testing"
type BasePost struct {
Id int64
Title string
Url string
}
type HNPost struct {
Id int64
BasePost `gorm:"embedded"`
Upvotes int32
}
type EngadgetPost struct {
Id int64
BasePost
ImageUrl string
}
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else {
if news.BasePost.Title == "hn_news" {
if news.BasePost.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
}

View File

@ -3,7 +3,6 @@ package gorm
import (
"fmt"
"reflect"
"strings"
)
func Query(scope *Scope) {
@ -57,10 +56,10 @@ func Query(scope *Scope) {
columns, _ := rows.Columns()
var values []interface{}
fields := scope.New(elem.Addr().Interface()).Fields()
for _, value := range columns {
field := elem.FieldByName(SnakeToUpperCamel(strings.ToLower(value)))
if field.IsValid() {
values = append(values, field.Addr().Interface())
if field, ok := fields[value]; ok {
values = append(values, field.Field.Addr().Interface())
} else {
var ignore interface{}
values = append(values, &ignore)

View File

@ -16,6 +16,7 @@ type relationship struct {
type Field struct {
Name string
DBName string
Field reflect.Value
Value interface{}
Tag reflect.StructTag
SqlTag string

View File

@ -29,7 +29,7 @@ func runMigration() {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
if err := DB.AutoMigrate(&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
}

View File

@ -33,6 +33,10 @@ func (scope *Scope) IndirectValue() reflect.Value {
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
// reflectKind := reflect.ValueOf(value).Kind()
// if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
// fmt.Printf("%v %v\n", fileWithLineNum(), "using unaddressable value")
// }
db.Value = value
return &Scope{db: db, Search: db.search, Value: value}
}
@ -252,6 +256,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value)
field.Field = value
field.Value = value.Interface()
field.IsBlank = isBlank(value)
@ -315,7 +320,12 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
case reflect.Struct:
embedded := settings["EMBEDDED"]
if embedded != "" {
return scope.New(field.Value).Fields()
var fields []*Field
for _, field := range scope.New(field.Field.Addr().Interface()).Fields() {
field.DBName = prefix + field.DBName
fields = append(fields, field)
}
return fields
} else if !field.IsTime() && !field.IsScanner() {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
@ -336,7 +346,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
}
// Fields get value's fields
func (scope *Scope) Fields() (fields []*Field) {
func (scope *Scope) Fields() map[string]*Field {
var fields = map[string]*Field{}
if scope.IndirectValue().IsValid() {
scopeTyp := scope.IndirectValue().Type()
for i := 0; i < scopeTyp.NumField(); i++ {
@ -344,10 +355,12 @@ func (scope *Scope) Fields() (fields []*Field) {
if !ast.IsExported(fieldStruct.Name) {
continue
}
fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
for _, field := range scope.fieldFromStruct(fieldStruct) {
fields[field.DBName] = field
}
}
return
}
return fields
}
// Raw set sql