diff --git a/schema/naming.go b/schema/naming.go index af753ce5..dbc71e04 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -42,7 +42,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { if strings.ToLower(str) == str { - return str + return ns.TablePrefix + str } if ns.SingularTable { @@ -53,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) + return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { - return fmt.Sprintf("chk_%s_%s", table, column) + return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + idxName = strings.Replace(idxName, ".", "_", -1) if utf8.RuneCountInString(idxName) > 64 { h := sha1.New() diff --git a/schema/naming_test.go b/schema/naming_test.go index a4600ceb..26b0dcf6 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,7 +1,6 @@ package schema import ( - "strings" "testing" ) @@ -34,27 +33,34 @@ func TestToDBName(t *testing.T) { } } -type NewNamingStrategy struct { - NamingStrategy -} +func TestNamingStrategy(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + } + idxName := ns.IndexName("public.table", "name") -func (ns NewNamingStrategy) ColumnName(table, column string) string { - baseColumnName := ns.NamingStrategy.ColumnName(table, column) - - if table == "" { - return baseColumnName + if idxName != "idx_public_table_name" { + t.Errorf("invalid index name generated, got %v", idxName) } - s := strings.Split(table, "_") - - var prefix string - switch len(s) { - case 1: - prefix = s[0][:3] - case 2: - prefix = s[0][:1] + s[1][:2] - default: - prefix = s[0][:1] + s[1][:1] + s[2][:1] + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.user_language" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.company" { + t.Errorf("invalid table name generated, got %v", tableName) } - return prefix + "_" + baseColumnName }