diff --git a/dialect.go b/dialect.go index acec284a..d1176039 100644 --- a/dialect.go +++ b/dialect.go @@ -14,6 +14,8 @@ type Dialect interface { PrimaryKeyTag(value reflect.Value, size int) string ReturningStr(key string) string Quote(key string) string + HasTable(scope *Scope, tableName string) bool + HasColumn(scope *Scope, tableName string, columnName string) bool } func NewDialect(driver string) Dialect { diff --git a/mysql.go b/mysql.go index 6a5ed6ef..3fc0487a 100644 --- a/mysql.go +++ b/mysql.go @@ -68,10 +68,21 @@ func (s *mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (s *mysql) HasTable(tableName string) bool { - return true +func (s *mysql) HasTable(scope *Scope, tableName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v", newScope.AddToVars(tableName))) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 } -func (s *mysql) HasColumn(tableName string, columnName string) bool { - return true +func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_name = %v AND column_name = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(columnName), + )) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 } diff --git a/postgres.go b/postgres.go index 7c5c61e4..e54a8680 100644 --- a/postgres.go +++ b/postgres.go @@ -61,3 +61,22 @@ func (s *postgres) ReturningStr(key string) string { func (s *postgres) Quote(key string) string { return fmt.Sprintf("\"%s\"", key) } + +func (s *postgres) HasTable(scope *Scope, tableName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v", newScope.AddToVars(tableName))) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_name = %v AND column_name = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(columnName), + )) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} diff --git a/scope_private.go b/scope_private.go index ee57316f..67f08bda 100644 --- a/scope_private.go +++ b/scope_private.go @@ -447,28 +447,14 @@ func (scope *Scope) removeIndex(indexName string) { } func (scope *Scope) autoMigrate() *Scope { - var tableName string - scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", - scope.AddToVars(scope.TableName()))) - scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) - scope.SqlVars = []interface{}{} - - // If table doesn't exist - if len(tableName) == 0 { + if !scope.Dialect().HasTable(scope, scope.TableName()) { scope.createTable() } else { for _, field := range scope.Fields() { - var column, data string - scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v AND column_name = %v", - scope.AddToVars(scope.TableName()), - scope.AddToVars(field.DBName), - )) - scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&column, &data) - scope.SqlVars = []interface{}{} - - // If column doesn't exist - if len(column) == 0 && len(field.SqlTag) > 0 && !field.IsIgnored { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() + if !scope.Dialect().HasColumn(scope, scope.TableName(), field.DBName) { + if len(field.SqlTag) > 0 && !field.IsIgnored { + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() + } } } } diff --git a/sqlite3.go b/sqlite3.go index cbeac9a3..1b2fadc6 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -52,10 +52,29 @@ func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *sqlite3) ReturningStr(key string) (str string) { - return +func (s *sqlite3) ReturningStr(key string) string { + return "" } -func (s *sqlite3) Quote(key string) (str string) { +func (s *sqlite3) Quote(key string) string { return fmt.Sprintf("\"%s\"", key) } + +func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v", newScope.AddToVars(tableName))) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +} + +func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + newScope := scope.New(nil) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_name = %v AND column_name = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(columnName), + )) + newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) + return count > 0 +}