forked from mirror/gorm
added naming strategy option for db, table and column names (#2040)
This commit is contained in:
parent
32455088f2
commit
6f58f8a52c
|
@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string {
|
||||||
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
||||||
s.defaultTableName = tabler.TableName()
|
s.defaultTableName = tabler.TableName()
|
||||||
} else {
|
} else {
|
||||||
tableName := ToDBName(s.ModelType.Name())
|
tableName := ToTableName(s.ModelType.Name())
|
||||||
if db == nil || !db.parent.singularTable {
|
if db == nil || !db.parent.singularTable {
|
||||||
tableName = inflection.Plural(tableName)
|
tableName = inflection.Plural(tableName)
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ type Relationship struct {
|
||||||
|
|
||||||
func getForeignField(column string, fields []*StructField) *StructField {
|
func getForeignField(column string, fields []*StructField) *StructField {
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
|
if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
|
||||||
return field
|
return field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
// if defined join table's foreign key
|
// if defined join table's foreign key
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
|
||||||
} else {
|
} else {
|
||||||
defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
|
||||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
|
||||||
} else {
|
} else {
|
||||||
// join table foreign keys for association
|
// join table foreign keys for association
|
||||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
|
||||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
}
|
}
|
||||||
|
|
||||||
joinTableHandler := JoinTableHandler{}
|
joinTableHandler := JoinTableHandler{}
|
||||||
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
|
||||||
relationship.JoinTableHandler = &joinTableHandler
|
relationship.JoinTableHandler = &joinTableHandler
|
||||||
field.Relationship = relationship
|
field.Relationship = relationship
|
||||||
} else {
|
} else {
|
||||||
|
@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
if value, ok := field.TagSettings["COLUMN"]; ok {
|
if value, ok := field.TagSettings["COLUMN"]; ok {
|
||||||
field.DBName = value
|
field.DBName = value
|
||||||
} else {
|
} else {
|
||||||
field.DBName = ToDBName(fieldStruct.Name)
|
field.DBName = ToColumnName(fieldStruct.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
|
|
|
@ -0,0 +1,124 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Namer is a function type which is given a string and return a string
|
||||||
|
type Namer func(string) string
|
||||||
|
|
||||||
|
// NamingStrategy represents naming strategies
|
||||||
|
type NamingStrategy struct {
|
||||||
|
DB Namer
|
||||||
|
Table Namer
|
||||||
|
Column Namer
|
||||||
|
}
|
||||||
|
|
||||||
|
// TheNamingStrategy is being initialized with defaultNamingStrategy
|
||||||
|
var TheNamingStrategy = &NamingStrategy{
|
||||||
|
DB: defaultNamer,
|
||||||
|
Table: defaultNamer,
|
||||||
|
Column: defaultNamer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNamingStrategy sets the naming strategy
|
||||||
|
func AddNamingStrategy(ns *NamingStrategy) {
|
||||||
|
if ns.DB == nil {
|
||||||
|
ns.DB = defaultNamer
|
||||||
|
}
|
||||||
|
if ns.Table == nil {
|
||||||
|
ns.Table = defaultNamer
|
||||||
|
}
|
||||||
|
if ns.Column == nil {
|
||||||
|
ns.Column = defaultNamer
|
||||||
|
}
|
||||||
|
TheNamingStrategy = ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBName alters the given name by DB
|
||||||
|
func (ns *NamingStrategy) DBName(name string) string {
|
||||||
|
return ns.DB(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName alters the given name by Table
|
||||||
|
func (ns *NamingStrategy) TableName(name string) string {
|
||||||
|
return ns.Table(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnName alters the given name by Column
|
||||||
|
func (ns *NamingStrategy) ColumnName(name string) string {
|
||||||
|
return ns.Column(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDBName convert string to db name
|
||||||
|
func ToDBName(name string) string {
|
||||||
|
return TheNamingStrategy.DBName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTableName convert string to table name
|
||||||
|
func ToTableName(name string) string {
|
||||||
|
return TheNamingStrategy.TableName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToColumnName convert string to db name
|
||||||
|
func ToColumnName(name string) string {
|
||||||
|
return TheNamingStrategy.ColumnName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var smap = newSafeMap()
|
||||||
|
|
||||||
|
func defaultNamer(name string) string {
|
||||||
|
const (
|
||||||
|
lower = false
|
||||||
|
upper = true
|
||||||
|
)
|
||||||
|
|
||||||
|
if v := smap.Get(name); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
|
buf = bytes.NewBufferString("")
|
||||||
|
lastCase, currCase, nextCase, nextNumber bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, v := range value[:len(value)-1] {
|
||||||
|
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
||||||
|
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
|
||||||
|
|
||||||
|
if i > 0 {
|
||||||
|
if currCase == upper {
|
||||||
|
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
} else {
|
||||||
|
if value[i-1] != '_' && value[i+1] != '_' {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currCase = upper
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
lastCase = currCase
|
||||||
|
currCase = nextCase
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(value[len(value)-1])
|
||||||
|
|
||||||
|
s := strings.ToLower(buf.String())
|
||||||
|
smap.Set(name, s)
|
||||||
|
return s
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTheNamingStrategy(t *testing.T) {
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
namer gorm.Namer
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB},
|
||||||
|
{name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table},
|
||||||
|
{name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
result := c.namer(c.name)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamingStrategy(t *testing.T) {
|
||||||
|
|
||||||
|
dbNameNS := func(name string) string {
|
||||||
|
return "db_" + name
|
||||||
|
}
|
||||||
|
tableNameNS := func(name string) string {
|
||||||
|
return "tbl_" + name
|
||||||
|
}
|
||||||
|
columnNameNS := func(name string) string {
|
||||||
|
return "col_" + name
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := &gorm.NamingStrategy{
|
||||||
|
DB: dbNameNS,
|
||||||
|
Table: tableNameNS,
|
||||||
|
Column: columnNameNS,
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
namer gorm.Namer
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "auth", expected: "db_auth", namer: ns.DB},
|
||||||
|
{name: "user", expected: "tbl_user", namer: ns.Table},
|
||||||
|
{name: "password", expected: "col_password", namer: ns.Column},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
result := c.namer(c.name)
|
||||||
|
if result != c.expected {
|
||||||
|
t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
4
scope.go
4
scope.go
|
@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field {
|
||||||
// FieldByName find `gorm.Field` with field name or db name
|
// FieldByName find `gorm.Field` with field name or db name
|
||||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||||
var (
|
var (
|
||||||
dbName = ToDBName(name)
|
dbName = ToColumnName(name)
|
||||||
mostMatchedField *Field
|
mostMatchedField *Field
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
for _, key := range reflectValue.MapKeys() {
|
for _, key := range reflectValue.MapKeys() {
|
||||||
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
for _, field := range (&Scope{Value: values}).Fields() {
|
for _, field := range (&Scope{Value: values}).Fields() {
|
||||||
|
|
61
utils.go
61
utils.go
|
@ -1,7 +1,6 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -58,66 +57,6 @@ func newSafeMap() *safeMap {
|
||||||
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||||
}
|
}
|
||||||
|
|
||||||
var smap = newSafeMap()
|
|
||||||
|
|
||||||
type strCase bool
|
|
||||||
|
|
||||||
const (
|
|
||||||
lower strCase = false
|
|
||||||
upper strCase = true
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToDBName convert string to db name
|
|
||||||
func ToDBName(name string) string {
|
|
||||||
if v := smap.Get(name); v != "" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
if name == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
value = commonInitialismsReplacer.Replace(name)
|
|
||||||
buf = bytes.NewBufferString("")
|
|
||||||
lastCase, currCase, nextCase, nextNumber strCase
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, v := range value[:len(value)-1] {
|
|
||||||
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
|
||||||
nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9')
|
|
||||||
|
|
||||||
if i > 0 {
|
|
||||||
if currCase == upper {
|
|
||||||
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
|
|
||||||
buf.WriteRune(v)
|
|
||||||
} else {
|
|
||||||
if value[i-1] != '_' && value[i+1] != '_' {
|
|
||||||
buf.WriteRune('_')
|
|
||||||
}
|
|
||||||
buf.WriteRune(v)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
buf.WriteRune(v)
|
|
||||||
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
|
|
||||||
buf.WriteRune('_')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
currCase = upper
|
|
||||||
buf.WriteRune(v)
|
|
||||||
}
|
|
||||||
lastCase = currCase
|
|
||||||
currCase = nextCase
|
|
||||||
}
|
|
||||||
|
|
||||||
buf.WriteByte(value[len(value)-1])
|
|
||||||
|
|
||||||
s := strings.ToLower(buf.String())
|
|
||||||
smap.Set(name, s)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQL expression
|
// SQL expression
|
||||||
type expr struct {
|
type expr struct {
|
||||||
expr string
|
expr string
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
package gorm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestToDBNameGenerateFriendlyName(t *testing.T) {
|
|
||||||
var maps = map[string]string{
|
|
||||||
"": "",
|
|
||||||
"X": "x",
|
|
||||||
"ThisIsATest": "this_is_a_test",
|
|
||||||
"PFAndESI": "pf_and_esi",
|
|
||||||
"AbcAndJkl": "abc_and_jkl",
|
|
||||||
"EmployeeID": "employee_id",
|
|
||||||
"SKU_ID": "sku_id",
|
|
||||||
"UTF8": "utf8",
|
|
||||||
"Level1": "level1",
|
|
||||||
"SHA256Hash": "sha256_hash",
|
|
||||||
"FieldX": "field_x",
|
|
||||||
"HTTPAndSMTP": "http_and_smtp",
|
|
||||||
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
|
|
||||||
"UUID": "uuid",
|
|
||||||
"HTTPURL": "http_url",
|
|
||||||
"HTTP_URL": "http_url",
|
|
||||||
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value := range maps {
|
|
||||||
if gorm.ToDBName(key) != value {
|
|
||||||
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue