From b2360c11da818551df4a2898fb1c0538c5961d10 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 Aug 2014 17:21:43 +0800 Subject: [PATCH] Add Fields for embedded struct --- anonymous_struct_test.go | 18 +++++++++++------- main.go | 9 ++++----- migration_test.go | 4 +--- scope.go | 20 +++++++++++--------- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/anonymous_struct_test.go b/anonymous_struct_test.go index 20ed7864..75bc3051 100644 --- a/anonymous_struct_test.go +++ b/anonymous_struct_test.go @@ -9,8 +9,8 @@ type BasePost struct { } type HNPost struct { - BasePost - Upvotes int32 + BasePost `gorm:"embedded"` + Upvotes int32 } type EngadgetPost struct { @@ -18,11 +18,15 @@ type EngadgetPost struct { ImageUrl string } -func TestAnonymousStruct(t *testing.T) { - hn := HNPost{} - hn.Title = "hn_news" - DB.Debug().Save(hn) +func TestSaveAndQueryEmbeddedStruct(t *testing.T) { + DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}}) var news HNPost - DB.Debug().First(&news) + if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else { + if news.BasePost.Title == "hn_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + } } diff --git a/main.go b/main.go index 3ab03aaf..af2c26c4 100644 --- a/main.go +++ b/main.go @@ -394,14 +394,13 @@ func (s *DB) Association(column string) *Association { } var field *Field - scopeType := scope.IndirectValue().Type() - if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok { - field = scope.fieldFromStruct(f) + var ok bool + if field, ok = scope.FieldByName(SnakeToUpperCamel(column)); ok { if field.Relationship == nil || field.Relationship.ForeignKey == "" { - scope.Err(fmt.Errorf("invalid association %v for %v", column, scopeType)) + scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())) } } else { - scope.Err(fmt.Errorf("%v doesn't have column %v", scopeType, column)) + scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)) } return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} diff --git a/migration_test.go b/migration_test.go index 211abd6d..4e3bfb5a 100644 --- a/migration_test.go +++ b/migration_test.go @@ -29,11 +29,9 @@ func runMigration() { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } - if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}).Error; err != nil { + if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } - - DB.AutoMigrate(HNPost{}, EngadgetPost{}) } func TestIndexes(t *testing.T) { diff --git a/scope.go b/scope.go index b2973dd6..96d280d1 100644 --- a/scope.go +++ b/scope.go @@ -235,20 +235,20 @@ func (scope *Scope) CombinedConditionSql() string { } func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var f reflect.StructField if scope.Value != nil { if scope.IndirectValue().Kind() == reflect.Struct { - if f, ok = scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok { - field = scope.fieldFromStruct(f) + if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok { + return scope.fieldFromStruct(f)[0], true } } } - return + return nil, false } -func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { +func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field { var field Field field.Name = fieldStruct.Name + field.DBName = ToSnake(fieldStruct.Name) value := scope.IndirectValue().FieldByName(fieldStruct.Name) indirectValue := reflect.Indirect(value) @@ -262,7 +262,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { if prefix == "-" { prefix = "" } - field.DBName = prefix + ToSnake(fieldStruct.Name) if scope.PrimaryKey() == field.DBName { field.IsPrimaryKey = true @@ -314,7 +313,10 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { } } case reflect.Struct: - if !field.IsTime() && !field.IsScanner() { + embedded := settings["EMBEDDED"] + if embedded != "" { + return scope.New(field.Value).Fields() + } else if !field.IsTime() && !field.IsScanner() { if foreignKey == "" && scope.HasColumn(field.Name+"Id") { field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} } else if scope.HasColumn(foreignKey) { @@ -330,7 +332,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { } } } - return &field + return []*Field{&field} } // Fields get value's fields @@ -342,7 +344,7 @@ func (scope *Scope) Fields() (fields []*Field) { if !ast.IsExported(fieldStruct.Name) { continue } - fields = append(fields, scope.fieldFromStruct(fieldStruct)) + fields = append(fields, scope.fieldFromStruct(fieldStruct)...) } } return