diff --git a/sqlite3.go b/sqlite3.go index bff6b7c..076c9bd 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -811,6 +811,7 @@ func (s *SQLiteStmt) Close() error { return errors.New("sqlite statement with already closed database connection") } rv := C.sqlite3_finalize(s.s) + s.s = nil if rv != C.SQLITE_OK { return s.c.lastError() } diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index f076b81..a5f4aae 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -8,9 +8,13 @@ package sqlite3 import ( + "context" "database/sql" + "fmt" + "math/rand" "os" "testing" + "time" ) func TestNamedParams(t *testing.T) { @@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) { t.Error("Failed to db.QueryRow: not matched results") } } + +var ( + testTableStatements = []string{ + `DROP TABLE IF EXISTS test_table`, + ` +CREATE TABLE IF NOT EXISTS test_table ( + key1 VARCHAR(64) PRIMARY KEY, + key_id VARCHAR(64) NOT NULL, + key2 VARCHAR(64) NOT NULL, + key3 VARCHAR(64) NOT NULL, + key4 VARCHAR(64) NOT NULL, + key5 VARCHAR(64) NOT NULL, + key6 VARCHAR(64) NOT NULL, + data BLOB NOT NULL +);`, + } + letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +) + +func randStringBytes(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} + +func initDatabase(t *testing.T, db *sql.DB, rowCount int64) { + t.Logf("Executing db initializing statements") + for _, query := range testTableStatements { + _, err := db.Exec(query) + if err != nil { + t.Fatal(err) + } + } + for i := int64(0); i < rowCount; i++ { + query := `INSERT INTO test_table + (key1, key_id, key2, key3, key4, key5, key6, data) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?);` + args := []interface{}{ + randStringBytes(50), + fmt.Sprint(i), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(2048), + } + _, err := db.Exec(query, args...) + if err != nil { + t.Fatal(err) + } + } +} + +func TestShortTimeout(t *testing.T) { + db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared") + if err != nil { + t.Fatal(err) + } + defer db.Close() + initDatabase(t, db, 10000) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond) + defer cancel() + query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data + FROM test_table + ORDER BY key2 ASC` + rows, err := db.QueryContext(ctx, query) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var key1, keyid, key2, key3, key4, key5, key6 string + var data []byte + err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data) + if err != nil { + break + } + } + if context.DeadlineExceeded != ctx.Err() { + t.Fatal(ctx.Err()) + } +}