From 3f92bfdfb28cd667adaf06838e0cb135b636f3ed Mon Sep 17 00:00:00 2001 From: Halo Arrow Date: Fri, 25 Aug 2023 18:54:44 -0700 Subject: [PATCH] Use crypto/rand for mask key --- conn.go | 13 ++++++++++--- prepared_test.go | 9 +++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 9353252..44b1aed 100644 --- a/conn.go +++ b/conn.go @@ -6,10 +6,10 @@ package websocket import ( "bufio" + "crypto/rand" "encoding/binary" "errors" "io" - "math/rand" "net" "strconv" "strings" @@ -180,9 +180,16 @@ var ( errInvalidControlFrame = errors.New("websocket: invalid control frame") ) +// maskRand is an io.Reader for generating mask bytes. The reader is initialized +// to crypto/rand Reader. Tests swap the reader to a math/rand reader for +// reproducible results. +var maskRand = rand.Reader + +// newMaskKey returns a new 32 bit value for masking client frames. func newMaskKey() [4]byte { - n := rand.Uint32() - return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} + var k [4]byte + _, _ = io.ReadFull(maskRand, k[:]) + return k } func hideTempErr(err error) error { diff --git a/prepared_test.go b/prepared_test.go index 2297802..536d58d 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -33,6 +33,11 @@ var preparedMessageTests = []struct { } func TestPreparedMessage(t *testing.T) { + testRand := rand.New(rand.NewSource(99)) + prevMaskRand := maskRand + maskRand = testRand + defer func() { maskRand = prevMaskRand }() + for _, tt := range preparedMessageTests { var data = []byte("this is a test") var buf bytes.Buffer @@ -43,7 +48,7 @@ func TestPreparedMessage(t *testing.T) { c.SetCompressionLevel(tt.compressionLevel) // Seed random number generator for consistent frame mask. - rand.Seed(1234) + testRand.Seed(1234) if err := c.WriteMessage(tt.messageType, data); err != nil { t.Fatal(err) @@ -59,7 +64,7 @@ func TestPreparedMessage(t *testing.T) { copy(data, "hello world") // Seed random number generator for consistent frame mask. - rand.Seed(1234) + testRand.Seed(1234) buf.Reset() if err := c.WritePreparedMessage(pm); err != nil {