diff --git a/sqlite3_go18.go b/sqlite3_go18.go index 43e6418..82b7fea 100644 --- a/sqlite3_go18.go +++ b/sqlite3_go18.go @@ -10,7 +10,6 @@ package sqlite3 import ( "database/sql/driver" - "errors" "context" ) @@ -18,7 +17,8 @@ import ( // Ping implement Pinger. func (c *SQLiteConn) Ping(ctx context.Context) error { if c.db == nil { - return errors.New("Connection was closed") + // must be ErrBadConn for sql to close the database + return driver.ErrBadConn } return nil } diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index 741ed90..c9e79e7 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -11,6 +11,7 @@ import ( "context" "database/sql" "fmt" + "io/ioutil" "math/rand" "os" "testing" @@ -154,3 +155,217 @@ func TestExecCancel(t *testing.T) { } } } + +func doTestOpenContext(t *testing.T, option string) (string, error) { + tempFilename := TempFilename(t) + url := tempFilename + option + + defer func() { + err := os.Remove(tempFilename) + if err != nil { + t.Error("temp file remove error:", err) + } + }() + + db, err := sql.Open("sqlite3", url) + if err != nil { + return "Failed to open database:", err + } + + defer func() { + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 55*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + return "ping error:", err + } + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + return "Failed to create table:", err + } + + if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { + return "Failed to create ./foo.db", nil + } + + return "", nil +} + +func TestOpenContext(t *testing.T) { + cases := map[string]bool{ + "": true, + "?_txlock=immediate": true, + "?_txlock=deferred": true, + "?_txlock=exclusive": true, + "?_txlock=bogus": false, + } + for option, expectedPass := range cases { + result, err := doTestOpenContext(t, option) + if result == "" { + if !expectedPass { + errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option) + t.Fatal(errmsg) + } + } else if expectedPass { + if err == nil { + t.Fatal(result) + } else { + t.Fatal(result, err) + } + } + } +} + +func TestFileCopyTruncate(t *testing.T) { + var err error + tempFilename := TempFilename(t) + + defer func() { + err = os.Remove(tempFilename) + if err != nil { + t.Error("temp file remove error:", err) + } + }() + + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("open error:", err) + } + + defer func() { + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 55*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + t.Fatal("ping error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + t.Fatal("create table error:", err) + } + + // copy db to new file + var data []byte + data, err = ioutil.ReadFile(tempFilename) + if err != nil { + t.Fatal("read file error:", err) + } + + var f *os.File + f, err = os.Create(tempFilename + "-db-copy") + if err != nil { + t.Fatal("create file error:", err) + } + + defer func() { + err = os.Remove(tempFilename + "-db-copy") + if err != nil { + t.Error("temp file moved remove error:", err) + } + }() + + _, err = f.Write(data) + if err != nil { + f.Close() + t.Fatal("write file error:", err) + } + err = f.Close() + if err != nil { + t.Fatal("close file error:", err) + } + + // truncate current db file + f, err = os.OpenFile(tempFilename, os.O_WRONLY|os.O_TRUNC, 0666) + if err != nil { + t.Fatal("open file error:", err) + } + err = f.Close() + if err != nil { + t.Fatal("close file error:", err) + } + + // test db after file truncate + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + t.Fatal("ping error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + if err == nil { + t.Fatal("drop table no error") + } + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + t.Fatal("create table error:", err) + } + + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + + // test copied file + db, err = sql.Open("sqlite3", tempFilename+"-db-copy") + if err != nil { + t.Fatal("open error:", err) + } + + defer func() { + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + }() + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + err = db.PingContext(ctx) + cancel() + if err != nil { + t.Fatal("ping error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "drop table foo") + cancel() + if err != nil { + t.Fatal("drop table error:", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second) + _, err = db.ExecContext(ctx, "create table foo (id integer)") + cancel() + if err != nil { + t.Fatal("create table error:", err) + } +} diff --git a/sqlite3_test.go b/sqlite3_test.go index 3ef8533..806ab8d 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -34,20 +34,32 @@ func TempFilename(t *testing.T) string { } func doTestOpen(t *testing.T, option string) (string, error) { - var url string tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - if option != "" { - url = tempFilename + option - } else { - url = tempFilename - } + url := tempFilename + option + + defer func() { + err := os.Remove(tempFilename) + if err != nil { + t.Error("temp file remove error:", err) + } + }() + db, err := sql.Open("sqlite3", url) if err != nil { return "Failed to open database:", err } - defer os.Remove(tempFilename) - defer db.Close() + + defer func() { + err = db.Close() + if err != nil { + t.Error("db close error:", err) + } + }() + + err = db.Ping() + if err != nil { + return "ping error:", err + } _, err = db.Exec("drop table foo") _, err = db.Exec("create table foo (id integer)")