mirror of https://github.com/go-gorm/gorm.git
Better support for scan into map, fix unfriendly data type for interface, close #3351
This commit is contained in:
parent
e6f4b711a7
commit
e73147fa8e
72
scan.go
72
scan.go
|
@ -2,12 +2,52 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
for idx, name := range columns {
|
||||||
|
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||||
|
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
} else if len(columnTypes) > 0 {
|
||||||
|
for idx, columnType := range columnTypes {
|
||||||
|
if columnType.ScanType() != nil {
|
||||||
|
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||||
|
} else {
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for idx := range columns {
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||||
|
for idx, column := range columns {
|
||||||
|
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||||
|
mapValue[column] = reflectValue.Interface()
|
||||||
|
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||||
|
mapValue[column], _ = valuer.Value()
|
||||||
|
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
|
||||||
|
mapValue[column] = string(b)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mapValue[column] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
values := make([]interface{}, len(columns))
|
values := make([]interface{}, len(columns))
|
||||||
|
@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
switch dest := db.Statement.Dest.(type) {
|
switch dest := db.Statement.Dest.(type) {
|
||||||
case map[string]interface{}, *map[string]interface{}:
|
case map[string]interface{}, *map[string]interface{}:
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
for idx := range columns {
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
values[idx] = new(interface{})
|
prepareValues(values, db, columnTypes, columns)
|
||||||
}
|
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
mapValue = *v
|
mapValue = *v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
scanIntoMap(mapValue, values, columns)
|
||||||
for idx, column := range columns {
|
|
||||||
if v, ok := values[idx].(*interface{}); ok {
|
|
||||||
if v == nil {
|
|
||||||
mapValue[column] = nil
|
|
||||||
} else {
|
|
||||||
mapValue[column] = *v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case *[]map[string]interface{}:
|
case *[]map[string]interface{}:
|
||||||
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
for idx := range columns {
|
prepareValues(values, db, columnTypes, columns)
|
||||||
values[idx] = new(interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
mapValue := map[string]interface{}{}
|
mapValue := map[string]interface{}{}
|
||||||
for idx, column := range columns {
|
scanIntoMap(mapValue, values, columns)
|
||||||
if v, ok := values[idx].(*interface{}); ok {
|
|
||||||
if v == nil {
|
|
||||||
mapValue[column] = nil
|
|
||||||
} else {
|
|
||||||
mapValue[column] = *v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
*dest = append(*dest, mapValue)
|
*dest = append(*dest, mapValue)
|
||||||
}
|
}
|
||||||
case *int, *int64, *uint, *uint64, *float32, *float64:
|
case *int, *int64, *uint, *uint64, *float32, *float64:
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -61,6 +62,54 @@ func TestFind(t *testing.T) {
|
||||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if _, ok := first[dbName].(string); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if _, ok := first[dbName].(uint); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if _, ok := first[dbName].(*time.Time); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||||
|
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("FirstMapWithTable", func(t *testing.T) {
|
||||||
|
var first = map[string]interface{}{}
|
||||||
|
if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query first: %v", err)
|
||||||
|
} else {
|
||||||
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
resultType := reflect.ValueOf(first[dbName]).Type().Name()
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if !strings.Contains(resultType, "string") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if !strings.Contains(resultType, "int") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
reflectValue := reflect.Indirect(reflect.ValueOf(users[0]))
|
||||||
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface())
|
||||||
})
|
})
|
||||||
|
@ -86,13 +135,29 @@ func TestFind(t *testing.T) {
|
||||||
t.Run("FirstSliceOfMap", func(t *testing.T) {
|
t.Run("FirstSliceOfMap", func(t *testing.T) {
|
||||||
var allMap = []map[string]interface{}{}
|
var allMap = []map[string]interface{}{}
|
||||||
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||||
t.Errorf("errors happened when query first: %v", err)
|
t.Errorf("errors happened when query find: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for idx, user := range users {
|
for idx, user := range users {
|
||||||
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||||
for _, name := range []string{"Name", "Age", "Birthday"} {
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
dbName := DB.NamingStrategy.ColumnName("", name)
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if _, ok := allMap[idx][dbName].(string); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if _, ok := allMap[idx][dbName].(uint); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if _, ok := allMap[idx][dbName].(*time.Time); !ok {
|
||||||
|
t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||||
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||||
})
|
})
|
||||||
|
@ -101,6 +166,43 @@ func TestFind(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("FindSliceOfMapWithTable", func(t *testing.T) {
|
||||||
|
var allMap = []map[string]interface{}{}
|
||||||
|
if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query find: %v", err)
|
||||||
|
} else {
|
||||||
|
for idx, user := range users {
|
||||||
|
t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||||
|
for _, name := range []string{"Name", "Age", "Birthday"} {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
dbName := DB.NamingStrategy.ColumnName("", name)
|
||||||
|
resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name()
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "Name":
|
||||||
|
if !strings.Contains(resultType, "string") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Age":
|
||||||
|
if !strings.Contains(resultType, "int") {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
case "Birthday":
|
||||||
|
if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) {
|
||||||
|
t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(user))
|
||||||
|
AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryWithAssociation(t *testing.T) {
|
func TestQueryWithAssociation(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue