diff --git a/upgrade/package.go b/upgrade/package.go index 895bc73..557ded3 100644 --- a/upgrade/package.go +++ b/upgrade/package.go @@ -1,3 +1,5 @@ +// +build !upgrade + // Package upgrade is a dummy package to ensure package can be loaded // // This file is to avoid the following error: diff --git a/upgrade/upgrade.go b/upgrade/upgrade.go index 9652c67..2c4a92d 100644 --- a/upgrade/upgrade.go +++ b/upgrade/upgrade.go @@ -7,13 +7,15 @@ import ( "archive/zip" "bufio" "bytes" + "crypto/sha1" + "encoding/hex" + "errors" "fmt" "io" "io/ioutil" "log" "net/http" "os" - "path" "path/filepath" "strings" "time" @@ -21,211 +23,235 @@ import ( "github.com/PuerkitoBio/goquery" ) -func download(prefix string) (url string, content []byte, err error) { +const buildFlags = "-DSQLITE_ENABLE_UPDATE_DELETE_LIMIT=1" + +func main() { + err := func() error { + fmt.Println("Go-SQLite3 Upgrade Tool") + + wd, err := os.Getwd() + if err != nil { + return err + } + if filepath.Base(wd) != "upgrade" { + return fmt.Errorf("Current directory is %q but should run in upgrade directory", wd) + } + + // Download Source + source, hash, err := download("sqlite-src-") + if err != nil { + return fmt.Errorf("failed to download: sqlite-src; %v", err) + } + fmt.Printf("Download successful and verified hash %x\n", hash) + + // Extract Source + baseDir, err := extractZip(source) + if baseDir != "" && !filepath.IsAbs(baseDir) { + defer func() { + fmt.Println("Cleaning up source: deleting", baseDir) + os.RemoveAll(baseDir) + }() + } + if err != nil { + return fmt.Errorf("failed to extract source: %v", err) + } + fmt.Println("Extracted sqlite source to", baseDir) + + // Build amalgamation files (OS-specific) + fmt.Printf("Starting to generate amalgamation with build flags: %s\n", buildFlags) + if err := buildAmalgamation(baseDir, buildFlags); err != nil { + return fmt.Errorf("failed to build amalgamation: %v", err) + } + fmt.Println("SQLite3 amalgamation built") + + // Patch bindings + patchSource(baseDir, "sqlite3.c", "../sqlite3-binding.c", "ext/userauth/userauth.c") + patchSource(baseDir, "sqlite3.h", "../sqlite3-binding.h", "ext/userauth/sqlite3userauth.h") + patchSource(baseDir, "sqlite3ext.h", "../sqlite3ext.h") + + fmt.Println("Done patching amalgamation") + return nil + }() + if err != nil { + log.Fatal("Returned with error:", err) + } +} + +func download(prefix string) (content, hash []byte, err error) { year := time.Now().Year() site := "https://www.sqlite.org/download.html" //fmt.Printf("scraping %v\n", site) doc, err := goquery.NewDocument(site) if err != nil { - log.Fatal(err) + return nil, nil, err } - doc.Find("a").Each(func(_ int, s *goquery.Selection) { - if strings.HasPrefix(s.Text(), prefix) { - url = fmt.Sprintf("https://www.sqlite.org/%d/", year) + s.Text() + url, hashString := "", "" + doc.Find("tr").EachWithBreak(func(_ int, s *goquery.Selection) bool { + found := false + s.Find("a").Each(func(_ int, s *goquery.Selection) { + if strings.HasPrefix(s.Text(), prefix) { + found = true + url = fmt.Sprintf("https://www.sqlite.org/%d/", year) + s.Text() + } + }) + if found { + s.Find("td").Each(func(_ int, s *goquery.Selection) { + text := s.Text() + split := strings.Split(text, "(sha1: ") + if len(split) < 2 { + return + } + text = split[1] + hashString = strings.Split(text, ")")[0] + }) } + return !found }) + targetHash, err := hex.DecodeString(hashString) + if err != nil || len(targetHash) != sha1.Size { + return nil, nil, fmt.Errorf("unable to find valid sha1 hash on sqlite.org: %q", hashString) + } + if url == "" { - return "", nil, fmt.Errorf("Unable to find prefix '%s' on sqlite.org", prefix) + return nil, nil, fmt.Errorf("unable to find prefix '%s' on sqlite.org", prefix) } fmt.Printf("Downloading %v\n", url) resp, err := http.Get(url) if err != nil { - log.Fatal(err) + return nil, nil, err } // Ready Body Content - content, err = ioutil.ReadAll(resp.Body) + shasum := sha1.New() + content, err = ioutil.ReadAll(io.TeeReader(resp.Body, shasum)) defer resp.Body.Close() if err != nil { - return "", nil, err + return nil, nil, err } - return url, content, nil + computedHash := shasum.Sum(nil) + if !bytes.Equal(targetHash, computedHash) { + return nil, nil, fmt.Errorf("invalid hash of file downloaded from %q: got %x instead of %x", url, computedHash, targetHash) + } + + return content, computedHash, nil } -func mergeFile(src string, dst string) error { - defer func() error { - fmt.Printf("Removing: %s\n", src) - err := os.Remove(src) +func extractZip(data []byte) (string, error) { + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return "", err + } - if err != nil { - return err + if len(zr.File) == 0 { + return "", errors.New("no files in zip archive") + } + if !zr.File[0].Mode().IsDir() { + return "", errors.New("expecting base directory at the top of zip archive") + } + baseDir := zr.File[0].Name + + for _, zf := range zr.File { + if !strings.HasPrefix(zf.Name, baseDir) { + return baseDir, fmt.Errorf("file %q in zip archive not in base directory %q", zf.Name, baseDir) } - return nil - }() - - // Open destination - fdst, err := os.OpenFile(dst, os.O_APPEND|os.O_WRONLY, 0666) - if err != nil { - return err - } - defer fdst.Close() - - // Read source content - content, err := ioutil.ReadFile(src) - if err != nil { - return err - } - - // Add Additional newline - if _, err := fdst.WriteString("\n"); err != nil { - return err - } - - fmt.Printf("Merging: %s into %s\n", src, dst) - if _, err = fdst.Write(content); err != nil { - return err - } - - return nil -} - -func main() { - fmt.Println("Go-SQLite3 Upgrade Tool") - - wd, err := os.Getwd() - if err != nil { - log.Fatal(err) - } - if filepath.Base(wd) != "upgrade" { - log.Printf("Current directory is %q but should run in upgrade directory", wd) - os.Exit(1) - } - - // Download Amalgamation - _, amalgamation, err := download("sqlite-amalgamation-") - if err != nil { - log.Fatalf("Failed to download: sqlite-amalgamation; %s", err) - } - - // Download Source - _, source, err := download("sqlite-src-") - if err != nil { - log.Fatalf("Failed to download: sqlite-src; %s", err) - } - - // Create Amalgamation Zip Reader - rAmalgamation, err := zip.NewReader(bytes.NewReader(amalgamation), int64(len(amalgamation))) - if err != nil { - log.Fatal(err) - } - - // Create Source Zip Reader - rSource, err := zip.NewReader(bytes.NewReader(source), int64(len(source))) - if err != nil { - log.Fatal(err) - } - - // Extract Amalgamation - for _, zf := range rAmalgamation.File { - var f *os.File - switch path.Base(zf.Name) { - case "sqlite3.c": - f, err = os.Create("../sqlite3-binding.c") - case "sqlite3.h": - f, err = os.Create("../sqlite3-binding.h") - case "sqlite3ext.h": - f, err = os.Create("../sqlite3ext.h") - default: + if zf.Mode().IsDir() { + if err := os.Mkdir(zf.Name, zf.Mode()); err != nil { + return baseDir, err + } continue } + f, err := os.OpenFile(zf.Name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, zf.Mode()) if err != nil { - log.Fatal(err) + return baseDir, err } - zr, err := zf.Open() - if err != nil { - log.Fatal(err) - } - - _, err = io.WriteString(f, "#ifndef USE_LIBSQLITE3\n") - if err != nil { - zr.Close() - f.Close() - log.Fatal(err) - } - scanner := bufio.NewScanner(zr) - for scanner.Scan() { - text := scanner.Text() - if text == `#include "sqlite3.h"` { - text = `#include "sqlite3-binding.h" -#ifdef __clang__ -#define assert(condition) ((void)0) -#endif -` - } - _, err = fmt.Fprintln(f, text) - if err != nil { - break - } - } - err = scanner.Err() - if err != nil { - zr.Close() - f.Close() - log.Fatal(err) - } - _, err = io.WriteString(f, "#else // USE_LIBSQLITE3\n // If users really want to link against the system sqlite3 we\n// need to make this file a noop.\n #endif") - if err != nil { - zr.Close() - f.Close() - log.Fatal(err) - } - zr.Close() - f.Close() - fmt.Printf("Extracted: %v\n", filepath.Base(f.Name())) - } - - //Extract Source - for _, zf := range rSource.File { - var f *os.File - switch path.Base(zf.Name) { - case "userauth.c": - f, err = os.Create("../userauth.c") - case "sqlite3userauth.h": - f, err = os.Create("../userauth.h") - default: + if zf.UncompressedSize == 0 { continue } - if err != nil { - log.Fatal(err) - } + zr, err := zf.Open() if err != nil { - log.Fatal(err) + return baseDir, err } _, err = io.Copy(f, zr) if err != nil { - log.Fatal(err) + return baseDir, err } - zr.Close() - f.Close() - fmt.Printf("extracted %v\n", filepath.Base(f.Name())) + if err := zr.Close(); err != nil { + return baseDir, err + } + if err := f.Close(); err != nil { + return baseDir, err + } } - // Merge SQLite User Authentication into amalgamation - if err := mergeFile("../userauth.c", "../sqlite3-binding.c"); err != nil { - log.Fatal(err) - } - if err := mergeFile("../userauth.h", "../sqlite3-binding.h"); err != nil { - log.Fatal(err) - } - - os.Exit(0) + return baseDir, nil +} + +func patchSource(baseDir, src, dst string, extensions ...string) error { + srcFile, err := os.Open(filepath.Join(baseDir, src)) + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.Create(dst) + if err != nil { + return err + } + defer dstFile.Close() + + _, err = io.WriteString(dstFile, "#ifndef USE_LIBSQLITE3\n") + if err != nil { + return err + } + scanner := bufio.NewScanner(srcFile) + for scanner.Scan() { + text := scanner.Text() + if text == `#include "sqlite3.h"` { + text = `#include "sqlite3-binding.h"` + } + _, err = fmt.Fprintln(dstFile, text) + if err != nil { + break + } + } + err = scanner.Err() + if err != nil { + return err + } + _, err = io.WriteString(dstFile, "#else // USE_LIBSQLITE3\n// If users really want to link against the system sqlite3 we\n// need to make this file a noop.\n#endif\n") + if err != nil { + return err + } + + for _, ext := range extensions { + ext = filepath.FromSlash(ext) + fmt.Printf("Merging: %s into %s\n", ext, dst) + + extFile, err := os.Open(filepath.Join(baseDir, ext)) + if err != nil { + return err + } + _, err = io.Copy(dstFile, extFile) + extFile.Close() + if err != nil { + return err + } + } + + if err := dstFile.Close(); err != nil { + return err + } + + fmt.Printf("Patched: %s -> %s\n", src, dst) + + return nil } diff --git a/upgrade/upgrade_unix.go b/upgrade/upgrade_unix.go new file mode 100644 index 0000000..61ad3bc --- /dev/null +++ b/upgrade/upgrade_unix.go @@ -0,0 +1,34 @@ +// +build !cgo +// +build upgrade,ignore +// +build !windows + +package main + +import ( + "fmt" + "os/exec" +) + +func buildAmalgamation(baseDir, buildFlags string) error { + args := []string{"configure"} + if buildFlags != "" { + args = append(args, "CFLAGS="+buildFlags) + } + cmd := exec.Command("sh", args...) + cmd.Dir = baseDir + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("configure failed: %v\n\n%s", err, out) + } + fmt.Println("Ran configure successfully") + + cmd = exec.Command("make", "sqlite3.c") + cmd.Dir = baseDir + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("make failed: %v\n\n%s", err, out) + } + fmt.Println("Ran make successfully") + + return nil +} diff --git a/upgrade/upgrade_windows.go b/upgrade/upgrade_windows.go new file mode 100644 index 0000000..7f926f8 --- /dev/null +++ b/upgrade/upgrade_windows.go @@ -0,0 +1,25 @@ +// +build !cgo +// +build upgrade,ignore + +package main + +import ( + "fmt" + "os/exec" +) + +func buildAmalgamation(baseDir, buildFlags string) error { + args := []string{"/f", "Makefile.msc", "sqlite3.c"} + if buildFlags != "" { + args = append(args, "OPTS="+buildFlags) + } + cmd := exec.Command("nmake", args...) + cmd.Dir = baseDir + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("nmake failed: %v\n\n%s", err, out) + } + fmt.Println("Ran nmake successfully") + + return nil +}