mirror of https://github.com/go-gorm/gorm.git
Fix check valid db name, close #3315
This commit is contained in:
parent
cd54dddd94
commit
d50dbb0896
|
@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||
}
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
case string:
|
||||
fields := strings.FieldsFunc(v, utils.IsChar)
|
||||
fields := strings.FieldsFunc(v, utils.IsValidDBNameChar)
|
||||
|
||||
// normal field names
|
||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||
|
@ -133,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
|||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
|
@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
|||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
fields := strings.FieldsFunc(name, utils.IsChar)
|
||||
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
||||
})
|
||||
|
|
|
@ -362,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
|||
tx.AddError(ErrModelValueRequired)
|
||||
}
|
||||
|
||||
fields := strings.FieldsFunc(column, utils.IsChar)
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
|
|
|
@ -29,8 +29,8 @@ func FileWithLineNum() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func IsChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*'
|
||||
func IsValidDBNameChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$'
|
||||
}
|
||||
|
||||
func CheckTruth(val interface{}) bool {
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidDBNameChar(t *testing.T) {
|
||||
for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} {
|
||||
if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 {
|
||||
t.Fatalf("failed to parse db name %v", db)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue