From 7983fc626923b7f0e80755499b9d2aed212d550e Mon Sep 17 00:00:00 2001 From: Nikola Kovacs Date: Sun, 22 May 2016 09:58:37 +0200 Subject: [PATCH] fix panic in AddForeignKey on mysql dialect --- association_test.go | 24 ++++++++++++++++-------- dialect_mysql.go | 10 +++++++--- migration_test.go | 12 +++++++++++- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/association_test.go b/association_test.go index 2b2d3bac..ad56d84e 100644 --- a/association_test.go +++ b/association_test.go @@ -842,25 +842,33 @@ func TestForeignKey(t *testing.T) { } } -func TestLongForeignKey(t *testing.T) { +func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) { if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { // sqlite does not support ADD CONSTRAINT in ALTER TABLE return } - targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{}) + targetScope := DB.NewScope(target) targetTableName := targetScope.TableName() - modelScope := DB.NewScope(&NotSoLongTableName{}) - modelField, ok := modelScope.FieldByName("ReallyLongThingID") + modelScope := DB.NewScope(source) + modelField, ok := modelScope.FieldByName(sourceFieldName) if !ok { - t.Fatalf("Failed to get field by name: ReallyLongThingID") + t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName)) } - targetField, ok := targetScope.FieldByName("ID") + targetField, ok := targetScope.FieldByName(targetFieldName) if !ok { - t.Fatalf("Failed to get field by name: ID") + t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName)) } dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error + err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error if err != nil { t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) } } + +func TestLongForeignKey(t *testing.T) { + testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID") +} + +func TestLongForeignKeyWithShortDest(t *testing.T) { + testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") +} diff --git a/dialect_mysql.go b/dialect_mysql.go index d848de35..bc4828de 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" "time" + "unicode/utf8" ) type mysql struct { @@ -120,7 +121,7 @@ func (mysql) SelectFromDummyTable() string { func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) - if len(keyName) <= 64 { + if utf8.RuneCountInString(keyName) <= 64 { return keyName } h := sha1.New() @@ -128,7 +129,10 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { bs := h.Sum(nil) // sha1 is 40 digits, keep first 24 characters of destination - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_") + destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) + if len(destRunes) > 24 { + destRunes = destRunes[:24] + } - return fmt.Sprintf("%s%x", keyName[:24], bs) + return fmt.Sprintf("%s%x", string(destRunes), bs) } diff --git a/migration_test.go b/migration_test.go index 4d15917c..3e385466 100644 --- a/migration_test.go +++ b/migration_test.go @@ -49,6 +49,16 @@ type ReallyLongTableNameToTestMySQLNameLengthLimit struct { Id int64 } +type ReallyLongThingThatReferencesShort struct { + Id int64 + ShortID int64 + Short Short +} + +type Short struct { + Id int64 +} + type CreditCard struct { ID int8 Number string @@ -241,7 +251,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} for _, value := range values { DB.DropTable(value) }