From 0cba662be09665c353f09f368d75acfd947d23aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 3 Jun 2014 17:15:05 +0800 Subject: [PATCH] Add method QuotedTableName for Scope --- callback_create.go | 4 ++-- callback_delete.go | 4 ++-- callback_update.go | 2 +- scope.go | 8 ++++++++ scope_private.go | 14 +++++++------- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/callback_create.go b/callback_create.go index 2f641b3d..38706cf3 100644 --- a/callback_create.go +++ b/callback_create.go @@ -35,13 +35,13 @@ func Create(scope *Scope) { if len(columns) == 0 { scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", - scope.TableName(), + scope.QuotedTableName(), scope.Dialect().ReturningStr(scope.PrimaryKey()), )) } else { scope.Raw(fmt.Sprintf( "INSERT INTO %v (%v) VALUES (%v) %v", - scope.TableName(), + scope.QuotedTableName(), strings.Join(columns, ","), strings.Join(sqls, ","), scope.Dialect().ReturningStr(scope.PrimaryKey()), diff --git a/callback_delete.go b/callback_delete.go index 3bf5b7b0..32d88630 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -14,12 +14,12 @@ func Delete(scope *Scope) { if !scope.Search.Unscope && scope.HasColumn("DeletedAt") { scope.Raw( fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", - scope.TableName(), + scope.QuotedTableName(), scope.AddToVars(time.Now()), scope.CombinedConditionSql(), )) } else { - scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())) } scope.Exec() diff --git a/callback_update.go b/callback_update.go index 095bbf48..b8178ab5 100644 --- a/callback_update.go +++ b/callback_update.go @@ -59,7 +59,7 @@ func Update(scope *Scope) { scope.Raw(fmt.Sprintf( "UPDATE %v SET %v %v", - scope.TableName(), + scope.QuotedTableName(), strings.Join(sqls, ", "), scope.CombinedConditionSql(), )) diff --git a/scope.go b/scope.go index 2a567058..fbc2ad5a 100644 --- a/scope.go +++ b/scope.go @@ -214,6 +214,14 @@ func (scope *Scope) TableName() string { } } +func (scope *Scope) QuotedTableName() string { + if scope.Search != nil && len(scope.Search.TableName) > 0 { + return scope.Search.TableName + } else { + return scope.Quote(scope.TableName()) + } +} + // CombinedConditionSql get combined condition sql func (scope *Scope) CombinedConditionSql() string { return scope.joinsSql() + scope.whereSql() + scope.groupSql() + diff --git a/scope_private.go b/scope_private.go index 232e6150..4db4ebbe 100644 --- a/scope_private.go +++ b/scope_private.go @@ -237,7 +237,7 @@ func (scope *Scope) prepareQuerySql() { if scope.Search.Raw { scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE ")) } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.TableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) } return } @@ -414,21 +414,21 @@ func (scope *Scope) createTable() *Scope { sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) } } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec() return scope } func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.TableName())).Exec() + scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() return scope } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() } func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() } func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { @@ -446,7 +446,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) removeIndex(indexName string) { - scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.TableName())).Exec() + scope.Raw(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Exec() } func (scope *Scope) autoMigrate() *Scope { @@ -456,7 +456,7 @@ func (scope *Scope) autoMigrate() *Scope { for _, field := range scope.Fields() { if !scope.Dialect().HasColumn(scope, scope.TableName(), field.DBName) { if len(field.SqlTag) > 0 && !field.IsIgnored { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.TableName(), field.DBName, field.SqlTag)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", scope.QuotedTableName(), field.DBName, field.SqlTag)).Exec() } } }