add message check as 'function -> method'
This commit is contained in:
parent
1891bdff1d
commit
ca3b83abee
125
ws/client.go
125
ws/client.go
|
@ -1,7 +1,6 @@
|
||||||
package ws
|
package ws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"unicode/utf8"
|
|
||||||
"time"
|
"time"
|
||||||
"sync"
|
"sync"
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -131,7 +130,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
||||||
func clientReader(c *client){
|
func clientReader(c *client){
|
||||||
var frag *Message
|
var frag *Message
|
||||||
|
|
||||||
errorCode := NORMAL
|
closeStatus := NORMAL
|
||||||
clientAck := true
|
clientAck := true
|
||||||
|
|
||||||
c.io.reading.Add(1)
|
c.io.reading.Add(1)
|
||||||
|
@ -147,14 +146,22 @@ func clientReader(c *client){
|
||||||
/* (2) Parse message */
|
/* (2) Parse message */
|
||||||
msg, err := readMessage(c.io.reader)
|
msg, err := readMessage(c.io.reader)
|
||||||
|
|
||||||
if err == UnmaskedFrameErr {
|
if err == ErrUnmaskedFrame {
|
||||||
errorCode = PROTOCOL_ERR
|
closeStatus = PROTOCOL_ERR
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil { break }
|
if err != nil { break }
|
||||||
|
|
||||||
/* (3) Fail on invalid message */
|
/* (3) Fail on invalid message */
|
||||||
if !isMessageValid(msg, &errorCode, &clientAck) {
|
// s0 := time.Now().UnixNano()
|
||||||
|
msgErr := msg.check(frag != nil)
|
||||||
|
// fmt.Printf("> %.3f us\n", float64(time.Now().UnixNano()-s0)/1e3)
|
||||||
|
|
||||||
|
if msgErr == ErrInvalidPayload {
|
||||||
|
closeStatus = INVALID_PAYLOAD
|
||||||
|
break
|
||||||
|
} else if msgErr != nil {
|
||||||
|
closeStatus = PROTOCOL_ERR
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,7 +173,7 @@ func clientReader(c *client){
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Store first fragmented msg */
|
/* (5) Store first fragment */
|
||||||
if frag == nil && !msg.Final {
|
if frag == nil && !msg.Final {
|
||||||
frag = &Message{
|
frag = &Message{
|
||||||
Type: msg.Type,
|
Type: msg.Type,
|
||||||
|
@ -177,31 +184,23 @@ func clientReader(c *client){
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// unexpected continuation
|
|
||||||
if msg.Type == CONTINUATION && frag == nil {
|
|
||||||
errorCode = PROTOCOL_ERR
|
|
||||||
break
|
|
||||||
|
|
||||||
}
|
|
||||||
// waiting fragment error
|
|
||||||
if frag != nil && msg.Type != CONTINUATION {
|
|
||||||
errorCode = PROTOCOL_ERR
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (6) Store fragments */
|
/* (6) Store fragments */
|
||||||
if frag != nil {
|
if frag != nil {
|
||||||
frag.Final = msg.Final
|
frag.Final = msg.Final
|
||||||
frag.Size += msg.Size
|
frag.Size += msg.Size
|
||||||
frag.Data = append(frag.Data, msg.Data...)
|
frag.Data = append(frag.Data, msg.Data...)
|
||||||
|
|
||||||
if !frag.Final {
|
if !frag.Final { // continue if not last fragment
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// check message errors
|
// check message errors
|
||||||
if !isMessageValid(frag, &errorCode, &clientAck) {
|
fragErr := frag.check(false)
|
||||||
frag = nil
|
if fragErr == ErrInvalidPayload {
|
||||||
|
closeStatus = INVALID_PAYLOAD
|
||||||
|
break
|
||||||
|
} else if fragErr != nil {
|
||||||
|
closeStatus = PROTOCOL_ERR
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,7 +221,7 @@ func clientReader(c *client){
|
||||||
|
|
||||||
/* (8) close channel (if not already done) */
|
/* (8) close channel (if not already done) */
|
||||||
// fmt.Printf("[reader] end\n")
|
// fmt.Printf("[reader] end\n")
|
||||||
c.close(errorCode, clientAck)
|
c.close(closeStatus, clientAck)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -337,85 +336,3 @@ func (c *client) close(status MessageError, clientACK bool){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func isMessageValid(m *Message, me *MessageError, ca *bool) bool{
|
|
||||||
|
|
||||||
/* (1) Too long control frame */
|
|
||||||
if m.Type == CLOSE || m.Type == PING || m.Type == PONG {
|
|
||||||
if m.Size > 125 || !m.Final {
|
|
||||||
*me = PROTOCOL_ERR
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/* (2) Invalid close */
|
|
||||||
if m.Type == CLOSE {
|
|
||||||
|
|
||||||
// uncomplete code || too long
|
|
||||||
if m.Size == 1 {
|
|
||||||
*me = PROTOCOL_ERR
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// invalid utf-8 reason
|
|
||||||
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
|
|
||||||
*me = INVALID_PAYLOAD
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// invalid code
|
|
||||||
if m.Size >= 2 {
|
|
||||||
cCode := binary.BigEndian.Uint16(m.Data[0:2])
|
|
||||||
if invalidCloseCode(cCode) {
|
|
||||||
*me = PROTOCOL_ERR
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (3) Invalid utf8 text */
|
|
||||||
if m.Type == TEXT && !utf8.Valid(m.Data) {
|
|
||||||
*me = INVALID_PAYLOAD
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (4) Invalid first fragment (not TEXT nor BINARY) */
|
|
||||||
if !m.Final && m.Type != CONTINUATION && m.Type != TEXT && m.Type != BINARY {
|
|
||||||
*me = PROTOCOL_ERR
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (5) Invalid OpCode */
|
|
||||||
switch m.Type {
|
|
||||||
case CONTINUATION:
|
|
||||||
case TEXT:
|
|
||||||
case BINARY:
|
|
||||||
case CLOSE:
|
|
||||||
case PING:
|
|
||||||
case PONG:
|
|
||||||
default:
|
|
||||||
*me = PROTOCOL_ERR
|
|
||||||
*ca = false
|
|
||||||
return false
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
func invalidCloseCode(code uint16) bool{
|
|
||||||
|
|
||||||
tolow := code < 1000
|
|
||||||
badrange1 := code >= 1004 && code <= 1006
|
|
||||||
badrange2 := code >= 1012 && code <= 1016
|
|
||||||
badspecific := code == 1100 || code == 2000 || code == 2999
|
|
||||||
|
|
||||||
return tolow || badrange1 || badrange2 || badspecific
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,12 +1,20 @@
|
||||||
package ws
|
package ws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"unicode/utf8"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
)
|
)
|
||||||
|
|
||||||
var UnmaskedFrameErr = fmt.Errorf("Received unmasked frame")
|
var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame")
|
||||||
|
var ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long")
|
||||||
|
var ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation")
|
||||||
|
var ErrInvalidSize = fmt.Errorf("Received invalid payload size")
|
||||||
|
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
|
||||||
|
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
|
||||||
|
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
|
||||||
|
|
||||||
|
|
||||||
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
||||||
const maximumHeaderSize = 1 + 1 + 8 + 4
|
const maximumHeaderSize = 1 + 1 + 8 + 4
|
||||||
|
@ -146,7 +154,7 @@ func readMessage(reader io.Reader) (*Message, error){
|
||||||
// we have to fully read it for read buffer to be clean
|
// we have to fully read it for read buffer to be clean
|
||||||
err = nil
|
err = nil
|
||||||
if mask == nil {
|
if mask == nil {
|
||||||
err = UnmaskedFrameErr
|
err = ErrUnmaskedFrame
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, err
|
return m, err
|
||||||
|
@ -222,3 +230,79 @@ func (m Message) Send(writer io.Writer) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Check for message errors with:
|
||||||
|
// (m) the current message
|
||||||
|
// (fragment) whether there is a fragment in construction
|
||||||
|
// returns the message error
|
||||||
|
func (m *Message) check(fragment bool) error{
|
||||||
|
|
||||||
|
/* (1) Invalid first fragment (not TEXT nor BINARY) */
|
||||||
|
if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY {
|
||||||
|
return ErrInvalidFragment
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Waiting fragment but received standalone frame */
|
||||||
|
if fragment && m.Type != CONTINUATION {
|
||||||
|
return ErrInvalidFragment
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Control frame too long */
|
||||||
|
if (m.Type == CLOSE || m.Type == PING || m.Type == PONG) && (m.Size > 125 || !m.Final) {
|
||||||
|
return ErrTooLongControlFrame
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Type {
|
||||||
|
case CONTINUATION:
|
||||||
|
// unexpected continuation
|
||||||
|
if !fragment {
|
||||||
|
return ErrInvalidFragment
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case TEXT:
|
||||||
|
if !utf8.Valid(m.Data) {
|
||||||
|
return ErrInvalidPayload
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case BINARY:
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case CLOSE:
|
||||||
|
// incomplete code
|
||||||
|
if m.Size == 1 {
|
||||||
|
return ErrInvalidCloseStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalid utf-8 reason
|
||||||
|
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
|
||||||
|
return ErrInvalidPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalid code
|
||||||
|
if m.Size >= 2 {
|
||||||
|
c := binary.BigEndian.Uint16(m.Data[0:2])
|
||||||
|
|
||||||
|
if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 {
|
||||||
|
return ErrInvalidCloseStatus
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case PING:
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case PONG:
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return ErrInvalidOpCode
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -19,7 +19,7 @@ func TestSimpleMessageReading(t *testing.T) {
|
||||||
"must fail on unmasked frame",
|
"must fail on unmasked frame",
|
||||||
[]byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f},
|
[]byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f},
|
||||||
Message{},
|
Message{},
|
||||||
UnmaskedFrameErr,
|
ErrUnmaskedFrame,
|
||||||
},
|
},
|
||||||
{ // FIN ; TEXT ; hello
|
{ // FIN ; TEXT ; hello
|
||||||
"simple hello text message",
|
"simple hello text message",
|
||||||
|
@ -175,8 +175,8 @@ func TestReadEOF(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.unmaskedError && err != UnmaskedFrameErr {
|
if tc.unmaskedError && err != ErrUnmaskedFrame {
|
||||||
t.Errorf("Expected UnmaskedFrameError, got %v", err)
|
t.Errorf("Expected UnmaskedFrameor, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got.Size != 0x00 {
|
if got.Size != 0x00 {
|
||||||
|
|
Loading…
Reference in New Issue