diff --git a/ws/message.go b/ws/message.go index a2f6e2d..1f5c202 100644 --- a/ws/message.go +++ b/ws/message.go @@ -59,8 +59,9 @@ func readMessage(reader io.Reader) (*Message, error){ /* (2) Byte 1: FIN and OpCode */ tmpBuf = make([]byte, 1) - _, err = reader.Read(tmpBuf) - if err != nil { return nil, err } + nbr, err := reader.Read(tmpBuf) + if err != nil { return m, err } + if nbr < 1 { return m, io.EOF } m.Final = bool( tmpBuf[0] & 0x80 == 0x80 ) @@ -68,8 +69,9 @@ func readMessage(reader io.Reader) (*Message, error){ /* (3) Byte 2: Mask and Length[0] */ tmpBuf = make([]byte, 1) - _, err = reader.Read(tmpBuf) - if err != nil { return nil, err } + nbr, err = reader.Read(tmpBuf) + if err != nil { return m, err } + if nbr < 1 { return m, io.EOF } // if mask, byte array not nil if tmpBuf[0] & 0x80 == 0x80 { @@ -83,16 +85,18 @@ func readMessage(reader io.Reader) (*Message, error){ if m.Size == 127 { tmpBuf = make([]byte, 8) - _, err := reader.Read(tmpBuf) - if err != nil { return nil, err } + nbr, err := reader.Read(tmpBuf) + if err != nil { return m, err } + if nbr < 8 { return m, io.EOF } m.Size = uint( binary.BigEndian.Uint64(tmpBuf) ) } else if m.Size == 126 { tmpBuf = make([]byte, 2) - _, err := reader.Read(tmpBuf) - if err != nil { return nil, err } + nbr, err := reader.Read(tmpBuf) + if err != nil { return m, err } + if nbr < 2 { return m, io.EOF } m.Size = uint( binary.BigEndian.Uint16(tmpBuf) ) @@ -102,8 +106,9 @@ func readMessage(reader io.Reader) (*Message, error){ if mask != nil { tmpBuf = make([]byte, 4) - _, err := reader.Read(tmpBuf) - if err != nil { return nil, err } + nbr, err := reader.Read(tmpBuf) + if err != nil { return m, err } + if nbr < 4 { return m, io.EOF } mask = make([]byte, 4) copy(mask, tmpBuf) @@ -113,11 +118,6 @@ func readMessage(reader io.Reader) (*Message, error){ /* (6) Read payload by chunks */ m.Data = make([]byte, int(m.Size)) - // If empty payload - if m.Size <= 0 { - return m, nil - } - cursor = 0 // {1} While we have data to read // @@ -126,7 +126,7 @@ func readMessage(reader io.Reader) (*Message, error){ // {2} Try to read (at least 1 byte) // nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1) if err != nil { - return nil, err + return m, err } // {3} Unmask data // @@ -145,11 +145,13 @@ func readMessage(reader io.Reader) (*Message, error){ } // return error if unmasked frame + // we have to fully read it for read buffer to be clean + err = nil if mask == nil { - return nil, UnmaskedFrameErr + err = UnmaskedFrameErr } - return m, nil + return m, err }