Add limit and offset parse error

This commit is contained in:
zaneli 2019-10-31 02:51:26 +09:00
parent 59408390c2
commit 23f6840776
6 changed files with 113 additions and 12 deletions

View File

@ -37,7 +37,7 @@ type Dialect interface {
ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`

View File

@ -139,14 +139,23 @@ func (s commonDialect) CurrentDatabase() (name string) {
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
// LimitAndOffsetSQL return generated SQL with Limit and Offset
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := s.parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := s.parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
@ -181,6 +190,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri
return indexName, columnName
}
func (commonDialect) parseInt(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))

View File

@ -6,7 +6,6 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := s.parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := s.parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}

View File

@ -168,14 +168,25 @@ func (s mssql) CurrentDatabase() (name string) {
return
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
parseInt := func(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"

View File

@ -457,6 +457,74 @@ func TestOffset(t *testing.T) {
}
}
func TestLimitAndOffsetSQL(t *testing.T) {
user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
t.Fatal(err)
}
tests := []struct {
name string
limit, offset interface{}
users []*User
ok bool
}{
{
name: "OK",
limit: float64(2),
offset: float64(2),
users: []*User{
&User{Name: "TestLimitAndOffsetSQL3", Age: 30},
&User{Name: "TestLimitAndOffsetSQL2", Age: 20},
},
ok: true,
},
{
name: "Limit parse error",
limit: float64(1000000), // 1e+06
offset: float64(2),
ok: false,
},
{
name: "Offset parse error",
limit: float64(2),
offset: float64(1000000), // 1e+06
ok: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var users []*User
err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
if tt.ok {
if err != nil {
t.Errorf("error expected nil, but got %v", err)
}
if len(users) != len(tt.users) {
t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
}
for i := range tt.users {
if users[i].Name != tt.users[i].Name {
t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
}
if users[i].Age != tt.users[i].Age {
t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
}
}
} else {
if err == nil {
t.Error("error expected not nil, but got nil")
}
}
})
}
}
func TestOr(t *testing.T) {
user1 := User{Name: "OrUser1", Age: 1}
user2 := User{Name: "OrUser2", Age: 10}

View File

@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string {
}
func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
scope.Err(err)
return sql
}
func (scope *Scope) groupSQL() string {