diff --git a/example/extension/extension.go b/example/extension/extension.go index 3ec145a..95f2f70 100644 --- a/example/extension/extension.go +++ b/example/extension/extension.go @@ -3,11 +3,13 @@ package main import ( "database/sql" "fmt" - _ "github.com/mattn/go-sqlite3" + "github.com/mattn/go-sqlite3" "log" ) func main() { + sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{true, nil}) + db, err := sql.Open("sqlite3_with_extensions", ":memory:") if err != nil { log.Fatal(err) diff --git a/sqlite3.go b/sqlite3.go index 0aba2db..43e255b 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -72,13 +72,13 @@ var SQLiteTimestampFormats = []string{ } func init() { - sql.Register("sqlite3", &SQLiteDriver{false}) - sql.Register("sqlite3_with_extensions", &SQLiteDriver{true}) + sql.Register("sqlite3", &SQLiteDriver{false, nil}) } // Driver struct. type SQLiteDriver struct { - enableLoadExtentions bool + EnableLoadExtentions bool + ConnectHook func(*SQLiteConn) } // Conn struct. @@ -179,7 +179,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { } enableLoadExtentions := 0 - if d.enableLoadExtentions { + if d.EnableLoadExtentions { enableLoadExtentions = 1 } rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtentions)) @@ -187,7 +187,13 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, errors.New(C.GoString(C.sqlite3_errmsg(db))) } - return &SQLiteConn{db}, nil + conn := &SQLiteConn{db} + + if d.ConnectHook != nil { + d.ConnectHook(conn) + } + + return conn, nil } // Close the connection.