mirror of https://github.com/go-gorm/gorm.git
fix(scan): update Scan function to reset structs to zero values for each scan (#7061)
Co-authored-by: waleed.masoom <waleed.masoom@wheniwork.com>
This commit is contained in:
parent
05167fd591
commit
73a988ceb2
3
scan.go
3
scan.go
|
@ -331,6 +331,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
}
|
||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
|
@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) {
|
|||
|
||||
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||
if err != nil {
|
||||
t.Errorf("Not error should happen, got %v", err)
|
||||
t.Errorf("No error should happen, got %v", err)
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
|
@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) {
|
|||
})
|
||||
|
||||
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||
t.Errorf("Should find expected results")
|
||||
t.Errorf("Should find expected results, got %+v", results)
|
||||
}
|
||||
|
||||
var ages int
|
||||
|
@ -158,7 +159,105 @@ func TestScanRows(t *testing.T) {
|
|||
|
||||
var name string
|
||||
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
|
||||
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
||||
t.Fatalf("failed to scan name, got error %v, name: %v", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) {
|
||||
DB.Save(&User{})
|
||||
|
||||
rows, err := DB.Table("users").
|
||||
Select(`
|
||||
NULL AS bool_field,
|
||||
NULL AS int_field,
|
||||
NULL AS int8_field,
|
||||
NULL AS int16_field,
|
||||
NULL AS int32_field,
|
||||
NULL AS int64_field,
|
||||
NULL AS uint_field,
|
||||
NULL AS uint8_field,
|
||||
NULL AS uint16_field,
|
||||
NULL AS uint32_field,
|
||||
NULL AS uint64_field,
|
||||
NULL AS float32_field,
|
||||
NULL AS float64_field,
|
||||
NULL AS string_field,
|
||||
NULL AS time_field,
|
||||
NULL AS time_ptr_field,
|
||||
NULL AS embedded_int_field,
|
||||
NULL AS nested_embedded_int_field,
|
||||
NULL AS embedded_ptr_int_field
|
||||
`).Rows()
|
||||
if err != nil {
|
||||
t.Errorf("No error should happen, got %v", err)
|
||||
}
|
||||
|
||||
type NestedEmbeddedStruct struct {
|
||||
NestedEmbeddedIntField int
|
||||
NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"`
|
||||
}
|
||||
|
||||
type EmbeddedStruct struct {
|
||||
EmbeddedIntField int
|
||||
NestedEmbeddedStruct `gorm:"embedded"`
|
||||
}
|
||||
|
||||
type EmbeddedPtrStruct struct {
|
||||
EmbeddedPtrIntField int
|
||||
*NestedEmbeddedStruct `gorm:"embedded"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
BoolField bool
|
||||
IntField int
|
||||
Int8Field int8
|
||||
Int16Field int16
|
||||
Int32Field int32
|
||||
Int64Field int64
|
||||
UIntField uint
|
||||
UInt8Field uint8
|
||||
UInt16Field uint16
|
||||
UInt32Field uint32
|
||||
UInt64Field uint64
|
||||
Float32Field float32
|
||||
Float64Field float64
|
||||
StringField string
|
||||
TimeField time.Time
|
||||
TimePtrField *time.Time
|
||||
EmbeddedStruct `gorm:"embedded"`
|
||||
*EmbeddedPtrStruct `gorm:"embedded"`
|
||||
}
|
||||
|
||||
currTime := time.Now()
|
||||
reusedVar := Result{
|
||||
BoolField: true,
|
||||
IntField: 1,
|
||||
Int8Field: 1,
|
||||
Int16Field: 1,
|
||||
Int32Field: 1,
|
||||
Int64Field: 1,
|
||||
UIntField: 1,
|
||||
UInt8Field: 1,
|
||||
UInt16Field: 1,
|
||||
UInt32Field: 1,
|
||||
UInt64Field: 1,
|
||||
Float32Field: 1.1,
|
||||
Float64Field: 1.1,
|
||||
StringField: "hello",
|
||||
TimeField: currTime,
|
||||
TimePtrField: &currTime,
|
||||
EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
|
||||
EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
if err := DB.ScanRows(rows, &reusedVar); err != nil {
|
||||
t.Errorf("should get no error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(reusedVar, Result{}) {
|
||||
t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue