diff --git a/schema/field.go b/schema/field.go index 4b8a5a2a..ce2808a8 100644 --- a/schema/field.go +++ b/schema/field.go @@ -326,7 +326,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } diff --git a/schema/naming.go b/schema/naming.go index ecdab791..af753ce5 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,7 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - JoinTableName(table string) string + JoinTableName(joinTable string) string RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string diff --git a/schema/naming_test.go b/schema/naming_test.go index 96b83ced..a4600ceb 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,6 +1,7 @@ package schema import ( + "strings" "testing" ) @@ -32,3 +33,28 @@ func TestToDBName(t *testing.T) { } } } + +type NewNamingStrategy struct { + NamingStrategy +} + +func (ns NewNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} diff --git a/schema/schema.go b/schema/schema.go index c3d3f6e0..cffc19a7 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } schema := &Schema{ Name: modelType.Name(), diff --git a/schema/schema_test.go b/schema/schema_test.go index 6ca5b269..a426cd90 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "strings" "sync" "testing" @@ -227,3 +228,72 @@ func TestEmbeddedStruct(t *testing.T) { }) } } + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} diff --git a/schema/utils.go b/schema/utils.go index 41bd9d60..55cbdeb4 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa return columns, queryValues } } + +type embeddedNamer struct { + Table string + Namer +} diff --git a/tests/go.mod b/tests/go.mod index 0db87934..c92fa0cf 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 + gorm.io/driver/postgres v1.0.1 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.1