forked from mirror/go-sqlite3
Merge pull request #73 from cookieo9/extlist
Change extension loading mechanism to use a string list of extensions
This commit is contained in:
commit
a3e3a8e981
|
@ -10,8 +10,9 @@ import (
|
|||
func main() {
|
||||
sql.Register("sqlite3_with_extensions",
|
||||
&sqlite3.SQLiteDriver{
|
||||
EnableLoadExtension: true,
|
||||
ConnectHook: nil,
|
||||
Extensions: []string{
|
||||
"sqlite3_mod_regexp.dll",
|
||||
},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
|
||||
|
@ -20,11 +21,15 @@ func main() {
|
|||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')")
|
||||
// Force db to make a new connection in pool
|
||||
// by putting the original in a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer tx.Commit()
|
||||
|
||||
// New connection works (hopefully!)
|
||||
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
|
46
sqlite3.go
46
sqlite3.go
|
@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
|
|||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", &SQLiteDriver{false, nil})
|
||||
sql.Register("sqlite3", &SQLiteDriver{})
|
||||
}
|
||||
|
||||
// Driver struct.
|
||||
type SQLiteDriver struct {
|
||||
EnableLoadExtension bool
|
||||
ConnectHook func(*SQLiteConn)
|
||||
Extensions []string
|
||||
ConnectHook func(*SQLiteConn) error
|
||||
}
|
||||
|
||||
// Conn struct.
|
||||
|
@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||
}
|
||||
|
||||
enableLoadExtension := 0
|
||||
if d.EnableLoadExtension {
|
||||
enableLoadExtension = 1
|
||||
}
|
||||
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||
}
|
||||
|
||||
conn := &SQLiteConn{db}
|
||||
|
||||
if len(d.Extensions) > 0 {
|
||||
rv = C.sqlite3_enable_load_extension(db, 1)
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||
}
|
||||
|
||||
stmt, err := conn.Prepare("SELECT load_extension(?);")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, extension := range d.Extensions {
|
||||
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = stmt.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rv = C.sqlite3_enable_load_extension(db, 0)
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||
}
|
||||
}
|
||||
|
||||
if d.ConnectHook != nil {
|
||||
d.ConnectHook(conn)
|
||||
if err := d.ConnectHook(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
|
Loading…
Reference in New Issue