From ba253982bf558543187f3eb88295b88610cdc83b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 20:08:24 +0800 Subject: [PATCH] Fix Pluck with Time and Scanner --- scan.go | 13 +++++++++++-- schema/field.go | 6 ++++-- tests/query_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/scan.go b/scan.go index be8782ed..d7cddbe6 100644 --- a/scan.go +++ b/scan.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "reflect" "strings" + "time" "gorm.io/gorm/schema" ) @@ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64, *string: + case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: for initialized || rows.Next() { initialized = false db.RowsAffected++ @@ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } // pluck values into slice of data - isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { + isPluck = true + } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + isPluck = true + } + } + for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/schema/field.go b/schema/field.go index ce2808a8..db516c33 100644 --- a/schema/field.go +++ b/schema/field.go @@ -18,6 +18,8 @@ type DataType string type TimeType int64 +var TimeReflectType = reflect.TypeOf(time.Time{}) + const ( UnixSecond TimeType = 1 UnixMillisecond TimeType = 2 @@ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { for i := 0; i < rv.Type().NumField(); i++ { newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { @@ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time diff --git a/tests/query_test.go b/tests/query_test.go index 9c9ad9f2..431ccce2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "fmt" "reflect" "regexp" @@ -431,6 +432,33 @@ func TestPluck(t *testing.T) { t.Errorf("Unexpected result on pluck id, got %+v", ids) } } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } } func TestSelect(t *testing.T) {