diff --git a/scan.go b/scan.go index 0b199029..89d9a07a 100644 --- a/scan.go +++ b/scan.go @@ -2,12 +2,52 @@ package gorm import ( "database/sql" + "database/sql/driver" "reflect" "strings" "gorm.io/gorm/schema" ) +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) @@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: if initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) db.RowsAffected++ db.AddError(rows.Scan(values...)) @@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue = *v } } - - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } + scanIntoMap(mapValue, values, columns) } case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + prepareValues(values, db, columnTypes, columns) initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) mapValue := map[string]interface{}{} - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } - + scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } case *int, *int64, *uint, *uint64, *float32, *float64: diff --git a/tests/query_test.go b/tests/query_test.go index d71c813a..6bb68cd3 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -6,6 +6,7 @@ import ( "regexp" "sort" "strconv" + "strings" "testing" "time" @@ -61,6 +62,54 @@ func TestFind(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := first[dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Age": + if _, ok := first[dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Birthday": + if _, ok := first[dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstMapWithTable", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(first[dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) @@ -86,13 +135,29 @@ func TestFind(t *testing.T) { t.Run("FirstSliceOfMap", func(t *testing.T) { var allMap = []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) + t.Errorf("errors happened when query find: %v", err) } else { for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := allMap[idx][dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Age": + if _, ok := allMap[idx][dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Birthday": + if _, ok := allMap[idx][dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) }) @@ -101,6 +166,43 @@ func TestFind(t *testing.T) { } } }) + + t.Run("FindSliceOfMapWithTable", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + } func TestQueryWithAssociation(t *testing.T) {