From 4e8d43dd4f7b69cdb5b62b5d5c70cf152e649f08 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 28 Feb 2015 17:01:27 +0800 Subject: [PATCH] Refactor check HasTable, HasColumn --- common_dialect.go | 16 +++------------- main_test.go | 2 +- migration_test.go | 2 +- mssql.go | 16 +++------------- mysql.go | 16 +++------------- postgres.go | 13 +++---------- scope_private.go | 4 +++- sqlite3.go | 6 +++--- 8 files changed, 20 insertions(+), 55 deletions(-) diff --git a/common_dialect.go b/common_dialect.go index e1a54c29..d8910dcd 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -86,26 +86,16 @@ func (s *commonDialect) databaseName(scope *Scope) string { func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", - newScope.AddToVars(tableName), - newScope.AddToVars(s.databaseName(scope)))) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v", - newScope.AddToVars(s.databaseName(scope)), - newScope.AddToVars(tableName), - newScope.AddToVars(columnName), - )) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec() + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/main_test.go b/main_test.go index 1fa3b248..6ffbbf8d 100644 --- a/main_test.go +++ b/main_test.go @@ -35,7 +35,7 @@ func init() { DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True") case "postgres": fmt.Println("testing postgres...") - DB, err = gorm.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") + DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") case "mssql": fmt.Println("testing mssql...") DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") diff --git a/migration_test.go b/migration_test.go index 1a7ae6f2..0a2f0852 100644 --- a/migration_test.go +++ b/migration_test.go @@ -7,7 +7,7 @@ import ( ) func runMigration() { - if err := DB.DropTable(&User{}).Error; err != nil { + if err := DB.DropTableIfExists(&User{}).Error; err != nil { fmt.Printf("Got error when try to delete table users, %+v\n", err) } diff --git a/mssql.go b/mssql.go index 30f9b824..720dc615 100644 --- a/mssql.go +++ b/mssql.go @@ -88,26 +88,16 @@ func (s *mssql) databaseName(scope *Scope) string { func (s *mssql) HasTable(scope *Scope, tableName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v", - newScope.AddToVars(tableName), - newScope.AddToVars(s.databaseName(scope)))) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE TABLE_CATALOG = %v AND table_name = %v AND column_name = %v", - newScope.AddToVars(s.databaseName(scope)), - newScope.AddToVars(tableName), - newScope.AddToVars(columnName), - )) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } func (s *mssql) RemoveIndex(scope *Scope, indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec() + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/mysql.go b/mysql.go index ec575336..3c077b4b 100644 --- a/mysql.go +++ b/mysql.go @@ -86,26 +86,16 @@ func (s *mysql) databaseName(scope *Scope) string { func (s *mysql) HasTable(scope *Scope, tableName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", - newScope.AddToVars(tableName), - newScope.AddToVars(s.databaseName(scope)))) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v", - newScope.AddToVars(s.databaseName(scope)), - newScope.AddToVars(tableName), - newScope.AddToVars(columnName), - )) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } func (s *mysql) RemoveIndex(scope *Scope, indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec() + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/postgres.go b/postgres.go index 1ffbe4aa..994ed9b8 100644 --- a/postgres.go +++ b/postgres.go @@ -81,25 +81,18 @@ func (s *postgres) Quote(key string) string { func (s *postgres) HasTable(scope *Scope, tableName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName))) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count) return count > 0 } func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_name = %v AND column_name = %v", - newScope.AddToVars(tableName), - newScope.AddToVars(columnName), - )) - newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) return count > 0 } func (s *postgres) RemoveIndex(scope *Scope, indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName))).Exec() + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } var hstoreType = reflect.TypeOf(Hstore{}) diff --git a/scope_private.go b/scope_private.go index cb5c0f5e..433e364a 100644 --- a/scope_private.go +++ b/scope_private.go @@ -471,7 +471,9 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) dropTableIfExists() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE IF EXISTS %v", scope.QuotedTableName())).Exec() + if scope.Dialect().HasTable(scope, scope.TableName()) { + scope.dropTable() + } return scope } diff --git a/sqlite3.go b/sqlite3.go index 7eb071d0..2ff10790 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -70,16 +70,16 @@ func (s *sqlite3) Quote(key string) string { func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { var count int - scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count) return count > 0 } func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count) + scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count) return count > 0 } func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v", indexName)).Exec() + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) }