diff --git a/query_test.go b/query_test.go index 15bf8b3c..2b7e0dff 100644 --- a/query_test.go +++ b/query_test.go @@ -133,6 +133,23 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } +func TestStringAgainstIncompleteParentheses(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.AutoMigrate(&AddressByZipCode{}) + DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"}) + + var address AddressByZipCode + var addresses []AddressByZipCode + _ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors() + if len(addresses) > 0 { + t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode) + } + +} func TestFindAsSliceOfPointers(t *testing.T) { DB.Save(&User{Name: "user"}) diff --git a/scope.go b/scope.go index c962c165..541fe522 100644 --- a/scope.go +++ b/scope.go @@ -277,6 +277,23 @@ func (scope *Scope) AddToVars(value interface{}) string { return scope.Dialect().BindVar(len(scope.SQLVars)) } +// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection +func (scope *Scope) IsCompleteParentheses(value string) bool { + count := 0 + for i, _ := range value { + if value[i] == 40 { // ( + count++ + } else if value[i] == 41 { // ) + count-- + } + if count < 0 { + break + } + i++ + } + return count == 0 +} + // SelectAttrs return selected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { @@ -556,6 +573,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) } if value != "" { + if !scope.IsCompleteParentheses(value) { + scope.Err(fmt.Errorf("incomplete parentheses found: %v", value)) + return + } if !include { if comparisonRegexp.MatchString(value) { str = fmt.Sprintf("NOT (%v)", value)