From 1ff3c79c75a1dd4bf7a1734a7f9fe10e1bd24169 Mon Sep 17 00:00:00 2001 From: Felix Sun Date: Fri, 25 Apr 2014 00:38:40 +0800 Subject: [PATCH] When multiple databases have the same table name and column names, automigrate doesn't work --- main.go | 5 +++-- main_private.go | 2 +- scope_private.go | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index 8fca8612..e985c5c6 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package gorm import ( "database/sql" - "github.com/jinzhu/gorm/dialect" ) @@ -18,11 +17,13 @@ type DB struct { dialect dialect.Dialect tagIdentifier string singularTable bool + source string } func Open(driver, source string) (DB, error) { var err error - db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback} + + db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source} db.db, err = sql.Open(driver, source) db.parent = &db return db, err diff --git a/main_private.go b/main_private.go index fe198c1d..9c096572 100644 --- a/main_private.go +++ b/main_private.go @@ -6,7 +6,7 @@ import ( ) func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, source: s.source} if s.search == nil { db.search = &search{} diff --git a/scope_private.go b/scope_private.go index d9f16438..de739ffa 100644 --- a/scope_private.go +++ b/scope_private.go @@ -447,8 +447,18 @@ func (scope *Scope) removeIndex(indexName string) { } func (scope *Scope) autoMigrate() *Scope { + // scope.db.source sample: root:@/testdatabase?parseTime=true + from := strings.Index(scope.db.source, "/") + to := strings.Index(scope.db.source, "?") + if to == -1 { + to = len(scope.db.source) + } + databaseName := scope.db.source[from:to] + var tableName string - scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_name = %v", scope.AddToVars(scope.TableName()))) + scope.Raw(fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.tables where table_schema = %v AND table_name = %v", + scope.AddToVars(databaseName), + scope.AddToVars(scope.TableName()))) scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&tableName) scope.SqlVars = []interface{}{} @@ -458,7 +468,8 @@ func (scope *Scope) autoMigrate() *Scope { } else { for _, field := range scope.Fields() { var column, data string - scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v and column_name = %v", + scope.Raw(fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = %v AND table_name = %v AND column_name = %v", + scope.AddToVars(databaseName), scope.AddToVars(scope.TableName()), scope.AddToVars(field.DBName), ))