mirror of https://github.com/go-gorm/gorm.git
Refactor PR #1751
This commit is contained in:
parent
57f031e083
commit
87fc1b2473
|
@ -114,3 +114,11 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
|
|||
|
||||
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
||||
}
|
||||
|
||||
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
|
||||
if strings.Contains(tableName, ".") {
|
||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||
return splitStrings[0], splitStrings[1]
|
||||
}
|
||||
return dialect.CurrentDatabase(), tableName
|
||||
}
|
||||
|
|
|
@ -92,7 +92,7 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
|
|||
|
||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := s.currentDatabaseAndTable(tableName)
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
@ -108,24 +108,14 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo
|
|||
|
||||
func (s commonDialect) HasTable(tableName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := s.currentDatabaseAndTable(tableName)
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) currentDatabaseAndTable(tableName string) (string, string) {
|
||||
currentDatabase := s.CurrentDatabase()
|
||||
if currentDatabase == "" && strings.Contains(tableName, ".") {
|
||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||
currentDatabase = splitStrings[0]
|
||||
tableName = splitStrings[1]
|
||||
}
|
||||
return currentDatabase, tableName
|
||||
}
|
||||
|
||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
currentDatabase, tableName := s.currentDatabaseAndTable(tableName)
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
@ -157,6 +147,7 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s
|
|||
return ""
|
||||
}
|
||||
|
||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
|
||||
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
|
||||
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
|
||||
|
|
|
@ -144,7 +144,8 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
|||
|
||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
|
|
@ -128,13 +128,15 @@ func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
|||
|
||||
func (s mssql) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
@ -168,3 +170,11 @@ func (mssql) SelectFromDummyTable() string {
|
|||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
|
||||
if strings.Contains(tableName, ".") {
|
||||
splitStrings := strings.SplitN(tableName, ".", 2)
|
||||
return splitStrings[0], splitStrings[1]
|
||||
}
|
||||
return dialect.CurrentDatabase(), tableName
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue