forked from mirror/gorm
Add Fields for embedded struct
This commit is contained in:
parent
ecc23e4432
commit
b2360c11da
|
@ -9,8 +9,8 @@ type BasePost struct {
|
|||
}
|
||||
|
||||
type HNPost struct {
|
||||
BasePost
|
||||
Upvotes int32
|
||||
BasePost `gorm:"embedded"`
|
||||
Upvotes int32
|
||||
}
|
||||
|
||||
type EngadgetPost struct {
|
||||
|
@ -18,11 +18,15 @@ type EngadgetPost struct {
|
|||
ImageUrl string
|
||||
}
|
||||
|
||||
func TestAnonymousStruct(t *testing.T) {
|
||||
hn := HNPost{}
|
||||
hn.Title = "hn_news"
|
||||
DB.Debug().Save(hn)
|
||||
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||
DB.Save(HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||
|
||||
var news HNPost
|
||||
DB.Debug().First(&news)
|
||||
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else {
|
||||
if news.BasePost.Title == "hn_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
9
main.go
9
main.go
|
@ -394,14 +394,13 @@ func (s *DB) Association(column string) *Association {
|
|||
}
|
||||
|
||||
var field *Field
|
||||
scopeType := scope.IndirectValue().Type()
|
||||
if f, ok := scopeType.FieldByName(SnakeToUpperCamel(column)); ok {
|
||||
field = scope.fieldFromStruct(f)
|
||||
var ok bool
|
||||
if field, ok = scope.FieldByName(SnakeToUpperCamel(column)); ok {
|
||||
if field.Relationship == nil || field.Relationship.ForeignKey == "" {
|
||||
scope.Err(fmt.Errorf("invalid association %v for %v", column, scopeType))
|
||||
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
|
||||
}
|
||||
} else {
|
||||
scope.Err(fmt.Errorf("%v doesn't have column %v", scopeType, column))
|
||||
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
|
||||
}
|
||||
|
||||
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
|
||||
|
|
|
@ -29,11 +29,9 @@ func runMigration() {
|
|||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}).Error; err != nil {
|
||||
if err := DB.AutoMigrate(&Product{}, Email{}, Address{}, CreditCard{}, Company{}, Role{}, Language{}, HNPost{}, EngadgetPost{}).Error; err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||
}
|
||||
|
||||
DB.AutoMigrate(HNPost{}, EngadgetPost{})
|
||||
}
|
||||
|
||||
func TestIndexes(t *testing.T) {
|
||||
|
|
20
scope.go
20
scope.go
|
@ -235,20 +235,20 @@ func (scope *Scope) CombinedConditionSql() string {
|
|||
}
|
||||
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
var f reflect.StructField
|
||||
if scope.Value != nil {
|
||||
if scope.IndirectValue().Kind() == reflect.Struct {
|
||||
if f, ok = scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
|
||||
field = scope.fieldFromStruct(f)
|
||||
if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
|
||||
return scope.fieldFromStruct(f)[0], true
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
||||
func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) []*Field {
|
||||
var field Field
|
||||
field.Name = fieldStruct.Name
|
||||
field.DBName = ToSnake(fieldStruct.Name)
|
||||
|
||||
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
||||
indirectValue := reflect.Indirect(value)
|
||||
|
@ -262,7 +262,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
|||
if prefix == "-" {
|
||||
prefix = ""
|
||||
}
|
||||
field.DBName = prefix + ToSnake(fieldStruct.Name)
|
||||
|
||||
if scope.PrimaryKey() == field.DBName {
|
||||
field.IsPrimaryKey = true
|
||||
|
@ -314,7 +313,10 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
|||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !field.IsTime() && !field.IsScanner() {
|
||||
embedded := settings["EMBEDDED"]
|
||||
if embedded != "" {
|
||||
return scope.New(field.Value).Fields()
|
||||
} else if !field.IsTime() && !field.IsScanner() {
|
||||
if foreignKey == "" && scope.HasColumn(field.Name+"Id") {
|
||||
field.Relationship = &relationship{ForeignKey: field.Name + "Id", Kind: "belongs_to"}
|
||||
} else if scope.HasColumn(foreignKey) {
|
||||
|
@ -330,7 +332,7 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField) *Field {
|
|||
}
|
||||
}
|
||||
}
|
||||
return &field
|
||||
return []*Field{&field}
|
||||
}
|
||||
|
||||
// Fields get value's fields
|
||||
|
@ -342,7 +344,7 @@ func (scope *Scope) Fields() (fields []*Field) {
|
|||
if !ast.IsExported(fieldStruct.Name) {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, scope.fieldFromStruct(fieldStruct))
|
||||
fields = append(fields, scope.fieldFromStruct(fieldStruct)...)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue