fix: scan array (#5624)

Co-authored-by: Jinzhu <wosmvp@gmail.com>
This commit is contained in:
Cr 2022-09-22 15:51:47 +08:00 committed by GitHub
parent 3a72ba102e
commit 101a7c789f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 7 deletions

22
scan.go
View File

@ -243,15 +243,18 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var elem reflect.Value var (
recyclableStruct := reflect.New(reflectValueType) elem reflect.Value
recyclableStruct = reflect.New(reflectValueType)
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 { if !update || reflectValue.Len() == 0 {
update = false update = false
// if the slice cap is externally initialized, the externally initialized slice is directly used here // if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 { if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else { } else if !isArrayKind {
reflectValue.SetLen(0) reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue) db.Statement.ReflectValue.Set(reflectValue)
} }
@ -285,10 +288,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
db.scanIntoStruct(rows, elem, values, fields, joinFields) db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update { if !update {
if isPtr { if !isPtr {
reflectValue = reflect.Append(reflectValue, elem) elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else { } else {
reflectValue = reflect.Append(reflectValue, elem.Elem()) reflectValue = reflect.Append(reflectValue, elem)
} }
} }
} }
@ -312,4 +320,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
db.AddError(ErrRecordNotFound) db.AddError(ErrRecordNotFound)
} }
} }

View File

@ -216,6 +216,30 @@ func TestFind(t *testing.T) {
} }
} }
// test array
var models2 [3]User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
} else {
for idx, user := range users {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models2[idx], user)
})
}
}
// test smaller array
var models3 [2]User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
} else {
for idx, user := range users[:2] {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models3[idx], user)
})
}
}
var none []User var none []User
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))