add message check as 'function -> method'

This commit is contained in:
xdrm-brackets 2018-05-08 11:15:01 +02:00
parent 1891bdff1d
commit ca3b83abee
3 changed files with 110 additions and 109 deletions

View File

@ -1,7 +1,6 @@
package ws package ws
import ( import (
"unicode/utf8"
"time" "time"
"sync" "sync"
"bufio" "bufio"
@ -131,7 +130,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
func clientReader(c *client){ func clientReader(c *client){
var frag *Message var frag *Message
errorCode := NORMAL closeStatus := NORMAL
clientAck := true clientAck := true
c.io.reading.Add(1) c.io.reading.Add(1)
@ -147,14 +146,22 @@ func clientReader(c *client){
/* (2) Parse message */ /* (2) Parse message */
msg, err := readMessage(c.io.reader) msg, err := readMessage(c.io.reader)
if err == UnmaskedFrameErr { if err == ErrUnmaskedFrame {
errorCode = PROTOCOL_ERR closeStatus = PROTOCOL_ERR
break break
} }
if err != nil { break } if err != nil { break }
/* (3) Fail on invalid message */ /* (3) Fail on invalid message */
if !isMessageValid(msg, &errorCode, &clientAck) { // s0 := time.Now().UnixNano()
msgErr := msg.check(frag != nil)
// fmt.Printf("> %.3f us\n", float64(time.Now().UnixNano()-s0)/1e3)
if msgErr == ErrInvalidPayload {
closeStatus = INVALID_PAYLOAD
break
} else if msgErr != nil {
closeStatus = PROTOCOL_ERR
break break
} }
@ -166,7 +173,7 @@ func clientReader(c *client){
continue continue
} }
/* (5) Store first fragmented msg */ /* (5) Store first fragment */
if frag == nil && !msg.Final { if frag == nil && !msg.Final {
frag = &Message{ frag = &Message{
Type: msg.Type, Type: msg.Type,
@ -177,31 +184,23 @@ func clientReader(c *client){
continue continue
} }
// unexpected continuation
if msg.Type == CONTINUATION && frag == nil {
errorCode = PROTOCOL_ERR
break
}
// waiting fragment error
if frag != nil && msg.Type != CONTINUATION {
errorCode = PROTOCOL_ERR
break
}
/* (6) Store fragments */ /* (6) Store fragments */
if frag != nil { if frag != nil {
frag.Final = msg.Final frag.Final = msg.Final
frag.Size += msg.Size frag.Size += msg.Size
frag.Data = append(frag.Data, msg.Data...) frag.Data = append(frag.Data, msg.Data...)
if !frag.Final { if !frag.Final { // continue if not last fragment
continue continue
} }
// check message errors // check message errors
if !isMessageValid(frag, &errorCode, &clientAck) { fragErr := frag.check(false)
frag = nil if fragErr == ErrInvalidPayload {
closeStatus = INVALID_PAYLOAD
break
} else if fragErr != nil {
closeStatus = PROTOCOL_ERR
break break
} }
@ -222,7 +221,7 @@ func clientReader(c *client){
/* (8) close channel (if not already done) */ /* (8) close channel (if not already done) */
// fmt.Printf("[reader] end\n") // fmt.Printf("[reader] end\n")
c.close(errorCode, clientAck) c.close(closeStatus, clientAck)
} }
@ -337,85 +336,3 @@ func (c *client) close(status MessageError, clientACK bool){
} }
func isMessageValid(m *Message, me *MessageError, ca *bool) bool{
/* (1) Too long control frame */
if m.Type == CLOSE || m.Type == PING || m.Type == PONG {
if m.Size > 125 || !m.Final {
*me = PROTOCOL_ERR
return false
}
}
/* (2) Invalid close */
if m.Type == CLOSE {
// uncomplete code || too long
if m.Size == 1 {
*me = PROTOCOL_ERR
return false
}
// invalid utf-8 reason
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
*me = INVALID_PAYLOAD
return false
}
// invalid code
if m.Size >= 2 {
cCode := binary.BigEndian.Uint16(m.Data[0:2])
if invalidCloseCode(cCode) {
*me = PROTOCOL_ERR
return false
}
}
return false
}
/* (3) Invalid utf8 text */
if m.Type == TEXT && !utf8.Valid(m.Data) {
*me = INVALID_PAYLOAD
return false
}
/* (4) Invalid first fragment (not TEXT nor BINARY) */
if !m.Final && m.Type != CONTINUATION && m.Type != TEXT && m.Type != BINARY {
*me = PROTOCOL_ERR
return false
}
/* (5) Invalid OpCode */
switch m.Type {
case CONTINUATION:
case TEXT:
case BINARY:
case CLOSE:
case PING:
case PONG:
default:
*me = PROTOCOL_ERR
*ca = false
return false
}
return true
}
func invalidCloseCode(code uint16) bool{
tolow := code < 1000
badrange1 := code >= 1004 && code <= 1006
badrange2 := code >= 1012 && code <= 1016
badspecific := code == 1100 || code == 2000 || code == 2999
return tolow || badrange1 || badrange2 || badspecific
}

View File

@ -1,12 +1,20 @@
package ws package ws
import ( import (
"unicode/utf8"
"fmt" "fmt"
"io" "io"
"encoding/binary" "encoding/binary"
) )
var UnmaskedFrameErr = fmt.Errorf("Received unmasked frame") var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame")
var ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long")
var ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation")
var ErrInvalidSize = fmt.Errorf("Received invalid payload size")
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
const maximumHeaderSize = 1 + 1 + 8 + 4 const maximumHeaderSize = 1 + 1 + 8 + 4
@ -146,7 +154,7 @@ func readMessage(reader io.Reader) (*Message, error){
// we have to fully read it for read buffer to be clean // we have to fully read it for read buffer to be clean
err = nil err = nil
if mask == nil { if mask == nil {
err = UnmaskedFrameErr err = ErrUnmaskedFrame
} }
return m, err return m, err
@ -220,5 +228,81 @@ func (m Message) Send(writer io.Writer) error {
} }
return nil
}
// Check for message errors with:
// (m) the current message
// (fragment) whether there is a fragment in construction
// returns the message error
func (m *Message) check(fragment bool) error{
/* (1) Invalid first fragment (not TEXT nor BINARY) */
if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY {
return ErrInvalidFragment
}
/* (2) Waiting fragment but received standalone frame */
if fragment && m.Type != CONTINUATION {
return ErrInvalidFragment
}
/* (3) Control frame too long */
if (m.Type == CLOSE || m.Type == PING || m.Type == PONG) && (m.Size > 125 || !m.Final) {
return ErrTooLongControlFrame
}
switch m.Type {
case CONTINUATION:
// unexpected continuation
if !fragment {
return ErrInvalidFragment
}
return nil
case TEXT:
if !utf8.Valid(m.Data) {
return ErrInvalidPayload
}
return nil
case BINARY:
return nil
case CLOSE:
// incomplete code
if m.Size == 1 {
return ErrInvalidCloseStatus
}
// invalid utf-8 reason
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
return ErrInvalidPayload
}
// invalid code
if m.Size >= 2 {
c := binary.BigEndian.Uint16(m.Data[0:2])
if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 {
return ErrInvalidCloseStatus
}
}
return nil
case PING:
return nil
case PONG:
return nil
default:
return ErrInvalidOpCode
}
return nil return nil
} }

View File

@ -19,7 +19,7 @@ func TestSimpleMessageReading(t *testing.T) {
"must fail on unmasked frame", "must fail on unmasked frame",
[]byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f}, []byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f},
Message{}, Message{},
UnmaskedFrameErr, ErrUnmaskedFrame,
}, },
{ // FIN ; TEXT ; hello { // FIN ; TEXT ; hello
"simple hello text message", "simple hello text message",
@ -175,8 +175,8 @@ func TestReadEOF(t *testing.T) {
return return
} }
if tc.unmaskedError && err != UnmaskedFrameErr { if tc.unmaskedError && err != ErrUnmaskedFrame {
t.Errorf("Expected UnmaskedFrameError, got %v", err) t.Errorf("Expected UnmaskedFrameor, got %v", err)
} }
if got.Size != 0x00 { if got.Size != 0x00 {