mirror of https://github.com/go-gorm/gorm.git
Add Fields for embedded struct
This commit is contained in:
parent
ecc23e4432
commit
b2360c11da
|
@ -9,7 +9,7 @@ type BasePost struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type HNPost struct {
|
type HNPost struct {
|
||||||
BasePost
|
BasePost `gorm:"embedded"`
|
||||||
Upvotes int32
|
Upvotes int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,11 +18,15 @@ type EngadgetPost struct {
|
||||||
ImageUrl string
|
ImageUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymousStruct(t *testing.T) {
|
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||||
hn := HNPost{}
|
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||||
hn.Title = "hn_news"
|
|
||||||
DB.Debug().Save(hn)
|
|
||||||
|
|
||||||
var news HNPost
|
var news HNPost
|
||||||
DB.Debug().First(&news)
|
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||||
|
} else {
|
||||||
|
if news.BasePost.Title == "hn_news" {
|
||||||
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
9
main.go
9
main.go
|
@ -394,14 +394,13 @@ func (s *DB) Association(column string) *Association {
|
||||||
}
|
}
|
||||||
|
|
||||||
var field *Field
|
var field *Field
|
||||||
scopeType := scope.IndirectValue().Type()
|
var ok bool
|
||||||
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok {
|
if field, ok = scope.FieldByName(SnakeToUpperCamel(column)); ok {
|
||||||
field = scope.fieldFromStruct(f)
|
|
||||||
if field.Relationship == nil || field.Relationship.ForeignKey == "" {
|
if field.Relationship == nil || field.Relationship.ForeignKey == "" {
|
||||||
scope.Err(fmt.Errorf("invalid association %v for %v", column, scopeType))
|
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
scope.Err(fmt.Errorf("%v doesn't have column %v", scopeType, column))
|
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
|
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
|
||||||
|
|
|
@ -29,11 +29,9 @@ func runMigration() {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}).Error; err != nil {
|
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.AutoMigrate(HNPost{}, EngadgetPost{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexes(t *testing.T) {
|
func TestIndexes(t *testing.T) {
|
||||||
|
|
20
scope.go
20
scope.go
|
@ -235,20 +235,20 @@ func (scope *Scope) CombinedConditionSql() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
var f reflect.StructField
|
|
||||||
if scope.Value != nil {
|
if scope.Value != nil {
|
||||||
if scope.IndirectValue().Kind() == reflect.Struct {
|
if scope.IndirectValue().Kind() == reflect.Struct {
|
||||||
if f, ok = scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
|
if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
|
||||||
field = scope.fieldFromStruct(f)
|
return scope.fieldFromStruct(f)[0], true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||||
var field Field
|
var field Field
|
||||||
field.Name = fieldStruct.Name
|
field.Name = fieldStruct.Name
|
||||||
|
field.DBName = ToSnake(fieldStruct.Name)
|
||||||
|
|
||||||
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
||||||
indirectValue := reflect.Indirect(value)
|
indirectValue := reflect.Indirect(value)
|
||||||
|
@ -262,7 +262,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
||||||
if prefix == "-" {
|
if prefix == "-" {
|
||||||
prefix = ""
|
prefix = ""
|
||||||
}
|
}
|
||||||
field.DBName = prefix + ToSnake(fieldStruct.Name)
|
|
||||||
|
|
||||||
if scope.PrimaryKey() == field.DBName {
|
if scope.PrimaryKey() == field.DBName {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
|
@ -314,7 +313,10 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if !field.IsTime() && !field.IsScanner() {
|
embedded := settings["EMBEDDED"]
|
||||||
|
if embedded != "" {
|
||||||
|
return scope.New(field.Value).Fields()
|
||||||
|
} else if !field.IsTime() && !field.IsScanner() {
|
||||||
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
||||||
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
||||||
} else if scope.HasColumn(foreignKey) {
|
} else if scope.HasColumn(foreignKey) {
|
||||||
|
@ -330,7 +332,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &field
|
return []*Field{&field}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fields get value's fields
|
// Fields get value's fields
|
||||||
|
@ -342,7 +344,7 @@ func (scope *Scope) Fields() (fields []*Field) {
|
||||||
if !ast.IsExported(fieldStruct.Name) {
|
if !ast.IsExported(fieldStruct.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fields = append(fields, scope.fieldFromStruct(fieldStruct))
|
fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue