Add Fields for embedded struct

This commit is contained in:
Jinzhu 2014-08-28 17:21:43 +08:00
parent ecc23e4432
commit b2360c11da
4 changed files with 27 additions and 24 deletions

View File

@ -9,8 +9,8 @@ type BasePost struct {
} }
type HNPost struct { type HNPost struct {
BasePost BasePost `gorm:"embedded"`
Upvotes int32 Upvotes int32
} }
type EngadgetPost struct { type EngadgetPost struct {
@ -18,11 +18,15 @@ type EngadgetPost struct {
ImageUrl string ImageUrl string
} }
func TestAnonymousStruct(t *testing.T) { func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
hn := HNPost{} DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
hn.Title = "hn_news"
DB.Debug().Save(hn)
var news HNPost var news HNPost
DB.Debug().First(&news) if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else {
if news.BasePost.Title == "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
}
} }

View File

@ -394,14 +394,13 @@ func (s *DB) Association(column string) *Association {
} }
var field *Field var field *Field
scopeType := scope.IndirectValue().Type() var ok bool
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok { if field, ok = scope.FieldByName(SnakeToUpperCamel(column)); ok {
field = scope.fieldFromStruct(f)
if field.Relationship == nil || field.Relationship.ForeignKey == "" { if field.Relationship == nil || field.Relationship.ForeignKey == "" {
scope.Err(fmt.Errorf("invalid association %v for %v", column, scopeType)) scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
} }
} else { } else {
scope.Err(fmt.Errorf("%v doesn't have column %v", scopeType, column)) scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
} }
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}

View File

@ -29,11 +29,9 @@ func runMigration() {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}).Error; err != nil { if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
DB.AutoMigrate(HNPost{}, EngadgetPost{})
} }
func TestIndexes(t *testing.T) { func TestIndexes(t *testing.T) {

View File

@ -235,20 +235,20 @@ func (scope *Scope) CombinedConditionSql() string {
} }
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var f reflect.StructField
if scope.Value != nil { if scope.Value != nil {
if scope.IndirectValue().Kind() == reflect.Struct { if scope.IndirectValue().Kind() == reflect.Struct {
if f, ok = scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok { if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
field = scope.fieldFromStruct(f) return scope.fieldFromStruct(f)[0], true
} }
} }
} }
return return nil, false
} }
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field { func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
var field Field var field Field
field.Name = fieldStruct.Name field.Name = fieldStruct.Name
field.DBName = ToSnake(fieldStruct.Name)
value := scope.IndirectValue().FieldByName(fieldStruct.Name) value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value) indirectValue := reflect.Indirect(value)
@ -262,7 +262,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
if prefix == "-" { if prefix == "-" {
prefix = "" prefix = ""
} }
field.DBName = prefix + ToSnake(fieldStruct.Name)
if scope.PrimaryKey() == field.DBName { if scope.PrimaryKey() == field.DBName {
field.IsPrimaryKey = true field.IsPrimaryKey = true
@ -314,7 +313,10 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
} }
} }
case reflect.Struct: case reflect.Struct:
if !field.IsTime() && !field.IsScanner() { embedded := settings["EMBEDDED"]
if embedded != "" {
return scope.New(field.Value).Fields()
} else if !field.IsTime() && !field.IsScanner() {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") { if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"} field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
} else if scope.HasColumn(foreignKey) { } else if scope.HasColumn(foreignKey) {
@ -330,7 +332,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
} }
} }
} }
return &field return []*Field{&field}
} }
// Fields get value's fields // Fields get value's fields
@ -342,7 +344,7 @@ func (scope *Scope) Fields() (fields []*Field) {
if !ast.IsExported(fieldStruct.Name) { if !ast.IsExported(fieldStruct.Name) {
continue continue
} }
fields = append(fields, scope.fieldFromStruct(fieldStruct)) fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
} }
} }
return return