diff --git a/callbacks.go b/callbacks.go index 3aed2d37..db8261c4 100644 --- a/callbacks.go +++ b/callbacks.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "reflect" "time" "github.com/jinzhu/gorm/logger" @@ -77,12 +78,11 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - err := stmt.Parse(stmt.Model) - - if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { db.AddError(err) } } + stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest)) } for _, f := range p.fns { diff --git a/logger/sql.go b/logger/sql.go index eec72d47..cb50ccf6 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -84,7 +84,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx)+"$", v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) } } diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index 34c0e687..720c9a5b 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -19,7 +19,7 @@ func (UserWithCallback) AfterCreate(*gorm.DB) { } func TestCallback(t *testing.T) { - user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user with callback, got error %v", err) } diff --git a/schema/check_test.go b/schema/check_test.go index f0ba553c..e4bc9ebe 100644 --- a/schema/check_test.go +++ b/schema/check_test.go @@ -15,7 +15,7 @@ type UserCheck struct { } func TestParseCheck(t *testing.T) { - user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user check, got error %v", err) } diff --git a/schema/field.go b/schema/field.go index f640ec3b..ea4e6a40 100644 --- a/schema/field.go +++ b/schema/field.go @@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field.Creatable = false field.Updatable = false - if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/field_test.go b/schema/field_test.go index 02e6aec0..15dfa41d 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -14,8 +14,8 @@ import ( func TestFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) - user = tests.User{ + userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user = tests.User{ Model: gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) - name = "pointer_field_valuer_and_setter" - age uint = 18 - active = true - user = User{ + userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { var ( - userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) - name = "advanced_data_type_valuer_and_setter" - deletedAt = mytime(time.Now()) - isAdmin = mybool(false) - user = AdvancedDataTypeUser{ + userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ ID: sql.NullInt64{Int64: 10, Valid: true}, Name: &sql.NullString{String: name, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, diff --git a/schema/index_test.go b/schema/index_test.go index 03d75b97..d0e8dfe0 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -19,7 +19,7 @@ type UserIndex struct { } func TestParseIndex(t *testing.T) { - user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user index, got error %v", err) } diff --git a/schema/relationship.go b/schema/relationship.go index 3b9d692a..4ffea8b3 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = err return } @@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many diff --git a/schema/schema.go b/schema/schema.go index c56932ad..2ac6d312 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -53,22 +53,21 @@ func (schema Schema) LookUpField(name string) *Field { } // get data type from dialector -func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) { - reflectValue := reflect.ValueOf(dest) - modelType := reflectValue.Type() +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { - return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { - return v.(*Schema), reflectValue, nil + return v.(*Schema), nil } schema := &Schema{ @@ -167,6 +166,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec } } + reflectValue := reflect.Indirect(reflect.New(modelType)) callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} for _, name := range callbacks { if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { @@ -185,10 +185,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec for _, field := range schema.Fields { if field.DataType == "" && field.Creatable { if schema.parseRelation(field); schema.err != nil { - return schema, reflectValue, schema.err + return schema, schema.err } } } - return schema, reflectValue, schema.err + return schema, schema.err } diff --git a/schema/schema_test.go b/schema/schema_test.go index 04cd9d82..ce225010 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -9,7 +9,7 @@ import ( ) func TestParseSchema(t *testing.T) { - user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } @@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) { } func TestParseSchemaWithPointerFields(t *testing.T) { - user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } @@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { } func TestParseSchemaWithAdvancedDataType(t *testing.T) { - user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } diff --git a/statement.go b/statement.go index 91f45b2b..ad30ed08 100644 --- a/statement.go +++ b/statement.go @@ -274,12 +274,8 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { - stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue) - - if stmt.Table == "" { - stmt.Table = stmt.Schema.Table - } + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + stmt.Table = stmt.Schema.Table } return err }