From 4e8370e18ba237a4c7bdad7eccddb297a45a906c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 15 Feb 2016 14:09:24 +0800 Subject: [PATCH] Refactor dialect --- customize_column_test.go | 2 +- ddl_errors_test.go | 3 +-- dialect.go | 22 +++++++++++----- dialect_common.go | 57 ++++++++++++++++------------------------ dialect_mssql.go | 31 +++++++++++----------- dialect_mysql.go | 9 +++++-- dialect_postgres.go | 20 ++++++-------- dialect_sqlite3.go | 20 ++++++-------- main.go | 6 ++--- migration_test.go | 16 +++++------ scope_private.go | 12 ++++----- 11 files changed, 94 insertions(+), 104 deletions(-) diff --git a/customize_column_test.go b/customize_column_test.go index 93bab2e1..177b4a5d 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) { DB.AutoMigrate(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope, scope.TableName(), col) { + if !scope.Dialect().HasColumn(scope.TableName(), col) { t.Errorf("CustomizeColumn should have column %s", col) } diff --git a/ddl_errors_test.go b/ddl_errors_test.go index aca59553..2c31b354 100644 --- a/ddl_errors_test.go +++ b/ddl_errors_test.go @@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) { } }() - DB.HasTable("foobarbaz") - if DB.Error == nil { + if err := DB.Find(&User{}).Error; err == nil { t.Errorf("Expected operation on closed db to produce an error, but err was nil") } } diff --git a/dialect.go b/dialect.go index 1923e66e..cce68789 100644 --- a/dialect.go +++ b/dialect.go @@ -10,6 +10,9 @@ import ( // Dialect interface contains behaviors that differ across SQL database type Dialect interface { + // SetDB set db for dialect + SetDB(db *sql.DB) + // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 BindVar(i int) string // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name @@ -18,13 +21,13 @@ type Dialect interface { DataTypeOf(field *StructField) string // HasIndex check has index or not - HasIndex(scope *Scope, tableName string, indexName string) bool + HasIndex(tableName string, indexName string) bool // RemoveIndex remove index - RemoveIndex(scope *Scope, indexName string) + RemoveIndex(tableName string, indexName string) error // HasTable check has table or not - HasTable(scope *Scope, tableName string) bool + HasTable(tableName string) bool // HasColumn check has column or not - HasColumn(scope *Scope, tableName string, columnName string) bool + HasColumn(tableName string, columnName string) bool // LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case LimitAndOffsetSQL(limit, offset int) string @@ -36,12 +39,17 @@ type Dialect interface { var dialectsMap = map[string]Dialect{} -func newDialect(name string) Dialect { - if dialect, ok := dialectsMap[name]; ok { +func newDialect(name string, db *sql.DB) Dialect { + if value, ok := dialectsMap[name]; ok { + dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) + dialect.SetDB(db) return dialect } + fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - return &commonDialect{} + commontDialect := &commonDialect{} + commontDialect.SetDB(db) + return commontDialect } // RegisterDialect register new dialect diff --git a/dialect_common.go b/dialect_common.go index d5a81ad6..333b0b45 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -1,18 +1,25 @@ package gorm import ( + "database/sql" "fmt" "reflect" "strings" "time" ) -type commonDialect struct{} +type commonDialect struct { + db *sql.DB +} func init() { RegisterDialect("common", &commonDialect{}) } +func (s *commonDialect) SetDB(db *sql.DB) { + s.db = db +} + func (commonDialect) BindVar(i int) string { return "$$" // ? } @@ -73,51 +80,31 @@ func (commonDialect) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { - var ( - count int - databaseName = c.currentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) +func (s commonDialect) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count) return count > 0 } -func (commonDialect) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) +func (s commonDialect) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) + return err } -func (c commonDialect) HasTable(scope *Scope, tableName string) bool { - var ( - count int - databaseName = c.currentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName) +func (s commonDialect) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count) return count > 0 } -func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { - var ( - count int - databaseName = c.currentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) +func (s commonDialect) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) return count > 0 } -// RawScanInt scans the first column of the first row into the `scan' int pointer. -// This function captures raw query errors and propagates them to the original scope. -func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) { - scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) -} - -// RawScanString scans the first column of the first row into the `scan' string pointer. -// This function captures raw query errors and propagates them to the original scope. -func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) { - scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) -} - -func (commonDialect) currentDatabase(scope *Scope) (name string) { - scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) +func (s commonDialect) currentDatabase() (name string) { + s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_mssql.go b/dialect_mssql.go index a2af49ad..63b46e9e 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -67,32 +67,31 @@ func (mssql) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s mssql) HasIndex(tableName string, indexName string) bool { var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) + s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) return count > 0 } -func (s mssql) HasTable(scope *Scope, tableName string) bool { - var ( - count int - databaseName = s.currentDatabase(scope) - ) - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) +func (s mssql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mssql) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count) return count > 0 } -func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { - var ( - count int - databaseName = s.currentDatabase(scope) - ) - s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) +func (s mssql) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) return count > 0 } -func (s mssql) currentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") +func (s mssql) currentDatabase() (name string) { + s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return } diff --git a/dialect_mysql.go b/dialect_mysql.go index 10d63db2..9e530a9a 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -88,8 +88,13 @@ func (mysql) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mysql) currentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT DATABASE()") +func (s mysql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mysql) currentDatabase() (name string) { + s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_postgres.go b/dialect_postgres.go index e726d233..3c18acc2 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -77,30 +77,26 @@ func (postgres) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s postgres) HasIndex(tableName string, indexName string) bool { var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) return count > 0 } -func (postgres) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) -} - -func (s postgres) HasTable(scope *Scope, tableName string) bool { +func (s postgres) HasTable(tableName string) bool { var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) return count > 0 } -func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s postgres) HasColumn(tableName string, columnName string) bool { var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) return count > 0 } -func (s postgres) currentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") +func (s postgres) currentDatabase() (name string) { + s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) return } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 3abdb92e..41e45517 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -65,29 +65,25 @@ func (sqlite3) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (s sqlite3) HasIndex(tableName string, indexName string) bool { var count int - s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) + s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) return count > 0 } -func (sqlite3) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) -} - -func (s sqlite3) HasTable(scope *Scope, tableName string) bool { +func (s sqlite3) HasTable(tableName string) bool { var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName) + s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) return count > 0 } -func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s sqlite3) HasColumn(tableName string, columnName string) bool { var count int - s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName) + s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName).Scan(&count) return count > 0 } -func (sqlite3) currentDatabase(scope *Scope) (name string) { +func (s sqlite3) currentDatabase() (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) @@ -96,7 +92,7 @@ func (sqlite3) currentDatabase(scope *Scope) (name string) { for i = 0; i < 3; i++ { ifaces[i] = &pointers[i] } - if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil { + if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { return } if pointers[1] != nil { diff --git a/main.go b/main.go index 51bd9914..9581a216 100644 --- a/main.go +++ b/main.go @@ -62,7 +62,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) { } db = DB{ - dialect: newDialect(dialect), + dialect: newDialect(dialect, dbSql.(*sql.DB)), logger: defaultLogger, callbacks: defaultCallback, source: source, @@ -430,7 +430,7 @@ func (s *DB) HasTable(value interface{}) bool { tableName = scope.TableName() } - has := scope.Dialect().HasTable(scope, tableName) + has := scope.Dialect().HasTable(tableName) s.AddError(scope.db.Error) return has } @@ -531,7 +531,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { + if table := handler.Table(s); scope.Dialect().HasTable(table) { s.Table(table).AutoMigrate(handler) } } diff --git a/migration_test.go b/migration_test.go index 0411872e..de35c1df 100644 --- a/migration_test.go +++ b/migration_test.go @@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) { } scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { t.Errorf("Email should have index idx_email_email") } @@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { t.Errorf("Email's index idx_email_email should be deleted") } @@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) { DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) scope := DB.NewScope(&BigEmail{}) - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") { + if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { t.Errorf("Failed to create index") } diff --git a/scope_private.go b/scope_private.go index 8e148820..4ed2060c 100644 --- a/scope_private.go +++ b/scope_private.go @@ -515,7 +515,7 @@ func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(scope, joinTable) { + if !scope.Dialect().HasTable(joinTable) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes, primaryKeys []string @@ -586,7 +586,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) dropTableIfExists() *Scope { - if scope.Dialect().HasTable(scope, scope.TableName()) { + if scope.Dialect().HasTable(scope.TableName()) { scope.dropTable() } return scope @@ -601,7 +601,7 @@ func (scope *Scope) dropColumn(column string) { } func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) { + if scope.Dialect().HasIndex(scope.TableName(), indexName) { return } @@ -626,18 +626,18 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on } func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope, indexName) + scope.Dialect().RemoveIndex(scope.TableName(), indexName) } func (scope *Scope) autoMigrate() *Scope { tableName := scope.TableName() quotedTableName := scope.QuotedTableName() - if !scope.Dialect().HasTable(scope, tableName) { + if !scope.Dialect().HasTable(tableName) { scope.createTable() } else { for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { + if !scope.Dialect().HasColumn(tableName, field.DBName) { if field.IsNormal { sqlTag := scope.Dialect().DataTypeOf(field) scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()