mirror of https://github.com/go-gorm/gorm.git
Fix parse stmt ReflectValue
This commit is contained in:
parent
5ccd76f76c
commit
04adbaf7f6
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
|
@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) {
|
|||
}
|
||||
|
||||
if stmt.Model != nil {
|
||||
err := stmt.Parse(stmt.Model)
|
||||
|
||||
if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
|
||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
|
||||
}
|
||||
|
||||
for _, f := range p.fns {
|
||||
|
|
|
@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
} else {
|
||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
||||
for idx, v := range vars {
|
||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1)
|
||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) {
|
|||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user with callback, got error %v", err)
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ type UserCheck struct {
|
|||
}
|
||||
|
||||
func TestParseCheck(t *testing.T) {
|
||||
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user check, got error %v", err)
|
||||
}
|
||||
|
|
|
@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
var err error
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
for _, ef := range field.EmbeddedSchema.Fields {
|
||||
|
|
|
@ -14,8 +14,8 @@ import (
|
|||
|
||||
func TestFieldValuerAndSetter(t *testing.T) {
|
||||
var (
|
||||
userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user = tests.User{
|
||||
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user = tests.User{
|
||||
Model: gorm.Model{
|
||||
ID: 10,
|
||||
CreatedAt: time.Now(),
|
||||
|
@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
|
|||
|
||||
func TestPointerFieldValuerAndSetter(t *testing.T) {
|
||||
var (
|
||||
userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
name = "pointer_field_valuer_and_setter"
|
||||
age uint = 18
|
||||
active = true
|
||||
user = User{
|
||||
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
name = "pointer_field_valuer_and_setter"
|
||||
age uint = 18
|
||||
active = true
|
||||
user = User{
|
||||
Model: &gorm.Model{
|
||||
ID: 10,
|
||||
CreatedAt: time.Now(),
|
||||
|
@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
|
|||
|
||||
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
||||
var (
|
||||
userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
name = "advanced_data_type_valuer_and_setter"
|
||||
deletedAt = mytime(time.Now())
|
||||
isAdmin = mybool(false)
|
||||
user = AdvancedDataTypeUser{
|
||||
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
name = "advanced_data_type_valuer_and_setter"
|
||||
deletedAt = mytime(time.Now())
|
||||
isAdmin = mybool(false)
|
||||
user = AdvancedDataTypeUser{
|
||||
ID: sql.NullInt64{Int64: 10, Valid: true},
|
||||
Name: &sql.NullString{String: name, Valid: true},
|
||||
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
|
|
|
@ -19,7 +19,7 @@ type UserIndex struct {
|
|||
}
|
||||
|
||||
func TestParseIndex(t *testing.T) {
|
||||
user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
|
|
|
@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
|
|||
}
|
||||
)
|
||||
|
||||
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||
}
|
||||
}
|
||||
|
||||
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
|
|
|
@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field {
|
|||
}
|
||||
|
||||
// get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
|
||||
reflectValue := reflect.ValueOf(dest)
|
||||
modelType := reflectValue.Type()
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
if modelType.PkgPath() == "" {
|
||||
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
return v.(*Schema), reflectValue, nil
|
||||
return v.(*Schema), nil
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
|
@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
|
|||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.New(modelType))
|
||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
||||
for _, name := range callbacks {
|
||||
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
|
||||
|
@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
|
|||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.Creatable {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, reflectValue, schema.err
|
||||
return schema, schema.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schema, reflectValue, schema.err
|
||||
return schema, schema.err
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func TestParseSchema(t *testing.T) {
|
||||
user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user, got error %v", err)
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParseSchemaWithPointerFields(t *testing.T) {
|
||||
user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
|||
}
|
||||
|
||||
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||
user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||
}
|
||||
|
|
|
@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) {
|
|||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
|
||||
|
||||
if stmt.Table == "" {
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue