diff --git a/ws/client.go b/ws/client.go index 7a544ff..f549fc7 100644 --- a/ws/client.go +++ b/ws/client.go @@ -129,88 +129,91 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli // reader reads and parses messages from the buffer func clientReader(c *client){ + var frag *Message errorCode := NORMAL clientAck := true + c.io.reading.Add(1) for { - /* if currently closing -> exit */ + /* (1) if currently closing -> exit */ if c.io.closing { fmt.Printf("[reader] killed because closing") break } - /*** Parse message ***/ + /* (2) Parse message */ msg, err := readMessage(c.io.reader) + if err == UnmaskedFrameErr { errorCode = PROTOCOL_ERR - clientAck = false + break + } + if err != nil { break } + + /* (3) Fail on invalid message */ + if !isMessageValid(msg, &errorCode, &clientAck) { break } - if err != nil { - break + /* (4) Ping <-> Pong */ + if msg.Type == PING && c.io.writing { + msg.Final = true + msg.Type = PONG + c.ch.send <- *msg + continue } - /* (4) CLOSE */ - if msg.Type == CLOSE { - // uncomplete code || too long - if msg.Size == 1 || msg.Size > 125 { - errorCode = PROTOCOL_ERR - } - // invalid utf-8 reason - if msg.Size > 2 && !utf8.Valid(msg.Data[2:]) { - errorCode = INVALID_PAYLOAD - } - // invalid code - if msg.Size >= 2 { - cCode := binary.BigEndian.Uint16(msg.Data[0:2]) - if invalidCloseCode(cCode) { - errorCode = PROTOCOL_ERR - } - } - clientAck = false - break - - } - - /* (5) PING size error */ - if msg.Type == PING && msg.Size > 125 { - // fmt.Printf(" [reader] PING payload too big\n") - // fmt.Printf("[reader] PING err\n") - errorCode = PROTOCOL_ERR - break - } - - /* (6) Send PONG */ - if msg.Type == PING { - // fmt.Printf("[reader] PING -> PONG\n") - if c.io.writing { - msg.Final = true - msg.Type = PONG - c.ch.send <- *msg + /* (5) Store first fragmented msg */ + if frag == nil && !msg.Final { + frag = &Message{ + Type: msg.Type, + Final: msg.Final, + Data: msg.Data, + Size: msg.Size, } continue } - /* (7) Invalid UTF8 */ - if msg.Type == TEXT && !utf8.Valid(msg.Data) { - // fmt.Printf(" [reader] invalid utf-8\n") - errorCode = INVALID_PAYLOAD + // unexpected continuation + if msg.Type == CONTINUATION && frag == nil { + errorCode = PROTOCOL_ERR break - } - /* (8) Unknown opcode */ - if msg.Type != TEXT && msg.Type != BINARY { - // fmt.Printf(" [reader] unknown OpCode %d\n", msg.Type) + } + // waiting fragment error + if frag != nil && msg.Type != CONTINUATION { errorCode = PROTOCOL_ERR break } - /* (9) Dispatch to receiver */ - c.ch.receive <- *msg + /* (6) Store fragments */ + if frag != nil { + frag.Final = msg.Final + frag.Size += msg.Size + frag.Data = append(frag.Data, msg.Data...) + + if !frag.Final { + continue + } + + // check message errors + if !isMessageValid(frag, &errorCode, &clientAck) { + frag = nil + break + } + + msg = frag + frag = nil + + } + + /* (7) Dispatch to receiver */ + if msg.Type == TEXT || msg.Type == BINARY { + c.ch.receive <- *msg + } } @@ -296,10 +299,10 @@ func (c *client) close(status MessageError, clientACK bool){ binary.BigEndian.PutUint16(msg.Data, uint16(status)) /* (4) Send message */ - err := msg.Send(c.io.sock) - if err != nil { - fmt.Printf("[close] send error (%s0\n", err) - } + msg.Send(c.io.sock) + // if err != nil { + // fmt.Printf("[close] send error (%s0\n", err) + // } } @@ -310,14 +313,14 @@ func (c *client) close(status MessageError, clientACK bool){ c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond)) /* Wait for message */ - msg, err := readMessage(c.io.reader) - if err != nil || msg.Type != CLOSE { - if err == nil { - fmt.Printf("[close] received OpCode = %d\n", msg.Type) - } else { - fmt.Printf("[close] read error (%v)\n", err) - } - } + readMessage(c.io.reader) + // if err != nil || msg.Type != CLOSE { + // if err == nil { + // fmt.Printf("[close] received OpCode = %d\n", msg.Type) + // } else { + // fmt.Printf("[close] read error (%v)\n", err) + // } + // } // fmt.Printf("[close] received ACK\n") @@ -337,6 +340,74 @@ func (c *client) close(status MessageError, clientACK bool){ +func isMessageValid(m *Message, me *MessageError, ca *bool) bool{ + + /* (1) Too long control frame */ + if m.Type == CLOSE || m.Type == PING || m.Type == PONG { + if m.Size > 125 || !m.Final { + *me = PROTOCOL_ERR + return false + } + } + + + /* (2) Invalid close */ + if m.Type == CLOSE { + + // uncomplete code || too long + if m.Size == 1 { + *me = PROTOCOL_ERR + return false + } + + // invalid utf-8 reason + if m.Size > 2 && !utf8.Valid(m.Data[2:]) { + *me = INVALID_PAYLOAD + return false + } + + // invalid code + if m.Size >= 2 { + cCode := binary.BigEndian.Uint16(m.Data[0:2]) + if invalidCloseCode(cCode) { + *me = PROTOCOL_ERR + return false + } + } + + return false + } + + /* (3) Invalid utf8 text */ + if m.Type == TEXT && !utf8.Valid(m.Data) { + *me = INVALID_PAYLOAD + return false + } + + /* (4) Invalid first fragment (not TEXT nor BINARY) */ + if !m.Final && m.Type != CONTINUATION && m.Type != TEXT && m.Type != BINARY { + *me = PROTOCOL_ERR + return false + } + + /* (5) Invalid OpCode */ + switch m.Type { + case CONTINUATION: + case TEXT: + case BINARY: + case CLOSE: + case PING: + case PONG: + default: + *me = PROTOCOL_ERR + *ca = false + return false + + } + + return true +} + func invalidCloseCode(code uint16) bool{ diff --git a/ws/message.go b/ws/message.go index 1f5c202..ee80dda 100644 --- a/ws/message.go +++ b/ws/message.go @@ -59,9 +59,8 @@ func readMessage(reader io.Reader) (*Message, error){ /* (2) Byte 1: FIN and OpCode */ tmpBuf = make([]byte, 1) - nbr, err := reader.Read(tmpBuf) + _, err = reader.Read(tmpBuf) if err != nil { return m, err } - if nbr < 1 { return m, io.EOF } m.Final = bool( tmpBuf[0] & 0x80 == 0x80 ) @@ -69,9 +68,8 @@ func readMessage(reader io.Reader) (*Message, error){ /* (3) Byte 2: Mask and Length[0] */ tmpBuf = make([]byte, 1) - nbr, err = reader.Read(tmpBuf) + _, err = reader.Read(tmpBuf) if err != nil { return m, err } - if nbr < 1 { return m, io.EOF } // if mask, byte array not nil if tmpBuf[0] & 0x80 == 0x80 {