Fix too long foreign key names in mysql.

The dialect must define its own foreign key generator method.
The previous default is available as a method on gorm.DefaultForeignKeyNamer
and can be embedded in other dialects.
The mysql dialect uses the first 24 characters plus an sha1 hash of the
full key name if the key name is more than 64 characters.
This commit is contained in:
Nikola Kovacs 2016-05-22 00:13:26 +02:00
parent 58ce6153e2
commit 6eb73ae654
7 changed files with 42 additions and 6 deletions

View File

@ -810,7 +810,7 @@ func TestRelated(t *testing.T) {
func TestForeignKey(t *testing.T) { func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() { for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID", "ReallyLongThingID"} { for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey { if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
} }
@ -849,7 +849,7 @@ func TestLongForeignKey(t *testing.T) {
} }
targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{}) targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{})
targetTableName := targetScope.TableName() targetTableName := targetScope.TableName()
modelScope := DB.NewScope(&User{}) modelScope := DB.NewScope(&NotSoLongTableName{})
modelField, ok := modelScope.FieldByName("ReallyLongThingID") modelField, ok := modelScope.FieldByName("ReallyLongThingID")
if !ok { if !ok {
t.Fatalf("Failed to get field by name: ReallyLongThingID") t.Fatalf("Failed to get field by name: ReallyLongThingID")
@ -859,7 +859,7 @@ func TestLongForeignKey(t *testing.T) {
t.Fatalf("Failed to get field by name: ID") t.Fatalf("Failed to get field by name: ID")
} }
dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
err := DB.Model(&User{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
if err != nil { if err != nil {
t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
} }

View File

@ -40,6 +40,9 @@ type Dialect interface {
SelectFromDummyTable() string SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string LastInsertIDReturningSuffix(tableName, columnName string) string
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string
} }
var dialectsMap = map[string]Dialect{} var dialectsMap = map[string]Dialect{}

View File

@ -4,12 +4,18 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"time" "time"
) )
// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
type commonDialect struct { type commonDialect struct {
db *sql.DB db *sql.DB
DefaultForeignKeyNamer
} }
func init() { func init() {
@ -135,3 +141,9 @@ func (commonDialect) SelectFromDummyTable() string {
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return "" return ""
} }
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
return keyName
}

View File

@ -1,8 +1,10 @@
package gorm package gorm
import ( import (
"crypto/sha1"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"time" "time"
) )
@ -115,3 +117,18 @@ func (s mysql) currentDatabase() (name string) {
func (mysql) SelectFromDummyTable() string { func (mysql) SelectFromDummyTable() string {
return "FROM DUAL" return "FROM DUAL"
} }
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
if len(keyName) <= 64 {
return keyName
}
h := sha1.New()
h.Write([]byte(keyName))
bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")
return fmt.Sprintf("%s%x", keyName[:24], bs)
}

View File

@ -24,6 +24,7 @@ func init() {
type mssql struct { type mssql struct {
db *sql.DB db *sql.DB
gorm.DefaultForeignKeyNamer
} }
func (mssql) GetName() string { func (mssql) GetName() string {

View File

@ -37,6 +37,10 @@ type User struct {
IgnoreStringSlice []string `sql:"-"` IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"` Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"` IgnoredPointer *User `sql:"-"`
}
type NotSoLongTableName struct {
Id int64
ReallyLongThingID int64 ReallyLongThingID int64
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
} }
@ -237,7 +241,7 @@ func runMigration() {
DB.Exec(fmt.Sprintf("drop table %v;", table)) DB.Exec(fmt.Sprintf("drop table %v;", table))
} }
values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}}
for _, value := range values { for _, value := range values {
DB.DropTable(value) DB.DropTable(value)
} }

View File

@ -1117,8 +1117,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
} }
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return return