Better support for scan into map, fix unfriendly data type for interface, close #3351

This commit is contained in:
Jinzhu 2020-09-01 16:55:30 +08:00
parent e6f4b711a7
commit e73147fa8e
2 changed files with 149 additions and 27 deletions

72
scan.go
View File

@ -2,12 +2,52 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"reflect" "reflect"
"strings" "strings"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func Scan(rows *sql.Rows, db *DB, initialized bool) { func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns() columns, _ := rows.Columns()
values := make([]interface{}, len(columns)) values := make([]interface{}, len(columns))
@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}: case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() { if initialized || rows.Next() {
for idx := range columns { columnTypes, _ := rows.ColumnTypes()
values[idx] = new(interface{}) prepareValues(values, db, columnTypes, columns)
}
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
mapValue = *v mapValue = *v
} }
} }
scanIntoMap(mapValue, values, columns)
for idx, column := range columns {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
} }
case *[]map[string]interface{}: case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() { for initialized || rows.Next() {
for idx := range columns { prepareValues(values, db, columnTypes, columns)
values[idx] = new(interface{})
}
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{} mapValue := map[string]interface{}{}
for idx, column := range columns { scanIntoMap(mapValue, values, columns)
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
*dest = append(*dest, mapValue) *dest = append(*dest, mapValue)
} }
case *int, *int64, *uint, *uint64, *float32, *float64: case *int, *int64, *uint, *uint64, *float32, *float64:

View File

@ -6,6 +6,7 @@ import (
"regexp" "regexp"
"sort" "sort"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -61,6 +62,54 @@ func TestFind(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} { for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name) dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := first[dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Age":
if _, ok := first[dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
case "Birthday":
if _, ok := first[dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
})
}
}
})
t.Run("FirstMapWithTable", func(t *testing.T) {
var first = map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
t.Errorf("errors happened when query first: %v", err)
} else {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(first[dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
}) })
@ -86,13 +135,29 @@ func TestFind(t *testing.T) {
t.Run("FirstSliceOfMap", func(t *testing.T) { t.Run("FirstSliceOfMap", func(t *testing.T) {
var allMap = []map[string]interface{}{} var allMap = []map[string]interface{}{}
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query first: %v", err) t.Errorf("errors happened when query find: %v", err)
} else { } else {
for idx, user := range users { for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} { for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name) dbName := DB.NamingStrategy.ColumnName("", name)
switch name {
case "Name":
if _, ok := allMap[idx][dbName].(string); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Age":
if _, ok := allMap[idx][dbName].(uint); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
case "Birthday":
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user)) reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
}) })
@ -101,6 +166,43 @@ func TestFind(t *testing.T) {
} }
} }
}) })
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
var allMap = []map[string]interface{}{}
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
t.Errorf("errors happened when query find: %v", err)
} else {
for idx, user := range users {
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
for _, name := range []string{"Name", "Age", "Birthday"} {
t.Run(name, func(t *testing.T) {
dbName := DB.NamingStrategy.ColumnName("", name)
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
switch name {
case "Name":
if !strings.Contains(resultType, "string") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Age":
if !strings.Contains(resultType, "int") {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
case "Birthday":
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
}
}
reflectValue := reflect.Indirect(reflect.ValueOf(user))
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
})
}
})
}
}
})
} }
func TestQueryWithAssociation(t *testing.T) { func TestQueryWithAssociation(t *testing.T) {