diff --git a/example/extension/extension.go b/example/extension/extension.go index 49eacf1..f58ea3a 100644 --- a/example/extension/extension.go +++ b/example/extension/extension.go @@ -8,29 +8,10 @@ import ( ) func main() { - const ( - use_hook = true - load_query = "SELECT load_extension('sqlite3_mod_regexp.dll')" - ) - sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{ - EnableLoadExtension: true, - ConnectHook: func(c *sqlite3.SQLiteConn) error { - if use_hook { - stmt, err := c.Prepare(load_query) - if err != nil { - return err - } - - _, err = stmt.Exec(nil) - if err != nil { - return err - } - - return stmt.Close() - } - return nil + Extensions: []string{ + "sqlite3_mod_regexp.dll", }, }) @@ -40,12 +21,6 @@ func main() { } defer db.Close() - if !use_hook { - if _, err = db.Exec(load_query); err != nil { - log.Fatal(err) - } - } - // Force db to make a new connection in pool // by putting the original in a transaction tx, err := db.Begin() diff --git a/sqlite3.go b/sqlite3.go index cc42c13..e7417ec 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{ } func init() { - sql.Register("sqlite3", &SQLiteDriver{false, nil}) + sql.Register("sqlite3", &SQLiteDriver{}) } // Driver struct. type SQLiteDriver struct { - EnableLoadExtension bool - ConnectHook func(*SQLiteConn) error + Extensions []string + ConnectHook func(*SQLiteConn) error } // Conn struct. @@ -182,17 +182,35 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) } - enableLoadExtension := 0 - if d.EnableLoadExtension { - enableLoadExtension = 1 - } - rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension)) - if rv != C.SQLITE_OK { - return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) - } - conn := &SQLiteConn{db} + if len(d.Extensions) > 0 { + rv = C.sqlite3_enable_load_extension(db, 1) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + + stmt, err := conn.Prepare("SELECT load_extension(?);") + if err != nil { + return nil, err + } + + for _, extension := range d.Extensions { + if _, err = stmt.Exec([]driver.Value{extension}); err != nil { + return nil, err + } + } + + if err = stmt.Close(); err != nil { + return nil, err + } + + rv = C.sqlite3_enable_load_extension(db, 0) + if rv != C.SQLITE_OK { + return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) + } + } + if d.ConnectHook != nil { if err := d.ConnectHook(conn); err != nil { return nil, err