Fix too long foreign key names in mysql.

The dialect must define its own foreign key generator method.
The previous default is available as a method on gorm.DefaultForeignKeyNamer
and can be embedded in other dialects.
The mysql dialect uses the first 24 characters plus an sha1 hash of the
full key name if the key name is more than 64 characters.
This commit is contained in:
Nikola Kovacs 2016-05-22 00:13:26 +02:00
parent 58ce6153e2
commit 6eb73ae654
7 changed files with 42 additions and 6 deletions

View File

@ -810,7 +810,7 @@ func TestRelated(t *testing.T) {
func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID", "ReallyLongThingID"} {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
@ -849,7 +849,7 @@ func TestLongForeignKey(t *testing.T) {
}
targetScope := DB.NewScope(&ReallyLongTableNameToTestMySQLNameLengthLimit{})
targetTableName := targetScope.TableName()
modelScope := DB.NewScope(&User{})
modelScope := DB.NewScope(&NotSoLongTableName{})
modelField, ok := modelScope.FieldByName("ReallyLongThingID")
if !ok {
t.Fatalf("Failed to get field by name: ReallyLongThingID")
@ -859,7 +859,7 @@ func TestLongForeignKey(t *testing.T) {
t.Fatalf("Failed to get field by name: ID")
}
dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
err := DB.Model(&User{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
err := DB.Model(&NotSoLongTableName{}).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
if err != nil {
t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
}

View File

@ -40,6 +40,9 @@ type Dialect interface {
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
BuildForeignKeyName(tableName, field, dest string) string
}
var dialectsMap = map[string]Dialect{}

View File

@ -4,12 +4,18 @@ import (
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)
// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
type commonDialect struct {
db *sql.DB
DefaultForeignKeyNamer
}
func init() {
@ -135,3 +141,9 @@ func (commonDialect) SelectFromDummyTable() string {
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
return keyName
}

View File

@ -1,8 +1,10 @@
package gorm
import (
"crypto/sha1"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)
@ -115,3 +117,18 @@ func (s mysql) currentDatabase() (name string) {
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
if len(keyName) <= 64 {
return keyName
}
h := sha1.New()
h.Write([]byte(keyName))
bs := h.Sum(nil)
// sha1 is 40 digits, keep first 24 characters of destination
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")
return fmt.Sprintf("%s%x", keyName[:24], bs)
}

View File

@ -24,6 +24,7 @@ func init() {
type mssql struct {
db *sql.DB
gorm.DefaultForeignKeyNamer
}
func (mssql) GetName() string {

View File

@ -37,6 +37,10 @@ type User struct {
IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"`
}
type NotSoLongTableName struct {
Id int64
ReallyLongThingID int64
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
}
@ -237,7 +241,7 @@ func runMigration() {
DB.Exec(fmt.Sprintf("drop table %v;", table))
}
values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}}
values := []interface{}{&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}}
for _, value := range values {
DB.DropTable(value)
}

View File

@ -1117,8 +1117,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest)
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
if scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return