mirror of https://github.com/go-gorm/gorm.git
Add limit and offset parse error
This commit is contained in:
parent
59408390c2
commit
23f6840776
|
@ -37,7 +37,7 @@ type Dialect interface {
|
|||
ModifyColumn(tableName string, columnName string, typ string) error
|
||||
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||
|
|
|
@ -139,14 +139,23 @@ func (s commonDialect) CurrentDatabase() (name string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset
|
||||
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
parsedLimit, err := s.parseInt(limit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
}
|
||||
}
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
parsedOffset, err := s.parseInt(offset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
|
@ -181,6 +190,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri
|
|||
return indexName, columnName
|
||||
}
|
||||
|
||||
func (commonDialect) parseInt(value interface{}) (int64, error) {
|
||||
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
|
||||
}
|
||||
|
||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
||||
func IsByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err
|
|||
return err
|
||||
}
|
||||
|
||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
parsedLimit, err := s.parseInt(limit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
parsedOffset, err := s.parseInt(offset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -168,14 +168,25 @@ func (s mssql) CurrentDatabase() (name string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
parseInt := func(value interface{}) (int64, error) {
|
||||
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
|
||||
}
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
parsedOffset, err := parseInt(offset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
|
||||
}
|
||||
}
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
parsedLimit, err := parseInt(limit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedLimit >= 0 {
|
||||
if sql == "" {
|
||||
// add default zero offset
|
||||
sql += " OFFSET 0 ROWS"
|
||||
|
|
|
@ -457,6 +457,74 @@ func TestOffset(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestLimitAndOffsetSQL(t *testing.T) {
|
||||
user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
|
||||
user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
|
||||
user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
|
||||
user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
|
||||
user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
|
||||
if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
limit, offset interface{}
|
||||
users []*User
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "OK",
|
||||
limit: float64(2),
|
||||
offset: float64(2),
|
||||
users: []*User{
|
||||
&User{Name: "TestLimitAndOffsetSQL3", Age: 30},
|
||||
&User{Name: "TestLimitAndOffsetSQL2", Age: 20},
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "Limit parse error",
|
||||
limit: float64(1000000), // 1e+06
|
||||
offset: float64(2),
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "Offset parse error",
|
||||
limit: float64(2),
|
||||
offset: float64(1000000), // 1e+06
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var users []*User
|
||||
err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
|
||||
if tt.ok {
|
||||
if err != nil {
|
||||
t.Errorf("error expected nil, but got %v", err)
|
||||
}
|
||||
if len(users) != len(tt.users) {
|
||||
t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
|
||||
}
|
||||
for i := range tt.users {
|
||||
if users[i].Name != tt.users[i].Name {
|
||||
t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
|
||||
}
|
||||
if users[i].Age != tt.users[i].Age {
|
||||
t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Error("error expected not nil, but got nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOr(t *testing.T) {
|
||||
user1 := User{Name: "OrUser1", Age: 1}
|
||||
user2 := User{Name: "OrUser2", Age: 10}
|
||||
|
|
4
scope.go
4
scope.go
|
@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string {
|
|||
}
|
||||
|
||||
func (scope *Scope) limitAndOffsetSQL() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
scope.Err(err)
|
||||
return sql
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSQL() string {
|
||||
|
|
Loading…
Reference in New Issue