diff --git a/mysql.go b/mysql.go index 3fc0487a..366a6775 100644 --- a/mysql.go +++ b/mysql.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "strings" "reflect" ) @@ -68,10 +69,21 @@ func (s *mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } +func (s *mysql) databaseName(scope *Scope) string { + from := strings.Index(scope.db.parent.source, "/") + 1 + to := strings.Index(scope.db.parent.source, "?") + if to == -1 { + to = len(scope.db.parent.source) + } + return scope.db.parent.source[from:to] +} + func (s *mysql) HasTable(scope *Scope, tableName string) bool { var count int newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v", newScope.AddToVars(tableName))) + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v", + newScope.AddToVars(tableName), + newScope.AddToVars(s.databaseName(scope)))) newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count) return count > 0 } @@ -79,7 +91,8 @@ func (s *mysql) HasTable(scope *Scope, tableName string) bool { func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int newScope := scope.New(nil) - newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_name = %v AND column_name = %v", + newScope.Raw(fmt.Sprintf("SELECT count(*) FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v", + newScope.AddToVars(s.databaseName(scope)), newScope.AddToVars(tableName), newScope.AddToVars(columnName), ))