mirror of https://github.com/go-gorm/gorm.git
Fix parse stmt ReflectValue
This commit is contained in:
parent
5ccd76f76c
commit
04adbaf7f6
|
@ -3,6 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/logger"
|
"github.com/jinzhu/gorm/logger"
|
||||||
|
@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.Model != nil {
|
if stmt.Model != nil {
|
||||||
err := stmt.Parse(stmt.Model)
|
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
|
||||||
|
|
||||||
if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
|
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range p.fns {
|
for _, f := range p.fns {
|
||||||
|
|
|
@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
||||||
} else {
|
} else {
|
||||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
||||||
for idx, v := range vars {
|
for idx, v := range vars {
|
||||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1)
|
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCallback(t *testing.T) {
|
func TestCallback(t *testing.T) {
|
||||||
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
|
user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse user with callback, got error %v", err)
|
t.Fatalf("failed to parse user with callback, got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ type UserCheck struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseCheck(t *testing.T) {
|
func TestParseCheck(t *testing.T) {
|
||||||
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse user check, got error %v", err)
|
t.Fatalf("failed to parse user check, got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
var err error
|
var err error
|
||||||
field.Creatable = false
|
field.Creatable = false
|
||||||
field.Updatable = false
|
field.Updatable = false
|
||||||
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
for _, ef := range field.EmbeddedSchema.Fields {
|
for _, ef := range field.EmbeddedSchema.Fields {
|
||||||
|
|
|
@ -14,8 +14,8 @@ import (
|
||||||
|
|
||||||
func TestFieldValuerAndSetter(t *testing.T) {
|
func TestFieldValuerAndSetter(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
user = tests.User{
|
user = tests.User{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: 10,
|
ID: 10,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
|
@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) {
|
||||||
|
|
||||||
func TestPointerFieldValuerAndSetter(t *testing.T) {
|
func TestPointerFieldValuerAndSetter(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
name = "pointer_field_valuer_and_setter"
|
name = "pointer_field_valuer_and_setter"
|
||||||
age uint = 18
|
age uint = 18
|
||||||
active = true
|
active = true
|
||||||
user = User{
|
user = User{
|
||||||
Model: &gorm.Model{
|
Model: &gorm.Model{
|
||||||
ID: 10,
|
ID: 10,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
|
@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
|
||||||
|
|
||||||
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
name = "advanced_data_type_valuer_and_setter"
|
name = "advanced_data_type_valuer_and_setter"
|
||||||
deletedAt = mytime(time.Now())
|
deletedAt = mytime(time.Now())
|
||||||
isAdmin = mybool(false)
|
isAdmin = mybool(false)
|
||||||
user = AdvancedDataTypeUser{
|
user = AdvancedDataTypeUser{
|
||||||
ID: sql.NullInt64{Int64: 10, Valid: true},
|
ID: sql.NullInt64{Int64: 10, Valid: true},
|
||||||
Name: &sql.NullString{String: name, Valid: true},
|
Name: &sql.NullString{String: name, Valid: true},
|
||||||
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
||||||
|
|
|
@ -19,7 +19,7 @@ type UserIndex struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseIndex(t *testing.T) {
|
func TestParseIndex(t *testing.T) {
|
||||||
user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse user index, got error %v", err)
|
t.Fatalf("failed to parse user index, got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||||
schema.err = err
|
schema.err = err
|
||||||
}
|
}
|
||||||
relation.JoinTable.Name = many2many
|
relation.JoinTable.Name = many2many
|
||||||
|
|
|
@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get data type from dialector
|
// get data type from dialector
|
||||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
|
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||||
reflectValue := reflect.ValueOf(dest)
|
modelType := reflect.ValueOf(dest).Type()
|
||||||
modelType := reflectValue.Type()
|
|
||||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
if modelType.PkgPath() == "" {
|
if modelType.PkgPath() == "" {
|
||||||
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||||
}
|
}
|
||||||
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := cacheStore.Load(modelType); ok {
|
if v, ok := cacheStore.Load(modelType); ok {
|
||||||
return v.(*Schema), reflectValue, nil
|
return v.(*Schema), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
schema := &Schema{
|
schema := &Schema{
|
||||||
|
@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.Indirect(reflect.New(modelType))
|
||||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
||||||
for _, name := range callbacks {
|
for _, name := range callbacks {
|
||||||
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
|
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
|
||||||
|
@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.DataType == "" && field.Creatable {
|
if field.DataType == "" && field.Creatable {
|
||||||
if schema.parseRelation(field); schema.err != nil {
|
if schema.parseRelation(field); schema.err != nil {
|
||||||
return schema, reflectValue, schema.err
|
return schema, schema.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema, reflectValue, schema.err
|
return schema, schema.err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSchema(t *testing.T) {
|
func TestParseSchema(t *testing.T) {
|
||||||
user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse user, got error %v", err)
|
t.Fatalf("failed to parse user, got error %v", err)
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseSchemaWithPointerFields(t *testing.T) {
|
func TestParseSchemaWithPointerFields(t *testing.T) {
|
||||||
user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse pointer user, got error %v", err)
|
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||||
user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse pointer user, got error %v", err)
|
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
|
||||||
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)
|
stmt.Table = stmt.Schema.Table
|
||||||
|
|
||||||
if stmt.Table == "" {
|
|
||||||
stmt.Table = stmt.Schema.Table
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue