diff --git a/callback_row_query.go b/callback_row_query.go index c2ff4a08..687b0039 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "database/sql" + "fmt" +) // Define callbacks for row query func init() { @@ -20,6 +23,9 @@ type RowsQueryResult struct { func rowQueryCallback(scope *Scope) { if result, ok := scope.InstanceGet("row_query_result"); ok { scope.prepareQuerySQL() + if str, ok := scope.Get("gorm:query_option"); ok { + scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) + } if rowResult, ok := result.(*RowQueryResult); ok { rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) diff --git a/main_test.go b/main_test.go index 1dc30093..a0d95369 100644 --- a/main_test.go +++ b/main_test.go @@ -1110,6 +1110,29 @@ func TestCountWithHaving(t *testing.T) { } } +func TestCountWithQueryOption(t *testing.T) { + db := DB.New() + db.Delete(User{}) + defer db.Delete(User{}) + + DB.Create(&User{Name: "user1"}) + DB.Create(&User{Name: "user2"}) + DB.Create(&User{Name: "user3"}) + + var count int + err := db.Model(User{}).Select("users.id"). + Set("gorm:query_option", "WHERE users.name='user2'"). + Count(&count).Error + + if err != nil { + t.Error("Unexpected error on query count with query_option") + } + + if count != 1 { + t.Error("Unexpected result on query count with query_option") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ {