// ws/message.go
//
// 357 lines
// 7.7 KiB
// Go

package websocket
import (
"encoding/binary"
"io"
"unicode/utf8"
)
// constErr is an error whose message is its own string value. Because its
// underlying kind is string, errors of this type can be declared with const
// and compared directly with ==.
type constErr string

// Error returns the error message.
func (e constErr) Error() string {
	return string(e)
}

// Errors reported while decoding and validating frames.
const (
	// ErrReservedBits is returned when any of the three reserved header
	// bits (mask 0x70 of the first header byte) is set.
	ErrReservedBits constErr = "Received reserved bits"
	// ErrInvalidOpCode is returned for an opcode this package does not handle.
	ErrInvalidOpCode constErr = "Received invalid OpCode"
	// ErrUnmaskedFrame is returned by Message.ReadFrom when a frame carries
	// no masking key; the frame is still consumed in full.
	ErrUnmaskedFrame constErr = "Received unmasked frame"
	// ErrInvalidSize reports a payload length that cannot be accepted.
	ErrInvalidSize constErr = "Received invalid payload size"

	// ErrInvalidFragment is returned when frame fragmentation rules are
	// broken (bad first fragment, or a new data frame inside a fragment).
	ErrInvalidFragment constErr = "Received invalid fragmentation"
	// ErrUnexpectedContinuation is returned for a Continuation frame that
	// arrives while no fragmented message is in progress.
	ErrUnexpectedContinuation constErr = "Received unexpected continuation frame"
	// ErrTooLongControlFrame is returned for a Close, Ping or Pong frame
	// that is not final or whose payload exceeds 125 bytes.
	ErrTooLongControlFrame constErr = "Received a control frame that is fragmented or too long"

	// ErrInvalidPayload is returned when text data (a Text payload or a
	// Close reason) is not valid UTF-8.
	ErrInvalidPayload constErr = "Received invalid utf8 payload"
	// ErrInvalidCloseStatus is returned for a malformed or disallowed
	// close status code.
	ErrInvalidCloseStatus constErr = "Received invalid close status"

	// ErrCloseFrame is returned for a well-formed Close frame; it signals
	// the close handshake rather than a malformed frame.
	ErrCloseFrame constErr = "Received close Frame"
)
// Sizes used when encoding frames.
const (
	// maximumHeaderSize is the longest possible frame header: one byte of
	// FIN/opcode, one byte of mask bit and 7-bit length, an 8-byte extended
	// length and a 4-byte masking key.
	maximumHeaderSize = 1 + 1 + 8 + 4

	// maxWriteChunk is the largest slice handed to a single Write call
	// by Message.WriteTo.
	maxWriteChunk = 1<<15 - 1
)
// MessageError is a websocket close status code (RFC 6455, section 7.4).
type MessageError uint16

// Close status codes. None is the zero value and stands for "no error".
const (
	// None means there is no error.
	None MessageError = 0
	// Normal is a normal closure.
	Normal MessageError = 1000
	// GoingAway means the endpoint is going away.
	GoingAway MessageError = 1001
	// ProtocolError means the peer violated the protocol.
	ProtocolError MessageError = 1002
	// UnacceptableOpCode means a received type of data cannot be accepted.
	UnacceptableOpCode MessageError = 1003
	// InvalidPayload means the payload does not match the message type,
	// e.g. a Text message that is not valid UTF-8.
	InvalidPayload MessageError = 1007
	// MessageTooLarge means a message is too big to process.
	MessageTooLarge MessageError = 1009
)
// MessageType is a websocket frame opcode, taken from the low four bits of
// the first header byte.
type MessageType byte

// Frame opcodes. Continuation, Text and Binary carry message data; Close,
// Ping and Pong are control frames.
const (
	// Continuation continues a fragmented message.
	Continuation MessageType = 0x0
	// Text carries UTF-8 text data.
	Text MessageType = 0x1
	// Binary carries arbitrary binary data.
	Binary MessageType = 0x2
	// Close starts or answers the close handshake.
	Close MessageType = 0x8
	// Ping is a keep-alive/heartbeat request.
	Ping MessageType = 0x9
	// Pong answers a Ping.
	Pong MessageType = 0xa
)
// Message is a single websocket frame: its header flags and its payload.
type Message struct {
// Final mirrors the FIN bit (0x80 of the first header byte): true when this
// frame is the last one of a message.
Final bool
// Type is the frame opcode (low four bits of the first header byte).
Type MessageType
// Size is the payload length. ReadFrom sets it from the frame header.
// WriteTo sends at most Size bytes of Data: all of Data when Size is larger
// than len(Data), otherwise only Data[:Size] (so a zero Size sends an empty
// payload even if Data is not empty).
Size uint
// Data is the payload; ReadFrom stores it already unmasked.
Data []byte
}
// ReadFrom reads one websocket frame from reader into m.
//
// It decodes the frame header (FIN bit, opcode, MASK bit, payload length and
// masking key) and then reads the payload into m.Data, unmasking it as it
// arrives. The returned count covers header and payload bytes; each header
// field is counted as soon as its read is attempted.
//
// Errors:
//   - errors from reader are returned unchanged;
//   - ErrReservedBits when any RSV bit is set; nothing after the first byte
//     is read in that case;
//   - ErrInvalidSize when a 64-bit payload length has its most significant
//     bit set or does not fit in an int (RFC 6455, section 5.2); the payload
//     is left unread, so the stream cannot be resynchronised;
//   - ErrUnmaskedFrame when the frame has no masking key. The payload is
//     still read in full so the stream stays positioned on the next frame.
//
// The payload buffer is grown as bytes actually arrive instead of being
// allocated up front from the announced length, so a peer cannot force a
// huge allocation (or a makeslice panic) just by sending a large length.
//
// implements io.ReaderFrom
func (m *Message) ReadFrom(reader io.Reader) (int64, error) {
	const (
		// largest payload length that fits in an int on this platform
		maxPayload = uint64(^uint(0) >> 1)
		// payload bytes allocated before any of them has been received
		initialAlloc = 1 << 16
	)

	var (
		read int64
		buf  [8]byte
		mask [4]byte
	)

	// byte 1: FIN, RSV1-3 and opcode
	read++
	if err := readBytes(reader, buf[:1]); err != nil {
		return read, err
	}
	if buf[0]&0x70 != 0 {
		return read, ErrReservedBits
	}
	m.Final = buf[0]&0x80 == 0x80
	m.Type = MessageType(buf[0] & 0x0f)

	// byte 2: MASK bit and 7-bit payload length
	read++
	if err := readBytes(reader, buf[:1]); err != nil {
		return read, err
	}
	masked := buf[0]&0x80 == 0x80
	m.Size = uint(buf[0] & 0x7f)

	// extended payload length: 126 => 16 bits follow, 127 => 64 bits follow
	switch m.Size {
	case 126:
		read += 2
		if err := readBytes(reader, buf[:2]); err != nil {
			return read, err
		}
		m.Size = uint(binary.BigEndian.Uint16(buf[:2]))
	case 127:
		read += 8
		if err := readBytes(reader, buf[:8]); err != nil {
			return read, err
		}
		length := binary.BigEndian.Uint64(buf[:8])
		// On 64-bit platforms this is exactly the RFC rule that the MSB must
		// be 0; on 32-bit platforms it also rejects lengths that would be
		// silently truncated by the conversion to uint.
		if length > maxPayload {
			return read, ErrInvalidSize
		}
		m.Size = uint(length)
	}

	// 4-byte masking key, present only when the MASK bit is set
	if masked {
		read += 4
		if err := readBytes(reader, mask[:]); err != nil {
			return read, err
		}
	}

	// payload
	size := int(m.Size)
	alloc := size
	if alloc > initialAlloc {
		alloc = initialAlloc
	}
	m.Data = make([]byte, 0, alloc)
	for len(m.Data) < size {
		if len(m.Data) == cap(m.Data) {
			// the buffer is full: double it, capped at the frame size
			newCap := 2 * cap(m.Data)
			if newCap > size || newCap <= 0 {
				newCap = size
			}
			grown := make([]byte, len(m.Data), newCap)
			copy(grown, m.Data)
			m.Data = grown
		}
		// cap(m.Data) never exceeds size, so reading up to cap is safe
		start := len(m.Data)
		n, err := io.ReadAtLeast(reader, m.Data[start:cap(m.Data)], 1)
		m.Data = m.Data[:start+n]
		if masked {
			// the masking key cycles over the payload from its first byte
			for i := start; i < len(m.Data); i++ {
				m.Data[i] ^= mask[i%4]
			}
		}
		if err != nil {
			return read + int64(len(m.Data)), err
		}
	}
	read += int64(size)

	// unmasked frames are rejected only after being fully consumed, so the
	// next read starts on a frame boundary
	if !masked {
		return read, ErrUnmaskedFrame
	}
	return read, nil
}
// WriteTo writes m as a single websocket frame to writer.
//
// At most m.Size bytes of m.Data are sent: all of m.Data when m.Size is
// larger than len(m.Data), otherwise only m.Data[:m.Size] (so a zero Size
// sends an empty payload). The FIN bit is set when m.Final is true, the
// opcode is m.Type, and the payload length uses the 7-bit, 16-bit or 64-bit
// encoding as required. The frame is written unmasked (MASK bit 0).
//
// The frame is handed to writer in chunks of at most maxWriteChunk bytes.
// The returned count is the total number of frame bytes (header included)
// accepted by writer, also when an error is returned. A writer that accepts
// zero bytes without reporting an error yields io.ErrShortWrite instead of
// looping forever.
//
// implements io.WriterTo
func (m Message) WriteTo(writer io.Writer) (int64, error) {
	// never send more than the data actually available
	if uint(len(m.Data)) <= m.Size {
		m.Size = uint(len(m.Data))
	}
	size := int(m.Size)

	frame := make([]byte, 0, maximumHeaderSize+size)

	// byte 0: FIN bit + opcode
	first := byte(m.Type)
	if m.Final {
		first |= 0x80
	}
	frame = append(frame, first)

	// byte 1 (+ extended length). A plain default branch is used for the
	// 64-bit form: comparing against 0xffffffffffffffff does not compile on
	// platforms where uint is 32 bits wide.
	switch {
	case size < 126:
		frame = append(frame, byte(size))
	case size <= 0xffff:
		frame = append(frame, 126, 0, 0)
		binary.BigEndian.PutUint16(frame[len(frame)-2:], uint16(size))
	default:
		frame = append(frame, 127, 0, 0, 0, 0, 0, 0, 0, 0)
		binary.BigEndian.PutUint64(frame[len(frame)-8:], uint64(size))
	}
	frame = append(frame, m.Data[:size]...)

	// write by chunks of at most maxWriteChunk bytes
	written := 0
	for written < len(frame) {
		end := written + maxWriteChunk
		if end > len(frame) {
			end = len(frame)
		}
		n, err := writer.Write(frame[written:end])
		written += n
		if err != nil {
			return int64(written), err
		}
		if n == 0 {
			// io.Writer contract violation; bail out instead of spinning
			return int64(written), io.ErrShortWrite
		}
	}
	return int64(written), nil
}
// check validates m against the websocket framing rules.
//
// fragment reports whether a fragmented message is currently in
// construction, i.e. whether a continuation frame is expected.
//
// It returns nil when the frame is acceptable, ErrCloseFrame for a
// well-formed Close frame (a signal for the caller rather than a protocol
// violation), and otherwise one of:
//   - ErrInvalidFragment: a non-final frame outside a fragmented message that
//     is not Text or Binary, or a Text, Binary or unknown frame arriving in
//     the middle of a fragmented message;
//   - ErrTooLongControlFrame: a Close, Ping or Pong frame that is not final
//     or whose payload exceeds 125 bytes;
//   - ErrUnexpectedContinuation: a Continuation frame with no fragmented
//     message in progress;
//   - ErrInvalidPayload: a final Text frame, or a Close reason, that is not
//     valid UTF-8;
//   - ErrInvalidCloseStatus: a one-byte Close payload, or a status code that
//     may not appear on the wire;
//   - ErrInvalidOpCode: an opcode this package does not handle.
//
// Only final (unfragmented) Text frames are UTF-8 checked here; a
// reassembled fragmented message is not validated by this function.
func (m *Message) check(fragment bool) error {
	isControl := m.Type == Close || m.Type == Ping || m.Type == Pong

	// a fragmented message may only start with a Text or Binary frame
	if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
		return ErrInvalidFragment
	}
	// inside a fragmented message only continuations and control frames may arrive
	if fragment && m.Type != Continuation && !isControl {
		return ErrInvalidFragment
	}
	// control frames must be final and carry at most 125 payload bytes
	if isControl && (m.Size > 125 || !m.Final) {
		return ErrTooLongControlFrame
	}

	switch m.Type {
	case Continuation:
		if !fragment {
			return ErrUnexpectedContinuation
		}
		return nil
	case Text:
		if m.Final && !utf8.Valid(m.Data) {
			return ErrInvalidPayload
		}
		return nil
	case Binary, Ping, Pong:
		return nil
	case Close:
		// a status code takes two bytes; a single byte is malformed
		if m.Size == 1 {
			return ErrInvalidCloseStatus
		}
		// the optional reason following the status code must be UTF-8
		if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
			return ErrInvalidPayload
		}
		if m.Size >= 2 {
			code := binary.BigEndian.Uint16(m.Data[0:2])
			// RFC 6455 section 7.4: 0-999 are unused and 1004-1006 / 1015
			// must never be sent. Only 1000-1003 and 1007-1011 are accepted
			// from the protocol range 1000-2999; 3000-4999 are for libraries
			// and applications, and nothing above 4999 is defined.
			switch {
			case code >= 1000 && code <= 1003,
				code >= 1007 && code <= 1011,
				code >= 3000 && code <= 4999:
				// acceptable status code
			default:
				return ErrInvalidCloseStatus
			}
		}
		return ErrCloseFrame
	default:
		return ErrInvalidOpCode
	}
}
// readBytes fills buffer completely from reader, issuing as many Read calls
// as needed, so it copes with connections that deliver data in chunks.
//
// It returns nil once buffer is full, even if the Read call that completed it
// also returned an error. io.Reader allows a final chunk of data to come
// together with an error such as io.EOF, and io.ReadAtLeast likewise discards
// the error once enough bytes have been read. Otherwise the first error
// reported by reader is returned unchanged; io.EOF is not converted to
// io.ErrUnexpectedEOF.
func readBytes(reader io.Reader, buffer []byte) error {
	filled := 0
	for filled < len(buffer) {
		n, err := reader.Read(buffer[filled:])
		// count the bytes before looking at err: both may be set at once
		filled += n
		if err != nil {
			if filled == len(buffer) {
				return nil
			}
			return err
		}
	}
	return nil
}