mirror of https://github.com/go-gorm/gorm.git
383 lines
10 KiB
Go
383 lines
10 KiB
Go
package tests_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
. "gorm.io/gorm/utils/tests"
|
|
)
|
|
|
|
func TestScannerValuer(t *testing.T) {
|
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
|
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
|
}
|
|
|
|
data := ScannerValuerStruct{
|
|
Name: sql.NullString{String: "name", Valid: true},
|
|
Gender: &sql.NullString{String: "M", Valid: true},
|
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
|
Male: sql.NullBool{Bool: true, Valid: true},
|
|
Height: sql.NullFloat64{Float64: 1.8888, Valid: true},
|
|
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
|
Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}},
|
|
Password: EncryptedData("pass1"),
|
|
Bytes: []byte("byte"),
|
|
Num: 18,
|
|
Strings: StringsSlice{"a", "b", "c"},
|
|
Structs: StructsSlice{
|
|
{"name1", "value1"},
|
|
{"name2", "value2"},
|
|
},
|
|
Role: Role{Name: "admin"},
|
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
|
}
|
|
|
|
if err := DB.Create(&data).Error; err != nil {
|
|
t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err)
|
|
}
|
|
|
|
var result ScannerValuerStruct
|
|
|
|
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
|
|
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
|
}
|
|
|
|
if result.ExampleStructPtr.Val != "value2" {
|
|
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
|
|
}
|
|
|
|
if result.ExampleStruct.Val != "value1" {
|
|
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
|
|
}
|
|
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
|
}
|
|
|
|
func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
|
t.Errorf("no error should happen when migrate scanner, valuer struct")
|
|
}
|
|
|
|
data := ScannerValuerStruct{
|
|
Name: sql.NullString{String: "name", Valid: true},
|
|
Gender: &sql.NullString{String: "M", Valid: true},
|
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
|
}
|
|
|
|
var result ScannerValuerStruct
|
|
tx := DB.Where(data).FirstOrCreate(&result)
|
|
|
|
if tx.RowsAffected != 1 {
|
|
t.Errorf("RowsAffected should be 1 after create some record")
|
|
}
|
|
|
|
if tx.Error != nil {
|
|
t.Errorf("Should not raise any error, but got %v", tx.Error)
|
|
}
|
|
|
|
AssertObjEqual(t, result, data, "Name", "Gender", "Age")
|
|
|
|
if err := DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil {
|
|
t.Errorf("Should not raise any error, but got %v", err)
|
|
}
|
|
|
|
if result.Age.Int64 != 18 {
|
|
t.Errorf("should update age to 18")
|
|
}
|
|
|
|
var result2 ScannerValuerStruct
|
|
if err := DB.First(&result2, result.ID).Error; err != nil {
|
|
t.Errorf("got error %v when query with %v", err, result.ID)
|
|
}
|
|
|
|
AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age")
|
|
}
|
|
|
|
func TestInvalidValuer(t *testing.T) {
|
|
DB.Migrator().DropTable(&ScannerValuerStruct{})
|
|
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
|
|
t.Errorf("no error should happen when migrate scanner, valuer struct")
|
|
}
|
|
|
|
data := ScannerValuerStruct{
|
|
Password: EncryptedData("xpass1"),
|
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
|
}
|
|
|
|
if err := DB.Create(&data).Error; err == nil {
|
|
t.Errorf("Should failed to create data with invalid data")
|
|
}
|
|
|
|
data.Password = EncryptedData("pass1")
|
|
if err := DB.Create(&data).Error; err != nil {
|
|
t.Errorf("Should got no error when creating data, but got %v", err)
|
|
}
|
|
|
|
if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil {
|
|
t.Errorf("Should failed to update data with invalid data")
|
|
}
|
|
|
|
if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil {
|
|
t.Errorf("Should got no error update data with valid data, but got %v", err)
|
|
}
|
|
|
|
AssertEqual(t, data.Password, EncryptedData("newpass"))
|
|
}
|
|
|
|
type ScannerValuerStruct struct {
|
|
gorm.Model
|
|
Name sql.NullString
|
|
Gender *sql.NullString
|
|
Age sql.NullInt64
|
|
Male sql.NullBool
|
|
Height sql.NullFloat64
|
|
Birthday sql.NullTime
|
|
Allergen NullString
|
|
Password EncryptedData
|
|
Bytes []byte
|
|
Num Num
|
|
Strings StringsSlice
|
|
Structs StructsSlice
|
|
Role Role
|
|
UserID *sql.NullInt64
|
|
User User
|
|
EmptyTime EmptyTime
|
|
ExampleStruct ExampleStruct
|
|
ExampleStructPtr *ExampleStruct
|
|
}
|
|
|
|
// EncryptedData is a byte slice that is written to the database with a
// three-byte "***" marker in front of the payload and read back with that
// marker stripped. It exists to exercise Scanner/Valuer error paths in tests.
type EncryptedData []byte

// Scan implements sql.Scanner.
//
// It accepts a []byte or a string that starts with "***" and stores the bytes
// following the marker. A value that is shorter than the marker or does not
// start with it yields the error "Too short"; any other type yields
// "Bytes expected". The receiver is left untouched when an error is returned.
func (data *EncryptedData) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return errors.New("Bytes expected")
	}

	// Both representations get the same validation; previously the string
	// form was sliced without any length or marker check.
	if len(raw) < 3 || string(raw[:3]) != "***" {
		return errors.New("Too short")
	}

	// Copy the payload: a []byte handed to Scan is owned by the driver and is
	// only valid until the next Scan call, so it must not be retained.
	payload := make(EncryptedData, len(raw)-3)
	copy(payload, raw[3:])
	*data = payload
	return nil
}

// Value implements driver.Valuer.
//
// It returns the payload prefixed with "***", or an error when the payload
// starts with 'x'.
func (data EncryptedData) Value() (driver.Value, error) {
	if len(data) > 0 && data[0] == 'x' {
		// needed to test failures
		return nil, errors.New("Should not start with 'x'")
	}

	// prepend asterisks
	return append([]byte("***"), data...), nil
}
|
|
|
|
// Num is an int64-based type with a custom Scan implementation.
type Num int64

// Scan implements sql.Scanner.
//
// It accepts an int64, or a []byte holding a base-10 integer. Malformed or
// out-of-range text is reported as an error and leaves the receiver
// unchanged. Any other source type is rejected.
func (i *Num) Scan(src interface{}) error {
	switch s := src.(type) {
	case []byte:
		// ParseInt with bitSize 64 covers the full range of Num on every
		// platform, and the parse error is returned instead of being dropped
		// (which used to turn bad input into 0).
		n, err := strconv.ParseInt(string(s), 10, 64)
		if err != nil {
			return fmt.Errorf("cannot scan Num from %q: %w", s, err)
		}
		*i = Num(n)
	case int64:
		*i = Num(s)
	default:
		return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
	}
	return nil
}
|
|
|
|
// StringsSlice is a list of strings persisted as a JSON array.
type StringsSlice []string

// Value implements driver.Valuer by returning the JSON encoding of the slice
// as a string, together with any marshalling error.
func (l StringsSlice) Value() (driver.Value, error) {
	encoded, err := json.Marshal(l)
	return string(encoded), err
}

// Scan implements sql.Scanner. It decodes JSON supplied as a string or a
// []byte into the slice and rejects every other input type.
func (l *StringsSlice) Scan(input interface{}) error {
	if text, ok := input.(string); ok {
		return json.Unmarshal([]byte(text), l)
	}
	if raw, ok := input.([]byte); ok {
		return json.Unmarshal(raw, l)
	}
	return errors.New("not supported")
}
|
|
|
|
// ExampleStruct is a two-field struct that is persisted as JSON through its
// Scanner/Valuer implementation.
type ExampleStruct struct {
	Name, Val string
}

// GormDataType reports "bytes" as the data type for this struct.
func (ExampleStruct) GormDataType() string { return "bytes" }

// Value implements driver.Valuer. A struct with an empty Name is stored as
// nil; otherwise the JSON encoding of the struct with Name blanked out is
// returned.
func (s ExampleStruct) Value() (driver.Value, error) {
	if s.Name == "" {
		return nil, nil
	}
	// for test, has no practical meaning
	blanked := ExampleStruct{Val: s.Val}
	return json.Marshal(blanked)
}

// Scan implements sql.Scanner. It decodes JSON supplied as a string or a
// []byte into the struct and rejects every other input type.
func (s *ExampleStruct) Scan(src interface{}) error {
	if text, ok := src.(string); ok {
		return json.Unmarshal([]byte(text), s)
	}
	if raw, ok := src.([]byte); ok {
		return json.Unmarshal(raw, s)
	}
	return errors.New("not supported")
}
|
|
|
|
type StructsSlice []ExampleStruct
|
|
|
|
func (l StructsSlice) Value() (driver.Value, error) {
|
|
bytes, err := json.Marshal(l)
|
|
return string(bytes), err
|
|
}
|
|
|
|
func (l *StructsSlice) Scan(input interface{}) error {
|
|
switch value := input.(type) {
|
|
case string:
|
|
return json.Unmarshal([]byte(value), l)
|
|
case []byte:
|
|
return json.Unmarshal(value, l)
|
|
default:
|
|
return errors.New("not supported")
|
|
}
|
|
}
|
|
|
|
// Role is a single-column value type stored as its Name.
type Role struct {
	Name string `gorm:"size:256"`
}

// Scan implements sql.Scanner.
//
// It accepts a []byte or a string and stores it as Name. A nil (SQL NULL)
// value resets Name to the empty string. Any other type is reported as an
// error rather than panicking on a failed type assertion.
func (role *Role) Scan(value interface{}) error {
	switch v := value.(type) {
	case []byte:
		role.Name = string(v)
	case string:
		role.Name = v
	case nil:
		role.Name = ""
	default:
		return fmt.Errorf("cannot scan %T into Role", value)
	}
	return nil
}

// Value implements driver.Valuer by returning Name.
func (role Role) Value() (driver.Value, error) {
	return role.Name, nil
}

// IsAdmin reports whether Name is exactly "admin".
func (role Role) IsAdmin() bool {
	return role.Name == "admin"
}
|
|
|
|
// EmptyTime embeds time.Time and supplies its own Scanner/Valuer pair.
type EmptyTime struct{ time.Time }

// Scan implements sql.Scanner by delegating to sql.NullTime and copying the
// resulting time into the receiver.
func (t *EmptyTime) Scan(v interface{}) error {
	var parsed sql.NullTime
	scanErr := parsed.Scan(v)
	// The time is copied even when Scan reports an error.
	t.Time = parsed.Time
	return scanErr
}

// Value implements driver.Valuer. It always returns the current time rather
// than the stored one.
func (t EmptyTime) Value() (driver.Value, error) {
	// pass tests, mysql 8 doesn't support 0000-00-00 by default
	now := time.Now()
	return now, nil
}
|
|
|
|
// NullString wraps sql.NullString in a named struct; the embedded type's
// fields and methods are promoted unchanged.
type NullString struct{ sql.NullString }
|
|
|
|
// Point is a pair of integer coordinates.
type Point struct {
	X int
	Y int
}

// Scan is a stub: it ignores its input, leaves the receiver untouched and
// always returns nil.
func (*Point) Scan(interface{}) error { return nil }

// GormDataType reports "geo" as the data type for Point.
func (Point) GormDataType() string { return "geo" }
|
|
|
|
func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
|
return clause.Expr{
|
|
SQL: "ST_PointFromText(?)",
|
|
Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)},
|
|
}
|
|
}
|
|
|
|
func TestGORMValuer(t *testing.T) {
|
|
type UserWithPoint struct {
|
|
Name string
|
|
Point Point
|
|
}
|
|
|
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
|
|
|
stmt := dryRunDB.Create(&UserWithPoint{
|
|
Name: "jinzhu",
|
|
Point: Point{X: 100, Y: 100},
|
|
}).Statement
|
|
|
|
if stmt.SQL.String() == "" || len(stmt.Vars) != 2 {
|
|
t.Errorf("Failed to generate sql, got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
|
|
stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{
|
|
"Name": "jinzhu",
|
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
|
}).Statement
|
|
|
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
|
|
stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{
|
|
"Name": "jinzhu",
|
|
"Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}},
|
|
}).Statement
|
|
|
|
if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) {
|
|
t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String())
|
|
}
|
|
|
|
if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) {
|
|
t.Errorf("generated vars is not equal, got %v", stmt.Vars)
|
|
}
|
|
}
|