mirror of https://github.com/go-gorm/gorm.git
fix(scan): update Scan function to reset structs to zero values for each scan (#7061)
Co-authored-by: waleed.masoom <waleed.masoom@wheniwork.com>
This commit is contained in:
parent
05167fd591
commit
73a988ceb2
3
scan.go
3
scan.go
|
@ -331,6 +331,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||||
}
|
}
|
||||||
case reflect.Struct, reflect.Ptr:
|
case reflect.Struct, reflect.Ptr:
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
|
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
|
||||||
|
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||||
|
}
|
||||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) {
|
||||||
|
|
||||||
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Not error should happen, got %v", err)
|
t.Errorf("No error should happen, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Result struct {
|
type Result struct {
|
||||||
|
@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||||
t.Errorf("Should find expected results")
|
t.Errorf("Should find expected results, got %+v", results)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ages int
|
var ages int
|
||||||
|
@ -158,7 +159,105 @@ func TestScanRows(t *testing.T) {
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
|
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
|
||||||
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
t.Fatalf("failed to scan name, got error %v, name: %v", err, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) {
|
||||||
|
DB.Save(&User{})
|
||||||
|
|
||||||
|
rows, err := DB.Table("users").
|
||||||
|
Select(`
|
||||||
|
NULL AS bool_field,
|
||||||
|
NULL AS int_field,
|
||||||
|
NULL AS int8_field,
|
||||||
|
NULL AS int16_field,
|
||||||
|
NULL AS int32_field,
|
||||||
|
NULL AS int64_field,
|
||||||
|
NULL AS uint_field,
|
||||||
|
NULL AS uint8_field,
|
||||||
|
NULL AS uint16_field,
|
||||||
|
NULL AS uint32_field,
|
||||||
|
NULL AS uint64_field,
|
||||||
|
NULL AS float32_field,
|
||||||
|
NULL AS float64_field,
|
||||||
|
NULL AS string_field,
|
||||||
|
NULL AS time_field,
|
||||||
|
NULL AS time_ptr_field,
|
||||||
|
NULL AS embedded_int_field,
|
||||||
|
NULL AS nested_embedded_int_field,
|
||||||
|
NULL AS embedded_ptr_int_field
|
||||||
|
`).Rows()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("No error should happen, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NestedEmbeddedStruct struct {
|
||||||
|
NestedEmbeddedIntField int
|
||||||
|
NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedStruct struct {
|
||||||
|
EmbeddedIntField int
|
||||||
|
NestedEmbeddedStruct `gorm:"embedded"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedPtrStruct struct {
|
||||||
|
EmbeddedPtrIntField int
|
||||||
|
*NestedEmbeddedStruct `gorm:"embedded"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
BoolField bool
|
||||||
|
IntField int
|
||||||
|
Int8Field int8
|
||||||
|
Int16Field int16
|
||||||
|
Int32Field int32
|
||||||
|
Int64Field int64
|
||||||
|
UIntField uint
|
||||||
|
UInt8Field uint8
|
||||||
|
UInt16Field uint16
|
||||||
|
UInt32Field uint32
|
||||||
|
UInt64Field uint64
|
||||||
|
Float32Field float32
|
||||||
|
Float64Field float64
|
||||||
|
StringField string
|
||||||
|
TimeField time.Time
|
||||||
|
TimePtrField *time.Time
|
||||||
|
EmbeddedStruct `gorm:"embedded"`
|
||||||
|
*EmbeddedPtrStruct `gorm:"embedded"`
|
||||||
|
}
|
||||||
|
|
||||||
|
currTime := time.Now()
|
||||||
|
reusedVar := Result{
|
||||||
|
BoolField: true,
|
||||||
|
IntField: 1,
|
||||||
|
Int8Field: 1,
|
||||||
|
Int16Field: 1,
|
||||||
|
Int32Field: 1,
|
||||||
|
Int64Field: 1,
|
||||||
|
UIntField: 1,
|
||||||
|
UInt8Field: 1,
|
||||||
|
UInt16Field: 1,
|
||||||
|
UInt32Field: 1,
|
||||||
|
UInt64Field: 1,
|
||||||
|
Float32Field: 1.1,
|
||||||
|
Float64Field: 1.1,
|
||||||
|
StringField: "hello",
|
||||||
|
TimeField: currTime,
|
||||||
|
TimePtrField: &currTime,
|
||||||
|
EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
|
||||||
|
EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err := DB.ScanRows(rows, &reusedVar); err != nil {
|
||||||
|
t.Errorf("should get no error, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(reusedVar, Result{}) {
|
||||||
|
t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue