diff --git a/dialect.go b/dialect.go index b6f95df7..749587f4 100644 --- a/dialect.go +++ b/dialect.go @@ -37,7 +37,7 @@ type Dialect interface { ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) string + LimitAndOffsetSQL(limit, offset interface{}) (string, error) // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` diff --git a/dialect_common.go b/dialect_common.go index 16da76dc..950c1986 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -139,14 +139,23 @@ func (s commonDialect) CurrentDatabase() (name string) { return } -func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +// LimitAndOffsetSQL return generated SQL with Limit and Offset +func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := s.parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) } } if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := s.parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } @@ -181,6 +190,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri return indexName, columnName } +func (commonDialect) parseInt(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) +} + // IsByteArrayOrSlice returns true of the reflected value is an array or slice func IsByteArrayOrSlice(value reflect.Value) bool { return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) diff --git a/dialect_mysql.go b/dialect_mysql.go index ab6a8a91..b4467ffa 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" "regexp" - "strconv" "strings" "time" "unicode/utf8" @@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err return err } -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := s.parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := s.parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index eb79f7e7..43acb379 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -168,14 +168,25 @@ func (s mssql) CurrentDatabase() (name string) { return } -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { + parseInt := func(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) + } if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { + parsedOffset, err := parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) } } if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { + parsedLimit, err := parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { if sql == "" { // add default zero offset sql += " OFFSET 0 ROWS" diff --git a/query_test.go b/query_test.go index 15bf8b3c..a23a9e24 100644 --- a/query_test.go +++ b/query_test.go @@ -457,6 +457,74 @@ func TestOffset(t *testing.T) { } } +func TestLimitAndOffsetSQL(t *testing.T) { + user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10} + user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20} + user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30} + user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40} + user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50} + if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + limit, offset interface{} + users []*User + ok bool + }{ + { + name: "OK", + limit: float64(2), + offset: float64(2), + users: []*User{ + &User{Name: "TestLimitAndOffsetSQL3", Age: 30}, + &User{Name: "TestLimitAndOffsetSQL2", Age: 20}, + }, + ok: true, + }, + { + name: "Limit parse error", + limit: float64(1000000), // 1e+06 + offset: float64(2), + ok: false, + }, + { + name: "Offset parse error", + limit: float64(2), + offset: float64(1000000), // 1e+06 + ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var users []*User + err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error + if tt.ok { + if err != nil { + t.Errorf("error expected nil, but got %v", err) + } + if len(users) != len(tt.users) { + t.Errorf("users length expected %d, but got %d", len(tt.users), len(users)) + } + for i := range tt.users { + if users[i].Name != tt.users[i].Name { + t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name) + } + if users[i].Age != tt.users[i].Age { + t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age) + } + } + } else { + if err == nil { + t.Error("error expected not nil, but got nil") + } + } + }) + } +} + func TestOr(t *testing.T) { user1 := User{Name: "OrUser1", Age: 1} user2 := User{Name: "OrUser2", Age: 10} diff --git a/scope.go b/scope.go index eb7525b8..0e9dfd1c 100644 --- a/scope.go +++ b/scope.go @@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string { } func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + scope.Err(err) + return sql } func (scope *Scope) groupSQL() string {