Fix Pluck with Time and Scanner

This commit is contained in:
Jinzhu 2020-09-24 20:08:24 +08:00
parent c0de3c5051
commit ba253982bf
3 changed files with 43 additions and 4 deletions

13
scan.go
View File

@ -5,6 +5,7 @@ import (
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
"strings" "strings"
"time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue) *dest = append(*dest, mapValue)
} }
case *int, *int64, *uint, *uint64, *float32, *float64, *string: case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time:
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
@ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
// pluck values into slice of data // pluck values into slice of data
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok {
isPluck = true
} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) {
isPluck = true
}
}
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++

View File

@ -18,6 +18,8 @@ type DataType string
type TimeType int64 type TimeType int64
var TimeReflectType = reflect.TypeOf(time.Time{})
const ( const (
UnixSecond TimeType = 1 UnixSecond TimeType = 1
UnixMillisecond TimeType = 2 UnixMillisecond TimeType = 2
@ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var getRealFieldValue func(reflect.Value) var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) { getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v) rv := reflect.Indirect(v)
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
for i := 0; i < rv.Type().NumField(); i++ { for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type newFieldType := rv.Type().Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
@ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Struct: case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok { if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { } else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
field.DataType = Time field.DataType = Time

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -431,6 +432,33 @@ func TestPluck(t *testing.T) {
t.Errorf("Unexpected result on pluck id, got %+v", ids) t.Errorf("Unexpected result on pluck id, got %+v", ids)
} }
} }
var times []time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range times {
AssertEqual(t, tv, users[idx].CreatedAt)
}
var ptrtimes []*time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range ptrtimes {
AssertEqual(t, tv, users[idx].CreatedAt)
}
var nulltimes []sql.NullTime
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)
}
for idx, tv := range nulltimes {
AssertEqual(t, tv.Time, users[idx].CreatedAt)
}
} }
func TestSelect(t *testing.T) { func TestSelect(t *testing.T) {