mirror of https://github.com/go-gorm/gorm.git
Add Scanner, Valuer tests
This commit is contained in:
parent
c422d75f4b
commit
c291c2f42c
|
@ -1,6 +1,9 @@
|
|||
package clause
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Expression expression interface
|
||||
type Expression interface {
|
||||
|
@ -28,16 +31,20 @@ func (expr Expr) Build(builder Builder) {
|
|||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '?' {
|
||||
if afterParenthesis {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
} else {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
|
|
|
@ -57,6 +57,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
vars[idx] = "NULL"
|
||||
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
vars[idx] = "NULL"
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
convertParams(v, idx)
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else {
|
||||
|
@ -74,10 +77,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
|
|||
}
|
||||
|
||||
for idx, v := range vars {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
convertParams(v, idx)
|
||||
}
|
||||
|
||||
|
|
|
@ -207,7 +207,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
field.DBDataType = val
|
||||
}
|
||||
|
||||
switch fieldValue.Elem().Kind() {
|
||||
switch reflect.Indirect(fieldValue).Kind() {
|
||||
case reflect.Bool:
|
||||
field.DataType = Bool
|
||||
if field.HasDefaultValue {
|
||||
|
|
|
@ -146,6 +146,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||
case clause.Expr:
|
||||
writer.WriteString(v.SQL)
|
||||
stmt.Vars = append(stmt.Vars, v.Vars...)
|
||||
case driver.Valuer:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
writer.WriteByte('(')
|
||||
|
|
|
@ -0,0 +1,175 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestScannerValuer(t *testing.T) {
|
||||
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
||||
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
||||
t.Errorf("no error should happen when migrate scanner, valuer struct")
|
||||
}
|
||||
|
||||
data := ScannerValuerStruct{
|
||||
Name: sql.NullString{String: "name", Valid: true},
|
||||
Gender: &sql.NullString{String: "M", Valid: true},
|
||||
Age: sql.NullInt64{Int64: 18, Valid: true},
|
||||
Male: sql.NullBool{Bool: true, Valid: true},
|
||||
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
|
||||
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
Password: EncryptedData("pass1"),
|
||||
Num: 18,
|
||||
Strings: StringsSlice{"a", "b", "c"},
|
||||
Structs: StructsSlice{
|
||||
{"name1", "value1"},
|
||||
{"name2", "value2"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
t.Errorf("No error should happend when create scanner valuer struct, but got %v", err)
|
||||
}
|
||||
|
||||
var result ScannerValuerStruct
|
||||
|
||||
if err := DB.Find(&result).Error; err != nil {
|
||||
t.Errorf("no error should happen when query scanner, valuer struct, but got %v", err)
|
||||
}
|
||||
|
||||
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Num", "Strings", "Structs")
|
||||
}
|
||||
|
||||
func TestInvalidValuer(t *testing.T) {
|
||||
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
||||
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
||||
t.Errorf("no error should happen when migrate scanner, valuer struct")
|
||||
}
|
||||
|
||||
data := ScannerValuerStruct{
|
||||
Password: EncryptedData("xpass1"),
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err == nil {
|
||||
t.Errorf("Should failed to create data with invalid data")
|
||||
}
|
||||
|
||||
data.Password = EncryptedData("pass1")
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
t.Errorf("Should got no error when creating data, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil {
|
||||
t.Errorf("Should failed to update data with invalid data")
|
||||
}
|
||||
|
||||
if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil {
|
||||
t.Errorf("Should got no error update data with valid data, but got %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, data.Password, EncryptedData("newpass"))
|
||||
}
|
||||
|
||||
type ScannerValuerStruct struct {
|
||||
gorm.Model
|
||||
Name sql.NullString
|
||||
Gender *sql.NullString
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
Birthday sql.NullTime
|
||||
Password EncryptedData
|
||||
Num Num
|
||||
Strings StringsSlice
|
||||
Structs StructsSlice
|
||||
}
|
||||
|
||||
type EncryptedData []byte
|
||||
|
||||
func (data *EncryptedData) Scan(value interface{}) error {
|
||||
if b, ok := value.([]byte); ok {
|
||||
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
|
||||
return errors.New("Too short")
|
||||
}
|
||||
|
||||
*data = b[3:]
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("Bytes expected")
|
||||
}
|
||||
|
||||
func (data EncryptedData) Value() (driver.Value, error) {
|
||||
if len(data) > 0 && data[0] == 'x' {
|
||||
//needed to test failures
|
||||
return nil, errors.New("Should not start with 'x'")
|
||||
}
|
||||
|
||||
//prepend asterisks
|
||||
return append([]byte("***"), data...), nil
|
||||
}
|
||||
|
||||
type Num int64
|
||||
|
||||
func (i *Num) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
n, _ := strconv.Atoi(string(s))
|
||||
*i = Num(n)
|
||||
case int64:
|
||||
*i = Num(s)
|
||||
default:
|
||||
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StringsSlice []string
|
||||
|
||||
func (l StringsSlice) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(l)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func (l *StringsSlice) Scan(input interface{}) error {
|
||||
switch value := input.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), l)
|
||||
case []byte:
|
||||
return json.Unmarshal(value, l)
|
||||
default:
|
||||
return errors.New("not supported")
|
||||
}
|
||||
}
|
||||
|
||||
type ExampleStruct struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
type StructsSlice []ExampleStruct
|
||||
|
||||
func (l StructsSlice) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(l)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func (l *StructsSlice) Scan(input interface{}) error {
|
||||
switch value := input.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), l)
|
||||
case []byte:
|
||||
return json.Unmarshal(value, l)
|
||||
default:
|
||||
return errors.New("not supported")
|
||||
}
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -89,12 +91,12 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||
if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) {
|
||||
t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format))
|
||||
}
|
||||
} else if got != expect {
|
||||
} else if fmt.Sprint(got) != fmt.Sprint(expect) {
|
||||
t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got)
|
||||
}
|
||||
}
|
||||
|
||||
if got == expect {
|
||||
if fmt.Sprint(got) == fmt.Sprint(expect) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -103,6 +105,14 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
if valuer, ok := got.(driver.Valuer); ok {
|
||||
got, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if valuer, ok := expect.(driver.Valuer); ok {
|
||||
expect, _ = valuer.Value()
|
||||
}
|
||||
|
||||
if got != nil {
|
||||
got = reflect.Indirect(reflect.ValueOf(got)).Interface()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue