mirror of https://github.com/mattn/go-sqlite3.git
Merge pull request #407 from zombiezen/foreignkeys
Add _foreign_keys connection parameter
This commit is contained in:
commit
46e826d22a
46
sqlite3.go
46
sqlite3.go
|
@ -400,14 +400,18 @@ func (c *SQLiteConn) AutoCommit() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SQLiteConn) lastError() error {
|
func (c *SQLiteConn) lastError() error {
|
||||||
rv := C.sqlite3_errcode(c.db)
|
return lastError(c.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lastError(db *C.sqlite3) error {
|
||||||
|
rv := C.sqlite3_errcode(db)
|
||||||
if rv == C.SQLITE_OK {
|
if rv == C.SQLITE_OK {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return Error{
|
return Error{
|
||||||
Code: ErrNo(rv),
|
Code: ErrNo(rv),
|
||||||
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
|
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
|
||||||
err: C.GoString(C.sqlite3_errmsg(c.db)),
|
err: C.GoString(C.sqlite3_errmsg(db)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,6 +541,8 @@ func errorString(err Error) string {
|
||||||
// _txlock=XXX
|
// _txlock=XXX
|
||||||
// Specify locking behavior for transactions. XXX can be "immediate",
|
// Specify locking behavior for transactions. XXX can be "immediate",
|
||||||
// "deferred", "exclusive".
|
// "deferred", "exclusive".
|
||||||
|
// _foreign_keys=X
|
||||||
|
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
|
||||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
if C.sqlite3_threadsafe() == 0 {
|
if C.sqlite3_threadsafe() == 0 {
|
||||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||||
|
@ -545,6 +551,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
var loc *time.Location
|
var loc *time.Location
|
||||||
txlock := "BEGIN"
|
txlock := "BEGIN"
|
||||||
busyTimeout := 5000
|
busyTimeout := 5000
|
||||||
|
foreignKeys := -1
|
||||||
pos := strings.IndexRune(dsn, '?')
|
pos := strings.IndexRune(dsn, '?')
|
||||||
if pos >= 1 {
|
if pos >= 1 {
|
||||||
params, err := url.ParseQuery(dsn[pos+1:])
|
params, err := url.ParseQuery(dsn[pos+1:])
|
||||||
|
@ -587,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// _foreign_keys
|
||||||
|
if val := params.Get("_foreign_keys"); val != "" {
|
||||||
|
switch val {
|
||||||
|
case "1":
|
||||||
|
foreignKeys = 1
|
||||||
|
case "0":
|
||||||
|
foreignKeys = 0
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(dsn, "file:") {
|
if !strings.HasPrefix(dsn, "file:") {
|
||||||
dsn = dsn[:pos]
|
dsn = dsn[:pos]
|
||||||
}
|
}
|
||||||
|
@ -613,6 +632,27 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return nil, Error{Code: ErrNo(rv)}
|
return nil, Error{Code: ErrNo(rv)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
exec := func(s string) error {
|
||||||
|
cs := C.CString(s)
|
||||||
|
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
|
||||||
|
C.free(unsafe.Pointer(cs))
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return lastError(db)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if foreignKeys == 0 {
|
||||||
|
if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
|
||||||
|
C.sqlite3_close_v2(db)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if foreignKeys == 1 {
|
||||||
|
if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
|
||||||
|
C.sqlite3_close_v2(db)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||||
|
|
||||||
if len(d.Extensions) > 0 {
|
if len(d.Extensions) > 0 {
|
||||||
|
|
|
@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForeignKeys(t *testing.T) {
|
||||||
|
cases := map[string]bool{
|
||||||
|
"?_foreign_keys=1": true,
|
||||||
|
"?_foreign_keys=0": false,
|
||||||
|
}
|
||||||
|
for option, want := range cases {
|
||||||
|
fname := TempFilename(t)
|
||||||
|
uri := "file:" + fname + option
|
||||||
|
db, err := sql.Open("sqlite3", uri)
|
||||||
|
if err != nil {
|
||||||
|
os.Remove(fname)
|
||||||
|
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var enabled bool
|
||||||
|
err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
|
||||||
|
db.Close()
|
||||||
|
os.Remove(fname)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("query foreign_keys for %s: %v", uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if enabled != want {
|
||||||
|
t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClose(t *testing.T) {
|
func TestClose(t *testing.T) {
|
||||||
tempFilename := TempFilename(t)
|
tempFilename := TempFilename(t)
|
||||||
defer os.Remove(tempFilename)
|
defer os.Remove(tempFilename)
|
||||||
|
|
Loading…
Reference in New Issue