diff --git a/association_test.go b/association_test.go index 52d2303f..2b2d3bac 100644 --- a/association_test.go +++ b/association_test.go @@ -2,6 +2,7 @@ package gorm_test import ( "fmt" + "os" "reflect" "sort" "testing" @@ -840,3 +841,26 @@ func TestForeignKey(t *testing.T) { } } } + +func TestLongForeignKey(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { + // sqlite does not support ADD CONSTRAINT in ALTER TABLE + return + } + targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{}) + targetTableName := targetScope.TableName() + modelScope := DB.NewScope(&NotSoLongTableName{}) + modelField, ok := modelScope.FieldByName("ReallyLongThingID") + if !ok { + t.Fatalf("Failed to get field by name: ReallyLongThingID") + } + targetField, ok := targetScope.FieldByName("ID") + if !ok { + t.Fatalf("Failed to get field by name: ID") + } + dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) + err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error + if err != nil { + t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) + } +} diff --git a/dialect.go b/dialect.go index 6c9405da..033b9555 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,9 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + + // BuildForeignKeyName returns a foreign key name for the given table, field and reference + BuildForeignKeyName(tableName, field, dest string) string } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index f009271b..6d43fb84 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -4,12 +4,18 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "time" ) +// DefaultForeignKeyNamer contains the default foreign key name generator method +type DefaultForeignKeyNamer struct { +} + type commonDialect struct { db *sql.DB + DefaultForeignKeyNamer } func init() { @@ -135,3 +141,9 @@ func (commonDialect) SelectFromDummyTable() string { func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { + keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + return keyName +} diff --git a/dialect_mysql.go b/dialect_mysql.go index f62bbe35..d848de35 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -1,8 +1,10 @@ package gorm import ( + "crypto/sha1" "fmt" "reflect" + "regexp" "strings" "time" ) @@ -115,3 +117,18 @@ func (s mysql) currentDatabase() (name string) { func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } + +func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { + keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) + if len(keyName) <= 64 { + return keyName + } + h := sha1.New() + h.Write([]byte(keyName)) + bs := h.Sum(nil) + + // sha1 is 40 digits, keep first 24 characters of destination + keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_") + + return fmt.Sprintf("%s%x", keyName[:24], bs) +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 5b994f9d..34eda717 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -24,6 +24,7 @@ func init() { type mssql struct { db *sql.DB + gorm.DefaultForeignKeyNamer } func (mssql) GetName() string { diff --git a/migration_test.go b/migration_test.go index 38e5c1c2..4d15917c 100644 --- a/migration_test.go +++ b/migration_test.go @@ -39,6 +39,16 @@ type User struct { IgnoredPointer *User `sql:"-"` } +type NotSoLongTableName struct { + Id int64 + ReallyLongThingID int64 + ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit +} + +type ReallyLongTableNameToTestMySQLNameLengthLimit struct { + Id int64 +} + type CreditCard struct { ID int8 Number string @@ -231,7 +241,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} + values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} for _, value := range values { DB.DropTable(value) } diff --git a/scope.go b/scope.go index 0116da08..1e755626 100644 --- a/scope.go +++ b/scope.go @@ -1117,8 +1117,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return