added naming strategy option for db, table and column names (#2040)

This commit is contained in:
Adem Özay 2018-09-10 00:52:20 +03:00 committed by Jinzhu
parent 32455088f2
commit 6f58f8a52c
6 changed files with 201 additions and 104 deletions

View File

@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string {
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
s.defaultTableName = tabler.TableName() s.defaultTableName = tabler.TableName()
} else { } else {
tableName := ToDBName(s.ModelType.Name()) tableName := ToTableName(s.ModelType.Name())
if db == nil || !db.parent.singularTable { if db == nil || !db.parent.singularTable {
tableName = inflection.Plural(tableName) tableName = inflection.Plural(tableName)
} }
@ -105,7 +105,7 @@ type Relationship struct {
func getForeignField(column string, fields []*StructField) *StructField { func getForeignField(column string, fields []*StructField) *StructField {
for _, field := range fields { for _, field := range fields {
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
return field return field
} }
} }
@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
// if defined join table's foreign key // if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else { } else {
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
} }
} }
@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
} else { } else {
// join table foreign keys for association // join table foreign keys for association
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
} }
} }
@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
joinTableHandler := JoinTableHandler{} joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, reflectType, elemType) joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
relationship.JoinTableHandler = &joinTableHandler relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship field.Relationship = relationship
} else { } else {
@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if value, ok := field.TagSettings["COLUMN"]; ok { if value, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = value field.DBName = value
} else { } else {
field.DBName = ToDBName(fieldStruct.Name) field.DBName = ToColumnName(fieldStruct.Name)
} }
modelStruct.StructFields = append(modelStruct.StructFields, field) modelStruct.StructFields = append(modelStruct.StructFields, field)

124
naming.go Normal file
View File

@ -0,0 +1,124 @@
package gorm
import (
"bytes"
"strings"
)
// Namer is a function type which is given a string and return a string
type Namer func(string) string
// NamingStrategy represents naming strategies
type NamingStrategy struct {
DB Namer
Table Namer
Column Namer
}
// TheNamingStrategy is being initialized with defaultNamingStrategy
var TheNamingStrategy = &NamingStrategy{
DB: defaultNamer,
Table: defaultNamer,
Column: defaultNamer,
}
// AddNamingStrategy sets the naming strategy
func AddNamingStrategy(ns *NamingStrategy) {
if ns.DB == nil {
ns.DB = defaultNamer
}
if ns.Table == nil {
ns.Table = defaultNamer
}
if ns.Column == nil {
ns.Column = defaultNamer
}
TheNamingStrategy = ns
}
// DBName alters the given name by DB
func (ns *NamingStrategy) DBName(name string) string {
return ns.DB(name)
}
// TableName alters the given name by Table
func (ns *NamingStrategy) TableName(name string) string {
return ns.Table(name)
}
// ColumnName alters the given name by Column
func (ns *NamingStrategy) ColumnName(name string) string {
return ns.Column(name)
}
// ToDBName convert string to db name
func ToDBName(name string) string {
return TheNamingStrategy.DBName(name)
}
// ToTableName convert string to table name
func ToTableName(name string) string {
return TheNamingStrategy.TableName(name)
}
// ToColumnName convert string to db name
func ToColumnName(name string) string {
return TheNamingStrategy.ColumnName(name)
}
var smap = newSafeMap()
func defaultNamer(name string) string {
const (
lower = false
upper = true
)
if v := smap.Get(name); v != "" {
return v
}
if name == "" {
return ""
}
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase, nextNumber bool
)
for i, v := range value[:len(value)-1] {
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
if i > 0 {
if currCase == upper {
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
buf.WriteRune(v)
} else {
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
buf.WriteRune('_')
}
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
buf.WriteByte(value[len(value)-1])
s := strings.ToLower(buf.String())
smap.Set(name, s)
return s
}

69
naming_test.go Normal file
View File

@ -0,0 +1,69 @@
package gorm_test
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestTheNamingStrategy(t *testing.T) {
cases := []struct {
name string
namer gorm.Namer
expected string
}{
{name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
{name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
{name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
result := c.namer(c.name)
if result != c.expected {
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
}
})
}
}
func TestNamingStrategy(t *testing.T) {
dbNameNS := func(name string) string {
return "db_" + name
}
tableNameNS := func(name string) string {
return "tbl_" + name
}
columnNameNS := func(name string) string {
return "col_" + name
}
ns := &gorm.NamingStrategy{
DB: dbNameNS,
Table: tableNameNS,
Column: columnNameNS,
}
cases := []struct {
name string
namer gorm.Namer
expected string
}{
{name: "auth", expected: "db_auth", namer: ns.DB},
{name: "user", expected: "tbl_user", namer: ns.Table},
{name: "password", expected: "col_password", namer: ns.Column},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
result := c.namer(c.name)
if result != c.expected {
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
}
})
}
}

View File

@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field {
// FieldByName find `gorm.Field` with field name or db name // FieldByName find `gorm.Field` with field name or db name
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
var ( var (
dbName = ToDBName(name) dbName = ToColumnName(name)
mostMatchedField *Field mostMatchedField *Field
) )
@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Map: case reflect.Map:
for _, key := range reflectValue.MapKeys() { for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
} }
default: default:
for _, field := range (&Scope{Value: values}).Fields() { for _, field := range (&Scope{Value: values}).Fields() {

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
@ -58,66 +57,6 @@ func newSafeMap() *safeMap {
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
} }
var smap = newSafeMap()
type strCase bool
const (
lower strCase = false
upper strCase = true
)
// ToDBName convert string to db name
func ToDBName(name string) string {
if v := smap.Get(name); v != "" {
return v
}
if name == "" {
return ""
}
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase, nextNumber strCase
)
for i, v := range value[:len(value)-1] {
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9')
if i > 0 {
if currCase == upper {
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
buf.WriteRune(v)
} else {
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
buf.WriteRune('_')
}
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
buf.WriteByte(value[len(value)-1])
s := strings.ToLower(buf.String())
smap.Set(name, s)
return s
}
// SQL expression // SQL expression
type expr struct { type expr struct {
expr string expr string

View File

@ -1,35 +0,0 @@
package gorm_test
import (
"testing"
"github.com/jinzhu/gorm"
)
func TestToDBNameGenerateFriendlyName(t *testing.T) {
var maps = map[string]string{
"": "",
"X": "x",
"ThisIsATest": "this_is_a_test",
"PFAndESI": "pf_and_esi",
"AbcAndJkl": "abc_and_jkl",
"EmployeeID": "employee_id",
"SKU_ID": "sku_id",
"UTF8": "utf8",
"Level1": "level1",
"SHA256Hash": "sha256_hash",
"FieldX": "field_x",
"HTTPAndSMTP": "http_and_smtp",
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
"UUID": "uuid",
"HTTPURL": "http_url",
"HTTP_URL": "http_url",
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
}
for key, value := range maps {
if gorm.ToDBName(key) != value {
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
}
}
}