Implement naming strategy

This commit is contained in:
Jinzhu 2020-01-31 14:17:02 +08:00
parent 1079e17caf
commit bc68fde6aa
5 changed files with 142 additions and 4 deletions

2
go.mod
View File

@ -1,3 +1,5 @@
module github.com/jinzhu/gorm
go 1.13
require github.com/jinzhu/inflection v1.0.0

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=

12
gorm.go
View File

@ -6,18 +6,18 @@ import (
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema"
)
// Config GORM config
type Config struct {
// Set true to use singular table name, by default, GORM will pluralize your struct's name as table name
// Refer https://github.com/jinzhu/inflection for inflection rules
SingularTable bool
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can cancel it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// Logger
Logger logger.Interface
@ -48,6 +48,10 @@ type Session struct {
// Open initialize db session based on dialector
func Open(dialector Dialector, config *Config) (db *DB, err error) {
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{}
}
return &DB{
Config: config,
Dialector: dialector,

96
schema/naming.go Normal file
View File

@ -0,0 +1,96 @@
package schema
import (
"fmt"
"strings"
"sync"
"github.com/jinzhu/inflection"
)
// Namer namer interface
type Namer interface {
TableName(string) string
ColumnName(string) string
}
// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
SingularTable bool
}
// TableName convert string to table name
func (ns NamingStrategy) TableName(str string) string {
if ns.SingularTable {
return ns.TablePrefix + toDBName(str)
}
return ns.TablePrefix + inflection.Plural(toDBName(str))
}
// ColumnName convert string to column name
func (ns NamingStrategy) ColumnName(str string) string {
return toDBName(str)
}
var (
smap sync.Map
// https://github.com/golang/lint/blob/master/lint.go#L770
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
commonInitialismsReplacer *strings.Replacer
)
func init() {
var commonInitialismsForReplacer []string
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
func toDBName(name string) string {
if name == "" {
return ""
} else if v, ok := smap.Load(name); ok {
return fmt.Sprint(v)
}
var (
value = commonInitialismsReplacer.Replace(name)
buf strings.Builder
lastCase, nextCase, nextNumber bool // upper case == true
curCase = value[0] <= 'Z' && value[0] >= 'A'
)
for i, v := range value[:len(value)-1] {
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
if curCase {
if lastCase && (nextCase || nextNumber) {
buf.WriteRune(v + 32)
} else {
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
buf.WriteByte('_')
}
buf.WriteRune(v + 32)
}
} else {
buf.WriteRune(v)
}
lastCase = curCase
curCase = nextCase
}
if curCase {
if !lastCase && len(value) > 1 {
buf.WriteByte('_')
}
buf.WriteByte(value[len(value)-1] + 32)
} else {
buf.WriteByte(value[len(value)-1])
}
return buf.String()
}

34
schema/naming_test.go Normal file
View File

@ -0,0 +1,34 @@
package schema
import (
"testing"
)
func TestToDBName(t *testing.T) {
var maps = map[string]string{
"": "",
"x": "x",
"X": "x",
"userRestrictions": "user_restrictions",
"ThisIsATest": "this_is_a_test",
"PFAndESI": "pf_and_esi",
"AbcAndJkl": "abc_and_jkl",
"EmployeeID": "employee_id",
"SKU_ID": "sku_id",
"FieldX": "field_x",
"HTTPAndSMTP": "http_and_smtp",
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
"UUID": "uuid",
"HTTPURL": "http_url",
"HTTP_URL": "http_url",
"SHA256Hash": "sha256_hash",
"SHA256HASH": "sha256_hash",
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
}
for key, value := range maps {
if toDBName(key) != value {
t.Errorf("%v toName should equal %v, but got %v", key, value, toDBName(key))
}
}
}