Fix check valid db name, close #3315

This commit is contained in:
Jinzhu 2020-08-27 19:15:40 +08:00
parent cd54dddd94
commit d50dbb0896
4 changed files with 20 additions and 6 deletions

View File

@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
}
delete(tx.Statement.Clauses, "SELECT")
case string:
fields := strings.FieldsFunc(v, utils.IsChar)
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
// normal field names
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
@ -133,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
} else {
tx.Statement.Omits = columns
}
@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance()
fields := strings.FieldsFunc(name, utils.IsChar)
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
})

View File

@ -362,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx.AddError(ErrModelValueRequired)
}
fields := strings.FieldsFunc(column, utils.IsChar)
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},

View File

@ -29,8 +29,8 @@ func FileWithLineNum() string {
return ""
}
func IsChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*'
func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
}
func CheckTruth(val interface{}) bool {

14
utils/utils_test.go Normal file
View File

@ -0,0 +1,14 @@
package utils
import (
"strings"
"testing"
)
func TestIsValidDBNameChar(t *testing.T) {
for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} {
if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 {
t.Fatalf("failed to parse db name %v", db)
}
}
}