From 6eb73ae65412e4e6fafaaf870048c738bfcd4bf3 Mon Sep 17 00:00:00 2001 From: Nikola Kovacs Date: Sun, 22 May 2016 00:13:26 +0200 Subject: [PATCH] Fix too long foreign key names in mysql. The dialect must define its own foreign key generator method. The previous default is available as a method on gorm.DefaultForeignKeyNamer and can be embedded in other dialects. The mysql dialect uses the first 24 characters plus an sha1 hash of the full key name if the key name is more than 64 characters. --- association_test.go | 6 +++--- dialect.go | 3 +++ dialect_common.go | 12 ++++++++++++ dialect_mysql.go | 17 +++++++++++++++++ dialects/mssql/mssql.go | 1 + migration_test.go | 6 +++++- scope.go | 3 +-- 7 files changed, 42 insertions(+), 6 deletions(-) diff --git a/association_test.go b/association_test.go index 1e2e5179..2b2d3bac 100644 --- a/association_test.go +++ b/association_test.go @@ -810,7 +810,7 @@ func TestRelated(t *testing.T) { func TestForeignKey(t *testing.T) { for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID", "ReallyLongThingID"} { + for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { if structField.Name == foreignKey && !structField.IsForeignKey { t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) } @@ -849,7 +849,7 @@ func TestLongForeignKey(t *testing.T) { } targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{}) targetTableName := targetScope.TableName() - modelScope := DB.NewScope(&User{}) + modelScope := DB.NewScope(&NotSoLongTableName{}) modelField, ok := modelScope.FieldByName("ReallyLongThingID") if !ok { t.Fatalf("Failed to get field by name: ReallyLongThingID") @@ -859,7 +859,7 @@ func TestLongForeignKey(t *testing.T) { t.Fatalf("Failed to get field by name: ID") } dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(&User{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error + err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error if err != nil { t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) } diff --git a/dialect.go b/dialect.go index 6c9405da..033b9555 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,9 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + + // BuildForeignKeyName returns a foreign key name for the given table, field and reference + BuildForeignKeyName(tableName, field, dest string) string } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index f009271b..6d43fb84 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -4,12 +4,18 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "time" ) +// DefaultForeignKeyNamer contains the default foreign key name generator method +type DefaultForeignKeyNamer struct { +} + type commonDialect struct { db *sql.DB + DefaultForeignKeyNamer } func init() { @@ -135,3 +141,9 @@ func (commonDialect) SelectFromDummyTable() string { func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { + keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + return keyName +} diff --git a/dialect_mysql.go b/dialect_mysql.go index f62bbe35..d848de35 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -1,8 +1,10 @@ package gorm import ( + "crypto/sha1" "fmt" "reflect" + "regexp" "strings" "time" ) @@ -115,3 +117,18 @@ func (s mysql) currentDatabase() (name string) { func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } + +func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { + keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) + if len(keyName) <= 64 { + return keyName + } + h := sha1.New() + h.Write([]byte(keyName)) + bs := h.Sum(nil) + + // sha1 is 40 digits, keep first 24 characters of destination + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_") + + return fmt.Sprintf("%s%x", keyName[:24], bs) +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 5b994f9d..34eda717 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -24,6 +24,7 @@ func init() { type mssql struct { db *sql.DB + gorm.DefaultForeignKeyNamer } func (mssql) GetName() string { diff --git a/migration_test.go b/migration_test.go index fc692bcc..4d15917c 100644 --- a/migration_test.go +++ b/migration_test.go @@ -37,6 +37,10 @@ type User struct { IgnoreStringSlice []string `sql:"-"` Ignored struct{ Name string } `sql:"-"` IgnoredPointer *User `sql:"-"` +} + +type NotSoLongTableName struct { + Id int64 ReallyLongThingID int64 ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit } @@ -237,7 +241,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} + values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} for _, value := range values { DB.DropTable(value) } diff --git a/scope.go b/scope.go index 0116da08..1e755626 100644 --- a/scope.go +++ b/scope.go @@ -1117,8 +1117,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return