diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index b17a2227..c85997fb 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -87,7 +87,7 @@ func Create(db *gorm.DB) { if rows.Next() { db.RowsAffected++ - err = rows.Scan(values) + db.AddError(rows.Scan(values)) } } } else { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de82f375..8e309faf 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -2,6 +2,7 @@ package mssql import ( "database/sql" + "database/sql/driver" "fmt" "regexp" "strconv" @@ -80,6 +81,15 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { var numericPlaceholder = regexp.MustCompile("@p(\\d+)") func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + for idx, v := range vars { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + if v, ok := v.(bool); ok { + vars[idx] = strconv.FormatBool(v) + } + } return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) } @@ -103,7 +113,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return sqlType case schema.Float: - return "decimal" + return "float" case schema.String: size := field.Size if field.PrimaryKey && size == 0 { @@ -116,7 +126,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Time: return "datetimeoffset" case schema.Bytes: - return "binary" + return "varbinary(MAX)" } return "" diff --git a/go.mod b/go.mod index 45bcf69c..7dabdd39 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/jinzhu/gorm go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd + github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 github.com/go-sql-driver/mysql v1.5.0 github.com/jinzhu/inflection v1.0.0 diff --git a/scan.go b/scan.go index c223f6eb..66cb0b94 100644 --- a/scan.go +++ b/scan.go @@ -20,7 +20,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { db.RowsAffected++ - rows.Scan(values...) + db.AddError(rows.Scan(values...)) } mapValue, ok := dest.(map[string]interface{}) @@ -41,7 +41,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for initialized || rows.Next() { initialized = false db.RowsAffected++ - rows.Scan(values...) + db.AddError(rows.Scan(values...)) v := map[string]interface{}{} for idx, column := range columns { @@ -53,7 +53,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for initialized || rows.Next() { initialized = false db.RowsAffected++ - rows.Scan(dest) + db.AddError(rows.Scan(dest)) } default: switch db.Statement.ReflectValue.Kind() { @@ -96,9 +96,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } + db.AddError(rows.Scan(values...)) if isPtr { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) @@ -130,9 +128,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } + db.AddError(rows.Scan(values...)) } } } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 38ffc919..88e7e12e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -103,6 +103,9 @@ func (data *EncryptedData) Scan(value interface{}) error { *data = b[3:] return nil + } else if s, ok := value.(string); ok { + *data = []byte(s)[3:] + return nil } return errors.New("Bytes expected") diff --git a/tests/utils.go b/tests/utils.go index 0a33edee..0add8143 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -121,6 +121,11 @@ func AssertEqual(t *testing.T, got, expect interface{}) { expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() } + if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() isEqual()