From 98a44bcf5949f178c8116fa30e62c9ac2ef65927 Mon Sep 17 00:00:00 2001 From: rittneje Date: Thu, 16 Apr 2020 01:45:59 -0400 Subject: [PATCH] report actual error message if sqlite3_load_extension fails (#800) * report actual error message if sqlite3_load_extension fails * more fixes and test cases Co-authored-by: Jesse Rittner --- _example/mod_regexp/Makefile | 17 +++--- _example/mod_vtable/Makefile | 17 +++--- _example/mod_vtable/sqlite3_mod_vtable.cc | 2 +- sqlite3_load_extension.go | 40 +++++++++----- sqlite3_load_extension_test.go | 63 +++++++++++++++++++++++ 5 files changed, 113 insertions(+), 26 deletions(-) create mode 100644 sqlite3_load_extension_test.go diff --git a/_example/mod_regexp/Makefile b/_example/mod_regexp/Makefile index 97b1e0f..1ef69a6 100644 --- a/_example/mod_regexp/Makefile +++ b/_example/mod_regexp/Makefile @@ -1,22 +1,27 @@ ifeq ($(OS),Windows_NT) EXE=extension.exe -EXT=sqlite3_mod_regexp.dll +LIB_EXT=dll RM=cmd /c del LDFLAG= else EXE=extension -EXT=sqlite3_mod_regexp.so -RM=rm +ifeq ($(shell uname -s),Darwin) +LIB_EXT=dylib +else +LIB_EXT=so +endif +RM=rm -f LDFLAG=-fPIC endif +LIB=sqlite3_mod_regexp.$(LIB_EXT) -all : $(EXE) $(EXT) +all : $(EXE) $(LIB) $(EXE) : extension.go go build $< -$(EXT) : sqlite3_mod_regexp.c +$(LIB) : sqlite3_mod_regexp.c gcc $(LDFLAG) -shared -o $@ $< -lsqlite3 -lpcre clean : - @-$(RM) $(EXE) $(EXT) + @-$(RM) $(EXE) $(LIB) diff --git a/_example/mod_vtable/Makefile b/_example/mod_vtable/Makefile index cdd4853..f65a004 100644 --- a/_example/mod_vtable/Makefile +++ b/_example/mod_vtable/Makefile @@ -1,24 +1,29 @@ ifeq ($(OS),Windows_NT) EXE=extension.exe -EXT=sqlite3_mod_vtable.dll +LIB_EXT=dll RM=cmd /c del LIBCURL=-lcurldll LDFLAG= else EXE=extension -EXT=sqlite3_mod_vtable.so -RM=rm +ifeq ($(shell uname -s),Darwin) +LIB_EXT=dylib +else +LIB_EXT=so +endif +RM=rm -f LDFLAG=-fPIC LIBCURL=-lcurl endif +LIB=sqlite3_mod_vtable.$(LIB_EXT) -all : $(EXE) $(EXT) +all : $(EXE) $(LIB) $(EXE) : extension.go go build $< -$(EXT) : sqlite3_mod_vtable.cc +$(LIB) : sqlite3_mod_vtable.cc g++ $(LDFLAG) -shared -o $@ $< -lsqlite3 $(LIBCURL) clean : - @-$(RM) $(EXE) $(EXT) + @-$(RM) $(EXE) $(LIB) diff --git a/_example/mod_vtable/sqlite3_mod_vtable.cc b/_example/mod_vtable/sqlite3_mod_vtable.cc index 5bd4e66..4caf484 100644 --- a/_example/mod_vtable/sqlite3_mod_vtable.cc +++ b/_example/mod_vtable/sqlite3_mod_vtable.cc @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include "picojson.h" diff --git a/sqlite3_load_extension.go b/sqlite3_load_extension.go index 23c5d31..e6c50f2 100644 --- a/sqlite3_load_extension.go +++ b/sqlite3_load_extension.go @@ -28,12 +28,9 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error { } for _, extension := range extensions { - cext := C.CString(extension) - defer C.free(unsafe.Pointer(cext)) - rv = C.sqlite3_load_extension(c.db, cext, nil, nil) - if rv != C.SQLITE_OK { + if err := c.loadExtension(extension, nil); err != nil { C.sqlite3_enable_load_extension(c.db, 0) - return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) + return err } } @@ -41,6 +38,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error { if rv != C.SQLITE_OK { return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) } + return nil } @@ -51,14 +49,9 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error { return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) } - clib := C.CString(lib) - defer C.free(unsafe.Pointer(clib)) - centry := C.CString(entry) - defer C.free(unsafe.Pointer(centry)) - - rv = C.sqlite3_load_extension(c.db, clib, centry, nil) - if rv != C.SQLITE_OK { - return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) + if err := c.loadExtension(lib, &entry); err != nil { + C.sqlite3_enable_load_extension(c.db, 0) + return err } rv = C.sqlite3_enable_load_extension(c.db, 0) @@ -68,3 +61,24 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error { return nil } + +func (c *SQLiteConn) loadExtension(lib string, entry *string) error { + clib := C.CString(lib) + defer C.free(unsafe.Pointer(clib)) + + var centry *C.char + if entry != nil { + centry := C.CString(*entry) + defer C.free(unsafe.Pointer(centry)) + } + + var errMsg *C.char + defer C.sqlite3_free(unsafe.Pointer(errMsg)) + + rv := C.sqlite3_load_extension(c.db, clib, centry, &errMsg) + if rv != C.SQLITE_OK { + return errors.New(C.GoString(errMsg)) + } + + return nil +} diff --git a/sqlite3_load_extension_test.go b/sqlite3_load_extension_test.go new file mode 100644 index 0000000..97b1123 --- /dev/null +++ b/sqlite3_load_extension_test.go @@ -0,0 +1,63 @@ +// Copyright (C) 2019 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build !sqlite_omit_load_extension + +package sqlite3 + +import ( + "database/sql" + "testing" +) + +func TestExtensionsError(t *testing.T) { + sql.Register("sqlite3_TestExtensionsError", + &SQLiteDriver{ + Extensions: []string{ + "foobar", + }, + }, + ) + + db, err := sql.Open("sqlite3_TestExtensionsError", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Ping() + if err == nil { + t.Fatal("expected error loading non-existent extension") + } + + if err.Error() == "not an error" { + t.Fatal("expected error from sqlite3_enable_load_extension to be returned") + } +} + +func TestLoadExtensionError(t *testing.T) { + sql.Register("sqlite3_TestLoadExtensionError", + &SQLiteDriver{ + ConnectHook: func(c *SQLiteConn) error { + return c.LoadExtension("foobar", "") + }, + }, + ) + + db, err := sql.Open("sqlite3_TestLoadExtensionError", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Ping() + if err == nil { + t.Fatal("expected error loading non-existent extension") + } + + if err.Error() == "not an error" { + t.Fatal("expected error from sqlite3_enable_load_extension to be returned") + } +}