forked from mirror/gorm
Better support for scan into map, fix unfriendly data type for interface, close #3351
This commit is contained in:
parent
e6f4b711a7
commit
e73147fa8e
72
scan.go
72
scan.go
|
@ -2,12 +2,52 @@ package gorm
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||
for idx, column := range columns {
|
||||
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||
mapValue[column] = reflectValue.Interface()
|
||||
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||
mapValue[column], _ = valuer.Value()
|
||||
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
|
||||
mapValue[column] = string(b)
|
||||
}
|
||||
} else {
|
||||
mapValue[column] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
columns, _ := rows.Columns()
|
||||
values := make([]interface{}, len(columns))
|
||||
|
@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
mapValue = *v
|
||||
}
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
mapValue[column] = nil
|
||||
} else {
|
||||
mapValue[column] = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue := map[string]interface{}{}
|
||||
for idx, column := range columns {
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
mapValue[column] = nil
|
||||
} else {
|
||||
mapValue[column] = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
*dest = append(*dest, mapValue)
|
||||
}
|
||||
case *int, *int64, *uint, *uint64, *float32, *float64:
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -61,6 +62,54 @@ func TestFind(t *testing.T) {
|
|||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if _, ok := first[dbName].(string); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||
}
|
||||
case "Age":
|
||||
if _, ok := first[dbName].(uint); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if _, ok := first[dbName].(*time.Time); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FirstMapWithTable", func(t *testing.T) {
|
||||
var first = map[string]interface{}{}
|
||||
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
} else {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
resultType := reflect.ValueOf(first[dbName]).Type().Name()
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if !strings.Contains(resultType, "string") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||
}
|
||||
case "Age":
|
||||
if !strings.Contains(resultType, "int") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
|
@ -86,13 +135,29 @@ func TestFind(t *testing.T) {
|
|||
t.Run("FirstSliceOfMap", func(t *testing.T) {
|
||||
var allMap = []map[string]interface{}{}
|
||||
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||
t.Errorf("errors happened when query first: %v", err)
|
||||
t.Errorf("errors happened when query find: %v", err)
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if _, ok := allMap[idx][dbName].(string); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||
}
|
||||
case "Age":
|
||||
if _, ok := allMap[idx][dbName].(uint); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
|
||||
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
|
@ -101,6 +166,43 @@ func TestFind(t *testing.T) {
|
|||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
|
||||
var allMap = []map[string]interface{}{}
|
||||
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||
t.Errorf("errors happened when query find: %v", err)
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
|
||||
|
||||
switch name {
|
||||
case "Name":
|
||||
if !strings.Contains(resultType, "string") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||
}
|
||||
case "Age":
|
||||
if !strings.Contains(resultType, "int") {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||
}
|
||||
case "Birthday":
|
||||
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestQueryWithAssociation(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue