Fix Scan struct with primary key, close #3357

This commit is contained in:
Jinzhu 2020-09-02 16:14:26 +08:00
parent 9a101c8a08
commit dbaa6b0ec3
6 changed files with 35 additions and 11 deletions

View File

@ -79,6 +79,8 @@ func (p *processor) Execute(db *DB) {
if stmt.Model == nil { if stmt.Model == nil {
stmt.Model = stmt.Dest stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
} }
if stmt.Model != nil { if stmt.Model != nil {

View File

@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) {
} }
if !db.DryRun { if !db.DryRun {
if _, ok := db.Get("rows"); ok { if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else { } else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

View File

@ -331,13 +331,13 @@ func (db *DB) Count(count *int64) (tx *DB) {
} }
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
tx := db.getInstance() tx := db.getInstance().InstanceSet("rows", false)
tx.callbacks.Row().Execute(tx) tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row) return tx.Statement.Dest.(*sql.Row)
} }
func (db *DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
tx := db.Set("rows", true) tx := db.getInstance().InstanceSet("rows", true)
tx.callbacks.Row().Execute(tx) tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error return tx.Statement.Dest.(*sql.Rows), tx.Error
} }
@ -345,8 +345,14 @@ func (db *DB) Rows() (*sql.Rows, error) {
// Scan scan value to a struct // Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = dest if rows, err := tx.Rows(); err != nil {
tx.callbacks.Query().Execute(tx) tx.AddError(err)
} else {
defer rows.Close()
if rows.Next() {
tx.ScanRows(rows, dest)
}
}
return return
} }
@ -379,7 +385,10 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance() tx := db.getInstance()
tx.Error = tx.Statement.Parse(dest) tx.Error = tx.Statement.Parse(dest)
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
}
Scan(rows, tx, true) Scan(rows, tx, true)
return tx.Error return tx.Error
} }

View File

@ -3,13 +3,14 @@ package logger
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"gorm.io/gorm/utils"
"reflect" "reflect"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unicode" "unicode"
"gorm.io/gorm/utils"
) )
func isPrintable(s []byte) bool { func isPrintable(s []byte) bool {

View File

@ -9,7 +9,7 @@ import (
// Migrator returns migrator // Migrator returns migrator
func (db *DB) Migrator() Migrator { func (db *DB) Migrator() Migrator {
return db.Dialector.Migrator(db) return db.Dialector.Migrator(db.Session(&Session{WithConditions: true}))
} }
// AutoMigrate run auto migration for given models // AutoMigrate run auto migration for given models

View File

@ -6,6 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -16,14 +17,25 @@ func TestScan(t *testing.T) {
DB.Save(&user1).Save(&user2).Save(&user3) DB.Save(&user1).Save(&user2).Save(&user3)
type result struct { type result struct {
ID uint
Name string Name string
Age int Age int
} }
var res result var res result
DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res)
if res.Name != user3.Name || res.Age != int(user3.Age) { if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
t.Errorf("Scan into struct should work") t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
}
DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res)
if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2)
}
DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res)
if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) {
t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3)
} }
var doubleAgeRes = &result{} var doubleAgeRes = &result{}