diff --git a/association_test.go b/association_test.go index 1e2e5179..2b2d3bac 100644 --- a/association_test.go +++ b/association_test.go @@ -810,7 +810,7 @@ func TestRelated(t *testing.T) { func TestForeignKey(t *testing.T) { for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID", "ReallyLongThingID"} { + for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { if structField.Name == foreignKey && !structField.IsForeignKey { t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) } @@ -849,7 +849,7 @@ func TestLongForeignKey(t *testing.T) { } targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{}) targetTableName := targetScope.TableName() - modelScope := DB.NewScope(&User{}) + modelScope := DB.NewScope(&NotSoLongTableName{}) modelField, ok := modelScope.FieldByName("ReallyLongThingID") if !ok { t.Fatalf("Failed to get field by name: ReallyLongThingID") @@ -859,7 +859,7 @@ func TestLongForeignKey(t *testing.T) { t.Fatalf("Failed to get field by name: ID") } dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(&User{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error + err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error if err != nil { t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) } diff --git a/dialect.go b/dialect.go index 6c9405da..033b9555 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,9 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + + // BuildForeignKeyName returns a foreign key name for the given table, field and reference + BuildForeignKeyName(tableName, field, dest string) string } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index f009271b..6d43fb84 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -4,12 +4,18 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "time" ) +// DefaultForeignKeyNamer contains the default foreign key name generator method +type DefaultForeignKeyNamer struct { +} + type commonDialect struct { db *sql.DB + DefaultForeignKeyNamer } func init() { @@ -135,3 +141,9 @@ func (commonDialect) SelectFromDummyTable() string { func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { + keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + return keyName +} diff --git a/dialect_mysql.go b/dialect_mysql.go index f62bbe35..d848de35 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -1,8 +1,10 @@ package gorm import ( + "crypto/sha1" "fmt" "reflect" + "regexp" "strings" "time" ) @@ -115,3 +117,18 @@ func (s mysql) currentDatabase() (name string) { func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } + +func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { + keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) + if len(keyName) <= 64 { + return keyName + } + h := sha1.New() + h.Write([]byte(keyName)) + bs := h.Sum(nil) + + // sha1 is 40 digits, keep first 24 characters of destination + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_") + + return fmt.Sprintf("%s%x", keyName[:24], bs) +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 5b994f9d..34eda717 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -24,6 +24,7 @@ func init() { type mssql struct { db *sql.DB + gorm.DefaultForeignKeyNamer } func (mssql) GetName() string { diff --git a/migration_test.go b/migration_test.go index fc692bcc..4d15917c 100644 --- a/migration_test.go +++ b/migration_test.go @@ -37,6 +37,10 @@ type User struct { IgnoreStringSlice []string `sql:"-"` Ignored struct{ Name string } `sql:"-"` IgnoredPointer *User `sql:"-"` +} + +type NotSoLongTableName struct { + Id int64 ReallyLongThingID int64 ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit } @@ -237,7 +241,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} + values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} for _, value := range values { DB.DropTable(value) } diff --git a/scope.go b/scope.go index 0116da08..1e755626 100644 --- a/scope.go +++ b/scope.go @@ -1117,8 +1117,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return