Fix Scanner tests on mssql

This commit is contained in:
Jinzhu 2020-05-31 01:21:16 +08:00
parent 7b6b9c4d22
commit 2b56fa0472
6 changed files with 27 additions and 13 deletions

View File

@ -87,7 +87,7 @@ func Create(db *gorm.DB) {
if rows.Next() { if rows.Next() {
db.RowsAffected++ db.RowsAffected++
err = rows.Scan(values) db.AddError(rows.Scan(values))
} }
} }
} else { } else {

View File

@ -2,6 +2,7 @@ package mssql
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
@ -80,6 +81,15 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
var numericPlaceholder = regexp.MustCompile("@p(\\d+)") var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
func (dialector Dialector) Explain(sql string, vars ...interface{}) string { func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
for idx, v := range vars {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
if v, ok := v.(bool); ok {
vars[idx] = strconv.FormatBool(v)
}
}
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
} }
@ -103,7 +113,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
} }
return sqlType return sqlType
case schema.Float: case schema.Float:
return "decimal" return "float"
case schema.String: case schema.String:
size := field.Size size := field.Size
if field.PrimaryKey && size == 0 { if field.PrimaryKey && size == 0 {
@ -116,7 +126,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.Time: case schema.Time:
return "datetimeoffset" return "datetimeoffset"
case schema.Bytes: case schema.Bytes:
return "binary" return "varbinary(MAX)"
} }
return "" return ""

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/jinzhu/gorm
go 1.14 go 1.14
require ( require (
github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
github.com/go-sql-driver/mysql v1.5.0 github.com/go-sql-driver/mysql v1.5.0
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0

14
scan.go
View File

@ -20,7 +20,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() { if initialized || rows.Next() {
db.RowsAffected++ db.RowsAffected++
rows.Scan(values...) db.AddError(rows.Scan(values...))
} }
mapValue, ok := dest.(map[string]interface{}) mapValue, ok := dest.(map[string]interface{})
@ -41,7 +41,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
rows.Scan(values...) db.AddError(rows.Scan(values...))
v := map[string]interface{}{} v := map[string]interface{}{}
for idx, column := range columns { for idx, column := range columns {
@ -53,7 +53,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
rows.Scan(dest) db.AddError(rows.Scan(dest))
} }
default: default:
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
@ -96,9 +96,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
db.RowsAffected++ db.RowsAffected++
if err := rows.Scan(values...); err != nil { db.AddError(rows.Scan(values...))
db.AddError(err)
}
if isPtr { if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
@ -130,9 +128,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() { if initialized || rows.Next() {
db.RowsAffected++ db.RowsAffected++
if err := rows.Scan(values...); err != nil { db.AddError(rows.Scan(values...))
db.AddError(err)
}
} }
} }
} }

View File

@ -103,6 +103,9 @@ func (data *EncryptedData) Scan(value interface{}) error {
*data = b[3:] *data = b[3:]
return nil return nil
} else if s, ok := value.(string); ok {
*data = []byte(s)[3:]
return nil
} }
return errors.New("Bytes expected") return errors.New("Bytes expected")

View File

@ -121,6 +121,11 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
} }
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
isEqual() isEqual()