Fixes and adds tests after merge

This commit is contained in:
Álvaro 2016-10-26 16:59:55 +01:00
parent ad5bc9c3e5
commit d3da35e3d0
3 changed files with 231 additions and 10 deletions

View File

@ -33,7 +33,7 @@ func TestEndToEnd(t *testing.T) {
defer os.RemoveAll(dir)
// Create stringer in temporary directory.
stringer := filepath.Join(dir, "stringer.exe")
err = run("go", "build", "-o", stringer, "enumer.go", "stringer.go")
err = run("go", "build", "-o", stringer, "enumer.go", "sql.go", "stringer.go")
if err != nil {
t.Fatalf("building stringer: %s", err)
}

View File

@ -34,6 +34,14 @@ var goldenJSON = []Golden{
{"prime", prime_json_in, prime_json_out},
}
var goldenSQL = []Golden{
{"prime", prime_sql_in, prime_sql_out},
}
var goldenJSONAndSQL = []Golden{
{"prime", prime_json_and_sql_in, prime_json_and_sql_out},
}
// Each example starts with "type XXX [u]int", with a single space separating them.
// Simple test: enumeration of type int starting at 0.
@ -426,16 +434,231 @@ func (i *Prime) UnmarshalJSON(data []byte) error {
}
`
const prime_sql_in = `type Prime int
const (
p2 Prime = 2
p3 Prime = 3
p5 Prime = 5
p7 Prime = 7
p77 Prime = 7 // Duplicate; note that p77 doesn't appear below.
p11 Prime = 11
p13 Prime = 13
p17 Prime = 17
p19 Prime = 19
p23 Prime = 23
p29 Prime = 29
p37 Prime = 31
p41 Prime = 41
p43 Prime = 43
)
`
const prime_sql_out = `
const _Prime_name = "p2p3p5p7p11p13p17p19p23p29p37p41p43"
var _Prime_map = map[Prime]string{
2: _Prime_name[0:2],
3: _Prime_name[2:4],
5: _Prime_name[4:6],
7: _Prime_name[6:8],
11: _Prime_name[8:11],
13: _Prime_name[11:14],
17: _Prime_name[14:17],
19: _Prime_name[17:20],
23: _Prime_name[20:23],
29: _Prime_name[23:26],
31: _Prime_name[26:29],
41: _Prime_name[29:32],
43: _Prime_name[32:35],
}
func (i Prime) String() string {
if str, ok := _Prime_map[i]; ok {
return str
}
return fmt.Sprintf("Prime(%d)", i)
}
var _PrimeNameToValue_map = map[string]Prime{
_Prime_name[0:2]: 2,
_Prime_name[2:4]: 3,
_Prime_name[4:6]: 5,
_Prime_name[6:8]: 7,
_Prime_name[8:11]: 11,
_Prime_name[11:14]: 13,
_Prime_name[14:17]: 17,
_Prime_name[17:20]: 19,
_Prime_name[20:23]: 23,
_Prime_name[23:26]: 29,
_Prime_name[26:29]: 31,
_Prime_name[29:32]: 41,
_Prime_name[32:35]: 43,
}
func PrimeString(s string) (Prime, error) {
if val, ok := _PrimeNameToValue_map[s]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Prime values", s)
}
func (i Prime) Value() (driver.Value, error) {
return i.String(), nil
}
func (i *Prime) Scan(value interface{}) error {
if value == nil {
return nil
}
str, ok := value.(string)
if !ok {
bytes, ok := value.([]byte)
if !ok {
return fmt.Errorf("value is not a byte slice")
}
str = string(bytes[:])
}
val, err := PrimeString(str)
if err != nil {
return err
}
*i = val
return nil
}
`
const prime_json_and_sql_in = `type Prime int
const (
p2 Prime = 2
p3 Prime = 3
p5 Prime = 5
p7 Prime = 7
p77 Prime = 7 // Duplicate; note that p77 doesn't appear below.
p11 Prime = 11
p13 Prime = 13
p17 Prime = 17
p19 Prime = 19
p23 Prime = 23
p29 Prime = 29
p37 Prime = 31
p41 Prime = 41
p43 Prime = 43
)
`
const prime_json_and_sql_out = `
const _Prime_name = "p2p3p5p7p11p13p17p19p23p29p37p41p43"
var _Prime_map = map[Prime]string{
2: _Prime_name[0:2],
3: _Prime_name[2:4],
5: _Prime_name[4:6],
7: _Prime_name[6:8],
11: _Prime_name[8:11],
13: _Prime_name[11:14],
17: _Prime_name[14:17],
19: _Prime_name[17:20],
23: _Prime_name[20:23],
29: _Prime_name[23:26],
31: _Prime_name[26:29],
41: _Prime_name[29:32],
43: _Prime_name[32:35],
}
func (i Prime) String() string {
if str, ok := _Prime_map[i]; ok {
return str
}
return fmt.Sprintf("Prime(%d)", i)
}
var _PrimeNameToValue_map = map[string]Prime{
_Prime_name[0:2]: 2,
_Prime_name[2:4]: 3,
_Prime_name[4:6]: 5,
_Prime_name[6:8]: 7,
_Prime_name[8:11]: 11,
_Prime_name[11:14]: 13,
_Prime_name[14:17]: 17,
_Prime_name[17:20]: 19,
_Prime_name[20:23]: 23,
_Prime_name[23:26]: 29,
_Prime_name[26:29]: 31,
_Prime_name[29:32]: 41,
_Prime_name[32:35]: 43,
}
func PrimeString(s string) (Prime, error) {
if val, ok := _PrimeNameToValue_map[s]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Prime values", s)
}
func (i Prime) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
func (i *Prime) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("Prime should be a string, got %s", data)
}
var err error
*i, err = PrimeString(s)
return err
}
func (i Prime) Value() (driver.Value, error) {
return i.String(), nil
}
func (i *Prime) Scan(value interface{}) error {
if value == nil {
return nil
}
str, ok := value.(string)
if !ok {
bytes, ok := value.([]byte)
if !ok {
return fmt.Errorf("value is not a byte slice")
}
str = string(bytes[:])
}
val, err := PrimeString(str)
if err != nil {
return err
}
*i = val
return nil
}
`
func TestGolden(t *testing.T) {
for _, test := range golden {
runGoldenTest(t, test, false)
runGoldenTest(t, test, false, false)
}
for _, test := range goldenJSON {
runGoldenTest(t, test, true)
runGoldenTest(t, test, true, false)
}
for _, test := range goldenSQL {
runGoldenTest(t, test, false, true)
}
for _, test := range goldenJSONAndSQL {
runGoldenTest(t, test, true, true)
}
}
func runGoldenTest(t *testing.T, test Golden, generateJSON bool) {
func runGoldenTest(t *testing.T, test Golden, generateJSON, generateSQL bool) {
var g Generator
input := "package test\n" + test.input
file := test.name + ".go"
@ -445,7 +668,7 @@ func runGoldenTest(t *testing.T, test Golden, generateJSON bool) {
if len(tokens) != 3 {
t.Fatalf("%s: need type declaration on first line", test.name)
}
g.generate(tokens[1], generateJSON)
g.generate(tokens[1], generateJSON, generateSQL)
got := string(g.format())
if got != test.output {
t.Errorf("%s: got\n====\n%s====\nexpected\n====%s", test.name, got, test.output)

View File

@ -146,7 +146,7 @@ func main() {
// Run generate for each type.
for _, typeName := range types {
g.generate(typeName, *json)
g.generate(typeName, *json, *sql)
}
// Format the output.
@ -282,7 +282,7 @@ func (pkg *Package) check(fs *token.FileSet, astFiles []*ast.File) {
}
// generate produces the String method for the named type.
func (g *Generator) generate(typeName string, includeJSON bool) {
func (g *Generator) generate(typeName string, includeJSON, includeSQL bool) {
values := make([]Value, 0, 100)
for _, file := range g.pkg.files {
// Set the state for this run of the walker.
@ -320,14 +320,12 @@ func (g *Generator) generate(typeName string, includeJSON bool) {
g.buildMap(runs, typeName)
}
// ENUMER part
g.buildValueToNameMap(runs, typeName, runsThreshold)
if includeJSON {
g.buildJSONMethods(runs, typeName, runsThreshold)
}
// SQL
if *sql {
if includeSQL {
g.addValueAndScanMethod(typeName)
}
}