diff --git a/scope.go b/scope.go index 484164ad..097c2243 100644 --- a/scope.go +++ b/scope.go @@ -447,7 +447,12 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { } } -var columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name` +var ( + columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name` + isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number + comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") + countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") +) func (scope *Scope) quoteIfPossible(str string) string { if columnRegexp.MatchString(str) { @@ -509,8 +514,7 @@ func (scope *Scope) primaryCondition(value interface{}) string { func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { switch value := clause["query"].(type) { case string: - // if string is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + if isNumberRegexp.MatchString(value) { return scope.primaryCondition(scope.AddToVars(value)) } else if value != "" { str = fmt.Sprintf("(%v)", value) @@ -573,11 +577,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string switch value := clause["query"].(type) { case string: - // is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { + if isNumberRegexp.MatchString(value) { id, _ := strconv.Atoi(value) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { + } else if comparisonRegexp.MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) notEqualSQL = fmt.Sprintf("NOT (%v)", value) } else { @@ -924,7 +927,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { } func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !regexp.MustCompile("(?i)^count(.+)$").MatchString(fmt.Sprint(query)) { + if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { scope.Search.Select("count(*)") } scope.Search.ignoreOrderQuery = true diff --git a/utils.go b/utils.go index 8f3d0f38..bf1e5666 100644 --- a/utils.go +++ b/utils.go @@ -26,6 +26,9 @@ var NowFunc = func() time.Time { var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) + func init() { var commonInitialismsForReplacer []string for _, initialism := range commonInitialisms { @@ -171,7 +174,7 @@ func toQueryValues(values [][]interface{}) (results []interface{}) { func fileWithLineNum() string { for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { return fmt.Sprintf("%v:%v", file, line) } }