From 6f58f8a52cc3ad21950402d1adaa09682e07ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adem=20=C3=96zay?= Date: Mon, 10 Sep 2018 00:52:20 +0300 Subject: [PATCH] added naming strategy option for db, table and column names (#2040) --- model_struct.go | 12 ++--- naming.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++++ naming_test.go | 69 +++++++++++++++++++++++++++ scope.go | 4 +- utils.go | 61 ------------------------ utils_test.go | 35 -------------- 6 files changed, 201 insertions(+), 104 deletions(-) create mode 100644 naming.go create mode 100644 naming_test.go delete mode 100644 utils_test.go diff --git a/model_struct.go b/model_struct.go index 8506fe87..5b5be618 100644 --- a/model_struct.go +++ b/model_struct.go @@ -34,7 +34,7 @@ func (s *ModelStruct) TableName(db *DB) string { if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { s.defaultTableName = tabler.TableName() } else { - tableName := ToDBName(s.ModelType.Name()) + tableName := ToTableName(s.ModelType.Name()) if db == nil || !db.parent.singularTable { tableName = inflection.Plural(tableName) } @@ -105,7 +105,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -269,7 +269,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { // if defined join table's foreign key relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) } else { - defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) } } @@ -300,7 +300,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) } else { // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } } @@ -308,7 +308,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) + joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { @@ -566,7 +566,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if value, ok := field.TagSettings["COLUMN"]; ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) diff --git a/naming.go b/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/naming_test.go b/naming_test.go new file mode 100644 index 00000000..0c6f7713 --- /dev/null +++ b/naming_test.go @@ -0,0 +1,69 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestTheNamingStrategy(t *testing.T) { + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "auth", namer: gorm.TheNamingStrategy.DB}, + {name: "userRestrictions", expected: "user_restrictions", namer: gorm.TheNamingStrategy.Table}, + {name: "clientID", expected: "client_id", namer: gorm.TheNamingStrategy.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} + +func TestNamingStrategy(t *testing.T) { + + dbNameNS := func(name string) string { + return "db_" + name + } + tableNameNS := func(name string) string { + return "tbl_" + name + } + columnNameNS := func(name string) string { + return "col_" + name + } + + ns := &gorm.NamingStrategy{ + DB: dbNameNS, + Table: tableNameNS, + Column: columnNameNS, + } + + cases := []struct { + name string + namer gorm.Namer + expected string + }{ + {name: "auth", expected: "db_auth", namer: ns.DB}, + {name: "user", expected: "tbl_user", namer: ns.Table}, + {name: "password", expected: "col_password", namer: ns.Column}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := c.namer(c.name) + if result != c.expected { + t.Errorf("error in naming strategy. expected: %v got :%v\n", c.expected, result) + } + }) + } + +} diff --git a/scope.go b/scope.go index ca861d8a..fbf7634e 100644 --- a/scope.go +++ b/scope.go @@ -134,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -880,7 +880,7 @@ func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: for _, field := range (&Scope{Value: values}).Fields() { diff --git a/utils.go b/utils.go index 99b532c5..ad700b98 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -58,66 +57,6 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = strCase(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression type expr struct { expr string diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 086c4450..00000000 --- a/utils_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "UTF8": "utf8", - "Level1": "level1", - "SHA256Hash": "sha256_hash", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -}