// Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Enumer is a tool to generate Go code that adds useful methods to Go enums (constants with a specific type). // It started as a fork of Rob Pike’s Stringer tool // // Please visit http://github.com/dmarkham/enumer for a comprehensive documentation package main import ( "bytes" "flag" "fmt" "go/ast" exact "go/constant" "go/format" "go/importer" "go/token" "go/types" "io/ioutil" "log" "os" "path/filepath" "sort" "strings" "unicode" "unicode/utf8" "golang.org/x/tools/go/packages" "github.com/pascaldekloe/name" ) type arrayFlags []string func (af arrayFlags) String() string { return strings.Join(af, "") } func (af *arrayFlags) Set(value string) error { *af = append(*af, value) return nil } var ( typeNames = flag.String("type", "", "comma-separated list of type names; must be set") sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.") json = flag.Bool("json", false, "if true, json marshaling methods will be generated. Default: false") yaml = flag.Bool("yaml", false, "if true, yaml marshaling methods will be generated. Default: false") text = flag.Bool("text", false, "if true, text marshaling methods will be generated. Default: false") output = flag.String("output", "", "output file name; default srcdir/_string.go") transformMethod = flag.String("transform", "noop", "enum item name transformation method. Default: noop") trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix. Default: \"\"") addPrefix = flag.String("addprefix", "", "transform each item name by adding a prefix. Default: \"\"") linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present") ) var comments arrayFlags func init() { flag.Var(&comments, "comment", "comments to include in generated code, can repeat. Default: \"\"") } // Usage is a replacement usage function for the flags package. func Usage() { _, _ = fmt.Fprintf(os.Stderr, "Enumer is a tool to generate Go code that adds useful methods to Go enums (constants with a specific type).") _, _ = fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) _, _ = fmt.Fprintf(os.Stderr, "\tEnumer [flags] -type T [directory]\n") _, _ = fmt.Fprintf(os.Stderr, "\tEnumer [flags] -type T files... # Must be a single package\n") _, _ = fmt.Fprintf(os.Stderr, "For more information, see:\n") _, _ = fmt.Fprintf(os.Stderr, "\thttp://godoc.org/github.com/dmarkham/enumer\n") _, _ = fmt.Fprintf(os.Stderr, "Flags:\n") flag.PrintDefaults() } func main() { log.SetFlags(0) log.SetPrefix("enumer: ") flag.Usage = Usage flag.Parse() if len(*typeNames) == 0 { flag.Usage() os.Exit(2) } typs := strings.Split(*typeNames, ",") // We accept either one directory or a list of files. Which do we have? args := flag.Args() if len(args) == 0 { // Default: process whole package in current directory. args = []string{"."} } // Parse the package once. var ( dir string g Generator ) if len(args) == 1 && isDirectory(args[0]) { dir = args[0] // g.parsePackageDir(args[0]) } else { dir = filepath.Dir(args[0]) // g.parsePackageFiles(args) } g.parsePackage(args, []string{}) // Print the header and package clause. g.Printf("// Code generated by \"enumer %s\"; DO NOT EDIT.\n", strings.Join(os.Args[1:], " ")) g.Printf("\n") if comments.String() != "" { g.Printf("// %s\n", comments.String()) } g.Printf("package %s", g.pkg.name) g.Printf("\n") g.Printf("import (\n") g.Printf("\t\"fmt\"\n") if *sql { g.Printf("\t\"database/sql/driver\"\n") } if *json { g.Printf("\t\"encoding/json\"\n") } g.Printf(")\n") // Run generate for each type. for _, typeName := range typs { g.generate(typeName, *json, *yaml, *sql, *text, *transformMethod, *trimPrefix, *addPrefix, *linecomment) } // Format the output. src := g.format() // Figure out filename to write to outputName := *output if outputName == "" { baseName := fmt.Sprintf("%s_enumer.go", typs[0]) outputName = filepath.Join(dir, strings.ToLower(baseName)) } // Write to tmpfile first tmpFile, err := ioutil.TempFile(dir, fmt.Sprintf("%s_enumer_", typs[0])) if err != nil { log.Fatalf("creating temporary file for output: %s", err) } _, err = tmpFile.Write(src) if err != nil { tmpFile.Close() os.Remove(tmpFile.Name()) log.Fatalf("writing output: %s", err) } tmpFile.Close() // Rename tmpfile to output file err = os.Rename(tmpFile.Name(), outputName) if err != nil { log.Fatalf("moving tempfile to output file: %s", err) } } // isDirectory reports whether the named file is a directory. func isDirectory(name string) bool { info, err := os.Stat(name) if err != nil { log.Fatal(err) } return info.IsDir() } // Generator holds the state of the analysis. Primarily used to buffer // the output for format.Source. type Generator struct { buf bytes.Buffer // Accumulated output. pkg *Package // Package we are scanning. } // Printf prints the string to the output func (g *Generator) Printf(format string, args ...interface{}) { _, _ = fmt.Fprintf(&g.buf, format, args...) } // File holds a single parsed file and associated data. type File struct { pkg *Package // Package to which this file belongs. file *ast.File // Parsed AST. // These fields are reset for each type being generated. typeName string // Name of the constant type. values []Value // Accumulator for constant values of that type. trimPrefix string lineComment bool } // Package holds information about a Go package type Package struct { dir string name string defs map[*ast.Ident]types.Object files []*File typesPkg *types.Package } // // parsePackageDir parses the package residing in the directory. // func (g *Generator) parsePackageDir(directory string) { // pkg, err := build.Default.ImportDir(directory, 0) // if err != nil { // log.Fatalf("cannot process directory %s: %s", directory, err) // } // var names []string // names = append(names, pkg.GoFiles...) // names = append(names, pkg.CgoFiles...) // // TODO: Need to think about constants in test files. Maybe write type_string_test.go // // in a separate pass? For later. // // names = append(names, pkg.TestGoFiles...) // These are also in the "foo" package. // names = append(names, pkg.SFiles...) // names = prefixDirectory(directory, names) // g.parsePackage(directory, names, nil) // } // // // parsePackageFiles parses the package occupying the named files. // func (g *Generator) parsePackageFiles(names []string) { // g.parsePackage(".", names, nil) // } // // prefixDirectory places the directory name on the beginning of each name in the list. // func prefixDirectory(directory string, names []string) []string { // if directory == "." { // return names // } // ret := make([]string, len(names)) // for i, n := range names { // ret[i] = filepath.Join(directory, n) // } // return ret // } // parsePackage analyzes the single package constructed from the patterns and tags. // parsePackage exits if there is an error. func (g *Generator) parsePackage(patterns []string, tags []string) { cfg := &packages.Config{ Mode: packages.LoadSyntax, // TODO: Need to think about constants in test files. Maybe write type_string_test.go // in a separate pass? For later. Tests: false, } pkgs, err := packages.Load(cfg, patterns...) if err != nil { log.Fatal(err) } if len(pkgs) != 1 { log.Fatalf("error: %d packages found", len(pkgs)) } g.addPackage(pkgs[0]) } // addPackage adds a type checked Package and its syntax files to the generator. func (g *Generator) addPackage(pkg *packages.Package) { g.pkg = &Package{ name: pkg.Name, defs: pkg.TypesInfo.Defs, files: make([]*File, len(pkg.Syntax)), } for i, file := range pkg.Syntax { g.pkg.files[i] = &File{ file: file, pkg: g.pkg, } } } // parsePackage analyzes the single package constructed from the named files. // If text is non-nil, it is a string to be used instead of the content of the file, // to be used for testing. parsePackage exits if there is an error. // func (g *Generator) parsePackagee(directory string, names []string, text interface{}) { // var files []*File // var astFiles []*ast.File // g.pkg = new(Package) // fs := token.NewFileSet() // for _, n := range names { // if !strings.HasSuffix(n, ".go") { // continue // } // parsedFile, err := parser.ParseFile(fs, n, text, 0) // if err != nil { // log.Fatalf("parsing package: %s: %s", n, err) // } // astFiles = append(astFiles, parsedFile) // files = append(files, &File{ // file: parsedFile, // pkg: g.pkg, // }) // } // if len(astFiles) == 0 { // log.Fatalf("%s: no buildable Go files", directory) // } // g.pkg.name = astFiles[0].Name.Name // g.pkg.files = files // g.pkg.dir = directory // // Type check the package. // g.pkg.check(fs, astFiles) // } // check type-checks the package. The package must be OK to proceed. func (pkg *Package) check(fs *token.FileSet, astFiles []*ast.File) { pkg.defs = make(map[*ast.Ident]types.Object) config := types.Config{Importer: importer.Default(), FakeImportC: true} info := &types.Info{ Defs: pkg.defs, } typesPkg, err := config.Check(pkg.dir, fs, astFiles, info) if err != nil { log.Fatalf("checking package: %s", err) } pkg.typesPkg = typesPkg } func (g *Generator) transformValueNames(values []Value, transformMethod string) { var fn func(src string) string switch transformMethod { case "snake": fn = func(s string) string { return strings.ToLower(name.Delimit(s, '_')) } case "snake_upper", "snake-upper": fn = func(s string) string { return strings.ToUpper(name.Delimit(s, '_')) } case "kebab": fn = func(s string) string { return strings.ToLower(name.Delimit(s, '-')) } case "kebab_upper", "kebab-upper": fn = func(s string) string { return strings.ToUpper(name.Delimit(s, '-')) } case "upper": fn = func(s string) string { return strings.ToUpper(s) } case "lower": fn = func(s string) string { return strings.ToLower(s) } case "title": fn = func(s string) string { return strings.Title(s) } case "title-lower": fn = func(s string) string { title := []rune(strings.Title(s)) title[0] = unicode.ToLower(title[0]) return string(title) } case "first": fn = func(s string) string { r, _ := utf8.DecodeRuneInString(s) return string(r) } case "first_upper", "first-upper": fn = func(s string) string { r, _ := utf8.DecodeRuneInString(s) return strings.ToUpper(string(r)) } case "first_lower", "first-lower": fn = func(s string) string { r, _ := utf8.DecodeRuneInString(s) return strings.ToLower(string(r)) } case "whitespace": fn = func(s string) string { return strings.ToLower(name.Delimit(s, ' ')) } default: return } for i := range values { values[i].name = fn(values[i].name) } } // trimValueNames removes a prefix from each name func (g *Generator) trimValueNames(values []Value, prefix string) { for i := range values { values[i].name = strings.TrimPrefix(values[i].name, prefix) } } // prefixValueNames adds a prefix to each name func (g *Generator) prefixValueNames(values []Value, prefix string) { for i := range values { values[i].name = prefix + values[i].name } } // generate produces the String method for the named type. func (g *Generator) generate(typeName string, includeJSON, includeYAML, includeSQL, includeText bool, transformMethod string, trimPrefix string, addPrefix string, lineComment bool) { values := make([]Value, 0, 100) for _, file := range g.pkg.files { file.lineComment = lineComment // Set the state for this run of the walker. file.typeName = typeName file.values = nil if file.file != nil { ast.Inspect(file.file, file.genDecl) values = append(values, file.values...) } } if len(values) == 0 { log.Fatalf("no values defined for type %s", typeName) } g.trimValueNames(values, trimPrefix) g.transformValueNames(values, transformMethod) g.prefixValueNames(values, addPrefix) runs := splitIntoRuns(values) // The decision of which pattern to use depends on the number of // runs in the numbers. If there's only one, it's easy. For more than // one, there's a tradeoff between complexity and size of the data // and code vs. the simplicity of a map. A map takes more space, // but so does the code. The decision here (crossover at 10) is // arbitrary, but considers that for large numbers of runs the cost // of the linear scan in the switch might become important, and // rather than use yet another algorithm such as binary search, // we punt and use a map. In any case, the likelihood of a map // being necessary for any realistic example other than bitmasks // is very low. And bitmasks probably deserve their own analysis, // to be done some other day. const runsThreshold = 10 switch { case len(runs) == 1: g.buildOneRun(runs, typeName) case len(runs) <= runsThreshold: g.buildMultipleRuns(runs, typeName) default: g.buildMap(runs, typeName) } g.buildBasicExtras(runs, typeName, runsThreshold) if includeJSON { g.buildJSONMethods(runs, typeName, runsThreshold) } if includeText { g.buildTextMethods(runs, typeName, runsThreshold) } if includeYAML { g.buildYAMLMethods(runs, typeName, runsThreshold) } if includeSQL { g.addValueAndScanMethod(typeName) } } // splitIntoRuns breaks the values into runs of contiguous sequences. // For example, given 1,2,3,5,6,7 it returns {1,2,3},{5,6,7}. // The input slice is known to be non-empty. func splitIntoRuns(values []Value) [][]Value { // We use stable sort so the lexically first name is chosen for equal elements. sort.Stable(byValue(values)) // Remove duplicates. Stable sort has put the one we want to print first, // so use that one. The String method won't care about which named constant // was the argument, so the first name for the given value is the only one to keep. // We need to do this because identical values would cause the switch or map // to fail to compile. j := 1 for i := 1; i < len(values); i++ { if values[i].value != values[i-1].value { values[j] = values[i] j++ } } values = values[:j] runs := make([][]Value, 0, 10) for len(values) > 0 { // One contiguous sequence per outer loop. i := 1 for i < len(values) && values[i].value == values[i-1].value+1 { i++ } runs = append(runs, values[:i]) values = values[i:] } return runs } // format returns the gofmt-ed contents of the Generator's buffer. func (g *Generator) format() []byte { src, err := format.Source(g.buf.Bytes()) if err != nil { // Should never happen, but can arise when developing this code. // The user can compile the output to see the error. log.Printf("warning: internal error: invalid Go generated: %s", err) log.Printf("warning: compile the package to analyze the error") return g.buf.Bytes() } return src } // Value represents a declared constant. type Value struct { name string // The name of the constant after transformation (i.e. camel case => snake case) // The value is stored as a bit pattern alone. The boolean tells us // whether to interpret it as an int64 or a uint64; the only place // this matters is when sorting. // Much of the time the str field is all we need; it is printed // by Value.String. value uint64 // Will be converted to int64 when needed. signed bool // Whether the constant is a signed type. str string // The string representation given by the "go/exact" package. } func (v *Value) String() string { return v.str } // byValue lets us sort the constants into increasing order. // We take care in the Less method to sort in signed or unsigned order, // as appropriate. type byValue []Value func (b byValue) Len() int { return len(b) } func (b byValue) Swap(i, j int) { b[i], b[j] = b[j], b[i] } func (b byValue) Less(i, j int) bool { if b[i].signed { return int64(b[i].value) < int64(b[j].value) } return b[i].value < b[j].value } // genDecl processes one declaration clause. func (f *File) genDecl(node ast.Node) bool { decl, ok := node.(*ast.GenDecl) if !ok || decl.Tok != token.CONST { // We only care about const declarations. return true } // The name of the type of the constants we are declaring. // Can change if this is a multi-element declaration. typ := "" // Loop over the elements of the declaration. Each element is a ValueSpec: // a list of names possibly followed by a type, possibly followed by values. // If the type and value are both missing, we carry down the type (and value, // but the "go/types" package takes care of that). for _, spec := range decl.Specs { vspec := spec.(*ast.ValueSpec) // Guaranteed to succeed as this is CONST. if vspec.Type == nil && len(vspec.Values) > 0 { // "X = 1". With no type but a value, the constant is untyped. // Skip this vspec and reset the remembered type. typ = "" continue } if vspec.Type != nil { // "X T". We have a type. Remember it. ident, ok := vspec.Type.(*ast.Ident) if !ok { continue } typ = ident.Name } if typ != f.typeName { // This is not the type we're looking for. continue } // We now have a list of names (from one line of source code) all being // declared with the desired type. // Grab their names and actual values and store them in f.values. for _, n := range vspec.Names { if n.Name == "_" { continue } // This dance lets the type checker find the values for us. It's a // bit tricky: look up the object declared by the n, find its // types.Const, and extract its value. obj, ok := f.pkg.defs[n] if !ok { log.Fatalf("no value for constant %s", n) } info := obj.Type().Underlying().(*types.Basic).Info() if info&types.IsInteger == 0 { log.Fatalf("can't handle non-integer constant type %s", typ) } value := obj.(*types.Const).Val() // Guaranteed to succeed as this is CONST. if value.Kind() != exact.Int { log.Fatalf("can't happen: constant is not an integer %s", n) } i64, isInt := exact.Int64Val(value) u64, isUint := exact.Uint64Val(value) if !isInt && !isUint { log.Fatalf("internal error: value of %s is not an integer: %s", n, value.String()) } if !isInt { u64 = uint64(i64) } v := Value{ name: n.Name, value: u64, signed: info&types.IsUnsigned == 0, str: value.String(), } if c := vspec.Comment; f.lineComment && c != nil && len(c.List) == 1 { v.name = strings.TrimSpace(c.Text()) } f.values = append(f.values, v) } } return false } // Helpers // usize returns the number of bits of the smallest unsigned integer // type that will hold n. Used to create the smallest possible slice of // integers to use as indexes into the concatenated strings. func usize(n int) int { switch { case n < 1<<8: return 8 case n < 1<<16: return 16 default: // 2^32 is enough constants for anyone. return 32 } } // declareIndexAndNameVars declares the index slices and concatenated names // strings representing the runs of values. func (g *Generator) declareIndexAndNameVars(runs [][]Value, typeName string) { var indexes, names []string for i, run := range runs { index, n := g.createIndexAndNameDecl(run, typeName, fmt.Sprintf("_%d", i)) indexes = append(indexes, index) names = append(names, n) } g.Printf("const (\n") for _, n := range names { g.Printf("\t%s\n", n) } g.Printf(")\n\n") g.Printf("var (") for _, index := range indexes { g.Printf("\t%s\n", index) } g.Printf(")\n\n") } // declareIndexAndNameVar is the single-run version of declareIndexAndNameVars func (g *Generator) declareIndexAndNameVar(run []Value, typeName string) { index, n := g.createIndexAndNameDecl(run, typeName, "") g.Printf("const %s\n", n) g.Printf("var %s\n", index) } // createIndexAndNameDecl returns the pair of declarations for the run. The caller will add "const" and "var". func (g *Generator) createIndexAndNameDecl(run []Value, typeName string, suffix string) (string, string) { b := new(bytes.Buffer) indexes := make([]int, len(run)) for i := range run { b.WriteString(run[i].name) indexes[i] = b.Len() } nameConst := fmt.Sprintf("_%sName%s = %q", typeName, suffix, b.String()) nameLen := b.Len() b.Reset() _, _ = fmt.Fprintf(b, "_%sIndex%s = [...]uint%d{0, ", typeName, suffix, usize(nameLen)) for i, v := range indexes { if i > 0 { _, _ = fmt.Fprintf(b, ", ") } _, _ = fmt.Fprintf(b, "%d", v) } _, _ = fmt.Fprintf(b, "}") return b.String(), nameConst } // declareNameVars declares the concatenated names string representing all the values in the runs. func (g *Generator) declareNameVars(runs [][]Value, typeName string, suffix string) { g.Printf("const _%sName%s = \"", typeName, suffix) for _, run := range runs { for i := range run { g.Printf("%s", run[i].name) } } g.Printf("\"\n") } // buildOneRun generates the variables and String method for a single run of contiguous values. func (g *Generator) buildOneRun(runs [][]Value, typeName string) { values := runs[0] g.Printf("\n") g.declareIndexAndNameVar(values, typeName) // The generated code is simple enough to write as a Printf format. lessThanZero := "" if values[0].signed { lessThanZero = "i < 0 || " } if values[0].value == 0 { // Signed or unsigned, 0 is still 0. g.Printf(stringOneRun, typeName, usize(len(values)), lessThanZero) } else { g.Printf(stringOneRunWithOffset, typeName, values[0].String(), usize(len(values)), lessThanZero) } } // Arguments to format are: // [1]: type name // [2]: size of index element (8 for uint8 etc.) // [3]: less than zero check (for signed types) const stringOneRun = `func (i %[1]s) String() string { if %[3]si >= %[1]s(len(_%[1]sIndex)-1) { return fmt.Sprintf("%[1]s(%%d)", i) } return _%[1]sName[_%[1]sIndex[i]:_%[1]sIndex[i+1]] } ` // Arguments to format are: // [1]: type name // [2]: lowest defined value for type, as a string // [3]: size of index element (8 for uint8 etc.) // [4]: less than zero check (for signed types) /* */ const stringOneRunWithOffset = `func (i %[1]s) String() string { i -= %[2]s if %[4]si >= %[1]s(len(_%[1]sIndex)-1) { return fmt.Sprintf("%[1]s(%%d)", i + %[2]s) } return _%[1]sName[_%[1]sIndex[i] : _%[1]sIndex[i+1]] } ` // buildMultipleRuns generates the variables and String method for multiple runs of contiguous values. // For this pattern, a single Printf format won't do. func (g *Generator) buildMultipleRuns(runs [][]Value, typeName string) { g.Printf("\n") g.declareIndexAndNameVars(runs, typeName) g.Printf("func (i %s) String() string {\n", typeName) g.Printf("\tswitch {\n") for i, values := range runs { if len(values) == 1 { g.Printf("\tcase i == %s:\n", &values[0]) g.Printf("\t\treturn _%sName_%d\n", typeName, i) continue } g.Printf("\tcase %s <= i && i <= %s:\n", &values[0], &values[len(values)-1]) if values[0].value != 0 { g.Printf("\t\ti -= %s\n", &values[0]) } g.Printf("\t\treturn _%sName_%d[_%sIndex_%d[i]:_%sIndex_%d[i+1]]\n", typeName, i, typeName, i, typeName, i) } g.Printf("\tdefault:\n") g.Printf("\t\treturn fmt.Sprintf(\"%s(%%d)\", i)\n", typeName) g.Printf("\t}\n") g.Printf("}\n") } // buildMap handles the case where the space is so sparse a map is a reasonable fallback. // It's a rare situation but has simple code. func (g *Generator) buildMap(runs [][]Value, typeName string) { g.Printf("\n") g.declareNameVars(runs, typeName, "") g.Printf("\nvar _%sMap = map[%s]string{\n", typeName, typeName) n := 0 for _, values := range runs { for _, value := range values { g.Printf("\t%s: _%sName[%d:%d],\n", &value, typeName, n, n+len(value.name)) n += len(value.name) } } g.Printf("}\n\n") g.Printf(stringMap, typeName) } // Argument to format is the type name. const stringMap = `func (i %[1]s) String() string { if str, ok := _%[1]sMap[i]; ok { return str } return fmt.Sprintf("%[1]s(%%d)", i) } `