forked from mirror/go-sqlcipher
Add connection option for recursive triggers
Similar to foreign keys, the recursive triggers PRAGMA affects the interpretation of all statements on a connection.
This commit is contained in:
parent
9cbb097044
commit
848386d7a2
26
sqlite3.go
26
sqlite3.go
|
@ -599,6 +599,8 @@ func errorString(err Error) string {
|
|||
// "deferred", "exclusive".
|
||||
// _foreign_keys=X
|
||||
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
|
||||
// _recursive_triggers=X
|
||||
// Enable or disable recursive triggers. X can be 1 or 0.
|
||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||
if C.sqlite3_threadsafe() == 0 {
|
||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||
|
@ -608,6 +610,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
txlock := "BEGIN"
|
||||
busyTimeout := 5000
|
||||
foreignKeys := -1
|
||||
recursiveTriggers := -1
|
||||
pos := strings.IndexRune(dsn, '?')
|
||||
if pos >= 1 {
|
||||
params, err := url.ParseQuery(dsn[pos+1:])
|
||||
|
@ -662,6 +665,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// _recursive_triggers
|
||||
if val := params.Get("_recursive_triggers"); val != "" {
|
||||
switch val {
|
||||
case "1":
|
||||
recursiveTriggers = 1
|
||||
case "0":
|
||||
recursiveTriggers = 0
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(dsn, "file:") {
|
||||
dsn = dsn[:pos]
|
||||
}
|
||||
|
@ -708,6 +723,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if recursiveTriggers == 0 {
|
||||
if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
} else if recursiveTriggers == 1 {
|
||||
if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||
|
||||
|
|
|
@ -136,6 +136,35 @@ func TestForeignKeys(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRecursiveTriggers(t *testing.T) {
|
||||
cases := map[string]bool{
|
||||
"?_recursive_triggers=1": true,
|
||||
"?_recursive_triggers=0": false,
|
||||
}
|
||||
for option, want := range cases {
|
||||
fname := TempFilename(t)
|
||||
uri := "file:" + fname + option
|
||||
db, err := sql.Open("sqlite3", uri)
|
||||
if err != nil {
|
||||
os.Remove(fname)
|
||||
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
|
||||
continue
|
||||
}
|
||||
var enabled bool
|
||||
err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
|
||||
db.Close()
|
||||
os.Remove(fname)
|
||||
if err != nil {
|
||||
t.Errorf("query recursive_triggers for %s: %v", uri, err)
|
||||
continue
|
||||
}
|
||||
if enabled != want {
|
||||
t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
|
|
Loading…
Reference in New Issue