Add Fields for embedded struct

This commit is contained in:
Jinzhu 2014-08-28 17:21:43 +08:00
parent ecc23e4432
commit b2360c11da
4 changed files with 27 additions and 24 deletions

View File

@ -9,8 +9,8 @@ type BasePost struct {
}
type HNPost struct {
BasePost
Upvotes int32
BasePost `gorm:"embedded"`
Upvotes int32
}
type EngadgetPost struct {
@ -18,11 +18,15 @@ type EngadgetPost struct {
ImageUrl string
}
func TestAnonymousStruct(t *testing.T) {
hn := HNPost{}
hn.Title = "hn_news"
DB.Debug().Save(hn)
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
DB.Debug().First(&news)
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else {
if news.BasePost.Title == "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
}
}

View File

@ -394,14 +394,13 @@ func (s *DB) Association(column string) *Association {
}
var field *Field
scopeType := scope.IndirectValue().Type()
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok {
field = scope.fieldFromStruct(f)
var ok bool
if field, ok = scope.FieldByName(SnakeToUpperCamel(column)); ok {
if field.Relationship == nil || field.Relationship.ForeignKey == "" {
scope.Err(fmt.Errorf("invalid association %v for %v", column, scopeType))
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
}
} else {
scope.Err(fmt.Errorf("%v doesn't have column %v", scopeType, column))
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
}
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}

View File

@ -29,11 +29,9 @@ func runMigration() {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}).Error; err != nil {
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
DB.AutoMigrate(HNPost{}, EngadgetPost{})
}
func TestIndexes(t *testing.T) {

View File

@ -235,20 +235,20 @@ func (scope *Scope) CombinedConditionSql() string {
}
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var f reflect.StructField
if scope.Value != nil {
if scope.IndirectValue().Kind() == reflect.Struct {
if f, ok = scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
field = scope.fieldFromStruct(f)
if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
return scope.fieldFromStruct(f)[0], true
}
}
}
return
return nil, false
}
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
var field Field
field.Name = fieldStruct.Name
field.DBName = ToSnake(fieldStruct.Name)
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value)
@ -262,7 +262,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
if prefix == "-" {
prefix = ""
}
field.DBName = prefix + ToSnake(fieldStruct.Name)
if scope.PrimaryKey() == field.DBName {
field.IsPrimaryKey = true
@ -314,7 +313,10 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
}
}
case reflect.Struct:
if !field.IsTime() && !field.IsScanner() {
embedded := settings["EMBEDDED"]
if embedded != "" {
return scope.New(field.Value).Fields()
} else if !field.IsTime() && !field.IsScanner() {
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
} else if scope.HasColumn(foreignKey) {
@ -330,7 +332,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
}
}
}
return &field
return []*Field{&field}
}
// Fields get value's fields
@ -342,7 +344,7 @@ func (scope *Scope) Fields() (fields []*Field) {
if !ast.IsExported(fieldStruct.Name) {
continue
}
fields = append(fields, scope.fieldFromStruct(fieldStruct))
fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
}
}
return