diff --git a/dialect_mysql.go b/dialect_mysql.go index 1addaf36..ac9b3b2e 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -2,6 +2,7 @@ package gorm import ( "crypto/sha1" + "database/sql" "fmt" "reflect" "regexp" @@ -161,6 +162,39 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { return count > 0 } +func (s mysql) HasTable(tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + var name string + if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM %s WHERE Tables_in_%s = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err == sql.ErrNoRows { + return false + } + panic(err) + } else { + return true + } +} + +func (s mysql) HasIndex(tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + +func (s mysql) HasColumn(tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + func (s mysql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return