add message check as 'function -> method'
This commit is contained in:
parent
1891bdff1d
commit
ca3b83abee
125
ws/client.go
125
ws/client.go
|
@ -1,7 +1,6 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"unicode/utf8"
|
||||
"time"
|
||||
"sync"
|
||||
"bufio"
|
||||
|
@ -131,7 +130,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
|||
func clientReader(c *client){
|
||||
var frag *Message
|
||||
|
||||
errorCode := NORMAL
|
||||
closeStatus := NORMAL
|
||||
clientAck := true
|
||||
|
||||
c.io.reading.Add(1)
|
||||
|
@ -147,14 +146,22 @@ func clientReader(c *client){
|
|||
/* (2) Parse message */
|
||||
msg, err := readMessage(c.io.reader)
|
||||
|
||||
if err == UnmaskedFrameErr {
|
||||
errorCode = PROTOCOL_ERR
|
||||
if err == ErrUnmaskedFrame {
|
||||
closeStatus = PROTOCOL_ERR
|
||||
break
|
||||
}
|
||||
if err != nil { break }
|
||||
|
||||
/* (3) Fail on invalid message */
|
||||
if !isMessageValid(msg, &errorCode, &clientAck) {
|
||||
// s0 := time.Now().UnixNano()
|
||||
msgErr := msg.check(frag != nil)
|
||||
// fmt.Printf("> %.3f us\n", float64(time.Now().UnixNano()-s0)/1e3)
|
||||
|
||||
if msgErr == ErrInvalidPayload {
|
||||
closeStatus = INVALID_PAYLOAD
|
||||
break
|
||||
} else if msgErr != nil {
|
||||
closeStatus = PROTOCOL_ERR
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -166,7 +173,7 @@ func clientReader(c *client){
|
|||
continue
|
||||
}
|
||||
|
||||
/* (5) Store first fragmented msg */
|
||||
/* (5) Store first fragment */
|
||||
if frag == nil && !msg.Final {
|
||||
frag = &Message{
|
||||
Type: msg.Type,
|
||||
|
@ -177,31 +184,23 @@ func clientReader(c *client){
|
|||
continue
|
||||
}
|
||||
|
||||
// unexpected continuation
|
||||
if msg.Type == CONTINUATION && frag == nil {
|
||||
errorCode = PROTOCOL_ERR
|
||||
break
|
||||
|
||||
}
|
||||
// waiting fragment error
|
||||
if frag != nil && msg.Type != CONTINUATION {
|
||||
errorCode = PROTOCOL_ERR
|
||||
break
|
||||
}
|
||||
|
||||
/* (6) Store fragments */
|
||||
if frag != nil {
|
||||
frag.Final = msg.Final
|
||||
frag.Size += msg.Size
|
||||
frag.Data = append(frag.Data, msg.Data...)
|
||||
|
||||
if !frag.Final {
|
||||
if !frag.Final { // continue if not last fragment
|
||||
continue
|
||||
}
|
||||
|
||||
// check message errors
|
||||
if !isMessageValid(frag, &errorCode, &clientAck) {
|
||||
frag = nil
|
||||
fragErr := frag.check(false)
|
||||
if fragErr == ErrInvalidPayload {
|
||||
closeStatus = INVALID_PAYLOAD
|
||||
break
|
||||
} else if fragErr != nil {
|
||||
closeStatus = PROTOCOL_ERR
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -222,7 +221,7 @@ func clientReader(c *client){
|
|||
|
||||
/* (8) close channel (if not already done) */
|
||||
// fmt.Printf("[reader] end\n")
|
||||
c.close(errorCode, clientAck)
|
||||
c.close(closeStatus, clientAck)
|
||||
|
||||
}
|
||||
|
||||
|
@ -337,85 +336,3 @@ func (c *client) close(status MessageError, clientACK bool){
|
|||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
func isMessageValid(m *Message, me *MessageError, ca *bool) bool{
|
||||
|
||||
/* (1) Too long control frame */
|
||||
if m.Type == CLOSE || m.Type == PING || m.Type == PONG {
|
||||
if m.Size > 125 || !m.Final {
|
||||
*me = PROTOCOL_ERR
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* (2) Invalid close */
|
||||
if m.Type == CLOSE {
|
||||
|
||||
// uncomplete code || too long
|
||||
if m.Size == 1 {
|
||||
*me = PROTOCOL_ERR
|
||||
return false
|
||||
}
|
||||
|
||||
// invalid utf-8 reason
|
||||
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
|
||||
*me = INVALID_PAYLOAD
|
||||
return false
|
||||
}
|
||||
|
||||
// invalid code
|
||||
if m.Size >= 2 {
|
||||
cCode := binary.BigEndian.Uint16(m.Data[0:2])
|
||||
if invalidCloseCode(cCode) {
|
||||
*me = PROTOCOL_ERR
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/* (3) Invalid utf8 text */
|
||||
if m.Type == TEXT && !utf8.Valid(m.Data) {
|
||||
*me = INVALID_PAYLOAD
|
||||
return false
|
||||
}
|
||||
|
||||
/* (4) Invalid first fragment (not TEXT nor BINARY) */
|
||||
if !m.Final && m.Type != CONTINUATION && m.Type != TEXT && m.Type != BINARY {
|
||||
*me = PROTOCOL_ERR
|
||||
return false
|
||||
}
|
||||
|
||||
/* (5) Invalid OpCode */
|
||||
switch m.Type {
|
||||
case CONTINUATION:
|
||||
case TEXT:
|
||||
case BINARY:
|
||||
case CLOSE:
|
||||
case PING:
|
||||
case PONG:
|
||||
default:
|
||||
*me = PROTOCOL_ERR
|
||||
*ca = false
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
func invalidCloseCode(code uint16) bool{
|
||||
|
||||
tolow := code < 1000
|
||||
badrange1 := code >= 1004 && code <= 1006
|
||||
badrange2 := code >= 1012 && code <= 1016
|
||||
badspecific := code == 1100 || code == 2000 || code == 2999
|
||||
|
||||
return tolow || badrange1 || badrange2 || badspecific
|
||||
|
||||
}
|
|
@ -1,12 +1,20 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"unicode/utf8"
|
||||
"fmt"
|
||||
"io"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
var UnmaskedFrameErr = fmt.Errorf("Received unmasked frame")
|
||||
var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame")
|
||||
var ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long")
|
||||
var ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation")
|
||||
var ErrInvalidSize = fmt.Errorf("Received invalid payload size")
|
||||
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
|
||||
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
|
||||
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
|
||||
|
||||
|
||||
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
||||
const maximumHeaderSize = 1 + 1 + 8 + 4
|
||||
|
@ -146,7 +154,7 @@ func readMessage(reader io.Reader) (*Message, error){
|
|||
// we have to fully read it for read buffer to be clean
|
||||
err = nil
|
||||
if mask == nil {
|
||||
err = UnmaskedFrameErr
|
||||
err = ErrUnmaskedFrame
|
||||
}
|
||||
|
||||
return m, err
|
||||
|
@ -222,3 +230,79 @@ func (m Message) Send(writer io.Writer) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// Check for message errors with:
|
||||
// (m) the current message
|
||||
// (fragment) whether there is a fragment in construction
|
||||
// returns the message error
|
||||
func (m *Message) check(fragment bool) error{
|
||||
|
||||
/* (1) Invalid first fragment (not TEXT nor BINARY) */
|
||||
if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY {
|
||||
return ErrInvalidFragment
|
||||
}
|
||||
|
||||
/* (2) Waiting fragment but received standalone frame */
|
||||
if fragment && m.Type != CONTINUATION {
|
||||
return ErrInvalidFragment
|
||||
}
|
||||
|
||||
/* (3) Control frame too long */
|
||||
if (m.Type == CLOSE || m.Type == PING || m.Type == PONG) && (m.Size > 125 || !m.Final) {
|
||||
return ErrTooLongControlFrame
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case CONTINUATION:
|
||||
// unexpected continuation
|
||||
if !fragment {
|
||||
return ErrInvalidFragment
|
||||
}
|
||||
return nil
|
||||
|
||||
case TEXT:
|
||||
if !utf8.Valid(m.Data) {
|
||||
return ErrInvalidPayload
|
||||
}
|
||||
return nil
|
||||
|
||||
case BINARY:
|
||||
return nil
|
||||
|
||||
case CLOSE:
|
||||
// incomplete code
|
||||
if m.Size == 1 {
|
||||
return ErrInvalidCloseStatus
|
||||
}
|
||||
|
||||
// invalid utf-8 reason
|
||||
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
|
||||
return ErrInvalidPayload
|
||||
}
|
||||
|
||||
// invalid code
|
||||
if m.Size >= 2 {
|
||||
c := binary.BigEndian.Uint16(m.Data[0:2])
|
||||
|
||||
if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 {
|
||||
return ErrInvalidCloseStatus
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case PING:
|
||||
return nil
|
||||
|
||||
case PONG:
|
||||
return nil
|
||||
|
||||
default:
|
||||
return ErrInvalidOpCode
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -19,7 +19,7 @@ func TestSimpleMessageReading(t *testing.T) {
|
|||
"must fail on unmasked frame",
|
||||
[]byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f},
|
||||
Message{},
|
||||
UnmaskedFrameErr,
|
||||
ErrUnmaskedFrame,
|
||||
},
|
||||
{ // FIN ; TEXT ; hello
|
||||
"simple hello text message",
|
||||
|
@ -175,8 +175,8 @@ func TestReadEOF(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if tc.unmaskedError && err != UnmaskedFrameErr {
|
||||
t.Errorf("Expected UnmaskedFrameError, got %v", err)
|
||||
if tc.unmaskedError && err != ErrUnmaskedFrame {
|
||||
t.Errorf("Expected UnmaskedFrameor, got %v", err)
|
||||
}
|
||||
|
||||
if got.Size != 0x00 {
|
||||
|
|
Loading…
Reference in New Issue