Merge pull request #1031 from nkovacs/add-foreign-key-panic-fix

fix panic in AddForeignKey on mysql dialect
This commit is contained in:
Jinzhu 2016-05-22 19:10:33 +08:00
commit 27522453de
3 changed files with 34 additions and 12 deletions

View File

@ -842,25 +842,33 @@ func TestForeignKey(t *testing.T) {
} }
} }
func TestLongForeignKey(t *testing.T) { func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) {
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
// sqlite does not support ADD CONSTRAINT in ALTER TABLE // sqlite does not support ADD CONSTRAINT in ALTER TABLE
return return
} }
targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{}) targetScope := DB.NewScope(target)
targetTableName := targetScope.TableName() targetTableName := targetScope.TableName()
modelScope := DB.NewScope(&NotSoLongTableName{}) modelScope := DB.NewScope(source)
modelField, ok := modelScope.FieldByName("ReallyLongThingID") modelField, ok := modelScope.FieldByName(sourceFieldName)
if !ok { if !ok {
t.Fatalf("Failed to get field by name: ReallyLongThingID") t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName))
} }
targetField, ok := targetScope.FieldByName("ID") targetField, ok := targetScope.FieldByName(targetFieldName)
if !ok { if !ok {
t.Fatalf("Failed to get field by name: ID") t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName))
} }
dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
if err != nil { if err != nil {
t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
} }
} }
func TestLongForeignKey(t *testing.T) {
testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID")
}
func TestLongForeignKeyWithShortDest(t *testing.T) {
testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID")
}

View File

@ -7,6 +7,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"time" "time"
"unicode/utf8"
) )
type mysql struct { type mysql struct {
@ -120,7 +121,7 @@ func (mysql) SelectFromDummyTable() string {
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
if len(keyName) <= 64 { if utf8.RuneCountInString(keyName) <= 64 {
return keyName return keyName
} }
h := sha1.New() h := sha1.New()
@ -128,7 +129,10 @@ func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
bs := h.Sum(nil) bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination // sha1 is 40 digits, keep first 24 characters of destination
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_") destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
if len(destRunes) > 24 {
destRunes = destRunes[:24]
}
return fmt.Sprintf("%s%x", keyName[:24], bs) return fmt.Sprintf("%s%x", string(destRunes), bs)
} }

View File

@ -49,6 +49,16 @@ type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
Id int64 Id int64
} }
type ReallyLongThingThatReferencesShort struct {
Id int64
ShortID int64
Short Short
}
type Short struct {
Id int64
}
type CreditCard struct { type CreditCard struct {
ID int8 ID int8
Number string Number string
@ -241,7 +251,7 @@ func runMigration() {
DB.Exec(fmt.Sprintf("drop table %v;", table)) DB.Exec(fmt.Sprintf("drop table %v;", table))
} }
values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}} values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}}
for _, value := range values { for _, value := range values {
DB.DropTable(value) DB.DropTable(value)
} }