diff --git a/chainable_api.go b/chainable_api.go index e1b73457..c8417a6d 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsChar) + fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -133,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) } else { tx.Statement.Omits = columns } @@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsChar) + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/finisher_api.go b/finisher_api.go index cf46f78a..2cde3c31 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -362,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsChar) + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/utils/utils.go b/utils/utils.go index e93f3055..71336f4b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -29,8 +29,8 @@ func FileWithLineNum() string { return "" } -func IsChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' } func CheckTruth(val interface{}) bool { diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..5737c511 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,14 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +}