Merge pull request #3 from webguerilla/master

Scanner and Valuer interface
This commit is contained in:
Álvaro López Espinosa 2016-10-26 16:45:19 +01:00 committed by GitHub
commit ad5bc9c3e5
3 changed files with 58 additions and 5 deletions

View File

@ -72,9 +72,10 @@ The generated code is exactly the same as the Stringer tool plus the mentioned a
The usage of Enumer is the same as Stringer, so you can refer to the [Stringer docs](https://godoc.org/golang.org/x/tools/cmd/stringer)
for more information.
As mentioned before, there is only one flag added: `json`.
There are two flags added: `json` and `sql`. If the json flag is set to true (i.e. `enumer -type=Pill -json`),
the JSON related methods will be generated. And if the sql flag is set to true, the Scanner and Valuer interface will
be implemented to seamlessly use the enum in a database model.
## Inspiring projects
* [Stringer](https://godoc.org/golang.org/x/tools/cmd/stringer)
* [jsonenums](https://github.com/campoy/jsonenums)

40
sql.go Normal file
View File

@ -0,0 +1,40 @@
package main
// Arguments to format are:
// [1]: type name
const valueMethod = `func (i %[1]s) Value() (driver.Value, error) {
return i.String(), nil
}
`
const scanMethod = `func (i *%[1]s) Scan(value interface{}) error {
if value == nil {
return nil
}
str, ok := value.(string)
if !ok {
bytes, ok := value.([]byte)
if !ok {
return fmt.Errorf("value is not a byte slice")
}
str = string(bytes[:])
}
val, err := %[1]sString(str)
if err != nil {
return err
}
*i = val
return nil
}
`
func (g *Generator) addValueAndScanMethod(typeName string) {
g.Printf("\n")
g.Printf(valueMethod, typeName)
g.Printf("\n\n")
g.Printf(scanMethod, typeName)
}

View File

@ -82,6 +82,7 @@ import (
var (
typeNames = flag.String("type", "", "comma-separated list of type names; must be set")
sql = flag.Bool("sql", false, "if true, the Scanner and Valuer interface will be implemented.")
json = flag.Bool("json", false, "if true, json marshaling methods will be generated. Default: false")
output = flag.String("output", "", "output file name; default srcdir/<type>_string.go")
)
@ -133,10 +134,15 @@ func main() {
g.Printf("\n")
g.Printf("package %s", g.pkg.name)
g.Printf("\n")
g.Printf("import \"fmt\"\n") // Used by all methods.
if *json {
g.Printf("import \"encoding/json\"\n")
g.Printf("import (\n")
g.Printf("\t\"fmt\"\n")
if *sql {
g.Printf("\t\"database/sql/driver\"\n")
}
if *json {
g.Printf("\t\"encoding/json\"\n")
}
g.Printf(")\n")
// Run generate for each type.
for _, typeName := range types {
@ -313,11 +319,17 @@ func (g *Generator) generate(typeName string, includeJSON bool) {
default:
g.buildMap(runs, typeName)
}
// ENUMER part
g.buildValueToNameMap(runs, typeName, runsThreshold)
if includeJSON {
g.buildJSONMethods(runs, typeName, runsThreshold)
}
// SQL
if *sql {
g.addValueAndScanMethod(typeName)
}
}
// splitIntoRuns breaks the values into runs of contiguous sequences.