Refactor decoder

This commit is contained in:
Masaaki Goshima 2020-04-22 17:59:01 +09:00
parent aa0aff6388
commit 0c42c47179
2 changed files with 70 additions and 36 deletions

View File

@ -24,13 +24,13 @@ const (
type Decoder struct {
r io.Reader
state int
literal []byte
value []byte
}
type context struct {
idx int
keys [][]byte
literals [][]byte
values [][]byte
start int
stack int
}
@ -38,14 +38,14 @@ type context struct {
func newContext() *context {
return &context{
keys: make([][]byte, 64),
literals: make([][]byte, 64),
values: make([][]byte, 64),
}
}
func (c *context) pushStack() {
if len(c.keys) <= c.stack {
c.keys = append(c.keys, nil)
c.literals = append(c.literals, nil)
c.values = append(c.values, nil)
}
c.stack++
}
@ -58,26 +58,26 @@ func (c *context) setKey(key []byte) {
c.keys[c.stack] = key
}
func (c *context) setLiteral(literal []byte) {
c.literals[c.stack] = literal
func (c *context) setValue(value []byte) {
c.values[c.stack] = value
}
func (c *context) key() ([]byte, error) {
if len(c.keys) <= c.stack {
return nil, errors.New("unexpected error")
return nil, errors.New("unexpected error key")
}
key := c.keys[c.stack]
if len(key) == 0 {
return nil, errors.New("unexpected error")
return nil, errors.New("unexpected error key")
}
return key, nil
}
func (c *context) literal() ([]byte, error) {
if len(c.literals) <= c.stack {
return nil, errors.New("unexpected error")
func (c *context) value() ([]byte, error) {
if len(c.values) <= c.stack {
return nil, errors.New("unexpected error value")
}
return c.literals[c.stack], nil
return c.values[c.stack], nil
}
var (
@ -419,11 +419,9 @@ func (d *Decoder) compileStruct(v reflect.Value) (DecodeOp, error) {
func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) error {
slen := len(src)
for i := ctx.idx; i < slen; i++ {
for i := 0; i < slen; i++ {
c := src[i]
switch c {
case ' ':
ctx.start++
case '{':
ctx.pushStack()
case '}':
@ -431,27 +429,23 @@ func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) err
if err != nil {
return err
}
if err := op(ptr, key, src[ctx.start:i]); err != nil {
if err := op(ptr, key, d.value); err != nil {
return err
}
ctx.popStack()
case '[':
d.state = stateArray
case ']':
d.state = stateNone
case ':':
if len(d.literal) == 0 {
return errors.New("unexpected error")
if len(d.value) == 0 {
return errors.New("unexpected error map value")
}
ctx.setKey(d.literal)
ctx.start = i + 1
ctx.setKey(d.value)
case ',':
literal := src[ctx.start:i]
key, err := ctx.key()
if err != nil {
return err
}
if err := op(ptr, key, literal); err != nil {
if err := op(ptr, key, d.value); err != nil {
return err
}
case '"':
@ -463,9 +457,47 @@ func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) err
}
end := i
if end <= start {
return errors.New("unexpected error")
return errors.New("unexpected error value")
}
d.literal = src[start:end]
d.value = src[start:end]
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
start := i
for ; i < slen; i++ {
c := src[i]
switch c {
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', 'e', 'E':
default:
goto end
}
}
end:
end := i
if end <= start {
return errors.New("unexpected error number")
}
d.value = src[start:end]
i--
case 't':
if i+3 < slen && src[i+1] == 'r' && src[i+2] == 'u' && src[i+3] == 'e' {
d.value = []byte("true")
} else {
return errors.New("unexpected error true")
}
i += 3
case 'f':
if i+4 < slen && src[i+1] == 'a' && src[i+2] == 'l' && src[i+3] == 's' && src[i+4] == 'e' {
d.value = []byte("false")
} else {
return errors.New("unexpected error false")
}
i += 4
case 'n':
if i+3 < slen && src[i+1] == 'u' && src[i+2] == 'l' && src[i+3] == 'l' {
d.value = []byte("null")
} else {
return errors.New("unexpected error null")
}
i += 3
}
}
return nil

View File

@ -9,9 +9,11 @@ import (
func Test_Decoder(t *testing.T) {
t.Run("struct", func(t *testing.T) {
var v struct {
A int
A int `json:"abcd"`
B string `json:"str"`
}
assertErr(t, json.Unmarshal([]byte(`{"a":123}`), &v))
assertEq(t, "struct.A", v.A, 123)
assertErr(t, json.Unmarshal([]byte(`{ "abcd" : 123 , "str" : "hello" }`), &v))
assertEq(t, "struct.A", 123, v.A)
assertEq(t, "struct.B", "hello", v.B)
})
}