diff --git a/tool/upgrade.go b/tool/upgrade.go index 4e99635..baadd0f 100644 --- a/tool/upgrade.go +++ b/tool/upgrade.go @@ -37,18 +37,21 @@ func main() { if err != nil { log.Fatal(err) } - defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { + resp.Body.Close() log.Fatal(err) } fmt.Printf("extracting %v\n", path.Base(url)) r, err := zip.NewReader(bytes.NewReader(b), resp.ContentLength) if err != nil { + resp.Body.Close() log.Fatal(err) } + resp.Body.Close() + for _, zf := range r.File { var f *os.File switch path.Base(zf.Name) { @@ -68,11 +71,27 @@ func main() { if err != nil { log.Fatal(err) } - _, err = io.Copy(f, zr) - f.Close() + + _, err = io.WriteString(f, "#ifndef USE_LIBSQLITE3\n") if err != nil { + zr.Close() + f.Close() log.Fatal(err) } + _, err = io.Copy(f, zr) + if err != nil { + zr.Close() + f.Close() + log.Fatal(err) + } + _, err = io.WriteString(f, "#else // USE_LIBSQLITE3\n // If users really want to link against the system sqlite3 we\n// need to make this file a noop.\n #endif") + if err != nil { + zr.Close() + f.Close() + log.Fatal(err) + } + zr.Close() + f.Close() fmt.Printf("extracted %v\n", filepath.Base(f.Name())) } }