Fix parse stmt ReflectValue

This commit is contained in:
Jinzhu 2020-02-24 08:51:35 +08:00
parent 5ccd76f76c
commit 04adbaf7f6
11 changed files with 35 additions and 39 deletions

View File

@ -3,6 +3,7 @@ package gorm
import (
"errors"
"fmt"
"reflect"
"time"
"github.com/jinzhu/gorm/logger"
@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) {
}
if stmt.Model != nil {
err := stmt.Parse(stmt.Model)
if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
db.AddError(err)
}
}
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
}
for _, f := range p.fns {

View File

@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1)
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
}
}

View File

@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) {
}
func TestCallback(t *testing.T) {
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with callback, got error %v", err)
}

View File

@ -15,7 +15,7 @@ type UserCheck struct {
}
func TestParseCheck(t *testing.T) {
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user check, got error %v", err)
}

View File

@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error
field.Creatable = false
field.Updatable = false
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
schema.err = err
}
for _, ef := range field.EmbeddedSchema.Fields {

View File

@ -14,7 +14,7 @@ import (
func TestFieldValuerAndSetter(t *testing.T) {
var (
userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{
Model: gorm.Model{
ID: 10,
@ -81,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
func TestPointerFieldValuerAndSetter(t *testing.T) {
var (
userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter"
age uint = 18
active = true
@ -151,7 +151,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
var (
userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now())
isAdmin = mybool(false)

View File

@ -19,7 +19,7 @@ type UserIndex struct {
}
func TestParseIndex(t *testing.T) {
user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}

View File

@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
}
)
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
}
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many

View File

@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field {
}
// get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
reflectValue := reflect.ValueOf(dest)
modelType := reflectValue.Type()
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), reflectValue, nil
return v.(*Schema), nil
}
schema := &Schema{
@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
}
}
reflectValue := reflect.Indirect(reflect.New(modelType))
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil {
return schema, reflectValue, schema.err
return schema, schema.err
}
}
}
return schema, reflectValue, schema.err
return schema, schema.err
}

View File

@ -9,7 +9,7 @@ import (
)
func TestParseSchema(t *testing.T) {
user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user, got error %v", err)
}
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
}
func TestParseSchemaWithPointerFields(t *testing.T) {
user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
}
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}

View File

@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) {
}
func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
if stmt.Table == "" {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
stmt.Table = stmt.Schema.Table
}
}
return err
}