Add Scanner, Valuer tests

This commit is contained in:
Jinzhu 2020-05-30 21:05:27 +08:00
parent c422d75f4b
commit c291c2f42c
6 changed files with 211 additions and 17 deletions

View File

@ -1,6 +1,9 @@
package clause
import "reflect"
import (
"database/sql/driver"
"reflect"
)
// Expression expression interface
type Expression interface {
@ -28,16 +31,20 @@ func (expr Expr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) {
if v == '?' {
if afterParenthesis {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
default:
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])

View File

@ -57,6 +57,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = "NULL"
} else if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else {
@ -74,10 +77,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
}
for idx, v := range vars {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
convertParams(v, idx)
}

View File

@ -207,7 +207,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DBDataType = val
}
switch fieldValue.Elem().Kind() {
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Bool:
field.DataType = Bool
if field.HasDefaultValue {

View File

@ -146,6 +146,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case clause.Expr:
writer.WriteString(v.SQL)
stmt.Vars = append(stmt.Vars, v.Vars...)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []interface{}:
if len(v) > 0 {
writer.WriteByte('(')

View File

@ -0,0 +1,175 @@
package tests_test
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"reflect"
"strconv"
"testing"
"time"
"github.com/jinzhu/gorm"
. "github.com/jinzhu/gorm/tests"
)
func TestScannerValuer(t *testing.T) {
DB.Migrator().DropTable(&ScannerValuerStruct{})
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
t.Errorf("no error should happen when migrate scanner, valuer struct")
}
data := ScannerValuerStruct{
Name: sql.NullString{String: "name", Valid: true},
Gender: &sql.NullString{String: "M", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: true},
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
Password: EncryptedData("pass1"),
Num: 18,
Strings: StringsSlice{"a", "b", "c"},
Structs: StructsSlice{
{"name1", "value1"},
{"name2", "value2"},
},
}
if err := DB.Create(&data).Error; err != nil {
t.Errorf("No error should happend when create scanner valuer struct, but got %v", err)
}
var result ScannerValuerStruct
if err := DB.Find(&result).Error; err != nil {
t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err)
}
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs")
}
func TestInvalidValuer(t *testing.T) {
DB.Migrator().DropTable(&ScannerValuerStruct{})
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
t.Errorf("no error should happen when migrate scanner, valuer struct")
}
data := ScannerValuerStruct{
Password: EncryptedData("xpass1"),
}
if err := DB.Create(&data).Error; err == nil {
t.Errorf("Should failed to create data with invalid data")
}
data.Password = EncryptedData("pass1")
if err := DB.Create(&data).Error; err != nil {
t.Errorf("Should got no error when creating data, but got %v", err)
}
if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil {
t.Errorf("Should failed to update data with invalid data")
}
if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil {
t.Errorf("Should got no error update data with valid data, but got %v", err)
}
AssertEqual(t, data.Password, EncryptedData("newpass"))
}
type ScannerValuerStruct struct {
gorm.Model
Name sql.NullString
Gender *sql.NullString
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
Birthday sql.NullTime
Password EncryptedData
Num Num
Strings StringsSlice
Structs StructsSlice
}
type EncryptedData []byte
func (data *EncryptedData) Scan(value interface{}) error {
if b, ok := value.([]byte); ok {
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
return errors.New("Too short")
}
*data = b[3:]
return nil
}
return errors.New("Bytes expected")
}
func (data EncryptedData) Value() (driver.Value, error) {
if len(data) > 0 && data[0] == 'x' {
//needed to test failures
return nil, errors.New("Should not start with 'x'")
}
//prepend asterisks
return append([]byte("***"), data...), nil
}
type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
n, _ := strconv.Atoi(string(s))
*i = Num(n)
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
}
type StringsSlice []string
func (l StringsSlice) Value() (driver.Value, error) {
bytes, err := json.Marshal(l)
return string(bytes), err
}
func (l *StringsSlice) Scan(input interface{}) error {
switch value := input.(type) {
case string:
return json.Unmarshal([]byte(value), l)
case []byte:
return json.Unmarshal(value, l)
default:
return errors.New("not supported")
}
}
type ExampleStruct struct {
Name string
Value string
}
type StructsSlice []ExampleStruct
func (l StructsSlice) Value() (driver.Value, error) {
bytes, err := json.Marshal(l)
return string(bytes), err
}
func (l *StructsSlice) Scan(input interface{}) error {
switch value := input.(type) {
case string:
return json.Unmarshal([]byte(value), l)
case []byte:
return json.Unmarshal(value, l)
default:
return errors.New("not supported")
}
}

View File

@ -1,6 +1,8 @@
package tests
import (
"database/sql/driver"
"fmt"
"reflect"
"sort"
"strconv"
@ -89,12 +91,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) {
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format))
}
} else if got != expect {
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
}
}
if got == expect {
if fmt.Sprint(got) == fmt.Sprint(expect) {
return
}
@ -103,6 +105,14 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
return
}
if valuer, ok := got.(driver.Valuer); ok {
got, _ = valuer.Value()
}
if valuer, ok := expect.(driver.Valuer); ok {
expect, _ = valuer.Value()
}
if got != nil {
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
}