Merge pull request #73 from cookieo9/extlist

Change extension loading mechanism to use a string list of extensions
This commit is contained in:
mattn 2013-08-25 08:53:29 -07:00
commit a3e3a8e981
2 changed files with 41 additions and 16 deletions

View File

@ -10,8 +10,9 @@ import (
func main() { func main() {
sql.Register("sqlite3_with_extensions", sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
EnableLoadExtension: true, Extensions: []string{
ConnectHook: nil, "sqlite3_mod_regexp.dll",
},
}) })
db, err := sql.Open("sqlite3_with_extensions", ":memory:") db, err := sql.Open("sqlite3_with_extensions", ":memory:")
@ -20,11 +21,15 @@ func main() {
} }
defer db.Close() defer db.Close()
_, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')") // Force db to make a new connection in pool
// by putting the original in a transaction
tx, err := db.Begin()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer tx.Commit()
// New connection works (hopefully!)
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'") rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
} }
func init() { func init() {
sql.Register("sqlite3", &SQLiteDriver{false, nil}) sql.Register("sqlite3", &SQLiteDriver{})
} }
// Driver struct. // Driver struct.
type SQLiteDriver struct { type SQLiteDriver struct {
EnableLoadExtension bool Extensions []string
ConnectHook func(*SQLiteConn) ConnectHook func(*SQLiteConn) error
} }
// Conn struct. // Conn struct.
@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
} }
enableLoadExtension := 0
if d.EnableLoadExtension {
enableLoadExtension = 1
}
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
conn := &SQLiteConn{db} conn := &SQLiteConn{db}
if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err
}
for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
return nil, err
}
}
if err = stmt.Close(); err != nil {
return nil, err
}
rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
}
if d.ConnectHook != nil { if d.ConnectHook != nil {
d.ConnectHook(conn) if err := d.ConnectHook(conn); err != nil {
return nil, err
}
} }
return conn, nil return conn, nil