diff --git a/schema/naming.go b/schema/naming.go index f6d15f5a..e10c9212 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -4,7 +4,6 @@ import ( "crypto/sha1" "fmt" "strings" - "sync" "unicode/utf8" "github.com/jinzhu/inflection" @@ -20,11 +19,17 @@ type Namer interface { IndexName(table, column string) string } +// Replacer replacer interface like strings.Replacer +type Replacer interface { + Replace(name string) string +} + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string SingularTable bool - NameReplacer *strings.Replacer + NameReplacer Replacer + NoLowerCase bool } // TableName convert string to table name @@ -42,7 +47,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { - if strings.ToLower(str) == str { + if !ns.NoLowerCase && strings.ToLower(str) == str { return ns.TablePrefix + str } @@ -81,7 +86,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { } var ( - smap sync.Map // https://github.com/golang/lint/blob/master/lint.go#L770 commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} commonInitialismsReplacer *strings.Replacer @@ -98,14 +102,16 @@ func init() { func (ns NamingStrategy) toDBName(name string) string { if name == "" { return "" - } else if v, ok := smap.Load(name); ok { - return v.(string) } if ns.NameReplacer != nil { name = ns.NameReplacer.Replace(name) } + if ns.NoLowerCase { + return name + } + var ( value = commonInitialismsReplacer.Replace(name) buf strings.Builder @@ -143,6 +149,5 @@ func (ns NamingStrategy) toDBName(name string) string { buf.WriteByte(value[len(value)-1]) } ret := buf.String() - smap.Store(name, ret) return ret } diff --git a/schema/naming_test.go b/schema/naming_test.go index b7a32160..08f8d498 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -72,3 +72,99 @@ func TestNamingStrategy(t *testing.T) { t.Errorf("invalid column name generated, got %v", columdName) } } + +type CustomReplacer struct { + f func(string) string +} + +func (r CustomReplacer) Replace(name string) string { + return r.f(name) +} + +func TestCustomReplacer(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: false, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_replaced_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here. + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.replaced_userlanguage" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.replaced_company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "replaced_name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestCustomReplacerWithNoLowerCase(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: true, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_REPLACED_NAME" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.REPLACED_USER_LANGUAGES" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.REPLACED_USERLANGUAGE" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.REPLACED_COMPANY" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "REPLACED_NAME_Cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +}