From 87fc1b24737a885147240041293603eceb844356 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 3 Feb 2018 20:27:19 +0800 Subject: [PATCH] Refactor PR #1751 --- dialect.go | 8 ++++++++ dialect_common.go | 17 ++++------------- dialect_mysql.go | 3 ++- dialects/mssql/mssql.go | 14 ++++++++++++-- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/dialect.go b/dialect.go index 9d3be249..90b1723f 100644 --- a/dialect.go +++ b/dialect.go @@ -114,3 +114,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel return fieldValue, dataType, size, strings.TrimSpace(additionalType) } + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +} diff --git a/dialect_common.go b/dialect_common.go index 9ccff6e9..30f035a5 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -92,7 +92,7 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -108,24 +108,14 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } -func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) { - currentDatabase := s.CurrentDatabase() - if currentDatabase == "" && strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - currentDatabase = splitStrings[0] - tableName = splitStrings[1] - } - return currentDatabase, tableName -} - func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - currentDatabase, tableName := s.currentDatabaseAndTable(tableName) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } @@ -157,6 +147,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") diff --git a/dialect_mysql.go b/dialect_mysql.go index d2fd53ca..f4858e10 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -144,7 +144,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..a4f8e87c 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -128,13 +128,15 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count) return count > 0 } func (s mssql) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } @@ -168,3 +170,11 @@ func (mssql) SelectFromDummyTable() string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName +}