From 96368eb967bbfbab8ef0bdef2e9ff1fcbdee6710 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2020 09:15:23 +0800 Subject: [PATCH] Test embedded struct implements Scan & Value interface --- migrator/migrator.go | 6 +---- schema/field.go | 18 ++++++-------- schema/schema_helper_test.go | 2 +- tests/embedded_struct_test.go | 45 +++++++++++++++++++++++++++++++++++ tests/go.mod | 8 +++---- 5 files changed, 58 insertions(+), 21 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 955cc6bb..8f872ee4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -44,10 +44,6 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error } func (m Migrator) DataTypeOf(field *schema.Field) string { - if field.DBDataType != "" { - return field.DBDataType - } - fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { @@ -155,7 +151,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) createTableSQL += "," } diff --git a/schema/field.go b/schema/field.go index ea6dcd25..8bfa3b22 100644 --- a/schema/field.go +++ b/schema/field.go @@ -38,7 +38,6 @@ type Field struct { DBName string BindNames []string DataType DataType - DBDataType string PrimaryKey bool AutoIncrement bool Creatable bool @@ -104,7 +103,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // if field is valuer, used its value or first fields as data type - if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { + valuer, isValuer := fieldValue.Interface().(driver.Valuer) + if isValuer { var overrideFieldValue bool if v, err := valuer.Value(); v != nil && err == nil { overrideFieldValue = true @@ -176,10 +176,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } - if val, ok := field.TagSettings["TYPE"]; ok { - field.DBDataType = val - } - switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool @@ -227,6 +223,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = DataType(dataTyper.GormDataType()) } + if val, ok := field.TagSettings["TYPE"]; ok { + field.DataType = DataType(val) + } + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond @@ -256,10 +256,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if field.DataType == "" && field.DBDataType != "" { - field.DataType = String - } - // setup permission if _, ok := field.TagSettings["-"]; ok { field.Creatable = false @@ -293,7 +289,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { var err error field.Creatable = false field.Updatable = false diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index d2e68536..f202b487 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -52,7 +52,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 9a1436fe..5f06f63c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -1,6 +1,9 @@ package tests_test import ( + "database/sql/driver" + "encoding/json" + "errors" "testing" "gorm.io/gorm" @@ -102,3 +105,45 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { t.Errorf("Should find correct value for embedded pointer type") } } + +type Content struct { + Content interface{} `gorm:"type:string"` +} + +func (c Content) Value() (driver.Value, error) { + return json.Marshal(c) +} + +func (c *Content) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + + var value Content + if err := json.Unmarshal(b, &value); err != nil { + return err + } + + *c = value + + return nil +} + +func TestEmbeddedScanValuer(t *testing.T) { + type HNPost struct { + gorm.Model + Content + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + hnPost := HNPost{Content: Content{Content: "hello world"}} + + if err := DB.Create(&hnPost).Error; err != nil { + t.Errorf("Failed to create got error %v", err) + } +} diff --git a/tests/go.mod b/tests/go.mod index e500edd7..07ec6be2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.0.0-20200609004954-b8310c61c3f2 - gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 - gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200610080012-25da0c25e81d + gorm.io/driver/mysql v0.2.0 + gorm.io/driver/postgres v0.2.0 + gorm.io/driver/sqlite v1.0.2 + gorm.io/driver/sqlserver v0.2.0 gorm.io/gorm v0.0.0-00010101000000-000000000000 )