From 101a7c789fa2c41f409da439056806756fd8ce22 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 15:51:47 +0800 Subject: [PATCH] fix: scan array (#5624) Co-authored-by: Jinzhu --- scan.go | 22 +++++++++++++++------- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/scan.go b/scan.go index df5a3714..70cd4284 100644 --- a/scan.go +++ b/scan.go @@ -243,15 +243,18 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var elem reflect.Value - recyclableStruct := reflect.New(reflectValueType) + var ( + elem reflect.Value + recyclableStruct = reflect.New(reflectValueType) + isArrayKind = reflectValue.Kind() == reflect.Array + ) if !update || reflectValue.Len() == 0 { update = false // if the slice cap is externally initialized, the externally initialized slice is directly used here if reflectValue.Cap() == 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else { + } else if !isArrayKind { reflectValue.SetLen(0) db.Statement.ReflectValue.Set(reflectValue) } @@ -285,10 +288,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + reflectValue = reflect.Append(reflectValue, elem) } } } @@ -312,4 +320,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} +} \ No newline at end of file diff --git a/tests/query_test.go b/tests/query_test.go index 4569fe1a..eccf0133 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -216,6 +216,30 @@ func TestFind(t *testing.T) { } } + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + var none []User if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))