Fix Scanner tests on mssql

This commit is contained in:
Jinzhu 2020-05-31 01:21:16 +08:00
parent 7b6b9c4d22
commit 2b56fa0472
6 changed files with 27 additions and 13 deletions

View File

@ -87,7 +87,7 @@ func Create(db *gorm.DB) {
if rows.Next() {
db.RowsAffected++
err = rows.Scan(values)
db.AddError(rows.Scan(values))
}
}
} else {

View File

@ -2,6 +2,7 @@ package mssql
import (
"database/sql"
"database/sql/driver"
"fmt"
"regexp"
"strconv"
@ -80,6 +81,15 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
for idx, v := range vars {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
if v, ok := v.(bool); ok {
vars[idx] = strconv.FormatBool(v)
}
}
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}
@ -103,7 +113,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
}
return sqlType
case schema.Float:
return "decimal"
return "float"
case schema.String:
size := field.Size
if field.PrimaryKey && size == 0 {
@ -116,7 +126,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.Time:
return "datetimeoffset"
case schema.Bytes:
return "binary"
return "varbinary(MAX)"
}
return ""

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/jinzhu/gorm
go 1.14
require (
github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
github.com/go-sql-driver/mysql v1.5.0
github.com/jinzhu/inflection v1.0.0

14
scan.go
View File

@ -20,7 +20,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() {
db.RowsAffected++
rows.Scan(values...)
db.AddError(rows.Scan(values...))
}
mapValue, ok := dest.(map[string]interface{})
@ -41,7 +41,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
rows.Scan(values...)
db.AddError(rows.Scan(values...))
v := map[string]interface{}{}
for idx, column := range columns {
@ -53,7 +53,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
rows.Scan(dest)
db.AddError(rows.Scan(dest))
}
default:
switch db.Statement.ReflectValue.Kind() {
@ -96,9 +96,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
db.AddError(rows.Scan(values...))
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
@ -130,9 +128,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() {
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
db.AddError(rows.Scan(values...))
}
}
}

View File

@ -103,6 +103,9 @@ func (data *EncryptedData) Scan(value interface{}) error {
*data = b[3:]
return nil
} else if s, ok := value.(string); ok {
*data = []byte(s)[3:]
return nil
}
return errors.New("Bytes expected")

View File

@ -121,6 +121,11 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
}
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
isEqual()