diff --git a/driver/config.go b/driver/config.go index 714a632..5305b0a 100644 --- a/driver/config.go +++ b/driver/config.go @@ -165,6 +165,7 @@ func (tx TxLock) String() string { } } +// Value returns the Transaction Lock Value func (tx TxLock) Value() string { return string(tx) } @@ -717,7 +718,7 @@ func (cfg *Config) createConnection() (driver.Conn, error) { return nil, errors.New("sqlite library was not compiled for thread-safe operation") } - if len(cfg.Database) == 0 { + if len(cfg.Database) == 0 || cfg.Database == "file:" { return nil, fmt.Errorf("No database configured") } @@ -750,7 +751,6 @@ func (cfg *Config) createConnection() (driver.Conn, error) { // Check if the database was opened succesful. if rv != C.SQLITE_OK { - fmt.Println(Error{Code: ErrNo(rv)}) return nil, Error{Code: ErrNo(rv)} } @@ -810,7 +810,7 @@ func (cfg *Config) createConnection() (driver.Conn, error) { // Register sqlite_crypt function with the CryptEncoder provided // within *Config.Authentication if err := conn.RegisterFunc("sqlite_crypt", cfg.Authentication.Encoder.Encode, true); err != nil { - return nil, fmt.Errorf("CryptEncoderSHA1: %s", err) + return nil, fmt.Errorf("CryptEncoder: %s", err) } // Register: authenticate diff --git a/driver/config_test.go b/driver/config_test.go index 6bbc4db..17d1039 100644 --- a/driver/config_test.go +++ b/driver/config_test.go @@ -13,10 +13,6 @@ import ( "time" ) -func TestConfig(t *testing.T) { - -} - func TestParseDSN(t *testing.T) { // URI uriCases := map[string]*Config{ diff --git a/driver/connection.go b/driver/connection.go index 86a8fc6..aecb293 100644 --- a/driver/connection.go +++ b/driver/connection.go @@ -45,6 +45,7 @@ type SQLiteConn struct { aggregators []*aggInfo } +// PRAGMA executes a PRAGMA statement func (c *SQLiteConn) PRAGMA(name, value string) error { stmt := fmt.Sprintf("PRAGMA %s = %s;", name, value) diff --git a/driver/crypt.go b/driver/crypt.go index 8f2be38..76af313 100644 --- a/driver/crypt.go +++ b/driver/crypt.go @@ -3,6 +3,8 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +package sqlite3 + // The crypt functions provides several different implementations for the // default embedded sqlite_crypt function. // This function is uses a ceasar-cypher by default @@ -40,7 +42,6 @@ // sqlite3_create_function() interface to provide an alternative // implementation of sqlite_crypt() that computes a stronger password hash, // perhaps using a cryptographic hash function like SHA1. -package sqlite3 import ( "crypto/sha1" diff --git a/driver/crypt_test.go b/driver/crypt_test.go index 5d85f73..72919a4 100644 --- a/driver/crypt_test.go +++ b/driver/crypt_test.go @@ -59,5 +59,9 @@ func TestEncoders(t *testing.T) { if strings.Compare(fmt.Sprintf("%x", h), e.expected) != 0 { t.Fatalf("Invalid %s hash: expected: %s; got: %x", strings.ToUpper(e.enc), e.expected, h) } + + if e.enc != enc.String() { + t.Fatalf("Invalid encoder; expected: %s, got: %s", e.enc, enc.String()) + } } } diff --git a/driver/driver_go110_test.go b/driver/driver_go110_test.go new file mode 100644 index 0000000..31bf3fa --- /dev/null +++ b/driver/driver_go110_test.go @@ -0,0 +1,72 @@ +// Copyright (C) 2018 The Go-SQLite3 Authors. +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build cgo +// +build go1.10 + +package sqlite3 + +import ( + "context" + "database/sql" + "database/sql/driver" + "os" + "testing" +) + +func TestOpenConnector(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + drv := &SQLiteDriver{} + + connector, err := drv.OpenConnector(tempFilename) + if err != nil { + t.Fatal(err) + } + conn, err := connector.Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("Failed to create connection to database") + } + defer conn.Close() + + stmt, err := conn.Prepare("create table if not exists foo (id integer)") + if err != nil { + t.Fatalf("Failed to create statement: %s", err) + } + defer stmt.Close() + if _, err := stmt.Exec([]driver.Value{}); err != nil { + t.Fatalf("Failed to exec statement: %s", err) + } + + // Verify database has been created + if _, err := os.Stat(tempFilename); os.IsNotExist(err) { + t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err) + } +} + +func TestOpenDB(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + cfg := NewConfig() + cfg.Database = tempFilename + + // OpenDB + db := sql.OpenDB(cfg) + defer db.Close() + + _, err := db.Exec("create table if not exists foo (id integer)") + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + + if _, err := os.Stat(tempFilename); os.IsNotExist(err) { + t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err) + } +} diff --git a/driver/driver_test.go b/driver/driver_test.go index 725bb71..244dba5 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -8,7 +8,11 @@ package sqlite3 import ( + "database/sql" + "database/sql/driver" + "fmt" "io/ioutil" + "os" "testing" ) @@ -20,3 +24,174 @@ func TempFilename(t *testing.T) string { f.Close() return f.Name() } + +func TestOpen(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatalf("Failed to open database: %s", err) + } + defer db.Close() + + _, err = db.Exec("create table if not exists foo (id integer)") + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + + if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { + t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err) + } + + tempFilename = TempFilename(t) + defer os.Remove(tempFilename) + + // Open Driver Directly + drv := &SQLiteDriver{} + conn, err := drv.Open(tempFilename) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("Failed to create connection to database") + } + defer conn.Close() + + stmt, err := conn.Prepare("create table if not exists foo (id integer)") + if err != nil { + t.Fatalf("Failed to create statement: %s", err) + } + defer stmt.Close() + if _, err := stmt.Exec([]driver.Value{}); err != nil { + t.Fatalf("Failed to exec statement: %s", err) + } + + // Verify database has been created + if _, err := os.Stat(tempFilename); os.IsNotExist(err) { + t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err) + } +} + +func TestOpenInvalidDSN(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + // Open Invalid DSN + drv := &SQLiteDriver{} + conn, err := drv.Open(fmt.Sprintf("%s?%35%2%%43?test=false", tempFilename)) + if err == nil { + t.Fatal("Connection created while error was expected") + } + if conn != nil { + t.Fatal("Conection created while error was expected") + } +} + +func TestOpenConfigDSN(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + cfg := NewConfig() + cfg.Database = tempFilename + + db, err := sql.Open("sqlite3", cfg.FormatDSN()) + if err != nil { + t.Fatalf("Failed to open database: %s", err) + } + defer db.Close() + + _, err = db.Exec("create table if not exists foo (id integer)") + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + + if _, err := os.Stat(tempFilename); os.IsNotExist(err) { + t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err) + } + + // Test Open Empry Database location + cfg.Database = "" + + db, err = sql.Open("sqlite3", cfg.FormatDSN()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("create table if not exists foo (id integer)") + if err == nil { + t.Fatalf("Table created while error was expected") + } +} + +func TestInvalidConnectHook(t *testing.T) { + driverName := "sqlite3_invalid_connecthook" + sql.Register(driverName, &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return fmt.Errorf("ConnectHook Error") + }, + }) + + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open(driverName, tempFilename) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("create table if not exists foo (id integer)") + if err == nil { + t.Fatalf("Table created while error was expected") + } +} + +func TestInvalidExtension(t *testing.T) { + driverName := "sqlite3_invalid_extension" + sql.Register(driverName, &SQLiteDriver{ + Extensions: []string{ + "invalid.extension", + }, + }) + + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open(driverName, tempFilename) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("create table if not exists foo (id integer)") + if err == nil { + t.Fatalf("Table created while error was expected") + } + + tempFilename = TempFilename(t) + defer os.Remove(tempFilename) + + driverName = "sqlite3_conn_invalid_extension" + var driverConn *SQLiteConn + sql.Register(driverName, &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + driverConn = conn + return nil + }, + }) + + db, err = sql.Open(driverName, tempFilename) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("SELECT 1;") + if err != nil { + t.Fatalf("Failed to exec ping statement") + } + + if err := driverConn.LoadExtension("invalid.extension", ""); err == nil { + t.Fatal("Extension loaded while error was expected") + } +} diff --git a/driver/extensions.go b/driver/extensions.go index 823a807..140b0d2 100644 --- a/driver/extensions.go +++ b/driver/extensions.go @@ -31,29 +31,29 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error { for _, extension := range extensions { cext := C.CString(extension) - cerr := C.CString("") defer C.free(unsafe.Pointer(cext)) - defer C.free(unsafe.Pointer(cerr)) - rv = C.sqlite3_load_extension(c.db, cext, nil, &cerr) + + rv = C.sqlite3_load_extension(c.db, cext, nil, nil) if rv != C.SQLITE_OK { - // Disable Extension Loading + // Disable extension loading on return rv = C.sqlite3_enable_load_extension(c.db, 0) if rv != C.SQLITE_OK { - return errors.New("Failed to disable extension loading") + return fmt.Errorf("Failed to disable extension loading: %s", C.GoString(C.sqlite3_errmsg(c.db))) } - fmt.Println(">>>") - fmt.Println(*cerr) - fmt.Printf("%v\n", cerr) - fmt.Printf("%v\n", *cerr) - return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) + + // Store last error + // Required because other wise next statement to disable extension loading + // will override the error + return fmt.Errorf("Failed to load extension: '%s'", extension) } } - // Disable extension loading + // Disable extension loading on return rv = C.sqlite3_enable_load_extension(c.db, 0) if rv != C.SQLITE_OK { - return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) + return fmt.Errorf("Failed to disable extension loading: %s", C.GoString(C.sqlite3_errmsg(c.db))) } + return nil } diff --git a/driver/opt_userauth.go b/driver/opt_userauth.go index 6c386a8..fa70c8f 100644 --- a/driver/opt_userauth.go +++ b/driver/opt_userauth.go @@ -85,12 +85,12 @@ const ( func (c *SQLiteConn) Authenticate(username, password string) error { rv := c.authenticate(username, password) switch rv { - case C.SQLITE_ERROR, C.SQLITE_AUTH: + case C.SQLITE_AUTH: return Error{Code: ErrNo(rv)} case C.SQLITE_OK: return nil default: - return c.lastError() + return Error{Code: ErrNo(rv)} } } @@ -132,12 +132,12 @@ func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error { rv := c.authUserAdd(username, password, isAdmin) switch rv { - case C.SQLITE_ERROR, C.SQLITE_AUTH: + case C.SQLITE_AUTH: return Error{Code: ErrNo(rv)} case C.SQLITE_OK: return nil default: - return c.lastError() + return Error{Code: ErrNo(rv)} } } @@ -181,12 +181,12 @@ func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error rv := c.authUserChange(username, password, isAdmin) switch rv { - case C.SQLITE_ERROR, C.SQLITE_AUTH: + case C.SQLITE_AUTH: return Error{Code: ErrNo(rv)} case C.SQLITE_OK: return nil default: - return c.lastError() + return Error{Code: ErrNo(rv)} } } @@ -228,12 +228,12 @@ func (c *SQLiteConn) authUserChange(username, password string, admin int) int { func (c *SQLiteConn) AuthUserDelete(username string) error { rv := c.authUserDelete(username) switch rv { - case C.SQLITE_ERROR, C.SQLITE_AUTH: + case C.SQLITE_AUTH: return Error{Code: ErrNo(rv)} case C.SQLITE_OK: return nil default: - return c.lastError() + return Error{Code: ErrNo(rv)} } } diff --git a/driver/opt_userauth_test.go b/driver/opt_userauth_test.go index 4b03107..8593f42 100644 --- a/driver/opt_userauth_test.go +++ b/driver/opt_userauth_test.go @@ -463,8 +463,8 @@ func TestUserAuthModifyUser(t *testing.T) { } func TestUserAuthDeleteUser(t *testing.T) { - f1, db1, c, err := connect(t, "", "admin", "admin") - if err != nil && c == nil && db1 == nil { + f1, db1, c1, err := connect(t, "", "admin", "admin") + if err != nil && c1 == nil && db1 == nil { t.Fatal(err) } defer os.Remove(f1) @@ -522,13 +522,10 @@ func TestUserAuthDeleteUser(t *testing.T) { } // Delete user through *SQLiteConn - rv, err = deleteUser(db1, "admin3") + err = c1.AuthUserDelete("admin3") if err != nil { t.Fatal(err) } - if rv != 0 { - t.Fatal("Failed to delete admin3") - } // Verify user admin3 deleted exists, err = userExists(db1, "admin3")