diff --git a/ws/client.go b/ws/client.go index ff24278..86c40ae 100644 --- a/ws/client.go +++ b/ws/client.go @@ -146,9 +146,8 @@ func clientReader(c *client){ /* (2) Parse message */ msg, err := readMessage(c.io.reader) - if err == ErrUnmaskedFrame { + if err == ErrUnmaskedFrame || err == ErrReservedBits { closeStatus = PROTOCOL_ERR - break } if err != nil { break } diff --git a/ws/message.go b/ws/message.go index 26f5404..d3b19ea 100644 --- a/ws/message.go +++ b/ws/message.go @@ -14,6 +14,7 @@ var ErrInvalidSize = fmt.Errorf("Received invalid payload size") var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") +var ErrReservedBits = fmt.Errorf("Received reserved bits") var CloseFrame = fmt.Errorf("Received close Frame") @@ -68,16 +69,21 @@ func readMessage(reader io.Reader) (*Message, error){ /* (2) Byte 1: FIN and OpCode */ tmpBuf = make([]byte, 1) - _, err = reader.Read(tmpBuf) + err = readBytes(reader, tmpBuf) if err != nil { return m, err } + // check reserved bits + if tmpBuf[0] & 0x70 != 0 { + return m, ErrReservedBits + } + m.Final = bool( tmpBuf[0] & 0x80 == 0x80 ) m.Type = MessageType( tmpBuf[0] & 0x0f ) /* (3) Byte 2: Mask and Length[0] */ tmpBuf = make([]byte, 1) - _, err = reader.Read(tmpBuf) + err = readBytes(reader, tmpBuf) if err != nil { return m, err } // if mask, byte array not nil @@ -92,18 +98,16 @@ func readMessage(reader io.Reader) (*Message, error){ if m.Size == 127 { tmpBuf = make([]byte, 8) - nbr, err := reader.Read(tmpBuf) + err := readBytes(reader, tmpBuf) if err != nil { return m, err } - if nbr < 8 { return m, io.EOF } m.Size = uint( binary.BigEndian.Uint64(tmpBuf) ) } else if m.Size == 126 { tmpBuf = make([]byte, 2) - nbr, err := reader.Read(tmpBuf) + err := readBytes(reader, tmpBuf) if err != nil { return m, err } - if nbr < 2 { return m, io.EOF } m.Size = uint( binary.BigEndian.Uint16(tmpBuf) ) @@ -113,9 +117,8 @@ func readMessage(reader io.Reader) (*Message, error){ if mask != nil { tmpBuf = make([]byte, 4) - nbr, err := reader.Read(tmpBuf) + err := readBytes(reader, tmpBuf) if err != nil { return m, err } - if nbr < 4 { return m, io.EOF } mask = make([]byte, 4) copy(mask, tmpBuf) @@ -247,7 +250,7 @@ func (m *Message) check(fragment bool) error{ } /* (2) Waiting fragment but received standalone frame */ - if fragment && m.Type != CONTINUATION { + if fragment && m.Type != CONTINUATION && m.Type != CLOSE && m.Type != PING { return ErrInvalidFragment } @@ -306,4 +309,25 @@ func (m *Message) check(fragment bool) error{ } return nil +} + + + +func readBytes(reader io.Reader, buffer []byte) error { + + var cur, len int = 0, len(buffer) + + // try to read until the full size is read + for cur < len { + + nbread, err := reader.Read(buffer[cur:]) + if err != nil { + return err + } + + cur += nbread + } + + return nil + } \ No newline at end of file