forked from mirror/go-sqlcipher
Changed extension support to load from a string list of extensions
By loading extensions this way, it's not possible to later load extensions using db.Exec, which improves security, and makes it much easier to load extensions correctly. The zero value for the slice (the empty slice) loads no extensions by default. The extension example has been updated to use this much simpler system. The ConnectHook field is still in SQLiteDriver in case it's needed for other driver-wide initialization. Updates #71 of mattn/go-sqlite3.
This commit is contained in:
parent
976f43861f
commit
0dd71564e2
|
@ -8,29 +8,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
const (
|
|
||||||
use_hook = true
|
|
||||||
load_query = "SELECT load_extension('sqlite3_mod_regexp.dll')"
|
|
||||||
)
|
|
||||||
|
|
||||||
sql.Register("sqlite3_with_extensions",
|
sql.Register("sqlite3_with_extensions",
|
||||||
&sqlite3.SQLiteDriver{
|
&sqlite3.SQLiteDriver{
|
||||||
EnableLoadExtension: true,
|
Extensions: []string{
|
||||||
ConnectHook: func(c *sqlite3.SQLiteConn) error {
|
"sqlite3_mod_regexp.dll",
|
||||||
if use_hook {
|
|
||||||
stmt, err := c.Prepare(load_query)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = stmt.Exec(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return stmt.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -40,12 +21,6 @@ func main() {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
if !use_hook {
|
|
||||||
if _, err = db.Exec(load_query); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force db to make a new connection in pool
|
// Force db to make a new connection in pool
|
||||||
// by putting the original in a transaction
|
// by putting the original in a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
|
|
34
sqlite3.go
34
sqlite3.go
|
@ -72,12 +72,12 @@ var SQLiteTimestampFormats = []string{
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sql.Register("sqlite3", &SQLiteDriver{false, nil})
|
sql.Register("sqlite3", &SQLiteDriver{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Driver struct.
|
// Driver struct.
|
||||||
type SQLiteDriver struct {
|
type SQLiteDriver struct {
|
||||||
EnableLoadExtension bool
|
Extensions []string
|
||||||
ConnectHook func(*SQLiteConn) error
|
ConnectHook func(*SQLiteConn) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,16 +182,34 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||||
}
|
}
|
||||||
|
|
||||||
enableLoadExtension := 0
|
conn := &SQLiteConn{db}
|
||||||
if d.EnableLoadExtension {
|
|
||||||
enableLoadExtension = 1
|
if len(d.Extensions) > 0 {
|
||||||
}
|
rv = C.sqlite3_enable_load_extension(db, 1)
|
||||||
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
|
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &SQLiteConn{db}
|
stmt, err := conn.Prepare("SELECT load_extension(?);")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, extension := range d.Extensions {
|
||||||
|
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = stmt.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rv = C.sqlite3_enable_load_extension(db, 0)
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if d.ConnectHook != nil {
|
if d.ConnectHook != nil {
|
||||||
if err := d.ConnectHook(conn); err != nil {
|
if err := d.ConnectHook(conn); err != nil {
|
||||||
|
|
Loading…
Reference in New Issue