diff --git a/callback_create.go b/callback_create.go index 6138e9d0..ea7ddaee 100644 --- a/callback_create.go +++ b/callback_create.go @@ -27,7 +27,7 @@ func Create(scope *Scope) { for _, field := range scope.Fields() { if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored { - columns = append(columns, scope.quote(field.DBName)) + columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Value)) } } diff --git a/callback_update.go b/callback_update.go index 09c12508..1b3ceda2 100644 --- a/callback_update.go +++ b/callback_update.go @@ -49,12 +49,12 @@ func Update(scope *Scope) { updateAttrs, ok := scope.Get("gorm:update_attrs") if ok { for key, value := range updateAttrs.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) } } else { for _, field := range scope.Fields() { if field.DBName != scope.PrimaryKey() && len(field.SqlTag) > 0 && !field.IsIgnored { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.quote(field.DBName), scope.AddToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value))) } } } diff --git a/scope.go b/scope.go index 1b8a0e19..d73ddb10 100644 --- a/scope.go +++ b/scope.go @@ -460,7 +460,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if foreignValue, ok := scope.FieldByName(foreignKey); ok { return toScope.inlineCondition(foreignValue).callCallbacks(scope.db.parent.callback.queries) } else if toScope.HasColumn(foreignKey) { - sql := fmt.Sprintf("%v = ?", scope.quote(toSnake(foreignKey))) + sql := fmt.Sprintf("%v = ?", scope.Quote(toSnake(foreignKey))) return toScope.inlineCondition(sql, scope.PrimaryKeyValue()).callCallbacks(scope.db.parent.callback.queries) } } diff --git a/scope_condition.go b/scope_condition.go index b5e99b4e..fd820c8f 100644 --- a/scope_condition.go +++ b/scope_condition.go @@ -10,12 +10,12 @@ import ( "strings" ) -func (scope *Scope) quote(str string) string { +func (scope *Scope) Quote(str string) string { return scope.Dialect().Quote(str) } func (scope *Scope) primaryCondiation(value interface{}) string { - return fmt.Sprintf("(%v = %v)", scope.quote(scope.PrimaryKey()), value) + return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) } func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { @@ -33,19 +33,19 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case sql.NullInt64: return scope.primaryCondiation(scope.AddToVars(value.Int64)) case []int64, []int, []int32, []string: - str = fmt.Sprintf("(%v in (?))", scope.quote(scope.PrimaryKey())) + str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: var sqls []string for _, field := range scope.New(value).Fields() { if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.quote(field.DBName), scope.AddToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value))) } } return strings.Join(sqls, " AND ") @@ -79,19 +79,19 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case string: if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.quote(scope.PrimaryKey()), id) + return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id) } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) not_equal_sql = fmt.Sprintf("NOT (%v)", value) } else { - str = fmt.Sprintf("(%v NOT IN (?))", scope.quote(value)) - not_equal_sql = fmt.Sprintf("(%v <> ?)", scope.quote(value)) + str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) + not_equal_sql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } case int, int64, int32: - return fmt.Sprintf("(%v <> %v)", scope.quote(scope.PrimaryKey()), value) + return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) case []int64, []int, []int32, []string: if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v not in (?))", scope.quote(scope.PrimaryKey())) + str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey())) clause["args"] = []interface{}{value} } else { return "" @@ -99,14 +99,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: var sqls []string for _, field := range scope.New(value).Fields() { if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.quote(field.DBName), scope.AddToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value))) } } return strings.Join(sqls, " AND ") diff --git a/scope_database.go b/scope_database.go index 99bc197e..3f17c0ff 100644 --- a/scope_database.go +++ b/scope_database.go @@ -9,7 +9,7 @@ func (scope *Scope) createTable() *Scope { var sqls []string for _, field := range scope.Fields() { if !field.IsIgnored && len(field.SqlTag) > 0 { - sqls = append(sqls, scope.quote(field.DBName)+" "+field.SqlTag) + sqls = append(sqls, scope.Quote(field.DBName)+" "+field.SqlTag) } } scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.TableName(), strings.Join(sqls, ","))).Exec() @@ -22,11 +22,11 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.quote(column), typ)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.TableName(), scope.Quote(column), typ)).Exec() } func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.quote(column))).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.TableName(), scope.Quote(column))).Exec() } func (scope *Scope) addIndex(column string, names ...string) { @@ -37,7 +37,7 @@ func (scope *Scope) addIndex(column string, names ...string) { indexName = fmt.Sprintf("index_%v_on_%v", scope.TableName(), column) } - scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.quote(column))).Exec() + scope.Raw(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", indexName, scope.TableName(), scope.Quote(column))).Exec() } func (scope *Scope) removeIndex(indexName string) {