diff --git a/README.md b/README.md index 97f1b70..6f78c36 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,14 @@ When Enumer is applied to a type, it will generate: * A method `String()` that returns the string representation of the enum value. This makes the enum conform the `Stringer` interface, so whenever you print an enum value, you'll get the string name instead of a number. -* A function `String(s string)` to get the enum value from its string representation. This is useful -when you need to read enum values from the command line arguments, from a configuration file, -from a REST API request... In short, from those places where using the real enum value (an integer) would +* A function `String(s string)` to get the enum value from its string representation. This is useful +when you need to read enum values from command line arguments, from a configuration file, or +from a REST API request... In short, from those places where using the real enum value (an integer) would be almost meaningless or hard to trace or use by a human. -* When the flag `json` is provided, two more methods will be generated, `MarshalJSON()` and `UnmarshalJSON()`. Those make -the enum conform the `json.Marshaler` and `json.Unmarshaler` interfaces. Very useful to use it in JSON APIs. +* When the flag `json` is provided, two additional methods will be generated, `MarshalJSON()` and `UnmarshalJSON()`. These make +the enum conform to the `json.Marshaler` and `json.Unmarshaler` interfaces. Very useful to use it in JSON APIs. +* When the flag `yaml` is provided, two additional methods will be generated, `MarshalYAML()` and `UnmarshalYAML()`. These make +the enum conform to the `gopkg.in/yaml.v2.Marshaler` and `gopkg.in/yaml.v2.Unmarshaler` interfaces. * When the flag `sql` is provided, the methods for implementing the Scanner and Valuer interfaces will be also generated. Useful when storing the enum in a database. @@ -74,8 +76,9 @@ The generated code is exactly the same as the Stringer tool plus the mentioned a The usage of Enumer is the same as Stringer, so you can refer to the [Stringer docs](https://godoc.org/golang.org/x/tools/cmd/stringer) for more information. -There are two flags added: `json` and `sql`. If the json flag is set to true (i.e. `enumer -type=Pill -json`), -the JSON related methods will be generated. And if the sql flag is set to true, the Scanner and Valuer interface will +There are three flags added: `json`, `yaml` and `sql`. If the json flag is set to true (i.e. `enumer -type=Pill -json`), +the JSON related methods will be generated. Similarly if the yaml flag is set to true, +the YAML related methods will be generated. And if the sql flag is set to true, the Scanner and Valuer interface will be implemented to seamlessly use the enum in a database model. ## Inspiring projects diff --git a/enumer.go b/enumer.go index 1a60b74..834fc30 100644 --- a/enumer.go +++ b/enumer.go @@ -57,3 +57,26 @@ func (i *%[1]s) UnmarshalJSON(data []byte) error { func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int) { g.Printf(jsonMethods, typeName) } + +// Arguments to format are: +// [1]: type name +const yamlMethods = ` +func (i %[1]s) MarshalYAML() (interface{}, error) { + return i.String(), nil +} + +func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + var err error + *i, err = %[1]sString(s) + return err +} +` + +func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) { + g.Printf(yamlMethods, typeName) +} diff --git a/golden_test.go b/golden_test.go index 37454ac..c64c231 100644 --- a/golden_test.go +++ b/golden_test.go @@ -34,6 +34,10 @@ var goldenJSON = []Golden{ {"prime", prime_json_in, prime_json_out}, } +var goldenYAML = []Golden{ + {"prime", prime_yaml_in, prime_yaml_out}, +} + var goldenSQL = []Golden{ {"prime", prime_sql_in, prime_sql_out}, } @@ -434,6 +438,90 @@ func (i *Prime) UnmarshalJSON(data []byte) error { } ` +const prime_yaml_in = `type Prime int +const ( + p2 Prime = 2 + p3 Prime = 3 + p5 Prime = 5 + p7 Prime = 7 + p77 Prime = 7 // Duplicate; note that p77 doesn't appear below. + p11 Prime = 11 + p13 Prime = 13 + p17 Prime = 17 + p19 Prime = 19 + p23 Prime = 23 + p29 Prime = 29 + p37 Prime = 31 + p41 Prime = 41 + p43 Prime = 43 +) +` + +const prime_yaml_out = ` +const _Prime_name = "p2p3p5p7p11p13p17p19p23p29p37p41p43" + +var _Prime_map = map[Prime]string{ + 2: _Prime_name[0:2], + 3: _Prime_name[2:4], + 5: _Prime_name[4:6], + 7: _Prime_name[6:8], + 11: _Prime_name[8:11], + 13: _Prime_name[11:14], + 17: _Prime_name[14:17], + 19: _Prime_name[17:20], + 23: _Prime_name[20:23], + 29: _Prime_name[23:26], + 31: _Prime_name[26:29], + 41: _Prime_name[29:32], + 43: _Prime_name[32:35], +} + +func (i Prime) String() string { + if str, ok := _Prime_map[i]; ok { + return str + } + return fmt.Sprintf("Prime(%d)", i) +} + +var _PrimeNameToValue_map = map[string]Prime{ + _Prime_name[0:2]: 2, + _Prime_name[2:4]: 3, + _Prime_name[4:6]: 5, + _Prime_name[6:8]: 7, + _Prime_name[8:11]: 11, + _Prime_name[11:14]: 13, + _Prime_name[14:17]: 17, + _Prime_name[17:20]: 19, + _Prime_name[20:23]: 23, + _Prime_name[23:26]: 29, + _Prime_name[26:29]: 31, + _Prime_name[29:32]: 41, + _Prime_name[32:35]: 43, +} + +func PrimeString(s string) (Prime, error) { + if val, ok := _PrimeNameToValue_map[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Prime values", s) +} + +func (i Prime) MarshalYAML() (interface{}, error) { + return i.String(), nil +} + +func (i *Prime) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + var err error + *i, err = PrimeString(s) + return err +} +` + const prime_sql_in = `type Prime int const ( p2 Prime = 2 @@ -645,20 +733,23 @@ func (i *Prime) Scan(value interface{}) error { func TestGolden(t *testing.T) { for _, test := range golden { - runGoldenTest(t, test, false, false) + runGoldenTest(t, test, false, false, false) } for _, test := range goldenJSON { - runGoldenTest(t, test, true, false) + runGoldenTest(t, test, true, false, false) + } + for _, test := range goldenYAML { + runGoldenTest(t, test, false, true, false) } for _, test := range goldenSQL { - runGoldenTest(t, test, false, true) + runGoldenTest(t, test, false, false, true) } for _, test := range goldenJSONAndSQL { - runGoldenTest(t, test, true, true) + runGoldenTest(t, test, true, false, true) } } -func runGoldenTest(t *testing.T, test Golden, generateJSON, generateSQL bool) { +func runGoldenTest(t *testing.T, test Golden, generateJSON, generateYAML, generateSQL bool) { var g Generator input := "package test\n" + test.input file := test.name + ".go" @@ -668,7 +759,7 @@ func runGoldenTest(t *testing.T, test Golden, generateJSON, generateSQL bool) { if len(tokens) != 3 { t.Fatalf("%s: need type declaration on first line", test.name) } - g.generate(tokens[1], generateJSON, generateSQL) + g.generate(tokens[1], generateJSON, generateYAML, generateSQL) got := string(g.format()) if got != test.output { t.Errorf("%s: got\n====\n%s====\nexpected\n====%s", test.name, got, test.output) diff --git a/stringer.go b/stringer.go index 180c53b..100198d 100644 --- a/stringer.go +++ b/stringer.go @@ -84,6 +84,7 @@ var ( typeNames = flag.String("type", "", "comma-separated list of type names; must be set") sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.") json = flag.Bool("json", false, "if true, json marshaling methods will be generated. Default: false") + yaml = flag.Bool("yaml", false, "if true, yaml marshaling methods will be generated. Default: false") output = flag.String("output", "", "output file name; default srcdir/_string.go") ) @@ -146,7 +147,7 @@ func main() { // Run generate for each type. for _, typeName := range types { - g.generate(typeName, *json, *sql) + g.generate(typeName, *json, *yaml, *sql) } // Format the output. @@ -282,7 +283,7 @@ func (pkg *Package) check(fs *token.FileSet, astFiles []*ast.File) { } // generate produces the String method for the named type. -func (g *Generator) generate(typeName string, includeJSON, includeSQL bool) { +func (g *Generator) generate(typeName string, includeJSON, includeYAML, includeSQL bool) { values := make([]Value, 0, 100) for _, file := range g.pkg.files { // Set the state for this run of the walker. @@ -324,7 +325,9 @@ func (g *Generator) generate(typeName string, includeJSON, includeSQL bool) { if includeJSON { g.buildJSONMethods(runs, typeName, runsThreshold) } - + if includeYAML { + g.buildYAMLMethods(runs, typeName, runsThreshold) + } if includeSQL { g.addValueAndScanMethod(typeName) }