forked from mirror/gorm
Add ScanRows support
This commit is contained in:
parent
51c5be0503
commit
5be642a435
|
@ -105,7 +105,7 @@ func Query(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
Scan(rows, db)
|
gorm.Scan(rows, db, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Preload(db *gorm.DB) {
|
func Preload(db *gorm.DB) {
|
||||||
|
|
|
@ -186,8 +186,13 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||||
return nil
|
tx := db.getInstance()
|
||||||
|
tx.Error = tx.Statement.Parse(dest)
|
||||||
|
tx.Statement.Dest = dest
|
||||||
|
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest))
|
||||||
|
Scan(rows, tx, true)
|
||||||
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
package callbacks
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Scan(rows *sql.Rows, db *gorm.DB) {
|
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
values := make([]interface{}, len(columns))
|
values := make([]interface{}, len(columns))
|
||||||
|
|
||||||
|
@ -19,7 +18,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
values[idx] = new(interface{})
|
values[idx] = new(interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
rows.Scan(values...)
|
rows.Scan(values...)
|
||||||
}
|
}
|
||||||
|
@ -39,7 +38,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
values[idx] = new(interface{})
|
values[idx] = new(interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for initialized || rows.Next() {
|
||||||
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
rows.Scan(values...)
|
rows.Scan(values...)
|
||||||
|
|
||||||
|
@ -50,7 +50,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
*dest = append(*dest, v)
|
*dest = append(*dest, v)
|
||||||
}
|
}
|
||||||
case *int, *int64, *uint, *uint64:
|
case *int, *int64, *uint, *uint64:
|
||||||
for rows.Next() {
|
for initialized || rows.Next() {
|
||||||
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
rows.Scan(dest)
|
rows.Scan(dest)
|
||||||
}
|
}
|
||||||
|
@ -78,7 +79,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for initialized || rows.Next() {
|
||||||
|
initialized = false
|
||||||
elem := reflect.New(db.Statement.Schema.ModelType).Elem()
|
elem := reflect.New(db.Statement.Schema.ModelType).Elem()
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if field != nil {
|
if field != nil {
|
||||||
|
@ -118,7 +120,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
if err := rows.Scan(values...); err != nil {
|
if err := rows.Scan(values...); err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
|
@ -128,6 +130,6 @@ func Scan(rows *sql.Rows, db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
|
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
|
||||||
db.AddError(gorm.ErrRecordNotFound)
|
db.AddError(ErrRecordNotFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,6 +1,9 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/jinzhu/gorm/tests"
|
. "github.com/jinzhu/gorm/tests"
|
||||||
|
@ -24,7 +27,7 @@ func TestScan(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var doubleAgeRes = &result{}
|
var doubleAgeRes = &result{}
|
||||||
if err := DB.Debug().Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil {
|
if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil {
|
||||||
t.Errorf("Scan to pointer of pointer")
|
t.Errorf("Scan to pointer of pointer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,9 +35,44 @@ func TestScan(t *testing.T) {
|
||||||
t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age)
|
t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ress []result
|
var results []result
|
||||||
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&ress)
|
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
|
||||||
if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
|
|
||||||
|
sort.Slice(results, func(i, j int) bool {
|
||||||
|
return strings.Compare(results[i].Name, results[j].Name) < -1
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
|
||||||
t.Errorf("Scan into struct map")
|
t.Errorf("Scan into struct map")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanRows(t *testing.T) {
|
||||||
|
user1 := User{Name: "ScanRowsUser1", Age: 1}
|
||||||
|
user2 := User{Name: "ScanRowsUser2", Age: 10}
|
||||||
|
user3 := User{Name: "ScanRowsUser3", Age: 20}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Not error should happen, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []Result
|
||||||
|
for rows.Next() {
|
||||||
|
var result Result
|
||||||
|
if err := DB.ScanRows(rows, &result); err != nil {
|
||||||
|
t.Errorf("should get no error, but got %v", err)
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||||
|
t.Errorf("Should find expected results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue