From cf831bd67e2f7f15abe856f8fae8067e9c805a9a Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Fri, 8 Nov 2024 21:10:22 -0500 Subject: [PATCH] test: add Exec tests and benchmarks --- sqlite3_test.go | 133 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d..ee5fd58 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -10,6 +10,7 @@ package sqlite3 import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -1090,6 +1091,67 @@ func TestExecer(t *testing.T) { } } +func TestExecDriverResult(t *testing.T) { + setup := func(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", t.TempDir()+"/test.sqlite3") + if err != nil { + t.Fatal("Failed to open database:", err) + } + if _, err := db.Exec(`CREATE TABLE foo (id INTEGER PRIMARY KEY);`); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + return db + } + + test := func(t *testing.T, execStmt string, args ...any) { + db := setup(t) + res, err := db.Exec(execStmt, args...) + if err != nil { + t.Fatal(err) + } + rows, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } + // We only return the changes from the last statement. + if rows != 1 { + t.Errorf("RowsAffected got: %d want: %d", rows, 1) + } + id, err := res.LastInsertId() + if err != nil { + t.Fatal(err) + } + if id != 3 { + t.Errorf("LastInsertId got: %d want: %d", id, 3) + } + var count int64 + err = db.QueryRow(`SELECT COUNT(*) FROM foo WHERE id IN (1, 2, 3);`).Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 3 { + t.Errorf("Expected count to be %d got: %d", 3, count) + } + } + + t.Run("NoArgs", func(t *testing.T) { + const stmt = ` + INSERT INTO foo(id) VALUES(1); + INSERT INTO foo(id) VALUES(2); + INSERT INTO foo(id) VALUES(3);` + test(t, stmt) + }) + + t.Run("WithArgs", func(t *testing.T) { + const stmt = ` + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?);` + test(t, stmt, 1, 2, 3) + }) +} + func TestQueryer(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2106,6 +2168,10 @@ var tests = []testing.InternalTest{ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkExec", F: benchmarkExec}, + {Name: "BenchmarkExecContext", F: benchmarkExecContext}, + {Name: "BenchmarkExecStep", F: benchmarkExecStep}, + {Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep}, + {Name: "BenchmarkExecTx", F: benchmarkExecTx}, {Name: "BenchmarkQuery", F: benchmarkQuery}, {Name: "BenchmarkParams", F: benchmarkParams}, {Name: "BenchmarkStmt", F: benchmarkStmt}, @@ -2458,10 +2524,75 @@ func testExecEmptyQuery(t *testing.T) { // benchmarkExec is benchmark for exec func benchmarkExec(b *testing.B) { + b.Run("Params", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select ?;", int64(1)); err != nil { + panic(err) + } + } + }) + b.Run("NoParams", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select 1;"); err != nil { + panic(err) + } + } + }) +} + +func benchmarkExecContext(b *testing.B) { + b.Run("Params", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select ?;", int64(1)); err != nil { + panic(err) + } + } + }) + b.Run("NoParams", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select 1;"); err != nil { + panic(err) + } + } + }) +} + +func benchmarkExecTx(b *testing.B) { for i := 0; i < b.N; i++ { - if _, err := db.Exec("select 1"); err != nil { + tx, err := db.Begin() + if err != nil { panic(err) } + if _, err := tx.Exec("select 1;"); err != nil { + panic(err) + } + if err := tx.Commit(); err != nil { + panic(err) + } + } +} + +var largeSelectStmt = strings.Repeat("select 1;\n", 1_000) + +func benchmarkExecStep(b *testing.B) { + for n := 0; n < b.N; n++ { + if _, err := db.Exec(largeSelectStmt); err != nil { + b.Fatal(err) + } + } +} + +func benchmarkExecContextStep(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for n := 0; n < b.N; n++ { + if _, err := db.ExecContext(ctx, largeSelectStmt); err != nil { + b.Fatal(err) + } } }