forked from mirror/enumer
Merge pull request #33 from dterei/master
Add comment support + temporary file before writing
This commit is contained in:
commit
9875c3a8c3
43
stringer.go
43
stringer.go
|
@ -31,6 +31,17 @@ import (
|
||||||
"github.com/pascaldekloe/name"
|
"github.com/pascaldekloe/name"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type arrayFlags []string
|
||||||
|
|
||||||
|
func (af arrayFlags) String() string {
|
||||||
|
return strings.Join(af, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (af *arrayFlags) Set(value string) error {
|
||||||
|
*af = append(*af, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
typeNames = flag.String("type", "", "comma-separated list of type names; must be set")
|
typeNames = flag.String("type", "", "comma-separated list of type names; must be set")
|
||||||
sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.")
|
sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.")
|
||||||
|
@ -42,13 +53,19 @@ var (
|
||||||
trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix. Default: \"\"")
|
trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix. Default: \"\"")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var comments arrayFlags
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.Var(&comments, "comment", "comments to include in generated code, can repeat. Default: \"\"")
|
||||||
|
}
|
||||||
|
|
||||||
// Usage is a replacement usage function for the flags package.
|
// Usage is a replacement usage function for the flags package.
|
||||||
func Usage() {
|
func Usage() {
|
||||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
||||||
fmt.Fprintf(os.Stderr, "\tstringer [flags] -type T [directory]\n")
|
fmt.Fprintf(os.Stderr, "\tenumer [flags] -type T [directory]\n")
|
||||||
fmt.Fprintf(os.Stderr, "\tstringer [flags] -type T files... # Must be a single package\n")
|
fmt.Fprintf(os.Stderr, "\tenumer [flags] -type T files... # Must be a single package\n")
|
||||||
fmt.Fprintf(os.Stderr, "For more information, see:\n")
|
fmt.Fprintf(os.Stderr, "For more information, see:\n")
|
||||||
fmt.Fprintf(os.Stderr, "\thttp://godoc.org/golang.org/x/tools/cmd/stringer\n")
|
fmt.Fprintf(os.Stderr, "\thttps://github.com/alvaroloes/enumer\n")
|
||||||
fmt.Fprintf(os.Stderr, "Flags:\n")
|
fmt.Fprintf(os.Stderr, "Flags:\n")
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
}
|
}
|
||||||
|
@ -88,6 +105,7 @@ func main() {
|
||||||
// Print the header and package clause.
|
// Print the header and package clause.
|
||||||
g.Printf("// Code generated by \"enumer %s\"; DO NOT EDIT.\n", strings.Join(os.Args[1:], " "))
|
g.Printf("// Code generated by \"enumer %s\"; DO NOT EDIT.\n", strings.Join(os.Args[1:], " "))
|
||||||
g.Printf("\n")
|
g.Printf("\n")
|
||||||
|
g.Printf("// %s\n", comments.String())
|
||||||
g.Printf("package %s", g.pkg.name)
|
g.Printf("package %s", g.pkg.name)
|
||||||
g.Printf("\n")
|
g.Printf("\n")
|
||||||
g.Printf("import (\n")
|
g.Printf("import (\n")
|
||||||
|
@ -108,16 +126,31 @@ func main() {
|
||||||
// Format the output.
|
// Format the output.
|
||||||
src := g.format()
|
src := g.format()
|
||||||
|
|
||||||
// Write to file.
|
// Figure out filename to write to
|
||||||
outputName := *output
|
outputName := *output
|
||||||
if outputName == "" {
|
if outputName == "" {
|
||||||
baseName := fmt.Sprintf("%s_enumer.go", types[0])
|
baseName := fmt.Sprintf("%s_enumer.go", types[0])
|
||||||
outputName = filepath.Join(dir, strings.ToLower(baseName))
|
outputName = filepath.Join(dir, strings.ToLower(baseName))
|
||||||
}
|
}
|
||||||
err := ioutil.WriteFile(outputName, src, 0644)
|
|
||||||
|
// Write to tmpfile first
|
||||||
|
tmpFile, err := ioutil.TempFile("", fmt.Sprintf("%s_enumer_", types[0]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Fatalf("creating temporary file for output: %s", err)
|
||||||
|
}
|
||||||
|
_, err = tmpFile.Write(src)
|
||||||
|
if err != nil {
|
||||||
|
tmpFile.Close()
|
||||||
|
os.Remove(tmpFile.Name())
|
||||||
log.Fatalf("writing output: %s", err)
|
log.Fatalf("writing output: %s", err)
|
||||||
}
|
}
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
// Rename tmpfile to output file
|
||||||
|
err = os.Rename(tmpFile.Name(), outputName)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("moving tempfile to output file: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isDirectory reports whether the named file is a directory.
|
// isDirectory reports whether the named file is a directory.
|
||||||
|
|
Loading…
Reference in New Issue