From 9cdda4e8ee9d7e00c682a9bf6f6677868870e2e5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Jul 2016 21:37:44 +0800 Subject: [PATCH] Expose current database name API --- dialect.go | 3 +++ dialect_common.go | 8 ++++---- dialect_mysql.go | 4 ++-- dialect_postgres.go | 2 +- dialect_sqlite3.go | 2 +- dialects/mssql/mssql.go | 6 +++--- 6 files changed, 14 insertions(+), 11 deletions(-) diff --git a/dialect.go b/dialect.go index 96bf4a2c..facde0d0 100644 --- a/dialect.go +++ b/dialect.go @@ -43,6 +43,9 @@ type Dialect interface { // BuildForeignKeyName returns a foreign key name for the given table, field and reference BuildForeignKeyName(tableName, field, dest string) string + + // CurrentDatabase return current database name + CurrentDatabase() string } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index 8e66110f..5b5682c5 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -93,7 +93,7 @@ func (commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) return count > 0 } @@ -108,17 +108,17 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) return count > 0 } func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) return count > 0 } -func (s commonDialect) currentDatabase() (name string) { +func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_mysql.go b/dialect_mysql.go index 0ddcea4d..11b894b3 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -114,11 +114,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), tableName, foreignKeyName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s mysql) currentDatabase() (name string) { +func (s mysql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } diff --git a/dialect_postgres.go b/dialect_postgres.go index 0c17d28e..5a6114c0 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -107,7 +107,7 @@ func (s postgres) HasColumn(tableName string, columnName string) bool { return count > 0 } -func (s postgres) currentDatabase() (name string) { +func (s postgres) CurrentDatabase() (name string) { s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) return } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 79adf6d2..2abcefa5 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -89,7 +89,7 @@ func (s sqlite3) HasColumn(tableName string, columnName string) bool { return count > 0 } -func (s sqlite3) currentDatabase() (name string) { +func (s sqlite3) CurrentDatabase() (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index fcdd0ed8..a7bca6b8 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -113,17 +113,17 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) return count > 0 } func (s mssql) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) return count > 0 } -func (s mssql) currentDatabase() (name string) { +func (s mssql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return }