Merge branch 'master' into count

This commit is contained in:
Emir Beganović 2019-04-30 12:24:53 +04:00 committed by GitHub
commit e617218f79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 73 additions and 4 deletions

View File

@ -48,6 +48,9 @@ type Dialect interface {
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
BuildKeyName(kind, tableName string, fields ...string) string BuildKeyName(kind, tableName string, fields ...string) string
// NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
// CurrentDatabase return current database name // CurrentDatabase return current database name
CurrentDatabase() string CurrentDatabase() string
} }

View File

@ -9,6 +9,8 @@ import (
"time" "time"
) )
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
// DefaultForeignKeyNamer contains the default foreign key name generator method // DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct { type DefaultForeignKeyNamer struct {
} }
@ -166,10 +168,15 @@ func (commonDialect) DefaultValueStr() string {
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") keyName = keyNameRegex.ReplaceAllString(keyName, "_")
return keyName return keyName
} }
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
return indexName, columnName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice // IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool { func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))

View File

@ -11,6 +11,8 @@ import (
"unicode/utf8" "unicode/utf8"
) )
var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`)
type mysql struct { type mysql struct {
commonDialect commonDialect
} }
@ -178,7 +180,7 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
bs := h.Sum(nil) bs := h.Sum(nil)
// sha1 is 40 characters, keep first 24 characters of destination // sha1 is 40 characters, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_"))
if len(destRunes) > 24 { if len(destRunes) > 24 {
destRunes = destRunes[:24] destRunes = destRunes[:24]
} }
@ -186,6 +188,17 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
return fmt.Sprintf("%s%x", string(destRunes), bs) return fmt.Sprintf("%s%x", string(destRunes), bs)
} }
// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
submatch := mysqlIndexRegex.FindStringSubmatch(indexName)
if len(submatch) != 3 {
return indexName, columnName
}
indexName = submatch[1]
columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2])
return indexName, columnName
}
func (mysql) DefaultValueStr() string { func (mysql) DefaultValueStr() string {
return "VALUES()" return "VALUES()"
} }

View File

@ -198,6 +198,11 @@ func (mssql) DefaultValueStr() string {
return "DEFAULT VALUES" return "DEFAULT VALUES"
} }
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
return indexName, columnName
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") { if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2) splitStrings := strings.SplitN(tableName, ".", 2)

View File

@ -538,3 +538,42 @@ func TestModifyColumnType(t *testing.T) {
t.Errorf("No error should happen when ModifyColumn, but got %v", err) t.Errorf("No error should happen when ModifyColumn, but got %v", err)
} }
} }
func TestIndexWithPrefixLength(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
t.Skip("Skipping this because only mysql support setting an index prefix length")
}
type IndexWithPrefix struct {
gorm.Model
Name string
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
type IndexesWithPrefix struct {
gorm.Model
Name string
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
type IndexesWithPrefixAndWithoutPrefix struct {
gorm.Model
Name string `gorm:"index:idx_index_with_prefixes_length"`
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
for _, table := range tables {
scope := DB.NewScope(table)
tableName := scope.TableName()
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
if err := DB.DropTableIfExists(table).Error; err != nil {
t.Errorf("Failed to drop %s table: %v", tableName, err)
}
if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err)
}
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName)
}
})
}
}

View File

@ -1284,7 +1284,8 @@ func (scope *Scope) autoIndex() *Scope {
if name == "INDEX" || name == "" { if name == "INDEX" || name == "" {
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
} }
indexes[name] = append(indexes[name], field.DBName) name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
indexes[name] = append(indexes[name], column)
} }
} }
@ -1295,7 +1296,8 @@ func (scope *Scope) autoIndex() *Scope {
if name == "UNIQUE_INDEX" || name == "" { if name == "UNIQUE_INDEX" || name == "" {
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
} }
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
uniqueIndexes[name] = append(uniqueIndexes[name], column)
} }
} }
} }