add FRAGMENTATION + refactored error management (not perfect but works)

This commit is contained in:
xdrm-brackets 2018-05-07 16:35:40 +02:00
parent b7cd292577
commit 6759c489cc
2 changed files with 137 additions and 68 deletions

View File

@ -129,88 +129,91 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
// reader reads and parses messages from the buffer // reader reads and parses messages from the buffer
func clientReader(c *client){ func clientReader(c *client){
var frag *Message
errorCode := NORMAL errorCode := NORMAL
clientAck := true clientAck := true
c.io.reading.Add(1) c.io.reading.Add(1)
for { for {
/* if currently closing -> exit */ /* (1) if currently closing -> exit */
if c.io.closing { if c.io.closing {
fmt.Printf("[reader] killed because closing") fmt.Printf("[reader] killed because closing")
break break
} }
/*** Parse message ***/ /* (2) Parse message */
msg, err := readMessage(c.io.reader) msg, err := readMessage(c.io.reader)
if err == UnmaskedFrameErr { if err == UnmaskedFrameErr {
errorCode = PROTOCOL_ERR errorCode = PROTOCOL_ERR
clientAck = false break
}
if err != nil { break }
/* (3) Fail on invalid message */
if !isMessageValid(msg, &errorCode, &clientAck) {
break break
} }
if err != nil { /* (4) Ping <-> Pong */
break if msg.Type == PING && c.io.writing {
msg.Final = true
msg.Type = PONG
c.ch.send <- *msg
continue
} }
/* (4) CLOSE */ /* (5) Store first fragmented msg */
if msg.Type == CLOSE { if frag == nil && !msg.Final {
// uncomplete code || too long frag = &Message{
if msg.Size == 1 || msg.Size > 125 { Type: msg.Type,
errorCode = PROTOCOL_ERR Final: msg.Final,
} Data: msg.Data,
// invalid utf-8 reason Size: msg.Size,
if msg.Size > 2 && !utf8.Valid(msg.Data[2:]) {
errorCode = INVALID_PAYLOAD
}
// invalid code
if msg.Size >= 2 {
cCode := binary.BigEndian.Uint16(msg.Data[0:2])
if invalidCloseCode(cCode) {
errorCode = PROTOCOL_ERR
}
}
clientAck = false
break
}
/* (5) PING size error */
if msg.Type == PING && msg.Size > 125 {
// fmt.Printf(" [reader] PING payload too big\n")
// fmt.Printf("[reader] PING err\n")
errorCode = PROTOCOL_ERR
break
}
/* (6) Send PONG */
if msg.Type == PING {
// fmt.Printf("[reader] PING -> PONG\n")
if c.io.writing {
msg.Final = true
msg.Type = PONG
c.ch.send <- *msg
} }
continue continue
} }
/* (7) Invalid UTF8 */ // unexpected continuation
if msg.Type == TEXT && !utf8.Valid(msg.Data) { if msg.Type == CONTINUATION && frag == nil {
// fmt.Printf(" [reader] invalid utf-8\n") errorCode = PROTOCOL_ERR
errorCode = INVALID_PAYLOAD
break break
}
/* (8) Unknown opcode */ }
if msg.Type != TEXT && msg.Type != BINARY { // waiting fragment error
// fmt.Printf(" [reader] unknown OpCode %d\n", msg.Type) if frag != nil && msg.Type != CONTINUATION {
errorCode = PROTOCOL_ERR errorCode = PROTOCOL_ERR
break break
} }
/* (9) Dispatch to receiver */ /* (6) Store fragments */
c.ch.receive <- *msg if frag != nil {
frag.Final = msg.Final
frag.Size += msg.Size
frag.Data = append(frag.Data, msg.Data...)
if !frag.Final {
continue
}
// check message errors
if !isMessageValid(frag, &errorCode, &clientAck) {
frag = nil
break
}
msg = frag
frag = nil
}
/* (7) Dispatch to receiver */
if msg.Type == TEXT || msg.Type == BINARY {
c.ch.receive <- *msg
}
} }
@ -296,10 +299,10 @@ func (c *client) close(status MessageError, clientACK bool){
binary.BigEndian.PutUint16(msg.Data, uint16(status)) binary.BigEndian.PutUint16(msg.Data, uint16(status))
/* (4) Send message */ /* (4) Send message */
err := msg.Send(c.io.sock) msg.Send(c.io.sock)
if err != nil { // if err != nil {
fmt.Printf("[close] send error (%s0\n", err) // fmt.Printf("[close] send error (%s0\n", err)
} // }
} }
@ -310,14 +313,14 @@ func (c *client) close(status MessageError, clientACK bool){
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond)) c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
/* Wait for message */ /* Wait for message */
msg, err := readMessage(c.io.reader) readMessage(c.io.reader)
if err != nil || msg.Type != CLOSE { // if err != nil || msg.Type != CLOSE {
if err == nil { // if err == nil {
fmt.Printf("[close] received OpCode = %d\n", msg.Type) // fmt.Printf("[close] received OpCode = %d\n", msg.Type)
} else { // } else {
fmt.Printf("[close] read error (%v)\n", err) // fmt.Printf("[close] read error (%v)\n", err)
} // }
} // }
// fmt.Printf("[close] received ACK\n") // fmt.Printf("[close] received ACK\n")
@ -337,6 +340,74 @@ func (c *client) close(status MessageError, clientACK bool){
func isMessageValid(m *Message, me *MessageError, ca *bool) bool{
/* (1) Too long control frame */
if m.Type == CLOSE || m.Type == PING || m.Type == PONG {
if m.Size > 125 || !m.Final {
*me = PROTOCOL_ERR
return false
}
}
/* (2) Invalid close */
if m.Type == CLOSE {
// uncomplete code || too long
if m.Size == 1 {
*me = PROTOCOL_ERR
return false
}
// invalid utf-8 reason
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
*me = INVALID_PAYLOAD
return false
}
// invalid code
if m.Size >= 2 {
cCode := binary.BigEndian.Uint16(m.Data[0:2])
if invalidCloseCode(cCode) {
*me = PROTOCOL_ERR
return false
}
}
return false
}
/* (3) Invalid utf8 text */
if m.Type == TEXT && !utf8.Valid(m.Data) {
*me = INVALID_PAYLOAD
return false
}
/* (4) Invalid first fragment (not TEXT nor BINARY) */
if !m.Final && m.Type != CONTINUATION && m.Type != TEXT && m.Type != BINARY {
*me = PROTOCOL_ERR
return false
}
/* (5) Invalid OpCode */
switch m.Type {
case CONTINUATION:
case TEXT:
case BINARY:
case CLOSE:
case PING:
case PONG:
default:
*me = PROTOCOL_ERR
*ca = false
return false
}
return true
}
func invalidCloseCode(code uint16) bool{ func invalidCloseCode(code uint16) bool{

View File

@ -59,9 +59,8 @@ func readMessage(reader io.Reader) (*Message, error){
/* (2) Byte 1: FIN and OpCode */ /* (2) Byte 1: FIN and OpCode */
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
nbr, err := reader.Read(tmpBuf) _, err = reader.Read(tmpBuf)
if err != nil { return m, err } if err != nil { return m, err }
if nbr < 1 { return m, io.EOF }
m.Final = bool( tmpBuf[0] & 0x80 == 0x80 ) m.Final = bool( tmpBuf[0] & 0x80 == 0x80 )
@ -69,9 +68,8 @@ func readMessage(reader io.Reader) (*Message, error){
/* (3) Byte 2: Mask and Length[0] */ /* (3) Byte 2: Mask and Length[0] */
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
nbr, err = reader.Read(tmpBuf) _, err = reader.Read(tmpBuf)
if err != nil { return m, err } if err != nil { return m, err }
if nbr < 1 { return m, io.EOF }
// if mask, byte array not nil // if mask, byte array not nil
if tmpBuf[0] & 0x80 == 0x80 { if tmpBuf[0] & 0x80 == 0x80 {