forked from mirror/go-sqlcipher
Merge pull request #73 from cookieo9/extlist
Change extension loading mechanism to use a string list of extensions
This commit is contained in:
commit
a3e3a8e981
|
@ -10,8 +10,9 @@ import (
|
||||||
func main() {
|
func main() {
|
||||||
sql.Register("sqlite3_with_extensions",
|
sql.Register("sqlite3_with_extensions",
|
||||||
&sqlite3.SQLiteDriver{
|
&sqlite3.SQLiteDriver{
|
||||||
EnableLoadExtension: true,
|
Extensions: []string{
|
||||||
ConnectHook: nil,
|
"sqlite3_mod_regexp.dll",
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
|
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
|
||||||
|
@ -20,11 +21,15 @@ func main() {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
_, err = db.Exec("select load_extension('sqlite3_mod_regexp.dll')")
|
// Force db to make a new connection in pool
|
||||||
|
// by putting the original in a transaction
|
||||||
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer tx.Commit()
|
||||||
|
|
||||||
|
// New connection works (hopefully!)
|
||||||
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
|
rows, err := db.Query("select 'hello world' where 'hello world' regexp '^hello.*d$'")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
|
46
sqlite3.go
46
sqlite3.go
|
@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sql.Register("sqlite3", &SQLiteDriver{false, nil})
|
sql.Register("sqlite3", &SQLiteDriver{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Driver struct.
|
// Driver struct.
|
||||||
type SQLiteDriver struct {
|
type SQLiteDriver struct {
|
||||||
EnableLoadExtension bool
|
Extensions []string
|
||||||
ConnectHook func(*SQLiteConn)
|
ConnectHook func(*SQLiteConn) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn struct.
|
// Conn struct.
|
||||||
|
@ -182,19 +182,39 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||||
}
|
}
|
||||||
|
|
||||||
enableLoadExtension := 0
|
|
||||||
if d.EnableLoadExtension {
|
|
||||||
enableLoadExtension = 1
|
|
||||||
}
|
|
||||||
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
|
|
||||||
if rv != C.SQLITE_OK {
|
|
||||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := &SQLiteConn{db}
|
conn := &SQLiteConn{db}
|
||||||
|
|
||||||
|
if len(d.Extensions) > 0 {
|
||||||
|
rv = C.sqlite3_enable_load_extension(db, 1)
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, err := conn.Prepare("SELECT load_extension(?);")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, extension := range d.Extensions {
|
||||||
|
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = stmt.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rv = C.sqlite3_enable_load_extension(db, 0)
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if d.ConnectHook != nil {
|
if d.ConnectHook != nil {
|
||||||
d.ConnectHook(conn)
|
if err := d.ConnectHook(conn); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|
Loading…
Reference in New Issue