Refactor decoder

This commit is contained in:
Masaaki Goshima 2020-04-22 17:59:01 +09:00
parent aa0aff6388
commit 0c42c47179
2 changed files with 70 additions and 36 deletions

View File

@ -22,30 +22,30 @@ const (
) )
type Decoder struct { type Decoder struct {
r io.Reader r io.Reader
state int state int
literal []byte value []byte
} }
type context struct { type context struct {
idx int idx int
keys [][]byte keys [][]byte
literals [][]byte values [][]byte
start int start int
stack int stack int
} }
func newContext() *context { func newContext() *context {
return &context{ return &context{
keys: make([][]byte, 64), keys: make([][]byte, 64),
literals: make([][]byte, 64), values: make([][]byte, 64),
} }
} }
func (c *context) pushStack() { func (c *context) pushStack() {
if len(c.keys) <= c.stack { if len(c.keys) <= c.stack {
c.keys = append(c.keys, nil) c.keys = append(c.keys, nil)
c.literals = append(c.literals, nil) c.values = append(c.values, nil)
} }
c.stack++ c.stack++
} }
@ -58,26 +58,26 @@ func (c *context) setKey(key []byte) {
c.keys[c.stack] = key c.keys[c.stack] = key
} }
func (c *context) setLiteral(literal []byte) { func (c *context) setValue(value []byte) {
c.literals[c.stack] = literal c.values[c.stack] = value
} }
func (c *context) key() ([]byte, error) { func (c *context) key() ([]byte, error) {
if len(c.keys) <= c.stack { if len(c.keys) <= c.stack {
return nil, errors.New("unexpected error") return nil, errors.New("unexpected error key")
} }
key := c.keys[c.stack] key := c.keys[c.stack]
if len(key) == 0 { if len(key) == 0 {
return nil, errors.New("unexpected error") return nil, errors.New("unexpected error key")
} }
return key, nil return key, nil
} }
func (c *context) literal() ([]byte, error) { func (c *context) value() ([]byte, error) {
if len(c.literals) <= c.stack { if len(c.values) <= c.stack {
return nil, errors.New("unexpected error") return nil, errors.New("unexpected error value")
} }
return c.literals[c.stack], nil return c.values[c.stack], nil
} }
var ( var (
@ -419,11 +419,9 @@ func (d *Decoder) compileStruct(v reflect.Value) (DecodeOp, error) {
func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) error { func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) error {
slen := len(src) slen := len(src)
for i := ctx.idx; i < slen; i++ { for i := 0; i < slen; i++ {
c := src[i] c := src[i]
switch c { switch c {
case ' ':
ctx.start++
case '{': case '{':
ctx.pushStack() ctx.pushStack()
case '}': case '}':
@ -431,27 +429,23 @@ func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) err
if err != nil { if err != nil {
return err return err
} }
if err := op(ptr, key, src[ctx.start:i]); err != nil { if err := op(ptr, key, d.value); err != nil {
return err return err
} }
ctx.popStack() ctx.popStack()
case '[': case '[':
d.state = stateArray
case ']': case ']':
d.state = stateNone
case ':': case ':':
if len(d.literal) == 0 { if len(d.value) == 0 {
return errors.New("unexpected error") return errors.New("unexpected error map value")
} }
ctx.setKey(d.literal) ctx.setKey(d.value)
ctx.start = i + 1
case ',': case ',':
literal := src[ctx.start:i]
key, err := ctx.key() key, err := ctx.key()
if err != nil { if err != nil {
return err return err
} }
if err := op(ptr, key, literal); err != nil { if err := op(ptr, key, d.value); err != nil {
return err return err
} }
case '"': case '"':
@ -463,9 +457,47 @@ func (d *Decoder) decode(ctx *context, src []byte, ptr uintptr, op DecodeOp) err
} }
end := i end := i
if end <= start { if end <= start {
return errors.New("unexpected error") return errors.New("unexpected error value")
} }
d.literal = src[start:end] d.value = src[start:end]
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
start := i
for ; i < slen; i++ {
c := src[i]
switch c {
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', 'e', 'E':
default:
goto end
}
}
end:
end := i
if end <= start {
return errors.New("unexpected error number")
}
d.value = src[start:end]
i--
case 't':
if i+3 < slen && src[i+1] == 'r' && src[i+2] == 'u' && src[i+3] == 'e' {
d.value = []byte("true")
} else {
return errors.New("unexpected error true")
}
i += 3
case 'f':
if i+4 < slen && src[i+1] == 'a' && src[i+2] == 'l' && src[i+3] == 's' && src[i+4] == 'e' {
d.value = []byte("false")
} else {
return errors.New("unexpected error false")
}
i += 4
case 'n':
if i+3 < slen && src[i+1] == 'u' && src[i+2] == 'l' && src[i+3] == 'l' {
d.value = []byte("null")
} else {
return errors.New("unexpected error null")
}
i += 3
} }
} }
return nil return nil

View File

@ -9,9 +9,11 @@ import (
func Test_Decoder(t *testing.T) { func Test_Decoder(t *testing.T) {
t.Run("struct", func(t *testing.T) { t.Run("struct", func(t *testing.T) {
var v struct { var v struct {
A int A int `json:"abcd"`
B string `json:"str"`
} }
assertErr(t, json.Unmarshal([]byte(`{"a":123}`), &v)) assertErr(t, json.Unmarshal([]byte(`{ "abcd" : 123 , "str" : "hello" }`), &v))
assertEq(t, "struct.A", v.A, 123) assertEq(t, "struct.A", 123, v.A)
assertEq(t, "struct.B", "hello", v.B)
}) })
} }