Merge pull request #73 from cookieo9/extlist

Change extension loading mechanism to use a string list of extensions
This commit is contained in:
mattn 2013-08-25 08:53:29 -07:00
commit a3e3a8e981
2 changed files with 41 additions and 16 deletions

View File

@ -10,8 +10,9 @@ import (
func main() {
sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{
EnableLoadExtension: true,
ConnectHook: nil,
Extensions: []string{
"sqlite3_mod_regexp.dll",
},
})
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
@ -20,11 +21,15 @@ func main() {
}
defer db.Close()
_, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')")
// Force db to make a new connection in pool
// by putting the original in a transaction
tx, err := db.Begin()
if err != nil {
log.Fatal(err)
}
defer tx.Commit()
// New connection works (hopefully!)
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
if err != nil {
log.Fatal(err)

View File

@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
}
func init() {
sql.Register("sqlite3", &SQLiteDriver{false, nil})
sql.Register("sqlite3", &SQLiteDriver{})
}
// Driver struct.
type SQLiteDriver struct {
EnableLoadExtension bool
ConnectHook func(*SQLiteConn)
Extensions []string
ConnectHook func(*SQLiteConn) error
}
// Conn struct.
@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
enableLoadExtension := 0
if d.EnableLoadExtension {
enableLoadExtension = 1
}
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
conn := &SQLiteConn{db}
if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err
}
for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
return nil, err
}
}
if err = stmt.Close(); err != nil {
return nil, err
}
rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
}
if d.ConnectHook != nil {
d.ConnectHook(conn)
if err := d.ConnectHook(conn); err != nil {
return nil, err
}
}
return conn, nil