Use crypto/rand for mask key

This commit is contained in:
Halo Arrow 2023-08-25 18:54:44 -07:00 committed by Canelo Hill
parent 1d5465562b
commit 3f92bfdfb2
2 changed files with 17 additions and 5 deletions

13
conn.go
View File

@ -6,10 +6,10 @@ package websocket
import ( import (
"bufio" "bufio"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"math/rand"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -180,9 +180,16 @@ var (
errInvalidControlFrame = errors.New("websocket: invalid control frame") errInvalidControlFrame = errors.New("websocket: invalid control frame")
) )
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
// reproducible results.
var maskRand = rand.Reader
// newMaskKey returns a new 32 bit value for masking client frames.
func newMaskKey() [4]byte { func newMaskKey() [4]byte {
n := rand.Uint32() var k [4]byte
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} _, _ = io.ReadFull(maskRand, k[:])
return k
} }
func hideTempErr(err error) error { func hideTempErr(err error) error {

View File

@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
} }
func TestPreparedMessage(t *testing.T) { func TestPreparedMessage(t *testing.T) {
testRand := rand.New(rand.NewSource(99))
prevMaskRand := maskRand
maskRand = testRand
defer func() { maskRand = prevMaskRand }()
for _, tt := range preparedMessageTests { for _, tt := range preparedMessageTests {
var data = []byte("this is a test") var data = []byte("this is a test")
var buf bytes.Buffer var buf bytes.Buffer
@ -43,7 +48,7 @@ func TestPreparedMessage(t *testing.T) {
c.SetCompressionLevel(tt.compressionLevel) c.SetCompressionLevel(tt.compressionLevel)
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
if err := c.WriteMessage(tt.messageType, data); err != nil { if err := c.WriteMessage(tt.messageType, data); err != nil {
t.Fatal(err) t.Fatal(err)
@ -59,7 +64,7 @@ func TestPreparedMessage(t *testing.T) {
copy(data, "hello world") copy(data, "hello world")
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
buf.Reset() buf.Reset()
if err := c.WritePreparedMessage(pm); err != nil { if err := c.WritePreparedMessage(pm); err != nil {