diff --git a/ws/client.go b/ws/client.go index 2561085..999555b 100644 --- a/ws/client.go +++ b/ws/client.go @@ -144,8 +144,13 @@ func clientReader(c *client){ /*** Parse message ***/ msg, err := readMessage(c.io.reader) + if err == UnmaskedFrameErr { + errorCode = PROTOCOL_ERR + clientAck = false + break + } + if err != nil { - // fmt.Printf(" [reader] %s\n", err) break } diff --git a/ws/message.go b/ws/message.go index 58b6f5f..3060d5a 100644 --- a/ws/message.go +++ b/ws/message.go @@ -1,11 +1,13 @@ package ws import ( + "fmt" "io" - "net" "encoding/binary" ) +var UnmaskedFrameErr = fmt.Errorf("Received unmasked frame") + // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask const maximumHeaderSize = 1 + 1 + 8 + 4 @@ -37,10 +39,10 @@ const ( // Represents a websocket message type Message struct { - Type MessageType - Data []byte - Size uint Final bool + Type MessageType + Size uint + Data []byte } @@ -141,6 +143,11 @@ func readMessage(reader io.Reader) (*Message, error){ } + // return error if unmasked frame + if mask == nil { + return nil, UnmaskedFrameErr + } + return m, nil } diff --git a/ws/message_test.go b/ws/message_test.go new file mode 100644 index 0000000..5aba272 --- /dev/null +++ b/ws/message_test.go @@ -0,0 +1,80 @@ +package ws + +import ( + "bytes" + "testing" +) + + +func TestSimpleMessageReading(t *testing.T) { + + cases := []struct{ + ReadBuffer []byte + Expected Message + Err error + }{ + { // FIN ; TEXT ; Unmasked -> error + []byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f}, + Message{}, + UnmaskedFrameErr, + }, + { // FIN ; TEXT ; hello + []byte{0x81,0x85,0x00,0x00,0x00,0x00,0x68,0x65,0x6c,0x6c,0x6f}, + Message{ true, TEXT, 5, []byte("hello") }, + nil, + }, + { // FIN ; BINARY ; hello + []byte{0x82,0x85,0x00,0x00,0x00,0x00,0x68,0x65,0x6c,0x6c,0x6f}, + Message{ true, BINARY, 5, []byte("hello") }, + nil, + }, + { // FIN ; BINARY ; test unmasking + []byte{0x82,0x88,0x01,0x02,0x03,0x04,0x10,0x20,0x30,0x40,0x50,0x60,0x70,0x80}, + Message{ true, BINARY, 8, []byte{0x11,0x22,0x33,0x44,0x51,0x62,0x73,0x84} }, + nil, + }, + { // FIN=0 ; TEXT ; + []byte{0x01,0x82,0x00,0x00,0x00,0x00,0x01,0x02}, + Message{ false, TEXT, 2, []byte{0x01,0x02} }, + nil, + }, + } + + for _, tc := range cases{ + + reader := bytes.NewBuffer(tc.ReadBuffer) + + got, err := readMessage(reader) + + if err != tc.Err { + t.Errorf("Expected %v error, got %v", tc.Err, err) + } + + // do not check message if error expected + if tc.Err != nil { + continue + } + + // check FIN + if got.Final != tc.Expected.Final { + t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, got.Final) + } + + // check OpCode + if got.Type != tc.Expected.Type { + t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, got.Type) + } + + // check Size + if got.Size != tc.Expected.Size { + t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, got.Size) + } + + // check Data + if string(got.Data) != string(tc.Expected.Data) { + t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, got.Data) + } + + } + +} \ No newline at end of file