forked from mirror/gorm
Fix Pluck with Time and Scanner
This commit is contained in:
parent
c0de3c5051
commit
ba253982bf
13
scan.go
13
scan.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
@ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
scanIntoMap(mapValue, values, columns)
|
scanIntoMap(mapValue, values, columns)
|
||||||
*dest = append(*dest, mapValue)
|
*dest = append(*dest, mapValue)
|
||||||
}
|
}
|
||||||
case *int, *int64, *uint, *uint64, *float32, *float64, *string:
|
case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time:
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
|
@ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// pluck values into slice of data
|
// pluck values into slice of data
|
||||||
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
|
isPluck := false
|
||||||
|
if len(fields) == 1 {
|
||||||
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok {
|
||||||
|
isPluck = true
|
||||||
|
} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) {
|
||||||
|
isPluck = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
|
|
|
@ -18,6 +18,8 @@ type DataType string
|
||||||
|
|
||||||
type TimeType int64
|
type TimeType int64
|
||||||
|
|
||||||
|
var TimeReflectType = reflect.TypeOf(time.Time{})
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UnixSecond TimeType = 1
|
UnixSecond TimeType = 1
|
||||||
UnixMillisecond TimeType = 2
|
UnixMillisecond TimeType = 2
|
||||||
|
@ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
var getRealFieldValue func(reflect.Value)
|
var getRealFieldValue func(reflect.Value)
|
||||||
getRealFieldValue = func(v reflect.Value) {
|
getRealFieldValue = func(v reflect.Value) {
|
||||||
rv := reflect.Indirect(v)
|
rv := reflect.Indirect(v)
|
||||||
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
|
||||||
for i := 0; i < rv.Type().NumField(); i++ {
|
for i := 0; i < rv.Type().NumField(); i++ {
|
||||||
newFieldType := rv.Type().Field(i).Type
|
newFieldType := rv.Type().Field(i).Type
|
||||||
for newFieldType.Kind() == reflect.Ptr {
|
for newFieldType.Kind() == reflect.Ptr {
|
||||||
|
@ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
|
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -431,6 +432,33 @@ func TestPluck(t *testing.T) {
|
||||||
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var times []time.Time
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tv := range times {
|
||||||
|
AssertEqual(t, tv, users[idx].CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ptrtimes []*time.Time
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tv := range ptrtimes {
|
||||||
|
AssertEqual(t, tv, users[idx].CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nulltimes []sql.NullTime
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tv := range nulltimes {
|
||||||
|
AssertEqual(t, tv.Time, users[idx].CreatedAt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelect(t *testing.T) {
|
func TestSelect(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue