Refactor NamingStrategy, close #3540

This commit is contained in:
Jinzhu 2020-09-28 10:55:27 +08:00
parent 9eec6ae066
commit a2faa41cbe
2 changed files with 30 additions and 23 deletions

View File

@ -42,7 +42,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string {
// JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string {
if strings.ToLower(str) == str {
return str
return ns.TablePrefix + str
}
if ns.SingularTable {
@ -53,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string {
// RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name))
return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1)
}
// CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string {
return fmt.Sprintf("chk_%s_%s", table, column)
return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1)
}
// IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string {
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
idxName = strings.Replace(idxName, ".", "_", -1)
if utf8.RuneCountInString(idxName) > 64 {
h := sha1.New()

View File

@ -1,7 +1,6 @@
package schema
import (
"strings"
"testing"
)
@ -34,27 +33,34 @@ func TestToDBName(t *testing.T) {
}
}
type NewNamingStrategy struct {
NamingStrategy
func TestNamingStrategy(t *testing.T) {
var ns = NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
func (ns NewNamingStrategy) ColumnName(table, column string) string {
baseColumnName := ns.NamingStrategy.ColumnName(table, column)
if table == "" {
return baseColumnName
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
s := strings.Split(table, "_")
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
var prefix string
switch len(s) {
case 1:
prefix = s[0][:3]
case 2:
prefix = s[0][:1] + s[1][:2]
default:
prefix = s[0][:1] + s[1][:1] + s[2][:1]
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.user_language" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
return prefix + "_" + baseColumnName
}