* fix code format
* add tests for Golang:1.10
* Fix error on extension loading
* Fix test on userauth
This commit is contained in:
Gert-Jan Timmer 2018-06-25 17:29:07 +02:00
parent f5c3c6f922
commit 8334dc42e6
10 changed files with 280 additions and 34 deletions

View File

@ -165,6 +165,7 @@ func (tx TxLock) String() string {
} }
} }
// Value returns the Transaction Lock Value
func (tx TxLock) Value() string { func (tx TxLock) Value() string {
return string(tx) return string(tx)
} }
@ -717,7 +718,7 @@ func (cfg *Config) createConnection() (driver.Conn, error) {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
} }
if len(cfg.Database) == 0 { if len(cfg.Database) == 0 || cfg.Database == "file:" {
return nil, fmt.Errorf("No database configured") return nil, fmt.Errorf("No database configured")
} }
@ -750,7 +751,6 @@ func (cfg *Config) createConnection() (driver.Conn, error) {
// Check if the database was opened succesful. // Check if the database was opened succesful.
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
fmt.Println(Error{Code: ErrNo(rv)})
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
@ -810,7 +810,7 @@ func (cfg *Config) createConnection() (driver.Conn, error) {
// Register sqlite_crypt function with the CryptEncoder provided // Register sqlite_crypt function with the CryptEncoder provided
// within *Config.Authentication // within *Config.Authentication
if err := conn.RegisterFunc("sqlite_crypt", cfg.Authentication.Encoder.Encode, true); err != nil { if err := conn.RegisterFunc("sqlite_crypt", cfg.Authentication.Encoder.Encode, true); err != nil {
return nil, fmt.Errorf("CryptEncoderSHA1: %s", err) return nil, fmt.Errorf("CryptEncoder: %s", err)
} }
// Register: authenticate // Register: authenticate

View File

@ -13,10 +13,6 @@ import (
"time" "time"
) )
func TestConfig(t *testing.T) {
}
func TestParseDSN(t *testing.T) { func TestParseDSN(t *testing.T) {
// URI // URI
uriCases := map[string]*Config{ uriCases := map[string]*Config{

View File

@ -45,6 +45,7 @@ type SQLiteConn struct {
aggregators []*aggInfo aggregators []*aggInfo
} }
// PRAGMA executes a PRAGMA statement
func (c *SQLiteConn) PRAGMA(name, value string) error { func (c *SQLiteConn) PRAGMA(name, value string) error {
stmt := fmt.Sprintf("PRAGMA %s = %s;", name, value) stmt := fmt.Sprintf("PRAGMA %s = %s;", name, value)

View File

@ -3,6 +3,8 @@
// Use of this source code is governed by an MIT-style // Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package sqlite3
// The crypt functions provides several different implementations for the // The crypt functions provides several different implementations for the
// default embedded sqlite_crypt function. // default embedded sqlite_crypt function.
// This function is uses a ceasar-cypher by default // This function is uses a ceasar-cypher by default
@ -40,7 +42,6 @@
// sqlite3_create_function() interface to provide an alternative // sqlite3_create_function() interface to provide an alternative
// implementation of sqlite_crypt() that computes a stronger password hash, // implementation of sqlite_crypt() that computes a stronger password hash,
// perhaps using a cryptographic hash function like SHA1. // perhaps using a cryptographic hash function like SHA1.
package sqlite3
import ( import (
"crypto/sha1" "crypto/sha1"

View File

@ -59,5 +59,9 @@ func TestEncoders(t *testing.T) {
if strings.Compare(fmt.Sprintf("%x", h), e.expected) != 0 { if strings.Compare(fmt.Sprintf("%x", h), e.expected) != 0 {
t.Fatalf("Invalid %s hash: expected: %s; got: %x", strings.ToUpper(e.enc), e.expected, h) t.Fatalf("Invalid %s hash: expected: %s; got: %x", strings.ToUpper(e.enc), e.expected, h)
} }
if e.enc != enc.String() {
t.Fatalf("Invalid encoder; expected: %s, got: %s", e.enc, enc.String())
}
} }
} }

View File

@ -0,0 +1,72 @@
// Copyright (C) 2018 The Go-SQLite3 Authors.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build cgo
// +build go1.10
package sqlite3
import (
"context"
"database/sql"
"database/sql/driver"
"os"
"testing"
)
func TestOpenConnector(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
drv := &SQLiteDriver{}
connector, err := drv.OpenConnector(tempFilename)
if err != nil {
t.Fatal(err)
}
conn, err := connector.Connect(context.Background())
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("Failed to create connection to database")
}
defer conn.Close()
stmt, err := conn.Prepare("create table if not exists foo (id integer)")
if err != nil {
t.Fatalf("Failed to create statement: %s", err)
}
defer stmt.Close()
if _, err := stmt.Exec([]driver.Value{}); err != nil {
t.Fatalf("Failed to exec statement: %s", err)
}
// Verify database has been created
if _, err := os.Stat(tempFilename); os.IsNotExist(err) {
t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err)
}
}
func TestOpenDB(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
cfg := NewConfig()
cfg.Database = tempFilename
// OpenDB
db := sql.OpenDB(cfg)
defer db.Close()
_, err := db.Exec("create table if not exists foo (id integer)")
if err != nil {
t.Fatalf("Failed to create table: %s", err)
}
if _, err := os.Stat(tempFilename); os.IsNotExist(err) {
t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err)
}
}

View File

@ -8,7 +8,11 @@
package sqlite3 package sqlite3
import ( import (
"database/sql"
"database/sql/driver"
"fmt"
"io/ioutil" "io/ioutil"
"os"
"testing" "testing"
) )
@ -20,3 +24,174 @@ func TempFilename(t *testing.T) string {
f.Close() f.Close()
return f.Name() return f.Name()
} }
func TestOpen(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatalf("Failed to open database: %s", err)
}
defer db.Close()
_, err = db.Exec("create table if not exists foo (id integer)")
if err != nil {
t.Fatalf("Failed to create table: %s", err)
}
if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err)
}
tempFilename = TempFilename(t)
defer os.Remove(tempFilename)
// Open Driver Directly
drv := &SQLiteDriver{}
conn, err := drv.Open(tempFilename)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("Failed to create connection to database")
}
defer conn.Close()
stmt, err := conn.Prepare("create table if not exists foo (id integer)")
if err != nil {
t.Fatalf("Failed to create statement: %s", err)
}
defer stmt.Close()
if _, err := stmt.Exec([]driver.Value{}); err != nil {
t.Fatalf("Failed to exec statement: %s", err)
}
// Verify database has been created
if _, err := os.Stat(tempFilename); os.IsNotExist(err) {
t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err)
}
}
func TestOpenInvalidDSN(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
// Open Invalid DSN
drv := &SQLiteDriver{}
conn, err := drv.Open(fmt.Sprintf("%s?%35%2%%43?test=false", tempFilename))
if err == nil {
t.Fatal("Connection created while error was expected")
}
if conn != nil {
t.Fatal("Conection created while error was expected")
}
}
func TestOpenConfigDSN(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
cfg := NewConfig()
cfg.Database = tempFilename
db, err := sql.Open("sqlite3", cfg.FormatDSN())
if err != nil {
t.Fatalf("Failed to open database: %s", err)
}
defer db.Close()
_, err = db.Exec("create table if not exists foo (id integer)")
if err != nil {
t.Fatalf("Failed to create table: %s", err)
}
if _, err := os.Stat(tempFilename); os.IsNotExist(err) {
t.Fatalf("Failed to create database: '%s'; %s", tempFilename, err)
}
// Test Open Empry Database location
cfg.Database = ""
db, err = sql.Open("sqlite3", cfg.FormatDSN())
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec("create table if not exists foo (id integer)")
if err == nil {
t.Fatalf("Table created while error was expected")
}
}
func TestInvalidConnectHook(t *testing.T) {
driverName := "sqlite3_invalid_connecthook"
sql.Register(driverName, &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
return fmt.Errorf("ConnectHook Error")
},
})
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open(driverName, tempFilename)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec("create table if not exists foo (id integer)")
if err == nil {
t.Fatalf("Table created while error was expected")
}
}
func TestInvalidExtension(t *testing.T) {
driverName := "sqlite3_invalid_extension"
sql.Register(driverName, &SQLiteDriver{
Extensions: []string{
"invalid.extension",
},
})
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open(driverName, tempFilename)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec("create table if not exists foo (id integer)")
if err == nil {
t.Fatalf("Table created while error was expected")
}
tempFilename = TempFilename(t)
defer os.Remove(tempFilename)
driverName = "sqlite3_conn_invalid_extension"
var driverConn *SQLiteConn
sql.Register(driverName, &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
driverConn = conn
return nil
},
})
db, err = sql.Open(driverName, tempFilename)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec("SELECT 1;")
if err != nil {
t.Fatalf("Failed to exec ping statement")
}
if err := driverConn.LoadExtension("invalid.extension", ""); err == nil {
t.Fatal("Extension loaded while error was expected")
}
}

View File

@ -31,29 +31,29 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
for _, extension := range extensions { for _, extension := range extensions {
cext := C.CString(extension) cext := C.CString(extension)
cerr := C.CString("")
defer C.free(unsafe.Pointer(cext)) defer C.free(unsafe.Pointer(cext))
defer C.free(unsafe.Pointer(cerr))
rv = C.sqlite3_load_extension(c.db, cext, nil, &cerr) rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
// Disable Extension Loading // Disable extension loading on return
rv = C.sqlite3_enable_load_extension(c.db, 0) rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return errors.New("Failed to disable extension loading") return fmt.Errorf("Failed to disable extension loading: %s", C.GoString(C.sqlite3_errmsg(c.db)))
} }
fmt.Println(">>>")
fmt.Println(*cerr) // Store last error
fmt.Printf("%v\n", cerr) // Required because other wise next statement to disable extension loading
fmt.Printf("%v\n", *cerr) // will override the error
return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) return fmt.Errorf("Failed to load extension: '%s'", extension)
} }
} }
// Disable extension loading // Disable extension loading on return
rv = C.sqlite3_enable_load_extension(c.db, 0) rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) return fmt.Errorf("Failed to disable extension loading: %s", C.GoString(C.sqlite3_errmsg(c.db)))
} }
return nil return nil
} }

View File

@ -85,12 +85,12 @@ const (
func (c *SQLiteConn) Authenticate(username, password string) error { func (c *SQLiteConn) Authenticate(username, password string) error {
rv := c.authenticate(username, password) rv := c.authenticate(username, password)
switch rv { switch rv {
case C.SQLITE_ERROR, C.SQLITE_AUTH: case C.SQLITE_AUTH:
return Error{Code: ErrNo(rv)} return Error{Code: ErrNo(rv)}
case C.SQLITE_OK: case C.SQLITE_OK:
return nil return nil
default: default:
return c.lastError() return Error{Code: ErrNo(rv)}
} }
} }
@ -132,12 +132,12 @@ func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error {
rv := c.authUserAdd(username, password, isAdmin) rv := c.authUserAdd(username, password, isAdmin)
switch rv { switch rv {
case C.SQLITE_ERROR, C.SQLITE_AUTH: case C.SQLITE_AUTH:
return Error{Code: ErrNo(rv)} return Error{Code: ErrNo(rv)}
case C.SQLITE_OK: case C.SQLITE_OK:
return nil return nil
default: default:
return c.lastError() return Error{Code: ErrNo(rv)}
} }
} }
@ -181,12 +181,12 @@ func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error
rv := c.authUserChange(username, password, isAdmin) rv := c.authUserChange(username, password, isAdmin)
switch rv { switch rv {
case C.SQLITE_ERROR, C.SQLITE_AUTH: case C.SQLITE_AUTH:
return Error{Code: ErrNo(rv)} return Error{Code: ErrNo(rv)}
case C.SQLITE_OK: case C.SQLITE_OK:
return nil return nil
default: default:
return c.lastError() return Error{Code: ErrNo(rv)}
} }
} }
@ -228,12 +228,12 @@ func (c *SQLiteConn) authUserChange(username, password string, admin int) int {
func (c *SQLiteConn) AuthUserDelete(username string) error { func (c *SQLiteConn) AuthUserDelete(username string) error {
rv := c.authUserDelete(username) rv := c.authUserDelete(username)
switch rv { switch rv {
case C.SQLITE_ERROR, C.SQLITE_AUTH: case C.SQLITE_AUTH:
return Error{Code: ErrNo(rv)} return Error{Code: ErrNo(rv)}
case C.SQLITE_OK: case C.SQLITE_OK:
return nil return nil
default: default:
return c.lastError() return Error{Code: ErrNo(rv)}
} }
} }

View File

@ -463,8 +463,8 @@ func TestUserAuthModifyUser(t *testing.T) {
} }
func TestUserAuthDeleteUser(t *testing.T) { func TestUserAuthDeleteUser(t *testing.T) {
f1, db1, c, err := connect(t, "", "admin", "admin") f1, db1, c1, err := connect(t, "", "admin", "admin")
if err != nil && c == nil && db1 == nil { if err != nil && c1 == nil && db1 == nil {
t.Fatal(err) t.Fatal(err)
} }
defer os.Remove(f1) defer os.Remove(f1)
@ -522,13 +522,10 @@ func TestUserAuthDeleteUser(t *testing.T) {
} }
// Delete user through *SQLiteConn // Delete user through *SQLiteConn
rv, err = deleteUser(db1, "admin3") err = c1.AuthUserDelete("admin3")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if rv != 0 {
t.Fatal("Failed to delete admin3")
}
// Verify user admin3 deleted // Verify user admin3 deleted
exists, err = userExists(db1, "admin3") exists, err = userExists(db1, "admin3")