mirror of https://github.com/go-gorm/gorm.git
Refactor dialect
This commit is contained in:
parent
6546ec3b5e
commit
4e8370e18b
|
@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
|
||||||
DB.AutoMigrate(&CustomizeColumn{})
|
DB.AutoMigrate(&CustomizeColumn{})
|
||||||
|
|
||||||
scope := DB.NewScope(&CustomizeColumn{})
|
scope := DB.NewScope(&CustomizeColumn{})
|
||||||
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
|
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||||
t.Errorf("CustomizeColumn should have column %s", col)
|
t.Errorf("CustomizeColumn should have column %s", col)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
DB.HasTable("foobarbaz")
|
if err := DB.Find(&User{}).Error; err == nil {
|
||||||
if DB.Error == nil {
|
|
||||||
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
22
dialect.go
22
dialect.go
|
@ -10,6 +10,9 @@ import (
|
||||||
|
|
||||||
// Dialect interface contains behaviors that differ across SQL database
|
// Dialect interface contains behaviors that differ across SQL database
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
|
// SetDB set db for dialect
|
||||||
|
SetDB(db *sql.DB)
|
||||||
|
|
||||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||||
BindVar(i int) string
|
BindVar(i int) string
|
||||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||||
|
@ -18,13 +21,13 @@ type Dialect interface {
|
||||||
DataTypeOf(field *StructField) string
|
DataTypeOf(field *StructField) string
|
||||||
|
|
||||||
// HasIndex check has index or not
|
// HasIndex check has index or not
|
||||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
HasIndex(tableName string, indexName string) bool
|
||||||
// RemoveIndex remove index
|
// RemoveIndex remove index
|
||||||
RemoveIndex(scope *Scope, indexName string)
|
RemoveIndex(tableName string, indexName string) error
|
||||||
// HasTable check has table or not
|
// HasTable check has table or not
|
||||||
HasTable(scope *Scope, tableName string) bool
|
HasTable(tableName string) bool
|
||||||
// HasColumn check has column or not
|
// HasColumn check has column or not
|
||||||
HasColumn(scope *Scope, tableName string, columnName string) bool
|
HasColumn(tableName string, columnName string) bool
|
||||||
|
|
||||||
// LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case
|
// LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case
|
||||||
LimitAndOffsetSQL(limit, offset int) string
|
LimitAndOffsetSQL(limit, offset int) string
|
||||||
|
@ -36,12 +39,17 @@ type Dialect interface {
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
|
|
||||||
func newDialect(name string) Dialect {
|
func newDialect(name string, db *sql.DB) Dialect {
|
||||||
if dialect, ok := dialectsMap[name]; ok {
|
if value, ok := dialectsMap[name]; ok {
|
||||||
|
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||||
|
dialect.SetDB(db)
|
||||||
return dialect
|
return dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
||||||
return &commonDialect{}
|
commontDialect := &commonDialect{}
|
||||||
|
commontDialect.SetDB(db)
|
||||||
|
return commontDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterDialect register new dialect
|
// RegisterDialect register new dialect
|
||||||
|
|
|
@ -1,18 +1,25 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type commonDialect struct{}
|
type commonDialect struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RegisterDialect("common", &commonDialect{})
|
RegisterDialect("common", &commonDialect{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *commonDialect) SetDB(db *sql.DB) {
|
||||||
|
s.db = db
|
||||||
|
}
|
||||||
|
|
||||||
func (commonDialect) BindVar(i int) string {
|
func (commonDialect) BindVar(i int) string {
|
||||||
return "$$" // ?
|
return "$$" // ?
|
||||||
}
|
}
|
||||||
|
@ -73,51 +80,31 @@ func (commonDialect) DataTypeOf(field *StructField) string {
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||||
var (
|
var count int
|
||||||
count int
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
|
||||||
databaseName = c.currentDatabase(scope)
|
|
||||||
)
|
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
|
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
func (s commonDialect) HasTable(tableName string) bool {
|
||||||
var (
|
var count int
|
||||||
count int
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
|
||||||
databaseName = c.currentDatabase(scope)
|
|
||||||
)
|
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
|
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||||
var (
|
var count int
|
||||||
count int
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
|
||||||
databaseName = c.currentDatabase(scope)
|
|
||||||
)
|
|
||||||
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// RawScanInt scans the first column of the first row into the `scan' int pointer.
|
func (s commonDialect) currentDatabase() (name string) {
|
||||||
// This function captures raw query errors and propagates them to the original scope.
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
|
|
||||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
|
||||||
}
|
|
||||||
|
|
||||||
// RawScanString scans the first column of the first row into the `scan' string pointer.
|
|
||||||
// This function captures raw query errors and propagates them to the original scope.
|
|
||||||
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
|
|
||||||
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (commonDialect) currentDatabase(scope *Scope) (name string) {
|
|
||||||
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -67,32 +67,31 @@ func (mssql) DataTypeOf(field *StructField) string {
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
|
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||||
var (
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
count int
|
return err
|
||||||
databaseName = s.currentDatabase(scope)
|
}
|
||||||
)
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||||
var (
|
var count int
|
||||||
count int
|
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
|
||||||
databaseName = s.currentDatabase(scope)
|
|
||||||
)
|
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
|
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) currentDatabase(scope *Scope) (name string) {
|
func (s mssql) currentDatabase() (name string) {
|
||||||
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
|
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -88,8 +88,13 @@ func (mysql) DataTypeOf(field *StructField) string {
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) currentDatabase(scope *Scope) (name string) {
|
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||||
s.RawScanString(scope, &name, "SELECT DATABASE()")
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) currentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,30 +77,26 @@ func (postgres) DataTypeOf(field *StructField) string {
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
|
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (postgres) RemoveIndex(scope *Scope, indexName string) {
|
func (s postgres) HasTable(tableName string) bool {
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s postgres) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s postgres) currentDatabase(scope *Scope) (name string) {
|
func (s postgres) currentDatabase() (name string) {
|
||||||
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
|
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,29 +65,25 @@ func (sqlite3) DataTypeOf(field *StructField) string {
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
func (s sqlite3) HasTable(tableName string) bool {
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
|
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
|
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sqlite3) currentDatabase(scope *Scope) (name string) {
|
func (s sqlite3) currentDatabase() (name string) {
|
||||||
var (
|
var (
|
||||||
ifaces = make([]interface{}, 3)
|
ifaces = make([]interface{}, 3)
|
||||||
pointers = make([]*string, 3)
|
pointers = make([]*string, 3)
|
||||||
|
@ -96,7 +92,7 @@ func (sqlite3) currentDatabase(scope *Scope) (name string) {
|
||||||
for i = 0; i < 3; i++ {
|
for i = 0; i < 3; i++ {
|
||||||
ifaces[i] = &pointers[i]
|
ifaces[i] = &pointers[i]
|
||||||
}
|
}
|
||||||
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
|
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if pointers[1] != nil {
|
if pointers[1] != nil {
|
||||||
|
|
6
main.go
6
main.go
|
@ -62,7 +62,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
db = DB{
|
db = DB{
|
||||||
dialect: newDialect(dialect),
|
dialect: newDialect(dialect, dbSql.(*sql.DB)),
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
callbacks: defaultCallback,
|
callbacks: defaultCallback,
|
||||||
source: source,
|
source: source,
|
||||||
|
@ -430,7 +430,7 @@ func (s *DB) HasTable(value interface{}) bool {
|
||||||
tableName = scope.TableName()
|
tableName = scope.TableName()
|
||||||
}
|
}
|
||||||
|
|
||||||
has := scope.Dialect().HasTable(scope, tableName)
|
has := scope.Dialect().HasTable(tableName)
|
||||||
s.AddError(scope.db.Error)
|
s.AddError(scope.db.Error)
|
||||||
return has
|
return has
|
||||||
}
|
}
|
||||||
|
@ -531,7 +531,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
|
||||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
handler.Setup(field.Relationship, many2many, source, destination)
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
field.Relationship.JoinTableHandler = handler
|
field.Relationship.JoinTableHandler = handler
|
||||||
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
|
if table := handler.Table(s); scope.Dialect().HasTable(table) {
|
||||||
s.Table(table).AutoMigrate(handler)
|
s.Table(table).AutoMigrate(handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
scope := DB.NewScope(&Email{})
|
scope := DB.NewScope(&Email{})
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email should have index idx_email_email")
|
t.Errorf("Email should have index idx_email_email")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email's index idx_email_email should be deleted")
|
t.Errorf("Email's index idx_email_email should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) {
|
||||||
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
|
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
|
||||||
|
|
||||||
scope := DB.NewScope(&BigEmail{})
|
scope := DB.NewScope(&BigEmail{})
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -515,7 +515,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
joinTable := joinTableHandler.Table(scope.db)
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(joinTable) {
|
||||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||||
|
|
||||||
var sqlTypes, primaryKeys []string
|
var sqlTypes, primaryKeys []string
|
||||||
|
@ -586,7 +586,7 @@ func (scope *Scope) dropTable() *Scope {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) dropTableIfExists() *Scope {
|
func (scope *Scope) dropTableIfExists() *Scope {
|
||||||
if scope.Dialect().HasTable(scope, scope.TableName()) {
|
if scope.Dialect().HasTable(scope.TableName()) {
|
||||||
scope.dropTable()
|
scope.dropTable()
|
||||||
}
|
}
|
||||||
return scope
|
return scope
|
||||||
|
@ -601,7 +601,7 @@ func (scope *Scope) dropColumn(column string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||||
if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
|
if scope.Dialect().HasIndex(scope.TableName(), indexName) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -626,18 +626,18 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) removeIndex(indexName string) {
|
func (scope *Scope) removeIndex(indexName string) {
|
||||||
scope.Dialect().RemoveIndex(scope, indexName)
|
scope.Dialect().RemoveIndex(scope.TableName(), indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) autoMigrate() *Scope {
|
func (scope *Scope) autoMigrate() *Scope {
|
||||||
tableName := scope.TableName()
|
tableName := scope.TableName()
|
||||||
quotedTableName := scope.QuotedTableName()
|
quotedTableName := scope.QuotedTableName()
|
||||||
|
|
||||||
if !scope.Dialect().HasTable(scope, tableName) {
|
if !scope.Dialect().HasTable(tableName) {
|
||||||
scope.createTable()
|
scope.createTable()
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.GetModelStruct().StructFields {
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.Dialect().DataTypeOf(field)
|
sqlTag := scope.Dialect().DataTypeOf(field)
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||||
|
|
Loading…
Reference in New Issue