mirror of https://github.com/go-gorm/gorm.git
Replacer interface for more flexible NamingStrategy (#4042)
* Change NameReplacer to an interface, allowing custom Replacers. * Add NoLowerCase option to skip the snake_casing of names. * Move sync.Map from global variable into member of NamingStrategy. This maintains backward compatibility by making the smap optional - the NamingStrategy still works if it is nil. gorm.Open activates it by calling Init() if the given Namer is a schema.NamingStrategy. Also, this changes the key stored in the smap to be the original name, instead of the replaced name. * Refactor NamingStrategy tests to add more assertions about how and when Replacers get called. * Remove the name cache from NamingStrategy.
This commit is contained in:
parent
a13b7a6acb
commit
5744e29fbd
|
@ -4,7 +4,6 @@ import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
|
@ -20,11 +19,17 @@ type Namer interface {
|
||||||
IndexName(table, column string) string
|
IndexName(table, column string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replacer replacer interface like strings.Replacer
|
||||||
|
type Replacer interface {
|
||||||
|
Replace(name string) string
|
||||||
|
}
|
||||||
|
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
type NamingStrategy struct {
|
type NamingStrategy struct {
|
||||||
TablePrefix string
|
TablePrefix string
|
||||||
SingularTable bool
|
SingularTable bool
|
||||||
NameReplacer *strings.Replacer
|
NameReplacer Replacer
|
||||||
|
NoLowerCase bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableName convert string to table name
|
// TableName convert string to table name
|
||||||
|
@ -42,7 +47,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
|
||||||
|
|
||||||
// JoinTableName convert string to join table name
|
// JoinTableName convert string to join table name
|
||||||
func (ns NamingStrategy) JoinTableName(str string) string {
|
func (ns NamingStrategy) JoinTableName(str string) string {
|
||||||
if strings.ToLower(str) == str {
|
if !ns.NoLowerCase && strings.ToLower(str) == str {
|
||||||
return ns.TablePrefix + str
|
return ns.TablePrefix + str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +86,6 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
smap sync.Map
|
|
||||||
// https://github.com/golang/lint/blob/master/lint.go#L770
|
// https://github.com/golang/lint/blob/master/lint.go#L770
|
||||||
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
commonInitialismsReplacer *strings.Replacer
|
commonInitialismsReplacer *strings.Replacer
|
||||||
|
@ -98,14 +102,16 @@ func init() {
|
||||||
func (ns NamingStrategy) toDBName(name string) string {
|
func (ns NamingStrategy) toDBName(name string) string {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return ""
|
return ""
|
||||||
} else if v, ok := smap.Load(name); ok {
|
|
||||||
return v.(string)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ns.NameReplacer != nil {
|
if ns.NameReplacer != nil {
|
||||||
name = ns.NameReplacer.Replace(name)
|
name = ns.NameReplacer.Replace(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ns.NoLowerCase {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
value = commonInitialismsReplacer.Replace(name)
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
buf strings.Builder
|
buf strings.Builder
|
||||||
|
@ -143,6 +149,5 @@ func (ns NamingStrategy) toDBName(name string) string {
|
||||||
buf.WriteByte(value[len(value)-1])
|
buf.WriteByte(value[len(value)-1])
|
||||||
}
|
}
|
||||||
ret := buf.String()
|
ret := buf.String()
|
||||||
smap.Store(name, ret)
|
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,3 +72,99 @@ func TestNamingStrategy(t *testing.T) {
|
||||||
t.Errorf("invalid column name generated, got %v", columdName)
|
t.Errorf("invalid column name generated, got %v", columdName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CustomReplacer struct {
|
||||||
|
f func(string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r CustomReplacer) Replace(name string) string {
|
||||||
|
return r.f(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomReplacer(t *testing.T) {
|
||||||
|
var ns = NamingStrategy{
|
||||||
|
TablePrefix: "public.",
|
||||||
|
SingularTable: true,
|
||||||
|
NameReplacer: CustomReplacer{
|
||||||
|
func(name string) string {
|
||||||
|
replaced := "REPLACED_" + strings.ToUpper(name)
|
||||||
|
return strings.NewReplacer("CID", "_Cid").Replace(replaced)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NoLowerCase: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
idxName := ns.IndexName("public.table", "name")
|
||||||
|
if idxName != "idx_public_table_replaced_name" {
|
||||||
|
t.Errorf("invalid index name generated, got %v", idxName)
|
||||||
|
}
|
||||||
|
|
||||||
|
chkName := ns.CheckerName("public.table", "name")
|
||||||
|
if chkName != "chk_public_table_name" {
|
||||||
|
t.Errorf("invalid checker name generated, got %v", chkName)
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTable := ns.JoinTableName("user_languages")
|
||||||
|
if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here.
|
||||||
|
t.Errorf("invalid join table generated, got %v", joinTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTable2 := ns.JoinTableName("UserLanguage")
|
||||||
|
if joinTable2 != "public.replaced_userlanguage" {
|
||||||
|
t.Errorf("invalid join table generated, got %v", joinTable2)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := ns.TableName("Company")
|
||||||
|
if tableName != "public.replaced_company" {
|
||||||
|
t.Errorf("invalid table name generated, got %v", tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
columdName := ns.ColumnName("", "NameCID")
|
||||||
|
if columdName != "replaced_name_cid" {
|
||||||
|
t.Errorf("invalid column name generated, got %v", columdName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||||
|
var ns = NamingStrategy{
|
||||||
|
TablePrefix: "public.",
|
||||||
|
SingularTable: true,
|
||||||
|
NameReplacer: CustomReplacer{
|
||||||
|
func(name string) string {
|
||||||
|
replaced := "REPLACED_" + strings.ToUpper(name)
|
||||||
|
return strings.NewReplacer("CID", "_Cid").Replace(replaced)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NoLowerCase: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
idxName := ns.IndexName("public.table", "name")
|
||||||
|
if idxName != "idx_public_table_REPLACED_NAME" {
|
||||||
|
t.Errorf("invalid index name generated, got %v", idxName)
|
||||||
|
}
|
||||||
|
|
||||||
|
chkName := ns.CheckerName("public.table", "name")
|
||||||
|
if chkName != "chk_public_table_name" {
|
||||||
|
t.Errorf("invalid checker name generated, got %v", chkName)
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTable := ns.JoinTableName("user_languages")
|
||||||
|
if joinTable != "public.REPLACED_USER_LANGUAGES" {
|
||||||
|
t.Errorf("invalid join table generated, got %v", joinTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTable2 := ns.JoinTableName("UserLanguage")
|
||||||
|
if joinTable2 != "public.REPLACED_USERLANGUAGE" {
|
||||||
|
t.Errorf("invalid join table generated, got %v", joinTable2)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := ns.TableName("Company")
|
||||||
|
if tableName != "public.REPLACED_COMPANY" {
|
||||||
|
t.Errorf("invalid table name generated, got %v", tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
columdName := ns.ColumnName("", "NameCID")
|
||||||
|
if columdName != "REPLACED_NAME_Cid" {
|
||||||
|
t.Errorf("invalid column name generated, got %v", columdName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue