diff --git a/schema/naming.go b/schema/naming.go index 47e313a7..8407bffa 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/hex" "fmt" + "regexp" "strings" "unicode/utf8" @@ -13,6 +14,7 @@ import ( // Namer namer interface type Namer interface { TableName(table string) string + SchemaName(table string) string ColumnName(table, column string) string JoinTableName(joinTable string) string RelationshipFKName(Relationship) string @@ -41,6 +43,16 @@ func (ns NamingStrategy) TableName(str string) string { return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } +// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName +func (ns NamingStrategy) SchemaName(table string) string { + table = strings.TrimPrefix(table, ns.TablePrefix) + + if ns.SingularTable { + return ns.toSchemaName(table) + } + return ns.toSchemaName(inflection.Singular(table)) +} + // ColumnName convert string to column name func (ns NamingStrategy) ColumnName(table, column string) string { return ns.toDBName(column) @@ -154,3 +166,11 @@ func (ns NamingStrategy) toDBName(name string) string { ret := buf.String() return ret } + +func (ns NamingStrategy) toSchemaName(name string) string { + result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) + for _, initialism := range commonInitialisms { + result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + } + return result +} diff --git a/schema/naming_test.go b/schema/naming_test.go index face9364..6add338e 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -33,6 +33,26 @@ func TestToDBName(t *testing.T) { t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) } } + + maps = map[string]string{ + "x": "X", + "user_restrictions": "UserRestriction", + "this_is_a_test": "ThisIsATest", + "abc_and_jkl": "AbcAndJkl", + "employee_id": "EmployeeID", + "field_x": "FieldX", + "http_and_smtp": "HTTPAndSMTP", + "http_server_handler_for_url_id": "HTTPServerHandlerForURLID", + "uuid": "UUID", + "http_url": "HTTPURL", + "sha256_hash": "Sha256Hash", + "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID", + } + for key, value := range maps { + if ns.SchemaName(key) != value { + t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key)) + } + } } func TestNamingStrategy(t *testing.T) {