mirror of https://github.com/go-gorm/gorm.git
Pass anonymous struct tests
This commit is contained in:
parent
b2360c11da
commit
d806b7084f
|
@ -3,29 +3,30 @@ package gorm_test
|
|||
import "testing"
|
||||
|
||||
type BasePost struct {
|
||||
Id int64
|
||||
Title string
|
||||
Url string
|
||||
}
|
||||
|
||||
type HNPost struct {
|
||||
Id int64
|
||||
BasePost `gorm:"embedded"`
|
||||
Upvotes int32
|
||||
}
|
||||
|
||||
type EngadgetPost struct {
|
||||
Id int64
|
||||
BasePost
|
||||
ImageUrl string
|
||||
}
|
||||
|
||||
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||
|
||||
var news HNPost
|
||||
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else {
|
||||
if news.BasePost.Title == "hn_news" {
|
||||
if news.BasePost.Title != "hn_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Query(scope *Scope) {
|
||||
|
@ -57,10 +56,10 @@ func Query(scope *Scope) {
|
|||
|
||||
columns, _ := rows.Columns()
|
||||
var values []interface{}
|
||||
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||
for _, value := range columns {
|
||||
field := elem.FieldByName(SnakeToUpperCamel(strings.ToLower(value)))
|
||||
if field.IsValid() {
|
||||
values = append(values, field.Addr().Interface())
|
||||
if field, ok := fields[value]; ok {
|
||||
values = append(values, field.Field.Addr().Interface())
|
||||
} else {
|
||||
var ignore interface{}
|
||||
values = append(values, &ignore)
|
||||
|
|
1
field.go
1
field.go
|
@ -16,6 +16,7 @@ type relationship struct {
|
|||
type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
Field reflect.Value
|
||||
Value interface{}
|
||||
Tag reflect.StructTag
|
||||
SqlTag string
|
||||
|
|
|
@ -29,7 +29,7 @@ func runMigration() {
|
|||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
|
||||
if err := DB.AutoMigrate(&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
}
|
||||
|
|
21
scope.go
21
scope.go
|
@ -33,6 +33,10 @@ func (scope *Scope) IndirectValue() reflect.Value {
|
|||
|
||||
// NewScope create scope for callbacks, including DB's search information
|
||||
func (db *DB) NewScope(value interface{}) *Scope {
|
||||
// reflectKind := reflect.ValueOf(value).Kind()
|
||||
// if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
|
||||
// fmt.Printf("%v %v\n", fileWithLineNum(), "using unaddressable value")
|
||||
// }
|
||||
db.Value = value
|
||||
return &Scope{db: db, Search: db.search, Value: value}
|
||||
}
|
||||
|
@ -252,6 +256,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
|
||||
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
||||
indirectValue := reflect.Indirect(value)
|
||||
field.Field = value
|
||||
field.Value = value.Interface()
|
||||
field.IsBlank = isBlank(value)
|
||||
|
||||
|
@ -315,7 +320,12 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
case reflect.Struct:
|
||||
embedded := settings["EMBEDDED"]
|
||||
if embedded != "" {
|
||||
return scope.New(field.Value).Fields()
|
||||
var fields []*Field
|
||||
for _, field := range scope.New(field.Field.Addr().Interface()).Fields() {
|
||||
field.DBName = prefix + field.DBName
|
||||
fields = append(fields, field)
|
||||
}
|
||||
return fields
|
||||
} else if !field.IsTime() && !field.IsScanner() {
|
||||
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
||||
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
||||
|
@ -336,7 +346,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
|||
}
|
||||
|
||||
// Fields get value's fields
|
||||
func (scope *Scope) Fields() (fields []*Field) {
|
||||
func (scope *Scope) Fields() map[string]*Field {
|
||||
var fields = map[string]*Field{}
|
||||
if scope.IndirectValue().IsValid() {
|
||||
scopeTyp := scope.IndirectValue().Type()
|
||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||
|
@ -344,10 +355,12 @@ func (scope *Scope) Fields() (fields []*Field) {
|
|||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
|
||||
for _, field := range scope.fieldFromStruct(fieldStruct) {
|
||||
fields[field.DBName] = field
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// Raw set sql
|
||||
|
|
Loading…
Reference in New Issue