diff --git a/field.go b/field.go index 459d6677..b69dfceb 100644 --- a/field.go +++ b/field.go @@ -52,6 +52,12 @@ func (scope *Scope) Fields() map[string]*Field { func (scope *Scope) getField(structField *StructField) *Field { field := Field{StructField: structField} - field.Field = scope.IndirectValue().FieldByName(structField.Name) + indirectValue := scope.IndirectValue() + if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct { + for _, name := range structField.Names { + indirectValue = reflect.Indirect(indirectValue.FieldByName(name)) + } + field.Field = indirectValue + } return &field } diff --git a/model_struct.go b/model_struct.go index 9229c540..01e8f95d 100644 --- a/model_struct.go +++ b/model_struct.go @@ -18,8 +18,9 @@ type ModelStruct struct { } type StructField struct { - Name string DBName string + Name string + Names []string IsPrimaryKey bool IsScanner bool IsTime bool @@ -96,7 +97,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value)) if reflectValue.Kind() == reflect.Slice { - reflectValue = reflect.Indirect(reflect.New(reflectValue.Elem().Type())) + reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem())) } scopeTyp := reflectValue.Type() @@ -125,7 +126,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { continue } - field := &StructField{Struct: fieldStruct} + field := &StructField{Struct: fieldStruct, Name: fieldStruct.Name, Names: []string{fieldStruct.Name}} if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true } else { @@ -206,6 +207,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { case reflect.Struct: if _, ok := field.GormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { for _, field := range scope.New(reflect.New(indirectType).Interface()).GetStructFields() { + field.Names = append([]string{fieldStruct.Name}, field.Names...) modelStruct.StructFields = append(modelStruct.StructFields, field) } break diff --git a/scope.go b/scope.go index ed1c97b8..7e91eaf4 100644 --- a/scope.go +++ b/scope.go @@ -95,7 +95,10 @@ func (scope *Scope) HasError() bool { } func (scope *Scope) PrimaryKeyField() *Field { - return scope.getField(scope.GetModelStruct().PrimaryKeyField) + if field := scope.GetModelStruct().PrimaryKeyField; field != nil { + return scope.getField(field) + } + return nil } // PrimaryKey get the primary key's column name