package websocket import ( "encoding/binary" "fmt" "io" "unicode/utf8" ) var ( // ErrUnmaskedFrame error ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame") // ErrTooLongControlFrame error ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long") // ErrInvalidFragment error ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation") // ErrUnexpectedContinuation error ErrUnexpectedContinuation = fmt.Errorf("Received unexpected continuation frame") // ErrInvalidSize error ErrInvalidSize = fmt.Errorf("Received invalid payload size") // ErrInvalidPayload error ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") // ErrInvalidCloseStatus error ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") // ErrInvalidOpCode error ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") // ErrReservedBits error ErrReservedBits = fmt.Errorf("Received reserved bits") // ErrCloseFrame error ErrCloseFrame = fmt.Errorf("Received close Frame") ) // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask const maximumHeaderSize = 1 + 1 + 8 + 4 const maxWriteChunk = 0x7fff // MessageError lists websocket close statuses type MessageError uint16 const ( // None used when there is no error None MessageError = 0 // Normal error Normal MessageError = 1000 // GoingAway error GoingAway MessageError = 1001 // ProtocolError error ProtocolError MessageError = 1002 // UnacceptableOpCode error UnacceptableOpCode MessageError = 1003 // InvalidPayload error InvalidPayload MessageError = 1007 // utf8 // MessageTooLarge error MessageTooLarge MessageError = 1009 ) // MessageType lists websocket message types type MessageType byte const ( // Continuation message type Continuation MessageType = 0x00 // Text message type Text MessageType = 0x01 // Binary message type Binary MessageType = 0x02 // Close message type Close MessageType = 0x08 // Ping message type Ping MessageType = 0x09 // Pong message type Pong MessageType = 0x0a ) // Message is a websocket message type Message struct { Final bool Type MessageType Size uint Data []byte } // receive reads a message form reader func readMessage(reader io.Reader) (*Message, error) { var err error var tmpBuf []byte var mask []byte var cursor int m := new(Message) /* (2) Byte 1: FIN and OpCode */ tmpBuf = make([]byte, 1) err = readBytes(reader, tmpBuf) if err != nil { return m, err } // check reserved bits if tmpBuf[0]&0x70 != 0 { return m, ErrReservedBits } m.Final = bool(tmpBuf[0]&0x80 == 0x80) m.Type = MessageType(tmpBuf[0] & 0x0f) /* (3) Byte 2: Mask and Length[0] */ tmpBuf = make([]byte, 1) err = readBytes(reader, tmpBuf) if err != nil { return m, err } // if mask, byte array not nil if tmpBuf[0]&0x80 == 0x80 { mask = make([]byte, 0) } // payload length m.Size = uint(tmpBuf[0] & 0x7f) /* (4) Extended payload */ if m.Size == 127 { tmpBuf = make([]byte, 8) err := readBytes(reader, tmpBuf) if err != nil { return m, err } m.Size = uint(binary.BigEndian.Uint64(tmpBuf)) } else if m.Size == 126 { tmpBuf = make([]byte, 2) err := readBytes(reader, tmpBuf) if err != nil { return m, err } m.Size = uint(binary.BigEndian.Uint16(tmpBuf)) } /* (5) Masking key */ if mask != nil { tmpBuf = make([]byte, 4) err := readBytes(reader, tmpBuf) if err != nil { return m, err } mask = make([]byte, 4) copy(mask, tmpBuf) } /* (6) Read payload by chunks */ m.Data = make([]byte, int(m.Size)) cursor = 0 // {1} While we have data to read // for uint(cursor) < m.Size { // {2} Try to read (at least 1 byte) // nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1) if err != nil { return m, err } // {3} Unmask data // if mask != nil { for i, l := cursor, cursor+nbread; i < l; i++ { mi := i % 4 // mask index m.Data[i] = m.Data[i] ^ mask[mi] } } // {4} Update cursor // cursor += nbread } // return error if unmasked frame // we have to fully read it for read buffer to be clean err = nil if mask == nil { err = ErrUnmaskedFrame } return m, err } // Send sends a frame over a socket func (m Message) Send(writer io.Writer) error { header := make([]byte, 0, maximumHeaderSize) // fix size if uint(len(m.Data)) <= m.Size { m.Size = uint(len(m.Data)) } /* (1) Byte 0 : FIN + opcode */ var final byte = 0x80 if !m.Final { final = 0 } header = append(header, final|byte(m.Type)) /* (2) Get payload length */ if m.Size < 126 { // simple header = append(header, byte(m.Size)) } else if m.Size <= 0xffff { // extended: 16 bits header = append(header, 126) buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(m.Size)) header = append(header, buf...) } else if m.Size <= 0xffffffffffffffff { // extended: 64 bits header = append(header, 127) buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, uint64(m.Size)) header = append(header, buf...) } /* (3) Build write buffer */ writeBuf := make([]byte, 0, len(header)+int(m.Size)) writeBuf = append(writeBuf, header...) writeBuf = append(writeBuf, m.Data[0:m.Size]...) /* (4) Send over socket by chunks */ toWrite := len(header) + int(m.Size) cursor := 0 for cursor < toWrite { maxBoundary := cursor + maxWriteChunk if maxBoundary > toWrite { maxBoundary = toWrite } // Try to wrote (at max 1024 bytes) // nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary]) if err != nil { return err } // Update cursor // cursor += nbwritten } return nil } // Check for message errors with: // (m) the current message // (fragment) whether there is a fragment in construction // returns the message error func (m *Message) check(fragment bool) error { /* (1) Invalid first fragment (not TEXT nor BINARY) */ if !m.Final && !fragment && m.Type != Text && m.Type != Binary { return ErrInvalidFragment } /* (2) Waiting fragment but received standalone frame */ if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong { return ErrInvalidFragment } /* (3) Control frame too long */ if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) { return ErrTooLongControlFrame } switch m.Type { case Continuation: // unexpected continuation if !fragment { return ErrUnexpectedContinuation } return nil case Text: if m.Final && !utf8.Valid(m.Data) { return ErrInvalidPayload } return nil case Binary: return nil case Close: // incomplete code if m.Size == 1 { return ErrInvalidCloseStatus } // invalid utf-8 reason if m.Size > 2 && !utf8.Valid(m.Data[2:]) { return ErrInvalidPayload } // invalid code if m.Size >= 2 { c := binary.BigEndian.Uint16(m.Data[0:2]) if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 { return ErrInvalidCloseStatus } } return ErrCloseFrame case Ping: return nil case Pong: return nil default: return ErrInvalidOpCode } return nil } // readBytes reads from a reader into a byte array // until the byte length is fully filled with data // loops while there is no error // // It manages connections which chunks data func readBytes(reader io.Reader, buffer []byte) error { var cur, len int = 0, len(buffer) // try to read until the full size is read for cur < len { nbread, err := reader.Read(buffer[cur:]) if err != nil { return err } cur += nbread } return nil }