diff --git a/utils.go b/utils.go index b457f321..158d8a3b 100644 --- a/utils.go +++ b/utils.go @@ -41,21 +41,49 @@ func newSafeMap() *safeMap { var smap = newSafeMap() +type Case bool + +const ( + lower Case = false + upper Case = true +) + func ToDBName(name string) string { if v := smap.Get(name); v != "" { return v } - value := commonInitialismsReplacer.Replace(name) - buf := bytes.NewBufferString("") - for i, v := range value { - if i > 0 && v >= 'A' && v <= 'Z' { - buf.WriteRune('_') + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase Case + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z' + if i > 0 { + if currCase == upper { + if lastCase == upper && nextCase == upper { + buf.WriteRune(v) + } else { + buf.WriteRune('_') + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + } + } else { + currCase = upper + buf.WriteRune(v) } - buf.WriteRune(v) + lastCase = currCase + currCase = nextCase } - s := strings.ToLower(buf.String()) + buf.WriteByte(value[len(value)-1]) + + s := strings.Replace(strings.ToLower(buf.String()), "__", "_", -1) + smap.Set(name, s) return s } diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 00000000..81b8ce9d --- /dev/null +++ b/utils_test.go @@ -0,0 +1,28 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestToDBNameGenerateFriendlyName(t *testing.T) { + var maps = map[string]string{ + "ThisIsATest": "this_is_a_test", + "PFAndESI": "pf_and_esi", + "AbcAndJkl": "abc_and_jkl", + "EmployeeID": "employee_id", + "HTTPAndSMTP": "http_and_smtp", + "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", + "UUID": "uuid", + "HTTPURL": "http_url", + "HTTP_URL": "http_url", + "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", + } + + for key, value := range maps { + if gorm.ToDBName(key) != value { + t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) + } + } +}