diff --git a/README.md b/README.md index 322aa6b..6e4d84e 100644 --- a/README.md +++ b/README.md @@ -72,9 +72,10 @@ The generated code is exactly the same as the Stringer tool plus the mentioned a The usage of Enumer is the same as Stringer, so you can refer to the [Stringer docs](https://godoc.org/golang.org/x/tools/cmd/stringer) for more information. -As mentioned before, there is only one flag added: `json`. +There are two flags added: `json` and `sql`. If the json flag is set to true (i.e. `enumer -type=Pill -json`), +the JSON related methods will be generated. And if the sql flag is set to true, the Scanner and Valuer interface will +be implemented to seamlessly use the enum in a database model. ## Inspiring projects * [Stringer](https://godoc.org/golang.org/x/tools/cmd/stringer) * [jsonenums](https://github.com/campoy/jsonenums) - diff --git a/sql.go b/sql.go new file mode 100644 index 0000000..67d5354 --- /dev/null +++ b/sql.go @@ -0,0 +1,40 @@ +package main + +// Arguments to format are: +// [1]: type name +const valueMethod = `func (i %[1]s) Value() (driver.Value, error) { + return i.String(), nil +} +` + +const scanMethod = `func (i *%[1]s) Scan(value interface{}) error { + if value == nil { + return nil + } + + str, ok := value.(string) + if !ok { + bytes, ok := value.([]byte) + if !ok { + return fmt.Errorf("value is not a byte slice") + } + + str = string(bytes[:]) + } + + val, err := %[1]sString(str) + if err != nil { + return err + } + + *i = val + return nil +} +` + +func (g *Generator) addValueAndScanMethod(typeName string) { + g.Printf("\n") + g.Printf(valueMethod, typeName) + g.Printf("\n\n") + g.Printf(scanMethod, typeName) +} diff --git a/stringer.go b/stringer.go index efd2fbd..bdfcb78 100644 --- a/stringer.go +++ b/stringer.go @@ -82,6 +82,7 @@ import ( var ( typeNames = flag.String("type", "", "comma-separated list of type names; must be set") + sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.") json = flag.Bool("json", false, "if true, json marshaling methods will be generated. Default: false") output = flag.String("output", "", "output file name; default srcdir/_string.go") ) @@ -133,10 +134,15 @@ func main() { g.Printf("\n") g.Printf("package %s", g.pkg.name) g.Printf("\n") - g.Printf("import \"fmt\"\n") // Used by all methods. - if *json { - g.Printf("import \"encoding/json\"\n") + g.Printf("import (\n") + g.Printf("\t\"fmt\"\n") + if *sql { + g.Printf("\t\"database/sql/driver\"\n") } + if *json { + g.Printf("\t\"encoding/json\"\n") + } + g.Printf(")\n") // Run generate for each type. for _, typeName := range types { @@ -313,11 +319,17 @@ func (g *Generator) generate(typeName string, includeJSON bool) { default: g.buildMap(runs, typeName) } + // ENUMER part g.buildValueToNameMap(runs, typeName, runsThreshold) if includeJSON { g.buildJSONMethods(runs, typeName, runsThreshold) } + + // SQL + if *sql { + g.addValueAndScanMethod(typeName) + } } // splitIntoRuns breaks the values into runs of contiguous sequences.