Fix parse stmt ReflectValue

This commit is contained in:
Jinzhu 2020-02-24 08:51:35 +08:00
parent 5ccd76f76c
commit 04adbaf7f6
11 changed files with 35 additions and 39 deletions

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"reflect"
"time" "time"
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) {
} }
if stmt.Model != nil { if stmt.Model != nil {
err := stmt.Parse(stmt.Model) if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
db.AddError(err) db.AddError(err)
} }
} }
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
} }
for _, f := range p.fns { for _, f := range p.fns {

View File

@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
} else { } else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars { for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1) sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
} }
} }

View File

@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) {
} }
func TestCallback(t *testing.T) { func TestCallback(t *testing.T) {
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user with callback, got error %v", err) t.Fatalf("failed to parse user with callback, got error %v", err)
} }

View File

@ -15,7 +15,7 @@ type UserCheck struct {
} }
func TestParseCheck(t *testing.T) { func TestParseCheck(t *testing.T) {
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user check, got error %v", err) t.Fatalf("failed to parse user check, got error %v", err)
} }

View File

@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error var err error
field.Creatable = false field.Creatable = false
field.Updatable = false field.Updatable = false
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
schema.err = err schema.err = err
} }
for _, ef := range field.EmbeddedSchema.Fields { for _, ef := range field.EmbeddedSchema.Fields {

View File

@ -14,8 +14,8 @@ import (
func TestFieldValuerAndSetter(t *testing.T) { func TestFieldValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{ user = tests.User{
Model: gorm.Model{ Model: gorm.Model{
ID: 10, ID: 10,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
func TestPointerFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter" name = "pointer_field_valuer_and_setter"
age uint = 18 age uint = 18
active = true active = true
user = User{ user = User{
Model: &gorm.Model{ Model: &gorm.Model{
ID: 10, ID: 10,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
var ( var (
userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter" name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now()) deletedAt = mytime(time.Now())
isAdmin = mybool(false) isAdmin = mybool(false)
user = AdvancedDataTypeUser{ user = AdvancedDataTypeUser{
ID: sql.NullInt64{Int64: 10, Valid: true}, ID: sql.NullInt64{Int64: 10, Valid: true},
Name: &sql.NullString{String: name, Valid: true}, Name: &sql.NullString{String: name, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true},

View File

@ -19,7 +19,7 @@ type UserIndex struct {
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user index, got error %v", err) t.Fatalf("failed to parse user index, got error %v", err)
} }

View File

@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
} }
) )
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err schema.err = err
return return
} }
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
} }
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err schema.err = err
} }
relation.JoinTable.Name = many2many relation.JoinTable.Name = many2many

View File

@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field {
} }
// get data type from dialector // get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
reflectValue := reflect.ValueOf(dest) modelType := reflect.ValueOf(dest).Type()
modelType := reflectValue.Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" { if modelType.PkgPath() == "" {
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
} }
if v, ok := cacheStore.Load(modelType); ok { if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), reflectValue, nil return v.(*Schema), nil
} }
schema := &Schema{ schema := &Schema{
@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
} }
} }
reflectValue := reflect.Indirect(reflect.New(modelType))
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks { for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable { if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil { if schema.parseRelation(field); schema.err != nil {
return schema, reflectValue, schema.err return schema, schema.err
} }
} }
} }
return schema, reflectValue, schema.err return schema, schema.err
} }

View File

@ -9,7 +9,7 @@ import (
) )
func TestParseSchema(t *testing.T) { func TestParseSchema(t *testing.T) {
user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user, got error %v", err) t.Fatalf("failed to parse user, got error %v", err)
} }
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
} }
func TestParseSchemaWithPointerFields(t *testing.T) { func TestParseSchemaWithPointerFields(t *testing.T) {
user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err) t.Fatalf("failed to parse pointer user, got error %v", err)
} }
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
} }
func TestParseSchemaWithAdvancedDataType(t *testing.T) { func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err) t.Fatalf("failed to parse pointer user, got error %v", err)
} }

View File

@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) {
} }
func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue) stmt.Table = stmt.Schema.Table
if stmt.Table == "" {
stmt.Table = stmt.Schema.Table
}
} }
return err return err
} }