forked from mirror/gorm
Fix Scanner tests on mssql
This commit is contained in:
parent
7b6b9c4d22
commit
2b56fa0472
|
@ -87,7 +87,7 @@ func Create(db *gorm.DB) {
|
|||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
err = rows.Scan(values)
|
||||
db.AddError(rows.Scan(values))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -2,6 +2,7 @@ package mssql
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
@ -80,6 +81,15 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
|||
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
for idx, v := range vars {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if v, ok := v.(bool); ok {
|
||||
vars[idx] = strconv.FormatBool(v)
|
||||
}
|
||||
}
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||
}
|
||||
|
||||
|
@ -103,7 +113,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
|||
}
|
||||
return sqlType
|
||||
case schema.Float:
|
||||
return "decimal"
|
||||
return "float"
|
||||
case schema.String:
|
||||
size := field.Size
|
||||
if field.PrimaryKey && size == 0 {
|
||||
|
@ -116,7 +126,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
|||
case schema.Time:
|
||||
return "datetimeoffset"
|
||||
case schema.Bytes:
|
||||
return "binary"
|
||||
return "varbinary(MAX)"
|
||||
}
|
||||
|
||||
return ""
|
||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/jinzhu/gorm
|
|||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd
|
||||
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
|
||||
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
|
|
14
scan.go
14
scan.go
|
@ -20,7 +20,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
|
||||
if initialized || rows.Next() {
|
||||
db.RowsAffected++
|
||||
rows.Scan(values...)
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
|
@ -41,7 +41,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
rows.Scan(values...)
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
v := map[string]interface{}{}
|
||||
for idx, column := range columns {
|
||||
|
@ -53,7 +53,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
rows.Scan(dest)
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
default:
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
|
@ -96,9 +96,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
if isPtr {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
||||
|
@ -130,9 +128,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
|
||||
if initialized || rows.Next() {
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,6 +103,9 @@ func (data *EncryptedData) Scan(value interface{}) error {
|
|||
|
||||
*data = b[3:]
|
||||
return nil
|
||||
} else if s, ok := value.(string); ok {
|
||||
*data = []byte(s)[3:]
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("Bytes expected")
|
||||
|
|
|
@ -121,6 +121,11 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
|
||||
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||
isEqual()
|
||||
|
|
Loading…
Reference in New Issue