add RSV bytes error if set + add function 'readBytes(io.Reader, []bytes)' to read until the expected size is met while there is no error (for octet-wise chops)
This commit is contained in:
parent
fff5d068d1
commit
fc9f174434
|
@ -146,9 +146,8 @@ func clientReader(c *client){
|
||||||
/* (2) Parse message */
|
/* (2) Parse message */
|
||||||
msg, err := readMessage(c.io.reader)
|
msg, err := readMessage(c.io.reader)
|
||||||
|
|
||||||
if err == ErrUnmaskedFrame {
|
if err == ErrUnmaskedFrame || err == ErrReservedBits {
|
||||||
closeStatus = PROTOCOL_ERR
|
closeStatus = PROTOCOL_ERR
|
||||||
break
|
|
||||||
}
|
}
|
||||||
if err != nil { break }
|
if err != nil { break }
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ var ErrInvalidSize = fmt.Errorf("Received invalid payload size")
|
||||||
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
|
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
|
||||||
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
|
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
|
||||||
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
|
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
|
||||||
|
var ErrReservedBits = fmt.Errorf("Received reserved bits")
|
||||||
var CloseFrame = fmt.Errorf("Received close Frame")
|
var CloseFrame = fmt.Errorf("Received close Frame")
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,16 +69,21 @@ func readMessage(reader io.Reader) (*Message, error){
|
||||||
|
|
||||||
/* (2) Byte 1: FIN and OpCode */
|
/* (2) Byte 1: FIN and OpCode */
|
||||||
tmpBuf = make([]byte, 1)
|
tmpBuf = make([]byte, 1)
|
||||||
_, err = reader.Read(tmpBuf)
|
err = readBytes(reader, tmpBuf)
|
||||||
if err != nil { return m, err }
|
if err != nil { return m, err }
|
||||||
|
|
||||||
|
|
||||||
|
// check reserved bits
|
||||||
|
if tmpBuf[0] & 0x70 != 0 {
|
||||||
|
return m, ErrReservedBits
|
||||||
|
}
|
||||||
|
|
||||||
m.Final = bool( tmpBuf[0] & 0x80 == 0x80 )
|
m.Final = bool( tmpBuf[0] & 0x80 == 0x80 )
|
||||||
m.Type = MessageType( tmpBuf[0] & 0x0f )
|
m.Type = MessageType( tmpBuf[0] & 0x0f )
|
||||||
|
|
||||||
/* (3) Byte 2: Mask and Length[0] */
|
/* (3) Byte 2: Mask and Length[0] */
|
||||||
tmpBuf = make([]byte, 1)
|
tmpBuf = make([]byte, 1)
|
||||||
_, err = reader.Read(tmpBuf)
|
err = readBytes(reader, tmpBuf)
|
||||||
if err != nil { return m, err }
|
if err != nil { return m, err }
|
||||||
|
|
||||||
// if mask, byte array not nil
|
// if mask, byte array not nil
|
||||||
|
@ -92,18 +98,16 @@ func readMessage(reader io.Reader) (*Message, error){
|
||||||
if m.Size == 127 {
|
if m.Size == 127 {
|
||||||
|
|
||||||
tmpBuf = make([]byte, 8)
|
tmpBuf = make([]byte, 8)
|
||||||
nbr, err := reader.Read(tmpBuf)
|
err := readBytes(reader, tmpBuf)
|
||||||
if err != nil { return m, err }
|
if err != nil { return m, err }
|
||||||
if nbr < 8 { return m, io.EOF }
|
|
||||||
|
|
||||||
m.Size = uint( binary.BigEndian.Uint64(tmpBuf) )
|
m.Size = uint( binary.BigEndian.Uint64(tmpBuf) )
|
||||||
|
|
||||||
} else if m.Size == 126 {
|
} else if m.Size == 126 {
|
||||||
|
|
||||||
tmpBuf = make([]byte, 2)
|
tmpBuf = make([]byte, 2)
|
||||||
nbr, err := reader.Read(tmpBuf)
|
err := readBytes(reader, tmpBuf)
|
||||||
if err != nil { return m, err }
|
if err != nil { return m, err }
|
||||||
if nbr < 2 { return m, io.EOF }
|
|
||||||
|
|
||||||
m.Size = uint( binary.BigEndian.Uint16(tmpBuf) )
|
m.Size = uint( binary.BigEndian.Uint16(tmpBuf) )
|
||||||
|
|
||||||
|
@ -113,9 +117,8 @@ func readMessage(reader io.Reader) (*Message, error){
|
||||||
if mask != nil {
|
if mask != nil {
|
||||||
|
|
||||||
tmpBuf = make([]byte, 4)
|
tmpBuf = make([]byte, 4)
|
||||||
nbr, err := reader.Read(tmpBuf)
|
err := readBytes(reader, tmpBuf)
|
||||||
if err != nil { return m, err }
|
if err != nil { return m, err }
|
||||||
if nbr < 4 { return m, io.EOF }
|
|
||||||
|
|
||||||
mask = make([]byte, 4)
|
mask = make([]byte, 4)
|
||||||
copy(mask, tmpBuf)
|
copy(mask, tmpBuf)
|
||||||
|
@ -247,7 +250,7 @@ func (m *Message) check(fragment bool) error{
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Waiting fragment but received standalone frame */
|
/* (2) Waiting fragment but received standalone frame */
|
||||||
if fragment && m.Type != CONTINUATION {
|
if fragment && m.Type != CONTINUATION && m.Type != CLOSE && m.Type != PING {
|
||||||
return ErrInvalidFragment
|
return ErrInvalidFragment
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -307,3 +310,24 @@ func (m *Message) check(fragment bool) error{
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
func readBytes(reader io.Reader, buffer []byte) error {
|
||||||
|
|
||||||
|
var cur, len int = 0, len(buffer)
|
||||||
|
|
||||||
|
// try to read until the full size is read
|
||||||
|
for cur < len {
|
||||||
|
|
||||||
|
nbread, err := reader.Read(buffer[cur:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cur += nbread
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue