From 976f43861ffde82e2f0939793124b172d2ebd2ec Mon Sep 17 00:00:00 2001 From: Carlos Castillo Date: Sat, 24 Aug 2013 20:04:51 -0700 Subject: [PATCH 1/2] Added error return to ConnectHook and fixed extension example The ConnectHook field of an SQLiteDriver should return an error in case something bad happened during the hook. The extension example needs to load the extension in a ConnectHook, otherwise the extension is only loaded in a single connection in the pool. By putting the extension loading in the ConnectHook, its called for every connection that is opened by the sql.DB. --- example/extension/extension.go | 34 ++++++++++++++++++++++++++++++++-- sqlite3.go | 6 ++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/example/extension/extension.go b/example/extension/extension.go index d4b8fdb..49eacf1 100644 --- a/example/extension/extension.go +++ b/example/extension/extension.go @@ -8,10 +8,30 @@ import ( ) func main() { + const ( + use_hook = true + load_query = "SELECT load_extension('sqlite3_mod_regexp.dll')" + ) + sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{ EnableLoadExtension: true, - ConnectHook: nil, + ConnectHook: func(c *sqlite3.SQLiteConn) error { + if use_hook { + stmt, err := c.Prepare(load_query) + if err != nil { + return err + } + + _, err = stmt.Exec(nil) + if err != nil { + return err + } + + return stmt.Close() + } + return nil + }, }) db, err := sql.Open("sqlite3_with_extensions", ":memory:") @@ -20,11 +40,21 @@ func main() { } defer db.Close() - _, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')") + if !use_hook { + if _, err = db.Exec(load_query); err != nil { + log.Fatal(err) + } + } + + // Force db to make a new connection in pool + // by putting the original in a transaction + tx, err := db.Begin() if err != nil { log.Fatal(err) } + defer tx.Commit() + // New connection works (hopefully!) rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'") if err != nil { log.Fatal(err) diff --git a/sqlite3.go b/sqlite3.go index 692306d..cc42c13 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -78,7 +78,7 @@ func init() { // Driver struct. type SQLiteDriver struct { EnableLoadExtension bool - ConnectHook func(*SQLiteConn) + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -194,7 +194,9 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { conn := &SQLiteConn{db} if d.ConnectHook != nil { - d.ConnectHook(conn) + if err := d.ConnectHook(conn); err != nil { + return nil, err + } } return conn, nil From 0dd71564e26e5fbe2177d420e2bf82a889568d64 Mon Sep 17 00:00:00 2001 From: Carlos Castillo Date: Sat, 24 Aug 2013 20:36:35 -0700 Subject: [PATCH 2/2] Changed extension support to load from a string list of extensions By loading extensions this way, it's not possible to later load extensions using db.Exec, which improves security, and makes it much easier to load extensions correctly. The zero value for the slice (the empty slice) loads no extensions by default. The extension example has been updated to use this much simpler system. The ConnectHook field is still in SQLiteDriver in case it's needed for other driver-wide initialization. Updates #71 of mattn/go-sqlite3. --- example/extension/extension.go | 29 ++--------------------- sqlite3.go | 42 ++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/example/extension/extension.go b/example/extension/extension.go index 49eacf1..f58ea3a 100644 --- a/example/extension/extension.go +++ b/example/extension/extension.go @@ -8,29 +8,10 @@ import ( ) func main() { - const ( - use_hook = true - load_query = "SELECT load_extension('sqlite3_mod_regexp.dll')" - ) - sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{ - EnableLoadExtension: true, - ConnectHook: func(c *sqlite3.SQLiteConn) error { - if use_hook { - stmt, err := c.Prepare(load_query) - if err != nil { - return err - } - - _, err = stmt.Exec(nil) - if err != nil { - return err - } - - return stmt.Close() - } - return nil + Extensions: []string{ + "sqlite3_mod_regexp.dll", }, }) @@ -40,12 +21,6 @@ func main() { } defer db.Close() - if !use_hook { - if _, err = db.Exec(load_query); err != nil { - log.Fatal(err) - } - } - // Force db to make a new connection in pool // by putting the original in a transaction tx, err := db.Begin() diff --git a/sqlite3.go b/sqlite3.go index cc42c13..e7417ec 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{ } func init() { - sql.Register("sqlite3", &SQLiteDriver{false, nil}) + sql.Register("sqlite3", &SQLiteDriver{}) } // Driver struct. type SQLiteDriver struct { - EnableLoadExtension bool - ConnectHook func(*SQLiteConn) error + Extensions []string + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -182,17 +182,35 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) } - enableLoadExtension := 0 - if d.EnableLoadExtension { - enableLoadExtension = 1 - } - rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension)) - if rv != C.SQLITE_OK { - return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) - } - conn := &SQLiteConn{db} + if len(d.Extensions) > 0 { + rv = C.sqlite3_enable_load_extension(db, 1) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + + stmt, err := conn.Prepare("SELECT load_extension(?);") + if err != nil { + return nil, err + } + + for _, extension := range d.Extensions { + if _, err = stmt.Exec([]driver.Value{extension}); err != nil { + return nil, err + } + } + + if err = stmt.Close(); err != nil { + return nil, err + } + + rv = C.sqlite3_enable_load_extension(db, 0) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + } + if d.ConnectHook != nil { if err := d.ConnectHook(conn); err != nil { return nil, err