package websocket

import (
	"encoding/binary"
	"io"
	"unicode/utf8"
)

// constErr is a string-backed error type; because its underlying type is
// string, the package's sentinel errors can be declared as constants.
type constErr string

// Error implements the error interface by returning the string itself.
func (c constErr) Error() string { return string(c) }

const (
	// ErrUnmaskedFrame is returned by ReadFrom when the frame did not have
	// its mask bit set. The frame is still read in full before it is returned.
	ErrUnmaskedFrame constErr = "Received unmasked frame"
	// ErrTooLongControlFrame is returned by check for a Close, Ping or Pong
	// frame that is not final or carries more than 125 bytes.
	ErrTooLongControlFrame constErr = "Received a control frame that is fragmented or too long"
	// ErrInvalidFragment is returned by check when a frame does not fit the
	// current fragmentation state.
	ErrInvalidFragment constErr = "Received invalid fragmentation"
	// ErrUnexpectedContinuation is returned by check for a Continuation
	// frame received while no fragmented message is in construction.
	ErrUnexpectedContinuation constErr = "Received unexpected continuation frame"
	// ErrInvalidSize is returned by ReadFrom when the 64-bit payload length
	// cannot be represented as a non-negative int.
	ErrInvalidSize constErr = "Received invalid payload size"
	// ErrInvalidPayload is returned by check when a text payload or a close
	// reason is not valid UTF-8.
	ErrInvalidPayload constErr = "Received invalid utf8 payload"
	// ErrInvalidCloseStatus is returned by check when a Close frame carries
	// an incomplete or rejected status code.
	ErrInvalidCloseStatus constErr = "Received invalid close status"
	// ErrInvalidOpCode is returned by check for an opcode that is not one of
	// the MessageType constants.
	ErrInvalidOpCode constErr = "Received invalid OpCode"
	// ErrReservedBits is returned by ReadFrom when any of the three reserved
	// bits of the first header byte is set.
	ErrReservedBits constErr = "Received reserved bits"
	// ErrCloseFrame is returned by check for a well-formed Close frame.
	ErrCloseFrame constErr = "Received close Frame"
)

// maximumHeaderSize is the largest possible frame header:
// Final/OpCode + isMask/Length + extended Length + Mask.
const maximumHeaderSize = 1 + 1 + 8 + 4

// maxWriteChunk is the largest slice WriteTo hands to a single Write call.
const maxWriteChunk = 0x7fff

// MessageError lists websocket close statuses.
type MessageError uint16

const (
	// None is used when there is no error.
	None MessageError = 0
	// Normal is close status 1000.
	Normal MessageError = 1000
	// GoingAway is close status 1001.
	GoingAway MessageError = 1001
	// ProtocolError is close status 1002.
	ProtocolError MessageError = 1002
	// UnacceptableOpCode is close status 1003.
	UnacceptableOpCode MessageError = 1003
	// InvalidPayload is close status 1007 (utf8).
	InvalidPayload MessageError = 1007
	// MessageTooLarge is close status 1009.
	MessageTooLarge MessageError = 1009
)

// MessageType lists websocket message types (frame opcodes).
type MessageType byte

const (
	// Continuation message type.
	Continuation MessageType = 0x00
	// Text message type.
	Text MessageType = 0x01
	// Binary message type.
	Binary MessageType = 0x02
	// Close message type.
	Close MessageType = 0x08
	// Ping message type.
	Ping MessageType = 0x09
	// Pong message type.
	Pong MessageType = 0x0a
)

// Message is a websocket message (a single frame on the wire).
type Message struct {
	// Final mirrors the FIN bit of the first header byte.
	Final bool
	// Type is the opcode held in the low four bits of the first header byte.
	Type MessageType
	// Size is the payload length in bytes.
	Size uint
	// Data is the payload; ReadFrom stores it already unmasked.
	Data []byte
}

// ReadFrom reads one frame from reader into m.
//
// It decodes the two fixed header bytes, the optional extended length, the
// optional masking key and then the payload, which is unmasked in place.
// The returned count is the number of bytes actually consumed from reader.
//
// Errors:
//   - any error reported by reader;
//   - ErrReservedBits when a reserved header bit is set (only the first byte
//     has been consumed at that point);
//   - ErrInvalidSize when the 64-bit length does not fit a non-negative int;
//   - ErrUnmaskedFrame when the mask bit is clear; the frame is nevertheless
//     read completely so that the stream stays aligned on the next frame.
//
// implements io.ReaderFrom
func (m *Message) ReadFrom(reader io.Reader) (int64, error) {
	const (
		// largest payload length that can be used as a slice length
		maxInt = uint64(^uint(0) >> 1)
		// cap on the up-front allocation: the declared length comes from the
		// peer, so the buffer only grows as bytes really arrive
		initialPayloadCap = 64 * 1024
	)

	var (
		read    int64
		scratch [8]byte
		mask    [4]byte
	)

	// byte 1: FIN and OpCode
	n, err := readCounted(reader, scratch[:1])
	read += int64(n)
	if err != nil {
		return read, err
	}

	// check reserved bits
	if scratch[0]&0x70 != 0 {
		return read, ErrReservedBits
	}

	m.Final = scratch[0]&0x80 == 0x80
	m.Type = MessageType(scratch[0] & 0x0f)

	// byte 2: mask and length[0]
	n, err = readCounted(reader, scratch[:1])
	read += int64(n)
	if err != nil {
		return read, err
	}

	masked := scratch[0]&0x80 == 0x80
	m.Size = uint(scratch[0] & 0x7f)

	// extended payload length
	switch m.Size {
	case 127:
		n, err = readCounted(reader, scratch[:8])
		read += int64(n)
		if err != nil {
			return read, err
		}

		size64 := binary.BigEndian.Uint64(scratch[:8])
		if size64 > maxInt {
			// would be negative (or truncated) once converted to int
			return read, ErrInvalidSize
		}
		m.Size = uint(size64)

	case 126:
		n, err = readCounted(reader, scratch[:2])
		read += int64(n)
		if err != nil {
			return read, err
		}
		m.Size = uint(binary.BigEndian.Uint16(scratch[:2]))
	}

	// masking key
	if masked {
		n, err = readCounted(reader, mask[:])
		read += int64(n)
		if err != nil {
			return read, err
		}
	}

	// read payload by chunks
	size := int(m.Size)
	alloc := size
	if alloc > initialPayloadCap {
		alloc = initialPayloadCap
	}
	m.Data = make([]byte, alloc)

	cursor := 0
	for cursor < size {
		// buffer full but payload not complete: double it, capped at size
		if cursor == len(m.Data) {
			grown := 2 * len(m.Data)
			if grown > size || grown < len(m.Data) {
				grown = size
			}
			next := make([]byte, grown)
			copy(next, m.Data)
			m.Data = next
		}

		// try to read (at least 1 byte)
		nbread, err := io.ReadAtLeast(reader, m.Data[cursor:], 1)
		if err != nil {
			return read + int64(cursor) + int64(nbread), err
		}

		// unmask data; the key index follows the absolute payload offset
		if masked {
			for i := cursor; i < cursor+nbread; i++ {
				m.Data[i] ^= mask[i%4]
			}
		}

		cursor += nbread
	}
	read += int64(cursor)

	// return error if unmasked frame
	// we have to fully read it for read buffer to be clean
	if !masked {
		return read, ErrUnmaskedFrame
	}

	return read, nil
}

// WriteTo writes the message as a single unmasked frame to writer.
//
// Size is clamped to len(Data) when it exceeds it; only the first Size bytes
// of Data are sent. The frame is written in chunks of at most maxWriteChunk
// bytes. The returned count is the total number of bytes (header included)
// accepted by writer, also when an error interrupts the write.
//
// implements io.WriterTo
func (m Message) WriteTo(writer io.Writer) (int64, error) {
	// fix size: never announce more bytes than Data holds
	if m.Size > uint(len(m.Data)) {
		m.Size = uint(len(m.Data))
	}

	header := make([]byte, 0, maximumHeaderSize)

	// byte 0 : FIN + opcode
	var final byte
	if m.Final {
		final = 0x80
	}
	header = append(header, final|byte(m.Type))

	// payload length
	switch {
	case m.Size < 126:
		// simple
		header = append(header, byte(m.Size))

	case m.Size <= 0xffff:
		// extended: 16 bits
		header = append(header, 126, 0, 0)
		binary.BigEndian.PutUint16(header[2:], uint16(m.Size))

	default:
		// extended: 64 bits
		header = append(header, 127, 0, 0, 0, 0, 0, 0, 0, 0)
		binary.BigEndian.PutUint64(header[2:], uint64(m.Size))
	}

	// build write buffer
	writeBuf := make([]byte, 0, len(header)+int(m.Size))
	writeBuf = append(writeBuf, header...)
	writeBuf = append(writeBuf, m.Data[0:m.Size]...)

	// write by chunks
	toWrite := len(writeBuf)
	cursor := 0
	for cursor < toWrite {
		maxBoundary := cursor + maxWriteChunk
		if maxBoundary > toWrite {
			maxBoundary = toWrite
		}

		// try to write (at most maxWriteChunk bytes)
		nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary])
		cursor += nbwritten
		if err != nil {
			return int64(cursor), err
		}

		// a writer that accepts nothing without an error would loop forever
		if nbwritten == 0 {
			return int64(cursor), io.ErrShortWrite
		}
	}

	return int64(cursor), nil
}

// check validates the message against the current fragmentation state.
//
// fragment tells whether a fragmented message is currently in construction.
// It returns nil for an acceptable data, ping or pong frame, ErrCloseFrame
// for a well-formed close frame, and one of the other Err* constants when the
// frame is rejected.
func (m *Message) check(fragment bool) error {
	isControl := m.Type == Close || m.Type == Ping || m.Type == Pong

	// invalid first fragment (not TEXT nor BINARY)
	if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
		return ErrInvalidFragment
	}

	// waiting fragment but received standalone frame
	if fragment && m.Type != Continuation && !isControl {
		return ErrInvalidFragment
	}

	// control frame too long or fragmented
	if isControl && (m.Size > 125 || !m.Final) {
		return ErrTooLongControlFrame
	}

	switch m.Type {
	case Continuation:
		// unexpected continuation
		if !fragment {
			return ErrUnexpectedContinuation
		}
		return nil

	case Text:
		// only a final frame is validated here
		if m.Final && !utf8.Valid(m.Data) {
			return ErrInvalidPayload
		}
		return nil

	case Binary, Ping, Pong:
		return nil

	case Close:
		// incomplete code
		if m.Size == 1 {
			return ErrInvalidCloseStatus
		}

		// invalid utf-8 reason
		if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
			return ErrInvalidPayload
		}

		// invalid code
		if m.Size >= 2 {
			c := binary.BigEndian.Uint16(m.Data[0:2])
			if c < 1000 ||
				(c >= 1004 && c <= 1006) ||
				(c >= 1012 && c <= 1016) ||
				c == 1100 || c == 2000 || c == 2999 {
				return ErrInvalidCloseStatus
			}
		}

		return ErrCloseFrame

	default:
		return ErrInvalidOpCode
	}
}

// readBytes reads from reader until buffer is completely filled.
//
// It keeps calling Read for connections that deliver data in chunks, and
// returns the reader's error if buffer could not be filled.
func readBytes(reader io.Reader, buffer []byte) error {
	_, err := readCounted(reader, buffer)
	return err
}

// readCounted fills buffer from reader and reports how many bytes were
// stored. Bytes returned by a Read call together with an error are kept: the
// error is reported only when buffer is still incomplete.
func readCounted(reader io.Reader, buffer []byte) (int, error) {
	filled := 0

	for filled < len(buffer) {
		nbread, err := reader.Read(buffer[filled:])
		filled += nbread

		if err != nil {
			if filled >= len(buffer) {
				return filled, nil
			}
			return filled, err
		}
	}

	return filled, nil
}