From c05e1e23ee00a3871b6aa8a8c1c1ec098b165096 Mon Sep 17 00:00:00 2001 From: Nao Yonashiro Date: Wed, 26 Jan 2022 01:02:27 +0900 Subject: [PATCH] fix: panic when decoding time.Time with context close #327 --- decode_test.go | 14 ++++++++++++++ internal/decoder/unmarshal_json.go | 16 ++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/decode_test.go b/decode_test.go index 5df2c35..b416f7b 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3828,3 +3828,17 @@ func TestIssue303(t *testing.T) { t.Fatalf("failed to decode. count = %d type = %s value = %v", v.Count, v.Type, v.Value) } } + +func TestIssue327(t *testing.T) { + var v struct { + Date time.Time `json:"date"` + } + dec := json.NewDecoder(strings.NewReader(`{"date": "2021-11-23T13:47:30+01:00"})`)) + if err := dec.DecodeContext(context.Background(), &v); err != nil { + t.Fatal(err) + } + expected := "2021-11-23T13:47:30+01:00" + if got := v.Date.Format(time.RFC3339); got != expected { + t.Fatalf("failed to decode. expected %q but got %q", expected, got) + } +} diff --git a/internal/decoder/unmarshal_json.go b/internal/decoder/unmarshal_json.go index d90f39c..e9b25c6 100644 --- a/internal/decoder/unmarshal_json.go +++ b/internal/decoder/unmarshal_json.go @@ -1,6 +1,7 @@ package decoder import ( + "context" "encoding/json" "unsafe" @@ -46,13 +47,20 @@ func (d *unmarshalJSONDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Poi typ: d.typ, ptr: p, })) - if (s.Option.Flags & ContextOption) != 0 { - if err := v.(unmarshalerContext).UnmarshalJSON(s.Option.Context, dst); err != nil { + switch v := v.(type) { + case unmarshalerContext: + var ctx context.Context + if (s.Option.Flags & ContextOption) != 0 { + ctx = s.Option.Context + } else { + ctx = context.Background() + } + if err := v.UnmarshalJSON(ctx, dst); err != nil { d.annotateError(s.cursor, err) return err } - } else { - if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { + case json.Unmarshaler: + if err := v.UnmarshalJSON(dst); err != nil { d.annotateError(s.cursor, err) return err }