diff --git a/README.md b/README.md index 821709c..e6b2ca7 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -#Enumer +# Enumer Enumer generates Go code to get string names from enum values and viceversa. It is a fork of [Rob Pike’s Stringer tool](https://godoc.org/golang.org/x/tools/cmd/stringer) but adding a *"string to enum value"* method to the generated code. @@ -50,4 +50,7 @@ The generated code is exactly the same as the Stringer tool plus the `Stri ## How to use The usage of Enumer is the same as Stringer, no changes were introduced. -For more information please refer to the [Stringer docs](https://godoc.org/golang.org/x/tools/cmd/stringer) \ No newline at end of file +For more information please refer to the [Stringer docs](https://godoc.org/golang.org/x/tools/cmd/stringer) + +## Additional functions of this fork +This fork additionally implements the Scanner and Valuer interface to use a enum seamlessly in a database model. diff --git a/sql.go b/sql.go new file mode 100644 index 0000000..27aee38 --- /dev/null +++ b/sql.go @@ -0,0 +1,31 @@ +package main + +// Arguments to format are: +// [1]: type name +const valuer = `func (i %[1]s) Value() (driver.Value, error) { + return i.String(), nil +} +` + +const scanner = `func (i %[1]s) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + fmt.Errorf("value is not a string") + } + + val, err := %[1]sString(str) + if err != nil { + return err + } + + i = val + return nil +} +` + +func (g *Generator) addValuerAndScanner(runs [][]Value, typeName string, runsThreshold int) { + g.Printf("\n") + g.Printf(valuer, typeName) + g.Printf("\n\n") + g.Printf(scanner, typeName) +} diff --git a/stringer.go b/stringer.go index 81efa59..0944bc0 100644 --- a/stringer.go +++ b/stringer.go @@ -128,11 +128,14 @@ func main() { } // Print the header and package clause. - g.Printf("// Code generated by \"stringer %s\"; DO NOT EDIT\n", strings.Join(os.Args[1:], " ")) + g.Printf("// Code generated by \"enumer %s\"; DO NOT EDIT\n", strings.Join(os.Args[1:], " ")) g.Printf("\n") g.Printf("package %s", g.pkg.name) g.Printf("\n") - g.Printf("import \"fmt\"\n") // Used by all methods. + g.Printf("import (\n") + g.Printf("\t\"fmt\"\n") + g.Printf("\t\"database/sql/driver\"\n") + g.Printf(")\n") // Run generate for each type. for _, typeName := range types { @@ -310,6 +313,9 @@ func (g *Generator) generate(typeName string) { } // ENUMER: This is the only addition over the original stringer code. Everything else is in enumer.go g.buildValueToNameMap(runs, typeName, 10) + + // SQL + g.addValuerAndScanner(runs, typeName, 0) } // splitIntoRuns breaks the values into runs of contiguous sequences.