Test embedded struct implements Scan & Value interface

This commit is contained in:
Jinzhu 2020-06-18 09:15:23 +08:00
parent 6b2f37189e
commit 96368eb967
5 changed files with 58 additions and 21 deletions

View File

@ -44,10 +44,6 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
} }
func (m Migrator) DataTypeOf(field *schema.Field) string { func (m Migrator) DataTypeOf(field *schema.Field) string {
if field.DBDataType != "" {
return field.DBDataType
}
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" {
@ -155,7 +151,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += fmt.Sprintf("? ?") createTableSQL += fmt.Sprintf("? ?")
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field))
createTableSQL += "," createTableSQL += ","
} }

View File

@ -38,7 +38,6 @@ type Field struct {
DBName string DBName string
BindNames []string BindNames []string
DataType DataType DataType DataType
DBDataType string
PrimaryKey bool PrimaryKey bool
AutoIncrement bool AutoIncrement bool
Creatable bool Creatable bool
@ -104,7 +103,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
var overrideFieldValue bool var overrideFieldValue bool
if v, err := valuer.Value(); v != nil && err == nil { if v, err := valuer.Value(); v != nil && err == nil {
overrideFieldValue = true overrideFieldValue = true
@ -176,10 +176,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Comment = val field.Comment = val
} }
if val, ok := field.TagSettings["TYPE"]; ok {
field.DBDataType = val
}
switch reflect.Indirect(fieldValue).Kind() { switch reflect.Indirect(fieldValue).Kind() {
case reflect.Bool: case reflect.Bool:
field.DataType = Bool field.DataType = Bool
@ -227,6 +223,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DataType = DataType(dataTyper.GormDataType()) field.DataType = DataType(dataTyper.GormDataType())
} }
if val, ok := field.TagSettings["TYPE"]; ok {
field.DataType = DataType(val)
}
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" { if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond field.AutoCreateTime = UnixNanosecond
@ -256,10 +256,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if field.DataType == "" && field.DBDataType != "" {
field.DataType = String
}
// setup permission // setup permission
if _, ok := field.TagSettings["-"]; ok { if _, ok := field.TagSettings["-"]; ok {
field.Creatable = false field.Creatable = false
@ -293,7 +289,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) {
var err error var err error
field.Creatable = false field.Creatable = false
field.Updatable = false field.Updatable = false

View File

@ -52,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
if parsedField, ok := s.FieldsByName[f.Name]; !ok { if parsedField, ok := s.FieldsByName[f.Name]; !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name) t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else { } else {
tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
if f.DBName != "" { if f.DBName != "" {
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {

View File

@ -1,6 +1,9 @@
package tests_test package tests_test
import ( import (
"database/sql/driver"
"encoding/json"
"errors"
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
@ -102,3 +105,45 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
t.Errorf("Should find correct value for embedded pointer type") t.Errorf("Should find correct value for embedded pointer type")
} }
} }
type Content struct {
Content interface{} `gorm:"type:string"`
}
func (c Content) Value() (driver.Value, error) {
return json.Marshal(c)
}
func (c *Content) Scan(src interface{}) error {
b, ok := src.([]byte)
if !ok {
return errors.New("Embedded.Scan byte assertion failed")
}
var value Content
if err := json.Unmarshal(b, &value); err != nil {
return err
}
*c = value
return nil
}
func TestEmbeddedScanValuer(t *testing.T) {
type HNPost struct {
gorm.Model
Content
}
DB.Migrator().DropTable(&HNPost{})
if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
hnPost := HNPost{Content: Content{Content: "hello world"}}
if err := DB.Create(&hnPost).Error; err != nil {
t.Errorf("Failed to create got error %v", err)
}
}

View File

@ -6,10 +6,10 @@ require (
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 gorm.io/driver/mysql v0.2.0
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/postgres v0.2.0
gorm.io/driver/sqlite v1.0.0 gorm.io/driver/sqlite v1.0.2
gorm.io/driver/sqlserver v0.0.0-20200610080012-25da0c25e81d gorm.io/driver/sqlserver v0.2.0
gorm.io/gorm v0.0.0-00010101000000-000000000000 gorm.io/gorm v0.0.0-00010101000000-000000000000
) )