mirror of https://github.com/go-gorm/gorm.git
Add more tests for setter, valuer
This commit is contained in:
parent
faee069a9f
commit
18236fa3d7
129
schema/field.go
129
schema/field.go
|
@ -164,6 +164,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
|
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
||||||
|
field.DataType = Time
|
||||||
}
|
}
|
||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
|
if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
|
||||||
|
@ -311,6 +313,24 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
||||||
|
reflectV := reflect.ValueOf(v)
|
||||||
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
|
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
||||||
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
if v, err = valuer.Value(); err == nil {
|
||||||
|
return setter(value, v)
|
||||||
|
}
|
||||||
|
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
||||||
|
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
||||||
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
|
return field.Setter(value, reflectV.Elem().Interface())
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Setter
|
// Setter
|
||||||
switch field.FieldType.Kind() {
|
switch field.FieldType.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
|
@ -321,17 +341,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
case *bool:
|
case *bool:
|
||||||
field.ReflectValuer(value).SetBool(*data)
|
field.ReflectValuer(value).SetBool(*data)
|
||||||
default:
|
default:
|
||||||
reflectV := reflect.ValueOf(v)
|
return recoverFunc(value, v, field.Setter)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else {
|
|
||||||
field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
field.Setter = func(value reflect.Value, v interface{}) error {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
case int64:
|
case int64:
|
||||||
field.ReflectValuer(value).SetInt(data)
|
field.ReflectValuer(value).SetInt(data)
|
||||||
|
@ -366,19 +381,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
reflectV := reflect.ValueOf(v)
|
return recoverFunc(value, v, field.Setter)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
return field.Setter(value, reflectV.Elem().Interface())
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
|
||||||
}
|
}
|
||||||
}
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
field.Setter = func(value reflect.Value, v interface{}) error {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
case uint64:
|
case uint64:
|
||||||
field.ReflectValuer(value).SetUint(data)
|
field.ReflectValuer(value).SetUint(data)
|
||||||
|
@ -413,19 +421,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
reflectV := reflect.ValueOf(v)
|
return recoverFunc(value, v, field.Setter)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
return field.Setter(value, reflectV.Elem().Interface())
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
|
||||||
}
|
}
|
||||||
}
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
field.Setter = func(value reflect.Value, v interface{}) error {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
case float64:
|
case float64:
|
||||||
field.ReflectValuer(value).SetFloat(data)
|
field.ReflectValuer(value).SetFloat(data)
|
||||||
|
@ -460,19 +461,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
reflectV := reflect.ValueOf(v)
|
return recoverFunc(value, v, field.Setter)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
return field.Setter(value, reflectV.Elem().Interface())
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
|
||||||
}
|
}
|
||||||
}
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.Setter = func(value reflect.Value, v interface{}) error {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
case string:
|
case string:
|
||||||
field.ReflectValuer(value).SetString(data)
|
field.ReflectValuer(value).SetString(data)
|
||||||
|
@ -483,16 +477,9 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
case float64, float32:
|
case float64, float32:
|
||||||
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||||
default:
|
default:
|
||||||
reflectV := reflect.ValueOf(v)
|
return recoverFunc(value, v, field.Setter)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
return field.Setter(value, reflectV.Elem().Interface())
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
|
||||||
}
|
}
|
||||||
}
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
fieldValue := reflect.New(field.FieldType)
|
fieldValue := reflect.New(field.FieldType)
|
||||||
|
@ -511,7 +498,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
|
return recoverFunc(value, v, field.Setter)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -529,29 +516,18 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
|
return recoverFunc(value, v, field.Setter)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
|
// struct scanner
|
||||||
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
reflectV := reflect.ValueOf(v)
|
||||||
if v, err = valuer.Value(); err == nil {
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
|
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
||||||
}
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
} else {
|
|
||||||
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if fieldValue.CanAddr() {
|
|
||||||
if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
|
||||||
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
if v, err = valuer.Value(); err == nil {
|
if v, err = valuer.Value(); err == nil {
|
||||||
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
|
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
|
@ -560,31 +536,26 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||||
}
|
// pointer scanner
|
||||||
}
|
|
||||||
|
|
||||||
if field.FieldType.Kind() == reflect.Ptr {
|
|
||||||
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
||||||
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
||||||
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
||||||
} else {
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
if v, err = valuer.Value(); err == nil {
|
||||||
|
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
return nil
|
} else {
|
||||||
|
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
field.Setter = func(value reflect.Value, v interface{}) (err error) {
|
||||||
reflectV := reflect.ValueOf(v)
|
return recoverFunc(value, v, field.Setter)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package schema_test
|
package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -13,8 +14,7 @@ import (
|
||||||
|
|
||||||
func TestFieldValuerAndSetter(t *testing.T) {
|
func TestFieldValuerAndSetter(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
cacheMap = sync.Map{}
|
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
|
|
||||||
user = tests.User{
|
user = tests.User{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: 10,
|
ID: 10,
|
||||||
|
@ -54,18 +54,36 @@ func TestFieldValuerAndSetter(t *testing.T) {
|
||||||
|
|
||||||
for k, v := range newValues {
|
for k, v := range newValues {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v", k)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
checkField(t, userSchema, reflectValue, newValues)
|
checkField(t, userSchema, reflectValue, newValues)
|
||||||
|
|
||||||
|
// test valuer and other type
|
||||||
|
age := myint(10)
|
||||||
|
newValues2 := map[string]interface{}{
|
||||||
|
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
|
||||||
|
"id": &sql.NullInt64{Int64: 3, Valid: true},
|
||||||
|
"created_at": tests.Now(),
|
||||||
|
"deleted_at": time.Now(),
|
||||||
|
"age": &age,
|
||||||
|
"birthday": mytime(time.Now()),
|
||||||
|
"active": mybool(true),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range newValues2 {
|
||||||
|
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
||||||
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkField(t, userSchema, reflectValue, newValues2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPointerFieldValuerAndSetter(t *testing.T) {
|
func TestPointerFieldValuerAndSetter(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
cacheMap = sync.Map{}
|
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
|
|
||||||
name = "pointer_field_valuer_and_setter"
|
name = "pointer_field_valuer_and_setter"
|
||||||
age = 18
|
age uint = 18
|
||||||
active = true
|
active = true
|
||||||
user = User{
|
user = User{
|
||||||
Model: &gorm.Model{
|
Model: &gorm.Model{
|
||||||
|
@ -110,22 +128,91 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
checkField(t, userSchema, reflectValue, newValues)
|
checkField(t, userSchema, reflectValue, newValues)
|
||||||
|
|
||||||
|
// test valuer and other type
|
||||||
|
age2 := myint(10)
|
||||||
|
newValues2 := map[string]interface{}{
|
||||||
|
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
|
||||||
|
"id": &sql.NullInt64{Int64: 3, Valid: true},
|
||||||
|
"created_at": tests.Now(),
|
||||||
|
"deleted_at": time.Now(),
|
||||||
|
"age": &age2,
|
||||||
|
"birthday": mytime(time.Now()),
|
||||||
|
"active": mybool(true),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range newValues2 {
|
||||||
|
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
||||||
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkField(t, userSchema, reflectValue, newValues2)
|
||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
||||||
*gorm.Model
|
var (
|
||||||
Name *string
|
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
Age *int
|
name = "advanced_data_type_valuer_and_setter"
|
||||||
Birthday *time.Time
|
deletedAt = mytime(time.Now())
|
||||||
Account *tests.Account
|
isAdmin = mybool(false)
|
||||||
Pets []*tests.Pet
|
user = AdvancedDataTypeUser{
|
||||||
Toys []tests.Toy `gorm:"polymorphic:Owner"`
|
ID: sql.NullInt64{Int64: 10, Valid: true},
|
||||||
CompanyID *int
|
Name: &sql.NullString{String: name, Valid: true},
|
||||||
Company *tests.Company
|
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
|
||||||
ManagerID *int
|
RegisteredAt: mytime(time.Now()),
|
||||||
Manager *User
|
DeletedAt: &deletedAt,
|
||||||
Team []User `gorm:"foreignkey:ManagerID"`
|
Active: mybool(true),
|
||||||
Languages []tests.Language `gorm:"many2many:UserSpeak"`
|
Admin: &isAdmin,
|
||||||
Friends []*User `gorm:"many2many:user_friends"`
|
}
|
||||||
Active *bool
|
reflectValue = reflect.ValueOf(&user)
|
||||||
|
)
|
||||||
|
|
||||||
|
// test valuer
|
||||||
|
values := map[string]interface{}{
|
||||||
|
"id": user.ID,
|
||||||
|
"name": user.Name,
|
||||||
|
"birthday": user.Birthday,
|
||||||
|
"registered_at": user.RegisteredAt,
|
||||||
|
"deleted_at": user.DeletedAt,
|
||||||
|
"active": user.Active,
|
||||||
|
"admin": user.Admin,
|
||||||
|
}
|
||||||
|
checkField(t, userSchema, reflectValue, values)
|
||||||
|
|
||||||
|
// test setter
|
||||||
|
newDeletedAt := mytime(time.Now())
|
||||||
|
newIsAdmin := mybool(true)
|
||||||
|
newValues := map[string]interface{}{
|
||||||
|
"id": sql.NullInt64{Int64: 1, Valid: true},
|
||||||
|
"name": &sql.NullString{String: name + "rename", Valid: true},
|
||||||
|
"birthday": time.Now(),
|
||||||
|
"registered_at": mytime(time.Now()),
|
||||||
|
"deleted_at": &newDeletedAt,
|
||||||
|
"active": mybool(false),
|
||||||
|
"admin": &newIsAdmin,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range newValues {
|
||||||
|
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
||||||
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkField(t, userSchema, reflectValue, newValues)
|
||||||
|
|
||||||
|
newValues2 := map[string]interface{}{
|
||||||
|
"id": 5,
|
||||||
|
"name": name + "rename2",
|
||||||
|
"birthday": time.Now(),
|
||||||
|
"registered_at": time.Now(),
|
||||||
|
"deleted_at": time.Now(),
|
||||||
|
"active": true,
|
||||||
|
"admin": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range newValues2 {
|
||||||
|
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
||||||
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkField(t, userSchema, reflectValue, newValues2)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
package schema_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/jinzhu/gorm/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
*gorm.Model
|
||||||
|
Name *string
|
||||||
|
Age *uint
|
||||||
|
Birthday *time.Time
|
||||||
|
Account *tests.Account
|
||||||
|
Pets []*tests.Pet
|
||||||
|
Toys []*tests.Toy `gorm:"polymorphic:Owner"`
|
||||||
|
CompanyID *int
|
||||||
|
Company *tests.Company
|
||||||
|
ManagerID *int
|
||||||
|
Manager *User
|
||||||
|
Team []*User `gorm:"foreignkey:ManagerID"`
|
||||||
|
Languages []*tests.Language `gorm:"many2many:UserSpeak"`
|
||||||
|
Friends []*User `gorm:"many2many:user_friends"`
|
||||||
|
Active *bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type mytime time.Time
|
||||||
|
type myint int
|
||||||
|
type mybool = bool
|
||||||
|
|
||||||
|
type AdvancedDataTypeUser struct {
|
||||||
|
ID sql.NullInt64
|
||||||
|
Name *sql.NullString
|
||||||
|
Birthday sql.NullTime
|
||||||
|
RegisteredAt mytime
|
||||||
|
DeletedAt *mytime
|
||||||
|
Active mybool
|
||||||
|
Admin *mybool
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package schema_test
|
package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -194,30 +195,39 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
||||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||||
for k, v := range values {
|
for k, v := range values {
|
||||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||||
field := s.FieldsByDBName[k]
|
var (
|
||||||
fv := field.ValueOf(value)
|
checker func(fv interface{}, v interface{})
|
||||||
|
field = s.FieldsByDBName[k]
|
||||||
|
fv = field.ValueOf(value)
|
||||||
|
)
|
||||||
|
|
||||||
if reflect.ValueOf(fv).Kind() == reflect.Ptr {
|
checker = func(fv interface{}, v interface{}) {
|
||||||
if reflect.ValueOf(v).Kind() == reflect.Ptr {
|
if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v {
|
||||||
if fv != v {
|
t.Errorf("expects: %p, but got %p", v, fv)
|
||||||
t.Errorf("pointer expects: %p, but got %p", v, fv)
|
} else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) {
|
||||||
|
if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv {
|
||||||
|
t.Errorf("expects: %p, but got %p", v, fv)
|
||||||
}
|
}
|
||||||
} else if fv == nil {
|
} else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) {
|
||||||
if v != nil {
|
if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v {
|
||||||
t.Errorf("expects: %+v, but got nil", v)
|
t.Errorf("expects: %p, but got %p", v, fv)
|
||||||
}
|
|
||||||
} else if reflect.ValueOf(fv).Elem().Interface() != v {
|
|
||||||
t.Errorf("expects: %+v, but got %+v", v, fv)
|
|
||||||
}
|
}
|
||||||
|
} else if valuer, isValuer := fv.(driver.Valuer); isValuer {
|
||||||
|
valuerv, _ := valuer.Value()
|
||||||
|
checker(valuerv, v)
|
||||||
|
} else if valuer, isValuer := v.(driver.Valuer); isValuer {
|
||||||
|
valuerv, _ := valuer.Value()
|
||||||
|
checker(fv, valuerv)
|
||||||
|
} else if reflect.ValueOf(fv).Kind() == reflect.Ptr {
|
||||||
|
checker(reflect.ValueOf(fv).Elem().Interface(), v)
|
||||||
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
|
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
|
||||||
if reflect.ValueOf(v).Elem().Interface() != fv {
|
checker(fv, reflect.ValueOf(v).Elem().Interface())
|
||||||
t.Errorf("expects: %+v, but got %+v", v, fv)
|
} else {
|
||||||
}
|
|
||||||
} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
|
|
||||||
if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
|
|
||||||
t.Errorf("expects: %+v, but got %+v", v, fv)
|
t.Errorf("expects: %+v, but got %+v", v, fv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checker(fv, v)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,13 +9,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSchema(t *testing.T) {
|
func TestParseSchema(t *testing.T) {
|
||||||
cacheMap := sync.Map{}
|
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
|
||||||
user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse user, got error %v", err)
|
t.Fatalf("failed to parse user, got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkUserSchema(t, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSchemaWithPointerFields(t *testing.T) {
|
||||||
|
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkUserSchema(t, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||||
// check schema
|
// check schema
|
||||||
checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"})
|
checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"})
|
||||||
|
|
||||||
|
@ -101,3 +112,31 @@ func TestParseSchema(t *testing.T) {
|
||||||
checkSchemaRelation(t, user, relation)
|
checkSchemaRelation(t, user, relation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||||
|
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse pointer user, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check schema
|
||||||
|
checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"})
|
||||||
|
|
||||||
|
// check fields
|
||||||
|
fields := []schema.Field{
|
||||||
|
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true},
|
||||||
|
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
||||||
|
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
||||||
|
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
|
||||||
|
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time},
|
||||||
|
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
||||||
|
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||||
|
f.Creatable = true
|
||||||
|
f.Updatable = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue