diff --git a/anonymous_struct_test.go b/anonymous_struct_test.go index 75bc3051..4bf1374d 100644 --- a/anonymous_struct_test.go +++ b/anonymous_struct_test.go @@ -3,29 +3,30 @@ package gorm_test import "testing" type BasePost struct { - Id int64 Title string Url string } type HNPost struct { + Id int64 BasePost `gorm:"embedded"` Upvotes int32 } type EngadgetPost struct { + Id int64 BasePost ImageUrl string } func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}}) + DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) var news HNPost if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else { - if news.BasePost.Title == "hn_news" { + if news.BasePost.Title != "hn_news" { t.Errorf("embedded struct's value should be scanned correctly") } } diff --git a/callback_query.go b/callback_query.go index c1dba707..7eaf5573 100644 --- a/callback_query.go +++ b/callback_query.go @@ -3,7 +3,6 @@ package gorm import ( "fmt" "reflect" - "strings" ) func Query(scope *Scope) { @@ -57,10 +56,10 @@ func Query(scope *Scope) { columns, _ := rows.Columns() var values []interface{} + fields := scope.New(elem.Addr().Interface()).Fields() for _, value := range columns { - field := elem.FieldByName(SnakeToUpperCamel(strings.ToLower(value))) - if field.IsValid() { - values = append(values, field.Addr().Interface()) + if field, ok := fields[value]; ok { + values = append(values, field.Field.Addr().Interface()) } else { var ignore interface{} values = append(values, &ignore) diff --git a/field.go b/field.go index f49bc7fe..9e9e6660 100644 --- a/field.go +++ b/field.go @@ -16,6 +16,7 @@ type relationship struct { type Field struct { Name string DBName string + Field reflect.Value Value interface{} Tag reflect.StructTag SqlTag string diff --git a/migration_test.go b/migration_test.go index 4e3bfb5a..71fd831c 100644 --- a/migration_test.go +++ b/migration_test.go @@ -29,7 +29,7 @@ func runMigration() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } - if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil { + if err := DB.AutoMigrate(&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } } diff --git a/scope.go b/scope.go index 96d280d1..52cd7d87 100644 --- a/scope.go +++ b/scope.go @@ -33,6 +33,10 @@ func (scope *Scope) IndirectValue() reflect.Value { // NewScope create scope for callbacks, including DB's search information func (db *DB) NewScope(value interface{}) *Scope { + // reflectKind := reflect.ValueOf(value).Kind() + // if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { + // fmt.Printf("%v %v\n", fileWithLineNum(), "using unaddressable value") + // } db.Value = value return &Scope{db: db, Search: db.search, Value: value} } @@ -252,6 +256,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { value := scope.IndirectValue().FieldByName(fieldStruct.Name) indirectValue := reflect.Indirect(value) + field.Field = value field.Value = value.Interface() field.IsBlank = isBlank(value) @@ -315,7 +320,12 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { case reflect.Struct: embedded := settings["EMBEDDED"] if embedded != "" { - return scope.New(field.Value).Fields() + var fields []*Field + for _, field := range scope.New(field.Field.Addr().Interface()).Fields() { + field.DBName = prefix + field.DBName + fields = append(fields, field) + } + return fields } else if !field.IsTime() && !field.IsScanner() { if foreignKey == "" && scope.HasColumn(field.Name+"Id") { field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} @@ -336,7 +346,8 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { } // Fields get value's fields -func (scope *Scope) Fields() (fields []*Field) { +func (scope *Scope) Fields() map[string]*Field { + var fields = map[string]*Field{} if scope.IndirectValue().IsValid() { scopeTyp := scope.IndirectValue().Type() for i := 0; i < scopeTyp.NumField(); i++ { @@ -344,10 +355,12 @@ func (scope *Scope) Fields() (fields []*Field) { if !ast.IsExported(fieldStruct.Name) { continue } - fields = append(fields, scope.fieldFromStruct(fieldStruct)...) + for _, field := range scope.fieldFromStruct(fieldStruct) { + fields[field.DBName] = field + } } } - return + return fields } // Raw set sql