diff --git a/ws/client.go b/ws/client.go index f549fc7..b09dc2c 100644 --- a/ws/client.go +++ b/ws/client.go @@ -1,7 +1,6 @@ package ws import ( - "unicode/utf8" "time" "sync" "bufio" @@ -131,7 +130,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli func clientReader(c *client){ var frag *Message - errorCode := NORMAL + closeStatus := NORMAL clientAck := true c.io.reading.Add(1) @@ -147,14 +146,22 @@ func clientReader(c *client){ /* (2) Parse message */ msg, err := readMessage(c.io.reader) - if err == UnmaskedFrameErr { - errorCode = PROTOCOL_ERR + if err == ErrUnmaskedFrame { + closeStatus = PROTOCOL_ERR break } if err != nil { break } /* (3) Fail on invalid message */ - if !isMessageValid(msg, &errorCode, &clientAck) { + // s0 := time.Now().UnixNano() + msgErr := msg.check(frag != nil) + // fmt.Printf("> %.3f us\n", float64(time.Now().UnixNano()-s0)/1e3) + + if msgErr == ErrInvalidPayload { + closeStatus = INVALID_PAYLOAD + break + } else if msgErr != nil { + closeStatus = PROTOCOL_ERR break } @@ -166,7 +173,7 @@ func clientReader(c *client){ continue } - /* (5) Store first fragmented msg */ + /* (5) Store first fragment */ if frag == nil && !msg.Final { frag = &Message{ Type: msg.Type, @@ -177,31 +184,23 @@ func clientReader(c *client){ continue } - // unexpected continuation - if msg.Type == CONTINUATION && frag == nil { - errorCode = PROTOCOL_ERR - break - - } - // waiting fragment error - if frag != nil && msg.Type != CONTINUATION { - errorCode = PROTOCOL_ERR - break - } - /* (6) Store fragments */ if frag != nil { frag.Final = msg.Final frag.Size += msg.Size frag.Data = append(frag.Data, msg.Data...) - if !frag.Final { + if !frag.Final { // continue if not last fragment continue } // check message errors - if !isMessageValid(frag, &errorCode, &clientAck) { - frag = nil + fragErr := frag.check(false) + if fragErr == ErrInvalidPayload { + closeStatus = INVALID_PAYLOAD + break + } else if fragErr != nil { + closeStatus = PROTOCOL_ERR break } @@ -222,7 +221,7 @@ func clientReader(c *client){ /* (8) close channel (if not already done) */ // fmt.Printf("[reader] end\n") - c.close(errorCode, clientAck) + c.close(closeStatus, clientAck) } @@ -337,85 +336,3 @@ func (c *client) close(status MessageError, clientACK bool){ } - - - -func isMessageValid(m *Message, me *MessageError, ca *bool) bool{ - - /* (1) Too long control frame */ - if m.Type == CLOSE || m.Type == PING || m.Type == PONG { - if m.Size > 125 || !m.Final { - *me = PROTOCOL_ERR - return false - } - } - - - /* (2) Invalid close */ - if m.Type == CLOSE { - - // uncomplete code || too long - if m.Size == 1 { - *me = PROTOCOL_ERR - return false - } - - // invalid utf-8 reason - if m.Size > 2 && !utf8.Valid(m.Data[2:]) { - *me = INVALID_PAYLOAD - return false - } - - // invalid code - if m.Size >= 2 { - cCode := binary.BigEndian.Uint16(m.Data[0:2]) - if invalidCloseCode(cCode) { - *me = PROTOCOL_ERR - return false - } - } - - return false - } - - /* (3) Invalid utf8 text */ - if m.Type == TEXT && !utf8.Valid(m.Data) { - *me = INVALID_PAYLOAD - return false - } - - /* (4) Invalid first fragment (not TEXT nor BINARY) */ - if !m.Final && m.Type != CONTINUATION && m.Type != TEXT && m.Type != BINARY { - *me = PROTOCOL_ERR - return false - } - - /* (5) Invalid OpCode */ - switch m.Type { - case CONTINUATION: - case TEXT: - case BINARY: - case CLOSE: - case PING: - case PONG: - default: - *me = PROTOCOL_ERR - *ca = false - return false - - } - - return true -} - - -func invalidCloseCode(code uint16) bool{ - - tolow := code < 1000 - badrange1 := code >= 1004 && code <= 1006 - badrange2 := code >= 1012 && code <= 1016 - badspecific := code == 1100 || code == 2000 || code == 2999 - - return tolow || badrange1 || badrange2 || badspecific - -} \ No newline at end of file diff --git a/ws/message.go b/ws/message.go index 6f300b2..6cb54f9 100644 --- a/ws/message.go +++ b/ws/message.go @@ -1,12 +1,20 @@ package ws import ( + "unicode/utf8" "fmt" "io" "encoding/binary" ) -var UnmaskedFrameErr = fmt.Errorf("Received unmasked frame") +var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame") +var ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long") +var ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation") +var ErrInvalidSize = fmt.Errorf("Received invalid payload size") +var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") +var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") +var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") + // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask const maximumHeaderSize = 1 + 1 + 8 + 4 @@ -146,7 +154,7 @@ func readMessage(reader io.Reader) (*Message, error){ // we have to fully read it for read buffer to be clean err = nil if mask == nil { - err = UnmaskedFrameErr + err = ErrUnmaskedFrame } return m, err @@ -220,5 +228,81 @@ func (m Message) Send(writer io.Writer) error { } + return nil +} + + + + +// Check for message errors with: +// (m) the current message +// (fragment) whether there is a fragment in construction +// returns the message error +func (m *Message) check(fragment bool) error{ + + /* (1) Invalid first fragment (not TEXT nor BINARY) */ + if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY { + return ErrInvalidFragment + } + + /* (2) Waiting fragment but received standalone frame */ + if fragment && m.Type != CONTINUATION { + return ErrInvalidFragment + } + + /* (3) Control frame too long */ + if (m.Type == CLOSE || m.Type == PING || m.Type == PONG) && (m.Size > 125 || !m.Final) { + return ErrTooLongControlFrame + } + + switch m.Type { + case CONTINUATION: + // unexpected continuation + if !fragment { + return ErrInvalidFragment + } + return nil + + case TEXT: + if !utf8.Valid(m.Data) { + return ErrInvalidPayload + } + return nil + + case BINARY: + return nil + + case CLOSE: + // incomplete code + if m.Size == 1 { + return ErrInvalidCloseStatus + } + + // invalid utf-8 reason + if m.Size > 2 && !utf8.Valid(m.Data[2:]) { + return ErrInvalidPayload + } + + // invalid code + if m.Size >= 2 { + c := binary.BigEndian.Uint16(m.Data[0:2]) + + if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 { + return ErrInvalidCloseStatus + } + } + return nil + + case PING: + return nil + + case PONG: + return nil + + default: + return ErrInvalidOpCode + + } + return nil } \ No newline at end of file diff --git a/ws/message_test.go b/ws/message_test.go index 9e8db72..bf1c2c6 100644 --- a/ws/message_test.go +++ b/ws/message_test.go @@ -19,7 +19,7 @@ func TestSimpleMessageReading(t *testing.T) { "must fail on unmasked frame", []byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f}, Message{}, - UnmaskedFrameErr, + ErrUnmaskedFrame, }, { // FIN ; TEXT ; hello "simple hello text message", @@ -175,8 +175,8 @@ func TestReadEOF(t *testing.T) { return } - if tc.unmaskedError && err != UnmaskedFrameErr { - t.Errorf("Expected UnmaskedFrameError, got %v", err) + if tc.unmaskedError && err != ErrUnmaskedFrame { + t.Errorf("Expected UnmaskedFrameor, got %v", err) } if got.Size != 0x00 {