mirror of https://github.com/go-gorm/gorm.git
Merge branch 'master' into pluck
This commit is contained in:
commit
483c2ef521
|
@ -1,6 +1,9 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// Define callbacks for row query
|
// Define callbacks for row query
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -20,6 +23,9 @@ type RowsQueryResult struct {
|
||||||
func rowQueryCallback(scope *Scope) {
|
func rowQueryCallback(scope *Scope) {
|
||||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||||
scope.prepareQuerySQL()
|
scope.prepareQuerySQL()
|
||||||
|
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||||
|
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||||
|
}
|
||||||
|
|
||||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
if rowResult, ok := result.(*RowQueryResult); ok {
|
||||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||||
|
|
|
@ -48,6 +48,9 @@ type Dialect interface {
|
||||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
BuildKeyName(kind, tableName string, fields ...string) string
|
BuildKeyName(kind, tableName string, fields ...string) string
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
|
||||||
|
NormalizeIndexAndColumn(indexName, columnName string) (string, string)
|
||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase() string
|
CurrentDatabase() string
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
|
||||||
|
|
||||||
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
||||||
type DefaultForeignKeyNamer struct {
|
type DefaultForeignKeyNamer struct {
|
||||||
}
|
}
|
||||||
|
@ -166,10 +168,15 @@ func (commonDialect) DefaultValueStr() string {
|
||||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||||
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
||||||
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
|
keyName = keyNameRegex.ReplaceAllString(keyName, "_")
|
||||||
return keyName
|
return keyName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
|
||||||
|
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
|
||||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
||||||
func IsByteArrayOrSlice(value reflect.Value) bool {
|
func IsByteArrayOrSlice(value reflect.Value) bool {
|
||||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`)
|
||||||
|
|
||||||
type mysql struct {
|
type mysql struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
@ -178,7 +180,7 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||||
bs := h.Sum(nil)
|
bs := h.Sum(nil)
|
||||||
|
|
||||||
// sha1 is 40 characters, keep first 24 characters of destination
|
// sha1 is 40 characters, keep first 24 characters of destination
|
||||||
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
|
destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_"))
|
||||||
if len(destRunes) > 24 {
|
if len(destRunes) > 24 {
|
||||||
destRunes = destRunes[:24]
|
destRunes = destRunes[:24]
|
||||||
}
|
}
|
||||||
|
@ -186,6 +188,17 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||||
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed
|
||||||
|
func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||||
|
submatch := mysqlIndexRegex.FindStringSubmatch(indexName)
|
||||||
|
if len(submatch) != 3 {
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
indexName = submatch[1]
|
||||||
|
columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2])
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
|
||||||
func (mysql) DefaultValueStr() string {
|
func (mysql) DefaultValueStr() string {
|
||||||
return "VALUES()"
|
return "VALUES()"
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,6 +198,11 @@ func (mssql) DefaultValueStr() string {
|
||||||
return "DEFAULT VALUES"
|
return "DEFAULT VALUES"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
|
||||||
|
func (mssql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
|
||||||
|
return indexName, columnName
|
||||||
|
}
|
||||||
|
|
||||||
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||||
|
|
23
main_test.go
23
main_test.go
|
@ -1141,6 +1141,29 @@ func TestPluck(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCountWithQueryOption(t *testing.T) {
|
||||||
|
db := DB.New()
|
||||||
|
db.Delete(User{})
|
||||||
|
defer db.Delete(User{})
|
||||||
|
|
||||||
|
DB.Create(&User{Name: "user1"})
|
||||||
|
DB.Create(&User{Name: "user2"})
|
||||||
|
DB.Create(&User{Name: "user3"})
|
||||||
|
|
||||||
|
var count int
|
||||||
|
err := db.Model(User{}).Select("users.id").
|
||||||
|
Set("gorm:query_option", "WHERE users.name='user2'").
|
||||||
|
Count(&count).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Unexpected error on query count with query_option")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 1 {
|
||||||
|
t.Error("Unexpected result on query count with query_option")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGorm(b *testing.B) {
|
func BenchmarkGorm(b *testing.B) {
|
||||||
b.N = 2000
|
b.N = 2000
|
||||||
for x := 0; x < b.N; x++ {
|
for x := 0; x < b.N; x++ {
|
||||||
|
|
|
@ -538,3 +538,42 @@ func TestModifyColumnType(t *testing.T) {
|
||||||
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIndexWithPrefixLength(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
||||||
|
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
||||||
|
}
|
||||||
|
|
||||||
|
type IndexWithPrefix struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
}
|
||||||
|
type IndexesWithPrefix struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
}
|
||||||
|
type IndexesWithPrefixAndWithoutPrefix struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"index:idx_index_with_prefixes_length"`
|
||||||
|
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
|
}
|
||||||
|
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
|
||||||
|
for _, table := range tables {
|
||||||
|
scope := DB.NewScope(table)
|
||||||
|
tableName := scope.TableName()
|
||||||
|
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
|
||||||
|
if err := DB.DropTableIfExists(table).Error; err != nil {
|
||||||
|
t.Errorf("Failed to drop %s table: %v", tableName, err)
|
||||||
|
}
|
||||||
|
if err := DB.CreateTable(table).Error; err != nil {
|
||||||
|
t.Errorf("Failed to create %s table: %v", tableName, err)
|
||||||
|
}
|
||||||
|
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
|
||||||
|
t.Errorf("Failed to create %s table index:", tableName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
6
scope.go
6
scope.go
|
@ -1288,7 +1288,8 @@ func (scope *Scope) autoIndex() *Scope {
|
||||||
if name == "INDEX" || name == "" {
|
if name == "INDEX" || name == "" {
|
||||||
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
indexes[name] = append(indexes[name], field.DBName)
|
name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
|
||||||
|
indexes[name] = append(indexes[name], column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1299,7 +1300,8 @@ func (scope *Scope) autoIndex() *Scope {
|
||||||
if name == "UNIQUE_INDEX" || name == "" {
|
if name == "UNIQUE_INDEX" || name == "" {
|
||||||
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName)
|
||||||
|
uniqueIndexes[name] = append(uniqueIndexes[name], column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue