From 9a06c80b6a6cbf3fb9e27114e09a16b8d5fb6abc Mon Sep 17 00:00:00 2001 From: Austin Green Date: Fri, 19 Jul 2019 17:25:24 -0400 Subject: [PATCH] port of the bson patch --- README.md | 5 +- enumer.go | 21 ++++++ golden_test.go | 174 ++++++++++++++++++++++++++++++++++++++++++++----- stringer.go | 12 +++- 4 files changed, 193 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index f558f22..22f7866 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,9 @@ When Enumer is applied to a type, it will generate: - When the flag `json` is provided, two additional methods will be generated, `MarshalJSON()` and `UnmarshalJSON()`. These make the enum conform to the `json.Marshaler` and `json.Unmarshaler` interfaces. Very useful to use it in JSON APIs. +- When the flag `bson` is provided, two additional methods will be generated, `MarshalBSONValue()` and `UnmarshalBSONValue()`. These make + the enum conform to the `go.mongodb.org/mongo-driver/bson.ValueMarshaler` and `go.mongodb.org/mongo-driver/bson.ValueUnmarshaler` interfaces. + This can be used when inserting and retrieving enums from MongoDB. - When the flag `text` is provided, two additional methods will be generated, `MarshalText()` and `UnmarshalText()`. These make the enum conform to the `encoding.TextMarshaler` and `encoding.TextUnmarshaler` interfaces. **Note:** If you use your enum values as keys in a map and you encode the map as _JSON_, you need this flag set to true to properly @@ -191,7 +194,7 @@ name := MyTypeValue.String() // name => "my_type_value" ## How to use -There are four boolean flags: `json`, `text`, `yaml` and `sql`. You can use any combination of them (i.e. `enumer -type=Pill -json -text`), +There are five boolean flags: `json`, `bson`, `text`, `yaml` and `sql`. You can use any combination of them (i.e. `enumer -type=Pill -json -text`), For enum string representation transformation the `transform` and `trimprefix` flags were added (i.e. `enumer -type=MyType -json -transform=snake`). diff --git a/enumer.go b/enumer.go index 8f92f58..7d74f1e 100644 --- a/enumer.go +++ b/enumer.go @@ -196,3 +196,24 @@ func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error { func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) { g.Printf(yamlMethods, typeName) } + +// Arguments to format are: +// [1]: type name +const bsonMethods = ` +// MarshalBSONValue implements the bson.ValueMarshaler interface for %[1]s +func (i %[1]s) MarshalBSONValue() (bsontype.Type, []byte, error) { + return bsontype.String, bsoncore.AppendString(nil, i.String()), nil +} + +// UnmarshalBSONValue implements the bson.ValueUnmarshaler interface for %[1]s +func (i *%[1]s) UnmarshalBSONValue(t bsontype.Type, src []byte) error { + str, _, _ := bsoncore.ReadString(src) + var err error + *i, err = %[1]sString(str) + return err +} +` + +func (g *Generator) buildBSONMethods(runs [][]Value, typeName string, runsThreshold int) { + g.Printf(bsonMethods, typeName) +} diff --git a/golden_test.go b/golden_test.go index 35aec3f..8431582 100644 --- a/golden_test.go +++ b/golden_test.go @@ -10,6 +10,7 @@ package main import ( + "fmt" "io/ioutil" "os" "path/filepath" @@ -48,6 +49,10 @@ var goldenSQL = []Golden{ {"prime", primeSqlIn, primeSqlOut}, } +var goldenBSON = []Golden{ + {"prime with BSON", primeBsonIn, primeBsonOut}, +} + var goldenJSONAndSQL = []Golden{ {"prime", primeJsonAndSqlIn, primeJsonAndSqlOut}, } @@ -1344,6 +1349,140 @@ func (i *Prime) Scan(value interface{}) error { } ` +const primeBsonIn = `type Prime int +const ( + p2 Prime = 2 + p3 Prime = 3 + p5 Prime = 5 + p7 Prime = 7 + p77 Prime = 7 // Duplicate; note that p77 doesn't appear below. + p11 Prime = 11 + p13 Prime = 13 + p17 Prime = 17 + p19 Prime = 19 + p23 Prime = 23 + p29 Prime = 29 + p37 Prime = 31 + p41 Prime = 41 + p43 Prime = 43 +) +` + +const primeBsonOut = ` +const _PrimeName = "p2p3p5p7p11p13p17p19p23p29p37p41p43" +const _PrimeLowerName = "p2p3p5p7p11p13p17p19p23p29p37p41p43" + +var _PrimeMap = map[Prime]string{ + 2: _PrimeName[0:2], + 3: _PrimeName[2:4], + 5: _PrimeName[4:6], + 7: _PrimeName[6:8], + 11: _PrimeName[8:11], + 13: _PrimeName[11:14], + 17: _PrimeName[14:17], + 19: _PrimeName[17:20], + 23: _PrimeName[20:23], + 29: _PrimeName[23:26], + 31: _PrimeName[26:29], + 41: _PrimeName[29:32], + 43: _PrimeName[32:35], +} + +func (i Prime) String() string { + if str, ok := _PrimeMap[i]; ok { + return str + } + return fmt.Sprintf("Prime(%d)", i) +} + +var _PrimeValues = []Prime{2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 41, 43} + +var _PrimeNameToValueMap = map[string]Prime{ + _PrimeName[0:2]: 2, + _PrimeLowerName[0:2]: 2, + _PrimeName[2:4]: 3, + _PrimeLowerName[2:4]: 3, + _PrimeName[4:6]: 5, + _PrimeLowerName[4:6]: 5, + _PrimeName[6:8]: 7, + _PrimeLowerName[6:8]: 7, + _PrimeName[8:11]: 11, + _PrimeLowerName[8:11]: 11, + _PrimeName[11:14]: 13, + _PrimeLowerName[11:14]: 13, + _PrimeName[14:17]: 17, + _PrimeLowerName[14:17]: 17, + _PrimeName[17:20]: 19, + _PrimeLowerName[17:20]: 19, + _PrimeName[20:23]: 23, + _PrimeLowerName[20:23]: 23, + _PrimeName[23:26]: 29, + _PrimeLowerName[23:26]: 29, + _PrimeName[26:29]: 31, + _PrimeLowerName[26:29]: 31, + _PrimeName[29:32]: 41, + _PrimeLowerName[29:32]: 41, + _PrimeName[32:35]: 43, + _PrimeLowerName[32:35]: 43, +} + +var _PrimeNames = []string{ + _PrimeName[0:2], + _PrimeName[2:4], + _PrimeName[4:6], + _PrimeName[6:8], + _PrimeName[8:11], + _PrimeName[11:14], + _PrimeName[14:17], + _PrimeName[17:20], + _PrimeName[20:23], + _PrimeName[23:26], + _PrimeName[26:29], + _PrimeName[29:32], + _PrimeName[32:35], +} + +// PrimeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func PrimeString(s string) (Prime, error) { + if val, ok := _PrimeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Prime values", s) +} + +// PrimeValues returns all values of the enum +func PrimeValues() []Prime { + return _PrimeValues +} + +// PrimeStrings returns a slice of all String values of the enum +func PrimeStrings() []string { + strs := make([]string, len(_PrimeNames)) + copy(strs, _PrimeNames) + return strs +} + +// IsAPrime returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Prime) IsAPrime() bool { + _, ok := _PrimeMap[i] + return ok +} + +// MarshalBSONValue implements the bson.ValueMarshaler interface for Prime +func (i Prime) MarshalBSONValue() (bsontype.Type, []byte, error) { + return bsontype.String, bsoncore.AppendString(nil, i.String()), nil +} + +// UnmarshalBSONValue implements the bson.ValueUnmarshaler interface for Prime +func (i *Prime) UnmarshalBSONValue(t bsontype.Type, src []byte) error { + str, _, _ := bsoncore.ReadString(src) + var err error + *i, err = PrimeString(str) + return err +} +` + const primeJsonAndSqlIn = `type Prime int const ( p2 Prime = 2 @@ -1524,35 +1663,38 @@ const ( func TestGolden(t *testing.T) { for _, test := range golden { - runGoldenTest(t, test, false, false, false, false, "", "") + runGoldenTest(t, test, false, false, false, false, false, "", "") } for _, test := range goldenJSON { - runGoldenTest(t, test, true, false, false, false, "", "") + runGoldenTest(t, test, true, false, false, false, false, "", "") } for _, test := range goldenText { - runGoldenTest(t, test, false, false, false, true, "", "") + runGoldenTest(t, test, false, false, false, true, false, "", "") } for _, test := range goldenYAML { - runGoldenTest(t, test, false, true, false, false, "", "") + runGoldenTest(t, test, false, true, false, false, false, "", "") } for _, test := range goldenSQL { - runGoldenTest(t, test, false, false, true, false, "", "") + runGoldenTest(t, test, false, false, true, false, false, "", "") + } + for _, test := range goldenBSON { + runGoldenTest(t, test, false, false, false, false, true, "", "") } for _, test := range goldenJSONAndSQL { - runGoldenTest(t, test, true, false, true, false, "", "") + runGoldenTest(t, test, true, false, true, false, false, "", "") } for _, test := range goldenTrimPrefix { - runGoldenTest(t, test, false, false, false, false, "Day", "") + runGoldenTest(t, test, false, false, false, false, false, "Day", "") } for _, test := range goldenWithPrefix { - runGoldenTest(t, test, false, false, false, false, "", "Day") + runGoldenTest(t, test, false, false, false, false, false, "", "Day") } for _, test := range goldenTrimAndAddPrefix { - runGoldenTest(t, test, false, false, false, false, "Day", "Night") + runGoldenTest(t, test, false, false, false, false, false, "Day", "Night") } } -func runGoldenTest(t *testing.T, test Golden, generateJSON, generateYAML, generateSQL, generateText bool, trimPrefix string, prefix string) { +func runGoldenTest(t *testing.T, test Golden, generateJSON, generateYAML, generateSQL, generateText, generateBSON bool, trimPrefix string, prefix string) { var g Generator file := test.name + ".go" input := "package test\n" + test.input @@ -1579,15 +1721,15 @@ func runGoldenTest(t *testing.T, test Golden, generateJSON, generateYAML, genera if len(tokens) != 3 { t.Fatalf("%s: need type declaration on first line", test.name) } - g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, "noop", trimPrefix, prefix, false) + g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateBSON, "noop", trimPrefix, prefix, false) got := string(g.format()) if got != test.output { // Use this to help build a golden text when changes are needed - //goldenFile := fmt.Sprintf("./goldendata/%v-%v-%v-%v-%v-%v-%v-%v-%v-%v.golden", test.name, tokens[1], generateJSON, generateYAML, generateSQL, generateText, "noop", trimPrefix, prefix, false) - //err = ioutil.WriteFile(goldenFile, []byte(got), 0644) - //if err != nil { - // t.Error(err) - //} + goldenFile := fmt.Sprintf("./goldendata/%v-%v-%v-%v-%v-%v-%v-%v-%v-%v-%v.golden", test.name, tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateBSON, "noop", trimPrefix, prefix, false) + err = ioutil.WriteFile(goldenFile, []byte(got), 0644) + if err != nil { + t.Error(err) + } t.Errorf("%s: got\n====\n%s====\nexpected\n====%s", test.name, got, test.output) } } diff --git a/stringer.go b/stringer.go index 1c23057..15751fd 100644 --- a/stringer.go +++ b/stringer.go @@ -49,6 +49,7 @@ var ( json = flag.Bool("json", false, "if true, json marshaling methods will be generated. Default: false") yaml = flag.Bool("yaml", false, "if true, yaml marshaling methods will be generated. Default: false") text = flag.Bool("text", false, "if true, text marshaling methods will be generated. Default: false") + bson = flag.Bool("bson", false, "if true, bson marshaling methods will be generated. Default: false") output = flag.String("output", "", "output file name; default srcdir/_string.go") transformMethod = flag.String("transform", "noop", "enum item name transformation method. Default: noop") trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix. Default: \"\"") @@ -124,11 +125,15 @@ func main() { if *json { g.Printf("\t\"encoding/json\"\n") } + if *bson { + g.Printf("\t\"go.mongodb.org/mongo-driver/bson/bsontype\"\n") + g.Printf("\t\"go.mongodb.org/mongo-driver/x/bsonx/bsoncore\"\n") + } g.Printf(")\n") // Run generate for each type. for _, typeName := range typs { - g.generate(typeName, *json, *yaml, *sql, *text, *transformMethod, *trimPrefix, *addPrefix, *linecomment) + g.generate(typeName, *json, *yaml, *sql, *text, *bson, *transformMethod, *trimPrefix, *addPrefix, *linecomment) } // Format the output. @@ -397,7 +402,7 @@ func (g *Generator) prefixValueNames(values []Value, prefix string) { } // generate produces the String method for the named type. -func (g *Generator) generate(typeName string, includeJSON, includeYAML, includeSQL, includeText bool, +func (g *Generator) generate(typeName string, includeJSON, includeYAML, includeSQL, includeText, includeBSON bool, transformMethod string, trimPrefix string, addPrefix string, lineComment bool) { values := make([]Value, 0, 100) for _, file := range g.pkg.files { @@ -454,6 +459,9 @@ func (g *Generator) generate(typeName string, includeJSON, includeYAML, includeS if includeYAML { g.buildYAMLMethods(runs, typeName, runsThreshold) } + if includeBSON { + g.buildBSONMethods(runs, typeName, runsThreshold) + } if includeSQL { g.addValueAndScanMethod(typeName) }