diff --git a/stringer.go b/stringer.go index f32e9e9..5788650 100644 --- a/stringer.go +++ b/stringer.go @@ -126,16 +126,31 @@ func main() { // Format the output. src := g.format() - // Write to file. + // Figure out filename to write to outputName := *output if outputName == "" { baseName := fmt.Sprintf("%s_enumer.go", types[0]) outputName = filepath.Join(dir, strings.ToLower(baseName)) } - err := ioutil.WriteFile(outputName, src, 0644) + + // Write to tmpfile first + tmpFile, err := ioutil.TempFile("", fmt.Sprintf("%s_enumer_", types[0])) if err != nil { + log.Fatalf("creating temporary file for output: %s", err) + } + _, err = tmpFile.Write(src) + if err != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) log.Fatalf("writing output: %s", err) } + tmpFile.Close() + + // Rename tmpfile to output file + err = os.Rename(tmpFile.Name(), outputName) + if err != nil { + log.Fatalf("moving tempfile to output file: %s", err) + } } // isDirectory reports whether the named file is a directory.