mirror of https://github.com/go-gorm/gorm.git
Fix Scanner tests on mssql
This commit is contained in:
parent
7b6b9c4d22
commit
2b56fa0472
|
@ -87,7 +87,7 @@ func Create(db *gorm.DB) {
|
||||||
|
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
err = rows.Scan(values)
|
db.AddError(rows.Scan(values))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -80,6 +81,15 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||||
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
|
var numericPlaceholder = regexp.MustCompile("@p(\\d+)")
|
||||||
|
|
||||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||||
|
for idx, v := range vars {
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := v.(bool); ok {
|
||||||
|
vars[idx] = strconv.FormatBool(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,7 +113,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
}
|
}
|
||||||
return sqlType
|
return sqlType
|
||||||
case schema.Float:
|
case schema.Float:
|
||||||
return "decimal"
|
return "float"
|
||||||
case schema.String:
|
case schema.String:
|
||||||
size := field.Size
|
size := field.Size
|
||||||
if field.PrimaryKey && size == 0 {
|
if field.PrimaryKey && size == 0 {
|
||||||
|
@ -116,7 +126,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||||
case schema.Time:
|
case schema.Time:
|
||||||
return "datetimeoffset"
|
return "datetimeoffset"
|
||||||
case schema.Bytes:
|
case schema.Bytes:
|
||||||
return "binary"
|
return "varbinary(MAX)"
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/jinzhu/gorm
|
||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd
|
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
|
||||||
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
|
github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5
|
||||||
github.com/go-sql-driver/mysql v1.5.0
|
github.com/go-sql-driver/mysql v1.5.0
|
||||||
github.com/jinzhu/inflection v1.0.0
|
github.com/jinzhu/inflection v1.0.0
|
||||||
|
|
14
scan.go
14
scan.go
|
@ -20,7 +20,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
rows.Scan(values...)
|
db.AddError(rows.Scan(values...))
|
||||||
}
|
}
|
||||||
|
|
||||||
mapValue, ok := dest.(map[string]interface{})
|
mapValue, ok := dest.(map[string]interface{})
|
||||||
|
@ -41,7 +41,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
rows.Scan(values...)
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
v := map[string]interface{}{}
|
v := map[string]interface{}{}
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
|
@ -53,7 +53,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
rows.Scan(dest)
|
db.AddError(rows.Scan(dest))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
@ -96,9 +96,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
if err := rows.Scan(values...); err != nil {
|
db.AddError(rows.Scan(values...))
|
||||||
db.AddError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isPtr {
|
if isPtr {
|
||||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
||||||
|
@ -130,9 +128,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
if err := rows.Scan(values...); err != nil {
|
db.AddError(rows.Scan(values...))
|
||||||
db.AddError(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,6 +103,9 @@ func (data *EncryptedData) Scan(value interface{}) error {
|
||||||
|
|
||||||
*data = b[3:]
|
*data = b[3:]
|
||||||
return nil
|
return nil
|
||||||
|
} else if s, ok := value.(string); ok {
|
||||||
|
*data = []byte(s)[3:]
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.New("Bytes expected")
|
return errors.New("Bytes expected")
|
||||||
|
|
|
@ -121,6 +121,11 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
||||||
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
|
expect = reflect.Indirect(reflect.ValueOf(expect)).Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() {
|
||||||
|
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||||
isEqual()
|
isEqual()
|
||||||
|
|
Loading…
Reference in New Issue