Add _foreign_keys connection parameter

Fixes #377
Updates #255
This commit is contained in:
Ross Light 2017-04-01 09:12:21 -07:00
parent b2e464529e
commit c6d43c40e6
2 changed files with 72 additions and 3 deletions

View File

@ -400,14 +400,18 @@ func (c *SQLiteConn) AutoCommit() bool {
} }
func (c *SQLiteConn) lastError() error { func (c *SQLiteConn) lastError() error {
rv := C.sqlite3_errcode(c.db) return lastError(c.db)
}
func lastError(db *C.sqlite3) error {
rv := C.sqlite3_errcode(db)
if rv == C.SQLITE_OK { if rv == C.SQLITE_OK {
return nil return nil
} }
return Error{ return Error{
Code: ErrNo(rv), Code: ErrNo(rv),
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)), ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
err: C.GoString(C.sqlite3_errmsg(c.db)), err: C.GoString(C.sqlite3_errmsg(db)),
} }
} }
@ -537,6 +541,8 @@ func errorString(err Error) string {
// _txlock=XXX // _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate", // Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive". // "deferred", "exclusive".
// _foreign_keys=X
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 { if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@ -545,6 +551,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
var loc *time.Location var loc *time.Location
txlock := "BEGIN" txlock := "BEGIN"
busyTimeout := 5000 busyTimeout := 5000
foreignKeys := -1
pos := strings.IndexRune(dsn, '?') pos := strings.IndexRune(dsn, '?')
if pos >= 1 { if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:]) params, err := url.ParseQuery(dsn[pos+1:])
@ -587,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
} }
} }
// _foreign_keys
if val := params.Get("_foreign_keys"); val != "" {
switch val {
case "1":
foreignKeys = 1
case "0":
foreignKeys = 0
default:
return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
}
}
if !strings.HasPrefix(dsn, "file:") { if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos] dsn = dsn[:pos]
} }
@ -612,6 +631,27 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
exec := func(s string) error {
cs := C.CString(s)
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
C.free(unsafe.Pointer(cs))
if rv != C.SQLITE_OK {
return lastError(db)
}
return nil
}
if foreignKeys == 0 {
if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
C.sqlite3_close_v2(db)
return nil, err
}
} else if foreignKeys == 1 {
if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
C.sqlite3_close_v2(db)
return nil, err
}
}
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {

View File

@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) {
} }
} }
func TestForeignKeys(t *testing.T) {
cases := map[string]bool{
"?_foreign_keys=1": true,
"?_foreign_keys=0": false,
}
for option, want := range cases {
fname := TempFilename(t)
uri := "file:" + fname + option
db, err := sql.Open("sqlite3", uri)
if err != nil {
os.Remove(fname)
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
continue
}
var enabled bool
err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
db.Close()
os.Remove(fname)
if err != nil {
t.Errorf("query foreign_keys for %s: %v", uri, err)
continue
}
if enabled != want {
t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
continue
}
}
}
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
tempFilename := TempFilename(t) tempFilename := TempFilename(t)
defer os.Remove(tempFilename) defer os.Remove(tempFilename)