Fix encoding of marshaler

This commit is contained in:
Masaaki Goshima 2021-03-18 23:56:56 +09:00
parent 141e3992af
commit a02cea2c89
6 changed files with 163 additions and 7 deletions

View File

@ -2,6 +2,7 @@ package json_test
import ( import (
"bytes" "bytes"
"fmt"
"testing" "testing"
"github.com/goccy/go-json" "github.com/goccy/go-json"
@ -12,7 +13,7 @@ type coverMarshalJSON struct {
} }
func (c coverMarshalJSON) MarshalJSON() ([]byte, error) { func (c coverMarshalJSON) MarshalJSON() ([]byte, error) {
return []byte(`"hello"`), nil return []byte(fmt.Sprint(c.A)), nil
} }
type coverPtrMarshalJSON struct { type coverPtrMarshalJSON struct {
@ -20,7 +21,22 @@ type coverPtrMarshalJSON struct {
} }
func (c *coverPtrMarshalJSON) MarshalJSON() ([]byte, error) { func (c *coverPtrMarshalJSON) MarshalJSON() ([]byte, error) {
return []byte(`"hello"`), nil if c == nil {
return []byte(`"NULL"`), nil
}
return []byte(fmt.Sprint(c.B)), nil
}
type coverPtrMarshalJSONString struct {
dummy int
C string
}
func (c *coverPtrMarshalJSONString) MarshalJSON() ([]byte, error) {
if c == nil {
return []byte(`"NULL"`), nil
}
return []byte(c.C), nil
} }
func TestCoverMarshalJSON(t *testing.T) { func TestCoverMarshalJSON(t *testing.T) {
@ -66,6 +82,117 @@ func TestCoverMarshalJSON(t *testing.T) {
name string name string
data interface{} data interface{}
}{ }{
{
name: "MarshalJSON",
data: coverMarshalJSON{A: 1},
},
{
name: "PtrMarshalJSON",
data: &coverMarshalJSON{A: 1},
},
{
name: "PtrMarshalJSON",
data: coverPtrMarshalJSON{B: 1},
},
{
name: "PtrPtrMarshalJSON",
data: &coverPtrMarshalJSON{B: 1},
},
{
name: "SliceMarshalJSON",
data: []coverMarshalJSON{{A: 1}, {A: 2}},
},
{
name: "SliceAddrMarshalJSON",
data: []*coverMarshalJSON{{A: 1}, {A: 2}},
},
{
name: "SlicePtrMarshalJSON",
data: []coverPtrMarshalJSON{{B: 1}, {B: 2}},
},
{
name: "SliceAddrPtrMarshalJSON",
data: []*coverPtrMarshalJSON{{B: 1}, {B: 2}},
},
{
name: "StructSliceMarshalJSON",
data: struct {
A []coverMarshalJSON
}{A: []coverMarshalJSON{{A: 1}, {A: 2}}},
},
{
name: "StructSliceAddrMarshalJSON",
data: struct {
A []*coverMarshalJSON
}{A: []*coverMarshalJSON{{A: 1}, {A: 2}}},
},
{
name: "StructSlicePtrMarshalJSON",
data: struct {
A []coverPtrMarshalJSON
}{A: []coverPtrMarshalJSON{{B: 1}, {B: 2}}},
},
{
name: "StructSliceAddrPtrMarshalJSON",
data: struct {
A []*coverPtrMarshalJSON
}{A: []*coverPtrMarshalJSON{{B: 1}, {B: 2}}},
},
{
name: "PtrStructSliceMarshalJSON",
data: &struct {
A []coverMarshalJSON
}{A: []coverMarshalJSON{{A: 1}, {A: 2}}},
},
{
name: "PtrStructSliceAddrMarshalJSON",
data: &struct {
A []*coverMarshalJSON
}{A: []*coverMarshalJSON{{A: 1}, {A: 2}}},
},
{
name: "PtrStructSlicePtrMarshalJSON",
data: &struct {
A []coverPtrMarshalJSON
}{A: []coverPtrMarshalJSON{{B: 1}, {B: 2}}},
},
{
name: "PtrStructSlicePtrMarshalJSONString",
data: &struct {
A []coverPtrMarshalJSONString
}{A: []coverPtrMarshalJSONString{{C: "1"}, {C: "2"}}},
},
{
name: "PtrStructSliceAddrPtrMarshalJSONString",
data: &struct {
A []*coverPtrMarshalJSONString
}{A: []*coverPtrMarshalJSONString{{C: "1"}, {C: "2"}}},
},
{
name: "PtrStructArrayPtrMarshalJSONString",
data: &struct {
A [2]coverPtrMarshalJSONString
}{A: [2]coverPtrMarshalJSONString{{C: "1"}, {C: "2"}}},
},
{
name: "PtrStructArrayAddrPtrMarshalJSONString",
data: &struct {
A [2]*coverPtrMarshalJSONString
}{A: [2]*coverPtrMarshalJSONString{{C: "1"}, {C: "2"}}},
},
{
name: "PtrStructMapPtrMarshalJSONString",
data: &struct {
A map[string]coverPtrMarshalJSONString
}{A: map[string]coverPtrMarshalJSONString{"a": {C: "1"}, "b": {C: "2"}}},
},
{
name: "PtrStructMapAddrPtrMarshalJSONString",
data: &struct {
A map[string]*coverPtrMarshalJSONString
}{A: map[string]*coverPtrMarshalJSONString{"a": {C: "1"}, "b": {C: "2"}}},
},
// HeadMarshalJSONZero // HeadMarshalJSONZero
{ {
name: "HeadMarshalJSONZero", name: "HeadMarshalJSONZero",

View File

@ -600,10 +600,11 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
header := newSliceHeaderCode(ctx) header := newSliceHeaderCode(ctx)
ctx.incIndex() ctx.incIndex()
code, err := compileSliceElem(ctx.withType(elem).incIndent()) code, err := compileListElem(ctx.withType(elem).incIndent())
if err != nil { if err != nil {
return nil, err return nil, err
} }
code.Indirect = true
// header => opcode => elem => end // header => opcode => elem => end
// ^ | // ^ |
@ -624,7 +625,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
return (*Opcode)(unsafe.Pointer(header)), nil return (*Opcode)(unsafe.Pointer(header)), nil
} }
func compileSliceElem(ctx *compileContext) (*Opcode, error) { func compileListElem(ctx *compileContext) (*Opcode, error) {
typ := ctx.typ typ := ctx.typ
switch { switch {
case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType): case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType):
@ -645,10 +646,11 @@ func compileArray(ctx *compileContext) (*Opcode, error) {
header := newArrayHeaderCode(ctx, alen) header := newArrayHeaderCode(ctx, alen)
ctx.incIndex() ctx.incIndex()
code, err := compile(ctx.withType(elem).incIndent(), false) code, err := compileListElem(ctx.withType(elem).incIndent())
if err != nil { if err != nil {
return nil, err return nil, err
} }
code.Indirect = true
// header => opcode => elem => end // header => opcode => elem => end
// ^ | // ^ |
// |________| // |________|
@ -690,6 +692,7 @@ func compileMap(ctx *compileContext) (*Opcode, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
valueCode.Indirect = true
key := newMapKeyCode(ctx, header) key := newMapKeyCode(ctx, header)
ctx.incIndex() ctx.incIndex()
@ -1057,10 +1060,11 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
// *struct{ field T } => struct { field *T } // *struct{ field T } => struct { field *T }
// func (*T) MarshalJSON() ([]byte, error) // func (*T) MarshalJSON() ([]byte, error)
// move pointer position from head to first field // move pointer position from head to first field
code, err := compileMarshalJSON(ctx.withType(runtime.PtrTo(fieldType))) code, err := compileMarshalJSON(ctx.withType(fieldType))
if err != nil { if err != nil {
return nil, err return nil, err
} }
addrForMarshaler = true
valueCode = code valueCode = code
nilcheck = false nilcheck = false
indirect = false indirect = false
@ -1069,10 +1073,11 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
// *struct{ field T } => struct { field *T } // *struct{ field T } => struct { field *T }
// func (*T) MarshalText() ([]byte, error) // func (*T) MarshalText() ([]byte, error)
// move pointer position from head to first field // move pointer position from head to first field
code, err := compileMarshalText(ctx.withType(runtime.PtrTo(fieldType))) code, err := compileMarshalText(ctx.withType(fieldType))
if err != nil { if err != nil {
return nil, err return nil, err
} }
addrForMarshaler = true
valueCode = code valueCode = code
nilcheck = false nilcheck = false
indirect = false indirect = false

View File

@ -272,6 +272,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalJSON(code, b, ptrToInterface(code, p), false) bb, err := appendMarshalJSON(code, b, ptrToInterface(code, p), false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -296,6 +299,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalText(code, b, ptrToInterface(code, p), false) bb, err := appendMarshalText(code, b, ptrToInterface(code, p), false)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -272,6 +272,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalJSON(code, b, ptrToInterface(code, p), true) bb, err := appendMarshalJSON(code, b, ptrToInterface(code, p), true)
if err != nil { if err != nil {
return nil, err return nil, err
@ -296,6 +299,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalText(code, b, ptrToInterface(code, p), true) bb, err := appendMarshalText(code, b, ptrToInterface(code, p), true)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -270,6 +270,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalJSON(ctx, code, b, ptrToInterface(code, p), code.Indent, true) bb, err := appendMarshalJSON(ctx, code, b, ptrToInterface(code, p), code.Indent, true)
if err != nil { if err != nil {
return nil, err return nil, err
@ -294,6 +297,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalText(code, b, ptrToInterface(code, p), true) bb, err := appendMarshalText(code, b, ptrToInterface(code, p), true)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -276,6 +276,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalJSON(ctx, code, b, ptrToInterface(code, p), code.Indent, false) bb, err := appendMarshalJSON(ctx, code, b, ptrToInterface(code, p), code.Indent, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -300,6 +303,9 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
code = code.Next code = code.Next
break break
} }
if code.Type.Kind() == reflect.Ptr && code.Indirect {
p = ptrToPtr(p)
}
bb, err := appendMarshalText(code, b, ptrToInterface(code, p), false) bb, err := appendMarshalText(code, b, ptrToInterface(code, p), false)
if err != nil { if err != nil {
return nil, err return nil, err