From dbaa6b0ec3f451903c2983fd091c52e5efc60669 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 16:14:26 +0800 Subject: [PATCH] Fix Scan struct with primary key, close #3357 --- callbacks.go | 2 ++ callbacks/row.go | 2 +- finisher_api.go | 19 ++++++++++++++----- logger/sql.go | 3 ++- migrator.go | 2 +- tests/scan_test.go | 18 +++++++++++++++--- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/callbacks.go b/callbacks.go index baeb6c09..eace06ca 100644 --- a/callbacks.go +++ b/callbacks.go @@ -79,6 +79,8 @@ func (p *processor) Execute(db *DB) { if stmt.Model == nil { stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model } if stmt.Model != nil { diff --git a/callbacks/row.go b/callbacks/row.go index 7e70382e..a36c0116 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { } if !db.DryRun { - if _, ok := db.Get("rows"); ok { + if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index a205b859..1d5ef5fc 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -331,13 +331,13 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance() + tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Row) } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.Set("rows", true) + tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Rows), tx.Error } @@ -345,8 +345,14 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) + if rows, err := tx.Rows(); err != nil { + tx.AddError(err) + } else { + defer rows.Close() + if rows.Next() { + tx.ScanRows(rows, dest) + } + } return } @@ -379,7 +385,10 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() tx.Error = tx.Statement.Parse(dest) tx.Statement.Dest = dest - tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + } Scan(rows, tx, true) return tx.Error } diff --git a/logger/sql.go b/logger/sql.go index 0efc0971..80645b0c 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,13 +3,14 @@ package logger import ( "database/sql/driver" "fmt" - "gorm.io/gorm/utils" "reflect" "regexp" "strconv" "strings" "time" "unicode" + + "gorm.io/gorm/utils" ) func isPrintable(s []byte) bool { diff --git a/migrator.go b/migrator.go index ed8a8e26..162fe680 100644 --- a/migrator.go +++ b/migrator.go @@ -9,7 +9,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db) + return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) } // AutoMigrate run auto migration for given models diff --git a/tests/scan_test.go b/tests/scan_test.go index d6a372bb..3e66a25a 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -16,14 +17,25 @@ func TestScan(t *testing.T) { DB.Save(&user1).Save(&user2).Save(&user3) type result struct { + ID uint Name string Age int } var res result - DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) - if res.Name != user3.Name || res.Age != int(user3.Age) { - t.Errorf("Scan into struct should work") + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) + if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) + } + + DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } var doubleAgeRes = &result{}