diff --git a/stringer.go b/stringer.go index 6c267e5..5788650 100644 --- a/stringer.go +++ b/stringer.go @@ -31,6 +31,17 @@ import ( "github.com/pascaldekloe/name" ) +type arrayFlags []string + +func (af arrayFlags) String() string { + return strings.Join(af, "") +} + +func (af *arrayFlags) Set(value string) error { + *af = append(*af, value) + return nil +} + var ( typeNames = flag.String("type", "", "comma-separated list of type names; must be set") sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.") @@ -42,13 +53,19 @@ var ( trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix. Default: \"\"") ) +var comments arrayFlags + +func init() { + flag.Var(&comments, "comment", "comments to include in generated code, can repeat. Default: \"\"") +} + // Usage is a replacement usage function for the flags package. func Usage() { fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) - fmt.Fprintf(os.Stderr, "\tstringer [flags] -type T [directory]\n") - fmt.Fprintf(os.Stderr, "\tstringer [flags] -type T files... # Must be a single package\n") + fmt.Fprintf(os.Stderr, "\tenumer [flags] -type T [directory]\n") + fmt.Fprintf(os.Stderr, "\tenumer [flags] -type T files... # Must be a single package\n") fmt.Fprintf(os.Stderr, "For more information, see:\n") - fmt.Fprintf(os.Stderr, "\thttp://godoc.org/golang.org/x/tools/cmd/stringer\n") + fmt.Fprintf(os.Stderr, "\thttps://github.com/alvaroloes/enumer\n") fmt.Fprintf(os.Stderr, "Flags:\n") flag.PrintDefaults() } @@ -88,6 +105,7 @@ func main() { // Print the header and package clause. g.Printf("// Code generated by \"enumer %s\"; DO NOT EDIT.\n", strings.Join(os.Args[1:], " ")) g.Printf("\n") + g.Printf("// %s\n", comments.String()) g.Printf("package %s", g.pkg.name) g.Printf("\n") g.Printf("import (\n") @@ -108,16 +126,31 @@ func main() { // Format the output. src := g.format() - // Write to file. + // Figure out filename to write to outputName := *output if outputName == "" { baseName := fmt.Sprintf("%s_enumer.go", types[0]) outputName = filepath.Join(dir, strings.ToLower(baseName)) } - err := ioutil.WriteFile(outputName, src, 0644) + + // Write to tmpfile first + tmpFile, err := ioutil.TempFile("", fmt.Sprintf("%s_enumer_", types[0])) if err != nil { + log.Fatalf("creating temporary file for output: %s", err) + } + _, err = tmpFile.Write(src) + if err != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) log.Fatalf("writing output: %s", err) } + tmpFile.Close() + + // Rename tmpfile to output file + err = os.Rename(tmpFile.Name(), outputName) + if err != nil { + log.Fatalf("moving tempfile to output file: %s", err) + } } // isDirectory reports whether the named file is a directory.