diff --git a/sqlite3.go b/sqlite3.go index 25b646f..f879718 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -198,11 +198,11 @@ type SQLiteResult struct { } func (r *SQLiteResult) LastInsertId() (int64, error) { - return int64(C.sqlite3_last_insert_rowid(r.s.s)), nil + return int64(C.sqlite3_last_insert_rowid(r.s.c.db)), nil } func (r *SQLiteResult) RowsAffected() (int64, error) { - return int64(C.sqlite3_changes(r.s.s)), nil + return int64(C.sqlite3_changes(r.s.c.db)), nil } func (s *SQLiteStmt) Exec(args []interface{}) (driver.Result, error) { diff --git a/sqlite3_test.go b/sqlite3_test.go index c46c2e3..ede867c 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1,7 +1,6 @@ package sqlite import ( - "fmt" "testing" "exp/sql" "os" @@ -10,7 +9,7 @@ import ( func TestOpen(t *testing.T) { db, err := sql.Open("sqlite3", "./foo.db") if err != nil { - fmt.Println(err) + t.Errorf("Failed to open database:", err) return } defer os.Remove("./foo.db") @@ -30,7 +29,7 @@ func TestOpen(t *testing.T) { func TestInsert(t *testing.T) { db, err := sql.Open("sqlite3", "./foo.db") if err != nil { - fmt.Println(err) + t.Errorf("Failed to open database:", err) return } defer os.Remove("./foo.db") @@ -42,15 +41,20 @@ func TestInsert(t *testing.T) { return } - _, err = db.Exec("insert into foo(id) values(123)") + res, err := db.Exec("insert into foo(id) values(123)") if err != nil { t.Errorf("Failed to insert record:", err) return } + affected, _ := res.RowsAffected() + if affected != 1 { + t.Errorf("Expected %d for affected rows, but %d:", 1, affected) + return + } rows, err := db.Query("select id from foo") if err != nil { - fmt.Println(err) + t.Errorf("Failed to select records:", err) return } defer rows.Close() @@ -67,7 +71,7 @@ func TestInsert(t *testing.T) { func TestUpdate(t *testing.T) { db, err := sql.Open("sqlite3", "./foo.db") if err != nil { - fmt.Println(err) + t.Errorf("Failed to open database:", err) return } defer os.Remove("./foo.db") @@ -79,21 +83,52 @@ func TestUpdate(t *testing.T) { return } - _, err = db.Exec("insert into foo(id) values(123)") + res, err := db.Exec("insert into foo(id) values(123)") if err != nil { t.Errorf("Failed to insert record:", err) return } + expected, err := res.LastInsertId() + if err != nil { + t.Errorf("Failed to get LastInsertId:", err) + return + } + affected, _ := res.RowsAffected() + if err != nil { + t.Errorf("Failed to get RowsAffected:", err) + return + } + if affected != 1 { + t.Errorf("Expected %d for affected rows, but %d:", 1, affected) + return + } - _, err = db.Exec("update foo set id = 234") + res, err = db.Exec("update foo set id = 234") if err != nil { t.Errorf("Failed to update record:", err) return } + lastId, err := res.LastInsertId() + if err != nil { + t.Errorf("Failed to get LastInsertId:", err) + return + } + if expected != lastId { + t.Errorf("Expected %q for last Id, but %q:", expected, lastId) + } + affected, _ = res.RowsAffected() + if err != nil { + t.Errorf("Failed to get RowsAffected:", err) + return + } + if affected != 1 { + t.Errorf("Expected %d for affected rows, but %d:", 1, affected) + return + } rows, err := db.Query("select id from foo") if err != nil { - fmt.Println(err) + t.Errorf("Failed to select records:", err) return } defer rows.Close() @@ -110,7 +145,7 @@ func TestUpdate(t *testing.T) { func TestDelete(t *testing.T) { db, err := sql.Open("sqlite3", "./foo.db") if err != nil { - fmt.Println(err) + t.Errorf("Failed to select records:", err) return } defer os.Remove("./foo.db") @@ -122,21 +157,50 @@ func TestDelete(t *testing.T) { return } - _, err = db.Exec("insert into foo(id) values(123)") + res, err := db.Exec("insert into foo(id) values(123)") if err != nil { t.Errorf("Failed to insert record:", err) return } + expected, err := res.LastInsertId() + if err != nil { + t.Errorf("Failed to get LastInsertId:", err) + return + } + affected, err := res.RowsAffected() + if err != nil { + t.Errorf("Failed to get RowsAffected:", err) + return + } + if affected != 1 { + t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) + } - _, err = db.Exec("delete from foo where id = 123") + res, err = db.Exec("delete from foo where id = 123") if err != nil { t.Errorf("Failed to delete record:", err) return } + lastId, err := res.LastInsertId() + if err != nil { + t.Errorf("Failed to get LastInsertId:", err) + return + } + if expected != lastId { + t.Errorf("Expected %q for last Id, but %q:", expected, lastId) + } + affected, err = res.RowsAffected() + if err != nil { + t.Errorf("Failed to get RowsAffected:", err) + return + } + if affected != 1 { + t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) + } rows, err := db.Query("select id from foo") if err != nil { - fmt.Println(err) + t.Errorf("Failed to select records:", err) return } defer rows.Close()