Better support for scan into map, fix unfriendly data type for interface, close #3351

This commit is contained in:
Jinzhu 2020-09-01 16:55:30 +08:00
parent e6f4b711a7
commit e73147fa8e
2 changed files with 149 additions and 27 deletions

72
scan.go
View File

@ -2,12 +2,52 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"reflect"
"strings"
"gorm.io/gorm/schema"
)
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++
db.AddError(rows.Scan(values...))
@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
mapValue = *v
}
}
for idx, column := range columns {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
prepareValues(values, db, columnTypes, columns)
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
for idx, column := range columns {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
}
case *int, *int64, *uint, *uint64, *float32, *float64:

View File

@ -6,6 +6,7 @@ import (
"regexp"
"sort"
"strconv"
"strings"
"testing"
"time"
@ -61,6 +62,54 @@ func TestFind(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := first[dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Age":
if _, ok := first[dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Birthday":
if _, ok := first[dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
}
}
})
t.Run("FirstMapWithTable", func(t *testing.T) {
var first = map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(first[dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
@ -86,13 +135,29 @@ func TestFind(t *testing.T) {
t.Run("FirstSliceOfMap", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
t.Errorf("errors happened when query find: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := allMap[idx][dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Age":
if _, ok := allMap[idx][dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Birthday":
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
@ -101,6 +166,43 @@ func TestFind(t *testing.T) {
}
}
})
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query find: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
}
})
}
}
})
}
func TestQueryWithAssociation(t *testing.T) {