mirror of https://github.com/go-gorm/gorm.git
Pass anonymous struct tests
This commit is contained in:
parent
b2360c11da
commit
d806b7084f
|
@ -3,29 +3,30 @@ package gorm_test
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
type BasePost struct {
|
type BasePost struct {
|
||||||
Id int64
|
|
||||||
Title string
|
Title string
|
||||||
Url string
|
Url string
|
||||||
}
|
}
|
||||||
|
|
||||||
type HNPost struct {
|
type HNPost struct {
|
||||||
|
Id int64
|
||||||
BasePost `gorm:"embedded"`
|
BasePost `gorm:"embedded"`
|
||||||
Upvotes int32
|
Upvotes int32
|
||||||
}
|
}
|
||||||
|
|
||||||
type EngadgetPost struct {
|
type EngadgetPost struct {
|
||||||
|
Id int64
|
||||||
BasePost
|
BasePost
|
||||||
ImageUrl string
|
ImageUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||||
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
|
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||||
|
|
||||||
var news HNPost
|
var news HNPost
|
||||||
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||||
} else {
|
} else {
|
||||||
if news.BasePost.Title == "hn_news" {
|
if news.BasePost.Title != "hn_news" {
|
||||||
t.Errorf("embedded struct's value should be scanned correctly")
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(scope *Scope) {
|
func Query(scope *Scope) {
|
||||||
|
@ -57,10 +56,10 @@ func Query(scope *Scope) {
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
var values []interface{}
|
var values []interface{}
|
||||||
|
fields := scope.New(elem.Addr().Interface()).Fields()
|
||||||
for _, value := range columns {
|
for _, value := range columns {
|
||||||
field := elem.FieldByName(SnakeToUpperCamel(strings.ToLower(value)))
|
if field, ok := fields[value]; ok {
|
||||||
if field.IsValid() {
|
values = append(values, field.Field.Addr().Interface())
|
||||||
values = append(values, field.Addr().Interface())
|
|
||||||
} else {
|
} else {
|
||||||
var ignore interface{}
|
var ignore interface{}
|
||||||
values = append(values, &ignore)
|
values = append(values, &ignore)
|
||||||
|
|
1
field.go
1
field.go
|
@ -16,6 +16,7 @@ type relationship struct {
|
||||||
type Field struct {
|
type Field struct {
|
||||||
Name string
|
Name string
|
||||||
DBName string
|
DBName string
|
||||||
|
Field reflect.Value
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Tag reflect.StructTag
|
Tag reflect.StructTag
|
||||||
SqlTag string
|
SqlTag string
|
||||||
|
|
|
@ -29,7 +29,7 @@ func runMigration() {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
|
if err := DB.AutoMigrate(&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}).Error; err != nil {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
21
scope.go
21
scope.go
|
@ -33,6 +33,10 @@ func (scope *Scope) IndirectValue() reflect.Value {
|
||||||
|
|
||||||
// NewScope create scope for callbacks, including DB's search information
|
// NewScope create scope for callbacks, including DB's search information
|
||||||
func (db *DB) NewScope(value interface{}) *Scope {
|
func (db *DB) NewScope(value interface{}) *Scope {
|
||||||
|
// reflectKind := reflect.ValueOf(value).Kind()
|
||||||
|
// if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
|
||||||
|
// fmt.Printf("%v %v\n", fileWithLineNum(), "using unaddressable value")
|
||||||
|
// }
|
||||||
db.Value = value
|
db.Value = value
|
||||||
return &Scope{db: db, Search: db.search, Value: value}
|
return &Scope{db: db, Search: db.search, Value: value}
|
||||||
}
|
}
|
||||||
|
@ -252,6 +256,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
|
|
||||||
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
||||||
indirectValue := reflect.Indirect(value)
|
indirectValue := reflect.Indirect(value)
|
||||||
|
field.Field = value
|
||||||
field.Value = value.Interface()
|
field.Value = value.Interface()
|
||||||
field.IsBlank = isBlank(value)
|
field.IsBlank = isBlank(value)
|
||||||
|
|
||||||
|
@ -315,7 +320,12 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
embedded := settings["EMBEDDED"]
|
embedded := settings["EMBEDDED"]
|
||||||
if embedded != "" {
|
if embedded != "" {
|
||||||
return scope.New(field.Value).Fields()
|
var fields []*Field
|
||||||
|
for _, field := range scope.New(field.Field.Addr().Interface()).Fields() {
|
||||||
|
field.DBName = prefix + field.DBName
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
|
return fields
|
||||||
} else if !field.IsTime() && !field.IsScanner() {
|
} else if !field.IsTime() && !field.IsScanner() {
|
||||||
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
||||||
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
||||||
|
@ -336,7 +346,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fields get value's fields
|
// Fields get value's fields
|
||||||
func (scope *Scope) Fields() (fields []*Field) {
|
func (scope *Scope) Fields() map[string]*Field {
|
||||||
|
var fields = map[string]*Field{}
|
||||||
if scope.IndirectValue().IsValid() {
|
if scope.IndirectValue().IsValid() {
|
||||||
scopeTyp := scope.IndirectValue().Type()
|
scopeTyp := scope.IndirectValue().Type()
|
||||||
for i := 0; i < scopeTyp.NumField(); i++ {
|
for i := 0; i < scopeTyp.NumField(); i++ {
|
||||||
|
@ -344,10 +355,12 @@ func (scope *Scope) Fields() (fields []*Field) {
|
||||||
if !ast.IsExported(fieldStruct.Name) {
|
if !ast.IsExported(fieldStruct.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
|
for _, field := range scope.fieldFromStruct(fieldStruct) {
|
||||||
|
fields[field.DBName] = field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
}
|
||||||
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
// Raw set sql
|
// Raw set sql
|
||||||
|
|
Loading…
Reference in New Issue