Compare commits

...

5 Commits

30 changed files with 1196 additions and 1461 deletions

179
client.go
View File

@ -4,10 +4,11 @@ import (
"bufio" "bufio"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"git.xdrm.io/go/websocket/internal/http/upgrade/request"
"net" "net"
"sync" "sync"
"time" "time"
"git.xdrm.io/go/ws/internal/http/upgrade"
) )
// Represents a client socket utility (reader, writer, ..) // Represents a client socket utility (reader, writer, ..)
@ -35,36 +36,27 @@ type client struct {
status MessageError // close status ; 0 = nothing ; else -> must close status MessageError // close status ; 0 = nothing ; else -> must close
} }
// Create creates a new client // newClient creates a new client
func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) { func newClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) {
req := &upgrade.Request{}
_, err := req.ReadFrom(s)
if err != nil {
return nil, fmt.Errorf("request read: %w", err)
}
/* (1) Manage UPGRADE request
---------------------------------------------------------*/
/* (1) Parse request */
req, _ := request.Parse(s)
/* (3) Build response */
res := req.BuildResponse() res := req.BuildResponse()
/* (4) Write into socket */ _, err = res.WriteTo(s)
_, err := res.Send(s)
if err != nil { if err != nil {
return nil, fmt.Errorf("Upgrade write error: %s", err) return nil, fmt.Errorf("upgrade write: %w", err)
} }
if res.GetStatusCode() != 101 { if res.StatusCode != 101 {
s.Close() s.Close()
return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode()) return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.StatusCode)
} }
/* (2) Initialise client var cli = &client{
---------------------------------------------------------*/
/* (1) Get upgrade data */
clientURI := req.GetURI()
clientProtocol := res.GetProtocol()
/* (2) Initialise client */
cli := &client{
io: clientIO{ io: clientIO{
sock: s, sock: s,
reader: bufio.NewReader(s), reader: bufio.NewReader(s),
@ -72,8 +64,8 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
}, },
iface: &Client{ iface: &Client{
Protocol: string(clientProtocol), Protocol: string(res.Protocol),
Arguments: [][]string{[]string{clientURI}}, Arguments: [][]string{{req.URI()}},
}, },
ch: clientChannelSet{ ch: clientChannelSet{
@ -82,67 +74,54 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
}, },
} }
/* (3) Find controller by URI // find controller by URI
---------------------------------------------------------*/ controller, arguments := ctl.Match(req.URI())
/* (1) Try to find one */
controller, arguments := ctl.Match(clientURI)
/* (2) If nothing found -> error */
if controller == nil { if controller == nil {
return nil, fmt.Errorf("No controller found, no default controller set\n") return nil, fmt.Errorf("no controller found, no default controller set")
} }
/* (3) Copy arguments */ // copy args
cli.iface.Arguments = arguments cli.iface.Arguments = arguments
/* (4) Launch client routines
---------------------------------------------------------*/
/* (1) Launch client controller */
go controller.Fun( go controller.Fun(
cli.iface, // pass the client cli.iface, // pass the client
cli.ch.receive, // the receiver cli.ch.receive, // the receiver
cli.ch.send, // the sender cli.ch.send, // the sender
serverCh.broadcast, // broadcast sender serverCh.broadcast, // broadcast sender
) )
/* (2) Launch message reader */
go clientReader(cli) go clientReader(cli)
/* (3) Launc writer */
go clientWriter(cli) go clientWriter(cli)
return cli, nil return cli, nil
} }
// reader reads and parses messages from the buffer // clientReader reads and parses messages from the buffer
func clientReader(c *client) { func clientReader(c *client) {
var frag *Message var (
frag *Message
closeStatus := NORMAL closeStatus = Normal
clientAck := true clientAck = true
)
c.io.reading.Add(1) c.io.reading.Add(1)
for { for {
// currently closing -> exit
/* (1) if currently closing -> exit */
if c.io.closing { if c.io.closing {
fmt.Printf("[reader] killed because closing") fmt.Printf("[reader] killed because closing")
break break
} }
/* (2) Parse message */ // Parse message
msg, err := readMessage(c.io.reader) var msg = &Message{}
_, err := msg.ReadFrom(c.io.reader)
if err == ErrUnmaskedFrame || err == ErrReservedBits { if err == ErrUnmaskedFrame || err == ErrReservedBits {
closeStatus = PROTOCOL_ERR closeStatus = ProtocolError
} }
if err != nil { if err != nil {
break break
} }
/* (3) Fail on invalid message */ // invalid message
msgErr := msg.check(frag != nil) msgErr := msg.check(frag != nil)
if msgErr != nil { if msgErr != nil {
@ -150,27 +129,27 @@ func clientReader(c *client) {
switch msgErr { switch msgErr {
// Fail // fail
case ErrUnexpectedContinuation: case ErrUnexpectedContinuation:
closeStatus = NONE closeStatus = None
clientAck = false clientAck = false
mustClose = true mustClose = true
// proper close // proper close
case CloseFrame: case ErrCloseFrame:
closeStatus = NORMAL closeStatus = Normal
clientAck = true clientAck = true
mustClose = true mustClose = true
// invalid payload proper close // invalid payload proper close
case ErrInvalidPayload: case ErrInvalidPayload:
closeStatus = INVALID_PAYLOAD closeStatus = InvalidPayload
clientAck = true clientAck = true
mustClose = true mustClose = true
// any other error -> protocol error // any other error -> protocol error
default: default:
closeStatus = PROTOCOL_ERR closeStatus = ProtocolError
clientAck = true clientAck = true
mustClose = true mustClose = true
} }
@ -181,15 +160,15 @@ func clientReader(c *client) {
} }
/* (4) Ping <-> Pong */ // ping <-> Pong
if msg.Type == PING && c.io.writing { if msg.Type == Ping && c.io.writing {
msg.Final = true msg.Final = true
msg.Type = PONG msg.Type = Pong
c.ch.send <- *msg c.ch.send <- *msg
continue continue
} }
/* (5) Store first fragment */ // store first fragment
if frag == nil && !msg.Final { if frag == nil && !msg.Final {
frag = &Message{ frag = &Message{
Type: msg.Type, Type: msg.Type,
@ -200,7 +179,7 @@ func clientReader(c *client) {
continue continue
} }
/* (6) Store fragments */ // store fragments
if frag != nil { if frag != nil {
frag.Final = msg.Final frag.Final = msg.Final
frag.Size += msg.Size frag.Size += msg.Size
@ -213,10 +192,10 @@ func clientReader(c *client) {
// check message errors // check message errors
fragErr := frag.check(false) fragErr := frag.check(false)
if fragErr == ErrInvalidPayload { if fragErr == ErrInvalidPayload {
closeStatus = INVALID_PAYLOAD closeStatus = InvalidPayload
break break
} else if fragErr != nil { } else if fragErr != nil {
closeStatus = PROTOCOL_ERR closeStatus = ProtocolError
break break
} }
@ -225,8 +204,8 @@ func clientReader(c *client) {
} }
/* (7) Dispatch to receiver */ // dispatch to receiver
if msg.Type == TEXT || msg.Type == BINARY { if msg.Type == Text || msg.Type == Binary {
c.ch.receive <- *msg c.ch.receive <- *msg
} }
@ -235,111 +214,79 @@ func clientReader(c *client) {
close(c.ch.receive) close(c.ch.receive)
c.io.reading.Done() c.io.reading.Done()
/* (8) close channel (if not already done) */ // close channel (if not already done)
// fmt.Printf("[reader] end\n") // fmt.Printf("[reader] end\n")
c.close(closeStatus, clientAck) c.close(closeStatus, clientAck)
} }
// writer writes into websocket // clientWriter writes to the websocket connection and is triggered by
// and is triggered by client.ch.send channel // client.ch.send channel
func clientWriter(c *client) { func clientWriter(c *client) {
c.io.writing = true // if channel still exists c.io.writing = true // if channel still exists
for msg := range c.ch.send { for msg := range c.ch.send {
_, err := msg.WriteTo(c.io.sock)
/* (2) Send message */
err := msg.Send(c.io.sock)
/* (3) Fail on error */
if err != nil { if err != nil {
fmt.Printf(" [writer] %s\n", err) fmt.Printf(" [writer] %s\n", err)
c.io.writing = false c.io.writing = false
break break
} }
} }
c.io.writing = false c.io.writing = false
/* (4) close channel (if not already done) */ // close channel (if not already done)
// fmt.Printf("[writer] end\n") // fmt.Printf("[writer] end\n")
c.close(NORMAL, true) c.close(Normal, true)
} }
// closes the connection // close the connection
// send CLOSE frame is 'status' is not NONE // send CLOSE frame is 'status' is not NONE
// wait for the next message (CLOSE acknowledge) if 'clientACK' // wait for the next message (CLOSE acknowledge) if 'clientACK'
// then delete client // then delete client
func (c *client) close(status MessageError, clientACK bool) { func (c *client) close(status MessageError, clientACK bool) {
// fail if already closing
/* (1) Fail if already closing */
alreadyClosing := false alreadyClosing := false
c.io.closingMu.Lock() c.io.closingMu.Lock()
alreadyClosing = c.io.closing alreadyClosing = c.io.closing
c.io.closing = true c.io.closing = true
c.io.closingMu.Unlock() c.io.closingMu.Unlock()
if alreadyClosing { if alreadyClosing {
return return
} }
/* (2) kill writer' if still running */ // kill writer' if still running
if c.io.writing { if c.io.writing {
close(c.ch.send) close(c.ch.send)
} }
/* (3) kill reader if still running */ // kill reader if still running
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1)) c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
c.io.reading.Wait() c.io.reading.Wait()
if status != NONE { if status != None {
/* (3) Build message */
msg := &Message{ msg := &Message{
Final: true, Final: true,
Type: CLOSE, Type: Close,
Size: 2, Size: 2,
Data: make([]byte, 2), Data: make([]byte, 2),
} }
binary.BigEndian.PutUint16(msg.Data, uint16(status)) binary.BigEndian.PutUint16(msg.Data, uint16(status))
/* (4) Send message */ msg.WriteTo(c.io.sock)
msg.Send(c.io.sock)
// if err != nil {
// fmt.Printf("[close] send error (%s0\n", err)
// }
} }
/* (2) Wait for client CLOSE if needed */ // wait for client CLOSE if needed
if clientACK { if clientACK {
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond)) c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
var tmpMsg = &Message{}
/* Wait for message */ tmpMsg.ReadFrom(c.io.reader)
readMessage(c.io.reader)
// if err != nil || msg.Type != CLOSE {
// if err == nil {
// fmt.Printf("[close] received OpCode = %d\n", msg.Type)
// } else {
// fmt.Printf("[close] read error (%v)\n", err)
// }
// }
// fmt.Printf("[close] received ACK\n")
} }
/* (3) Close socket */
c.io.sock.Close() c.io.sock.Close()
// fmt.Printf("[close] socket closed\n") // fmt.Printf("[close] socket closed\n")
/* (4) Unregister */
c.io.kill <- c c.io.kill <- c
return
} }

View File

@ -1,21 +1,21 @@
package iface package main
import ( import (
"fmt" "fmt"
ws "git.xdrm.io/go/websocket"
"time" "time"
ws "git.xdrm.io/go/ws"
) )
func main() { func main() {
startTime := time.Now().UnixNano() startTime := time.Now().UnixNano()
/* (1) Bind WebSocket server */ // creqte WebSocket server
serv := ws.CreateServer("0.0.0.0", 4444) serv := ws.NewServer("0.0.0.0", 4444)
/* (2) Bind default controller */ // bind default controller
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
defer func() { defer func() {
if recover() != nil { if recover() != nil {
fmt.Printf("*** PANIC\n") fmt.Printf("*** PANIC\n")
@ -23,35 +23,28 @@ func main() {
}() }()
for msg := range receiver { for msg := range receiver {
// if receive message -> send it back // if receive message -> send it back
sender <- msg sender <- msg
// close(sender) // close(sender)
} }
}) })
/* (3) Bind to URI */ // bnd to URI
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
fmt.Printf("[uri] connected\n") fmt.Printf("[uri] connected\n")
for msg := range receiver { for msg := range receiver {
fmt.Printf("[uri] received '%s'\n", msg.Data) fmt.Printf("[uri] received '%s'\n", msg.Data)
sender <- msg sender <- msg
} }
fmt.Printf("[uri] unexpectedly closed\n") fmt.Printf("[uri] unexpectedly closed\n")
}) })
if err != nil { if err != nil {
panic(err) panic(err)
} }
/* (4) Launch the server */ // launch the server
err = serv.Launch() err = serv.Launch()
if err != nil { if err != nil {
fmt.Printf("[ERROR] %s\n", err) fmt.Printf("[ERROR] %s\n", err)
@ -59,5 +52,4 @@ func main() {
} }
fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3) fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3)
} }

View File

@ -1,63 +1,48 @@
package websocket package websocket
import ( import "git.xdrm.io/go/ws/internal/uri"
"git.xdrm.io/go/websocket/internal/uri/parser"
)
// Represents available information about a client // Client contains available information about a client
type Client struct { type Client struct {
Protocol string // choosen protocol (Sec-WebSocket-Protocol) Protocol string // choosen protocol (Sec-WebSocket-Protocol)
Arguments [][]string // URI parameters, index 0 is full URI, then matching groups Arguments [][]string // URI parameters, index 0 is full URI, then matching groups
Store interface{} // store (for client implementation-specific data) Store interface{} // store (for client implementation-specific data)
} }
// Represents a websocket controller callback function // ControllerFunc is a websocket controller callback function
type ControllerFunc func(*Client, <-chan Message, chan<- Message, chan<- Message) type ControllerFunc func(*Client, <-chan Message, chan<- Message, chan<- Message)
// Represents a websocket controller // Controller is a websocket controller
type Controller struct { type Controller struct {
URI *parser.Scheme // uri scheme URI *uri.Scheme // uri scheme
Fun ControllerFunc // controller function Fun ControllerFunc // controller function
} }
// Represents a controller set // ControllerSet contains a set of controllers
type ControllerSet struct { type ControllerSet struct {
Def *Controller // default controller Def *Controller // default controller
Uri []*Controller // uri controllers URI []*Controller // uri controllers
} }
// Match finds a controller for a given URI // Match finds a controller for a given URI
// also it returns the matching string patterns // also it returns the matching string patterns
func (s *ControllerSet) Match(uri string) (*Controller, [][]string) { func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
arguments := [][]string{{uri}}
/* (1) Initialise argument list */ for _, c := range s.URI {
arguments := [][]string{[]string{uri}}
/* (2) Try each controller */
for _, c := range s.Uri {
/* 1. If matches */
if c.URI.Match(uri) { if c.URI.Match(uri) {
/* Extract matches */
match := c.URI.GetAllMatch() match := c.URI.GetAllMatch()
/* Add them to the 'arg' attribute */
arguments = append(arguments, match...) arguments = append(arguments, match...)
/* Mark that we have a controller */
return c, arguments return c, arguments
}
} }
} // fallback to default
/* (3) If no controller found -> set default controller */
if s.Def != nil { if s.Def != nil {
return s.Def, arguments return s.Def, arguments
} }
/* (4) If default is NIL, return empty controller */ // no default
return nil, arguments return nil, arguments
} }

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.xdrm.io/go/ws
go 1.16

View File

@ -6,48 +6,43 @@ package reader
// the golang standard library // the golang standard library
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
"bufio"
) )
// Maximum line length // Maximum line length
var maxLineLength = 4096 const maxLineLength = 4096
// Chunk reader // ChunkReader struct
type chunkReader struct { type ChunkReader struct {
reader *bufio.Reader // the reader reader *bufio.Reader // the reader
isEnded bool // If we are done (2 consecutive CRLF) isEnded bool // If we are done (2 consecutive CRLF)
} }
// NewReader creates a new reader
// New creates a new reader func NewReader(r io.Reader) *ChunkReader {
func NewReader(r io.Reader) (reader *chunkReader) {
br, ok := r.(*bufio.Reader) br, ok := r.(*bufio.Reader)
if !ok { if !ok {
br = bufio.NewReader(r) br = bufio.NewReader(r)
} }
return &chunkReader{reader: br} return &ChunkReader{reader: br}
} }
// Read reads a chunk, io.EOF when done
func (r *ChunkReader) Read() ([]byte, error) {
// Read reads a chunk, err is io.EOF when done // already ended
func (r *chunkReader) Read() ([]byte, error){
/* (1) If already ended */
if r.isEnded { if r.isEnded {
return nil, io.EOF return nil, io.EOF
} }
/* (2) Read line */ // read line
var line []byte var line []byte
line, err := r.reader.ReadSlice('\n') line, err := r.reader.ReadSlice('\n')
/* (3) manage errors */
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
@ -60,10 +55,8 @@ func (r *chunkReader) Read() ([]byte, error){
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength) return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
} }
/* (4) Trim */ line = trimSpaces(line)
line = removeTrailingSpace(line)
/* (5) Manage ending line */
if len(line) == 0 { if len(line) == 0 {
r.isEnded = true r.isEnded = true
return line, io.EOF return line, io.EOF
@ -73,15 +66,13 @@ func (r *chunkReader) Read() ([]byte, error){
} }
func trimSpaces(b []byte) []byte {
for len(b) > 0 && isSpaceChar(b[len(b)-1]) {
func removeTrailingSpace(b []byte) []byte{
for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
b = b[:len(b)-1] b = b[:len(b)-1]
} }
return b return b
} }
func isASCIISpace(b byte) bool { func isSpaceChar(b byte) bool {
return b == ' ' || b == '\t' || b == '\r' || b == '\n' return b == ' ' || b == '\t' || b == '\r' || b == '\n'
} }

View File

@ -0,0 +1,35 @@
package upgrade
import (
"fmt"
)
// ErrInvalidRequest for invalid requests
// - multiple-value if only 1 expected
type ErrInvalidRequest struct {
Field string
Reason string
}
func (err ErrInvalidRequest) Error() string {
return fmt.Sprintf("invalid field '%s': %s", err.Field, err.Reason)
}
// ErrIncompleteRequest when mandatory request fields are missing (request-line or headers)
// it contains the missing field as a string
type ErrIncompleteRequest string
func (err ErrIncompleteRequest) Error() string {
return fmt.Sprintf("incomplete request, '%s' is invalid or missing", string(err))
}
// ErrInvalidOriginPolicy when a request has a violated origin policy
type ErrInvalidOriginPolicy struct {
Host string
Origin string
err error
}
func (err ErrInvalidOriginPolicy) Error() string {
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
}

View File

@ -0,0 +1,74 @@
package upgrade
import (
"bytes"
"fmt"
"strings"
)
// HeaderType represents all 'valid' HTTP request headers
type HeaderType uint8
// header types
const (
Unknown HeaderType = iota
Host
Upgrade
Connection
Origin
WSKey
WSProtocol
WSExtensions
WSVersion
)
// HeaderValue represents a unique or multiple header value(s)
type HeaderValue [][]byte
// Header represents the data of a HTTP request header
type Header struct {
Name HeaderType
Values HeaderValue
}
// ReadHeader tries to parse an HTTP header from a byte array
func ReadHeader(b []byte) (*Header, error) {
// 1. Split by ':'
parts := bytes.Split(b, []byte(": "))
if len(parts) != 2 {
return nil, fmt.Errorf("invalid HTTP header format '%s'", b)
}
// 2. Create instance
inst := &Header{}
// 3. Check for header name
switch strings.ToLower(string(parts[0])) {
case "host":
inst.Name = Host
case "upgrade":
inst.Name = Upgrade
case "connection":
inst.Name = Connection
case "origin":
inst.Name = Origin
case "sec-websocket-key":
inst.Name = WSKey
case "sec-websocket-protocol":
inst.Name = WSProtocol
case "sec-websocket-extensions":
inst.Name = WSExtensions
case "sec-websocket-version":
inst.Name = WSVersion
default:
inst.Name = Unknown
}
// 4. Split values
inst.Values = bytes.Split(parts[1], []byte(", "))
return inst, nil
}

View File

@ -1,22 +1,20 @@
package request package upgrade
import ( import (
"fmt" "fmt"
"git.xdrm.io/go/websocket/internal/http/upgrade/request/parser/header"
"git.xdrm.io/go/websocket/internal/http/upgrade/response"
"strconv" "strconv"
"strings" "strings"
) )
// checkHost checks and extracts the Host header // checkHost checks and extracts the Host header
func (r *T) extractHostPort(bb header.HeaderValue) error { func (r *Request) extractHostPort(bb HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} return &ErrInvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
if len(bb[0]) <= 3 { if len(bb[0]) <= 3 {
return &InvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))} return &ErrInvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))}
} }
split := strings.Split(string(bb[0]), ":") split := strings.Split(string(bb[0]), ":")
@ -31,8 +29,8 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
// extract port // extract port
readPort, err := strconv.ParseUint(split[1], 10, 16) readPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil { if err != nil {
r.code = response.BAD_REQUEST r.statusCode = StatusBadRequest
return &InvalidRequest{"Host", "cannot read port"} return &ErrInvalidRequest{"Host", "cannot read port"}
} }
r.port = uint16(readPort) r.port = uint16(readPort)
@ -41,8 +39,8 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
if len(r.origin) > 0 { if len(r.origin) > 0 {
if err != nil { if err != nil {
err = r.checkOriginPolicy() err = r.checkOriginPolicy()
r.code = response.FORBIDDEN r.statusCode = StatusForbidden
return &InvalidOriginPolicy{r.host, r.origin, err} return &ErrInvalidOriginPolicy{r.host, r.origin, err}
} }
} }
@ -51,7 +49,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
} }
// checkOrigin checks the Origin Header // checkOrigin checks the Origin Header
func (r *T) extractOrigin(bb header.HeaderValue) error { func (r *Request) extractOrigin(bb HeaderValue) error {
// bypass // bypass
if bypassOriginPolicy { if bypassOriginPolicy {
@ -59,8 +57,8 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
} }
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.FORBIDDEN r.statusCode = StatusForbidden
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} return &ErrInvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
r.origin = string(bb[0]) r.origin = string(bb[0])
@ -69,8 +67,8 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
if len(r.host) > 0 { if len(r.host) > 0 {
err := r.checkOriginPolicy() err := r.checkOriginPolicy()
if err != nil { if err != nil {
r.code = response.FORBIDDEN r.statusCode = StatusForbidden
return &InvalidOriginPolicy{r.host, r.origin, err} return &ErrInvalidOriginPolicy{r.host, r.origin, err}
} }
} }
@ -79,7 +77,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
} }
// checkOriginPolicy origin policy based on 'host' value // checkOriginPolicy origin policy based on 'host' value
func (r *T) checkOriginPolicy() error { func (r *Request) checkOriginPolicy() error {
// TODO: Origin policy, for now BYPASS // TODO: Origin policy, for now BYPASS
r.validPolicy = true r.validPolicy = true
return nil return nil
@ -87,7 +85,7 @@ func (r *T) checkOriginPolicy() error {
// checkConnection checks the 'Connection' header // checkConnection checks the 'Connection' header
// it MUST contain 'Upgrade' // it MUST contain 'Upgrade'
func (r *T) checkConnection(bb header.HeaderValue) error { func (r *Request) checkConnection(bb HeaderValue) error {
for _, b := range bb { for _, b := range bb {
@ -98,18 +96,18 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
} }
r.code = response.BAD_REQUEST r.statusCode = StatusBadRequest
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} return &ErrInvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
} }
// checkUpgrade checks the 'Upgrade' header // checkUpgrade checks the 'Upgrade' header
// it MUST be 'websocket' // it MUST be 'websocket'
func (r *T) checkUpgrade(bb header.HeaderValue) error { func (r *Request) checkUpgrade(bb HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.BAD_REQUEST r.statusCode = StatusBadRequest
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
if strings.ToLower(string(bb[0])) == "websocket" { if strings.ToLower(string(bb[0])) == "websocket" {
@ -117,18 +115,18 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
return nil return nil
} }
r.code = response.BAD_REQUEST r.statusCode = StatusBadRequest
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
} }
// checkVersion checks the 'Sec-WebSocket-Version' header // checkVersion checks the 'Sec-WebSocket-Version' header
// it MUST be '13' // it MUST be '13'
func (r *T) checkVersion(bb header.HeaderValue) error { func (r *Request) checkVersion(bb HeaderValue) error {
if len(bb) != 1 || string(bb[0]) != "13" { if len(bb) != 1 || string(bb[0]) != "13" {
r.code = response.UPGRADE_REQUIRED r.statusCode = StatusUpgradeRequired
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} return &ErrInvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
} }
r.hasVersion = true r.hasVersion = true
@ -138,11 +136,11 @@ func (r *T) checkVersion(bb header.HeaderValue) error {
// extractKey extracts the 'Sec-WebSocket-Key' header // extractKey extracts the 'Sec-WebSocket-Key' header
// it MUST be 24 bytes (base64) // it MUST be 24 bytes (base64)
func (r *T) extractKey(bb header.HeaderValue) error { func (r *Request) extractKey(bb HeaderValue) error {
if len(bb) != 1 || len(bb[0]) != 24 { if len(bb) != 1 || len(bb[0]) != 24 {
r.code = response.BAD_REQUEST r.statusCode = StatusBadRequest
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} return &ErrInvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
} }
r.key = bb[0] r.key = bb[0]
@ -153,7 +151,7 @@ func (r *T) extractKey(bb header.HeaderValue) error {
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header // extractProtocols extracts the 'Sec-WebSocket-Protocol' header
// it can contain multiple values // it can contain multiple values
func (r *T) extractProtocols(bb header.HeaderValue) error { func (r *Request) extractProtocols(bb HeaderValue) error {
r.protocols = bb r.protocols = bb

View File

@ -0,0 +1,189 @@
package upgrade
import (
"fmt"
"io"
"git.xdrm.io/go/ws/internal/http/reader"
)
// whether origin is required
const bypassOriginPolicy = true
// Request represents an HTTP Upgrade request
type Request struct {
// whether the first line has been read (GET uri HTTP/version)
first bool
statusCode StatusCode
requestLine RequestLine
// data to check origin (depends on reading order)
host string
port uint16 // 0 if not set
origin string
validPolicy bool
// websocket specific
key []byte
protocols [][]byte
// mandatory fields to check
hasConnection bool
hasUpgrade bool
hasVersion bool
}
// ReadFrom reads an upgrade HTTP request ; typically from bufio.NewRead of the
// socket
//
// implements io.ReaderFrom
func (req *Request) ReadFrom(r io.Reader) (int64, error) {
var read int64
// reset request
req.statusCode = 500
// parse header line by line
var cr = reader.NewReader(r)
for {
line, err := cr.Read()
read += int64(len(line))
if err == io.EOF {
break
}
if err != nil {
return read, err
}
err = req.parseHeader(line)
if err != nil {
return read, err
}
}
err := req.isComplete()
if err != nil {
req.statusCode = StatusBadRequest
return read, err
}
req.statusCode = StatusSwitchingProtocols
return read, nil
}
// StatusCode returns the status current
func (req Request) StatusCode() StatusCode {
return req.statusCode
}
// BuildResponse builds a response from the request
func (req *Request) BuildResponse() *Response {
res := &Response{
StatusCode: req.statusCode,
Protocol: nil,
}
if len(req.protocols) > 0 {
res.Protocol = req.protocols[0]
}
res.ProcessKey(req.key)
return res
}
// URI returns the actual URI
func (req Request) URI() string {
return req.requestLine.URI()
}
// parseHeader parses any http request line
// (header and request-line)
func (req *Request) parseHeader(b []byte) error {
// first line -> GET {uri} HTTP/{version}
if !req.first {
_, err := req.requestLine.Read(b)
if err != nil {
req.statusCode = StatusBadRequest
return &ErrInvalidRequest{"Request-Line", err.Error()}
}
req.first = true
return nil
}
// other lines -> Header-Name: Header-Value
head, err := ReadHeader(b)
if err != nil {
req.statusCode = StatusBadRequest
return fmt.Errorf("parse header: %w", err)
}
// 2. Manage header
switch head.Name {
case Host:
err = req.extractHostPort(head.Values)
case Origin:
err = req.extractOrigin(head.Values)
case Upgrade:
err = req.checkUpgrade(head.Values)
case Connection:
err = req.checkConnection(head.Values)
case WSVersion:
err = req.checkVersion(head.Values)
case WSKey:
err = req.extractKey(head.Values)
case WSProtocol:
err = req.extractProtocols(head.Values)
default:
return nil
}
// dispatch error
if err != nil {
return err
}
return nil
}
// isComplete returns whether the Upgrade Request
// is complete (no required field missing)
// returns nil on success
func (req Request) isComplete() error {
if !req.first {
return ErrIncompleteRequest("Request-Line")
}
if len(req.host) == 0 {
return ErrIncompleteRequest("Host")
}
if !bypassOriginPolicy && len(req.origin) == 0 {
return ErrIncompleteRequest("Origin")
}
if !req.hasConnection {
return ErrIncompleteRequest("Connection")
}
if !req.hasUpgrade {
return ErrIncompleteRequest("Upgrade")
}
if !req.hasVersion {
return ErrIncompleteRequest("Sec-WebSocket-Version")
}
if len(req.key) < 1 {
return ErrIncompleteRequest("Sec-WebSocket-Key")
}
return nil
}

View File

@ -1,36 +0,0 @@
package request
import (
"fmt"
)
// invalid request
// - multiple-value if only 1 expected
type InvalidRequest struct {
Field string
Reason string
}
func (err InvalidRequest) Error() string {
return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason)
}
// Request misses fields (request-line or headers)
type IncompleteRequest struct {
MissingField string
}
func (err IncompleteRequest) Error() string {
return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField)
}
// Request has a violated origin policy
type InvalidOriginPolicy struct {
Host string
Origin string
err error
}
func (err InvalidOriginPolicy) Error() string {
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
}

View File

@ -1,41 +0,0 @@
package header
import (
// "regexp"
"fmt"
"strings"
"bytes"
)
// parse tries to return a 'T' (httpHeader) from a byte array
func Parse(b []byte) (*T, error) {
/* (1) Split by ':' */
parts := bytes.Split(b, []byte(": "))
if len(parts) != 2 {
return nil, fmt.Errorf("Invalid HTTP header format '%s'", b)
}
/* (2) Create instance */
inst := new(T)
/* (3) Check for header name */
switch strings.ToLower(string(parts[0])) {
case "host": inst.Name = HOST
case "upgrade": inst.Name = UPGRADE
case "connection": inst.Name = CONNECTION
case "origin": inst.Name = ORIGIN
case "sec-websocket-key": inst.Name = WSKEY
case "sec-websocket-protocol": inst.Name = WSPROTOCOL
case "sec-websocket-extensions": inst.Name = WSEXTENSIONS
case "sec-websocket-version": inst.Name = WSVERSION
default: inst.Name = UNKNOWN
}
/* (4) Split values */
inst.Values = bytes.Split(parts[1], []byte(", "))
return inst, nil
}

View File

@ -1,25 +0,0 @@
package header
// HeaderType represents all 'valid' HTTP request headers
type HeaderType byte
const (
UNKNOWN HeaderType = iota
HOST
UPGRADE
CONNECTION
ORIGIN
WSKEY
WSPROTOCOL
WSEXTENSIONS
WSVERSION
)
// HeaderValue represents a unique or multiple header value(s)
type HeaderValue [][]byte
// T represents the data of a HTTP request header
type T struct{
Name HeaderType
Values HeaderValue
}

View File

@ -1,109 +0,0 @@
package request
import (
"fmt"
"git.xdrm.io/go/websocket/internal/http/upgrade/request/parser/header"
"git.xdrm.io/go/websocket/internal/http/upgrade/response"
)
// parseHeader parses any http request line
// (header and request-line)
func (r *T) parseHeader(b []byte) error {
/* (1) First line -> GET {uri} HTTP/{version}
---------------------------------------------------------*/
if !r.first {
err := r.request.Parse(b)
if err != nil {
r.code = response.BAD_REQUEST
return &InvalidRequest{"Request-Line", err.Error()}
}
r.first = true
return nil
}
/* (2) Other lines -> Header-Name: Header-Value
---------------------------------------------------------*/
/* (1) Try to parse header */
head, err := header.Parse(b)
if err != nil {
r.code = response.BAD_REQUEST
return fmt.Errorf("Error parsing header: %s", err)
}
/* (2) Manage header */
switch head.Name {
case header.HOST:
err = r.extractHostPort(head.Values)
case header.ORIGIN:
err = r.extractOrigin(head.Values)
case header.UPGRADE:
err = r.checkUpgrade(head.Values)
case header.CONNECTION:
err = r.checkConnection(head.Values)
case header.WSVERSION:
err = r.checkVersion(head.Values)
case header.WSKEY:
err = r.extractKey(head.Values)
case header.WSPROTOCOL:
err = r.extractProtocols(head.Values)
default:
return nil
}
// dispatch error
if err != nil {
return err
}
return nil
}
// isComplete returns whether the Upgrade Request
// is complete (no missing required item)
func (r T) isComplete() error {
/* (1) Request-Line */
if !r.first {
return &IncompleteRequest{"Request-Line"}
}
/* (2) Host */
if len(r.host) == 0 {
return &IncompleteRequest{"Host"}
}
/* (3) Origin */
if !bypassOriginPolicy && len(r.origin) == 0 {
return &IncompleteRequest{"Origin"}
}
/* (4) Connection */
if !r.hasConnection {
return &IncompleteRequest{"Connection"}
}
/* (5) Upgrade */
if !r.hasUpgrade {
return &IncompleteRequest{"Upgrade"}
}
/* (6) Sec-WebSocket-Version */
if !r.hasVersion {
return &IncompleteRequest{"Sec-WebSocket-Version"}
}
/* (7) Sec-WebSocket-Key */
if len(r.key) < 1 {
return &IncompleteRequest{"Sec-WebSocket-Key"}
}
return nil
}

View File

@ -1,84 +0,0 @@
package request
import (
"fmt"
"git.xdrm.io/go/websocket/internal/http/reader"
"git.xdrm.io/go/websocket/internal/http/upgrade/response"
"io"
)
// Parse builds an upgrade HTTP request
// from a reader (typically bufio.NewRead of the socket)
func Parse(r io.Reader) (request *T, err error) {
req := new(T)
req.code = 500
/* (1) Parse request
---------------------------------------------------------*/
/* (1) Get chunk reader */
cr := reader.NewReader(r)
if err != nil {
return req, fmt.Errorf("Error while creating chunk reader: %s", err)
}
/* (2) Parse header line by line */
for {
line, err := cr.Read()
if err == io.EOF {
break
}
if err != nil {
return req, err
}
err = req.parseHeader(line)
if err != nil {
return req, err
}
}
/* (3) Check completion */
err = req.isComplete()
if err != nil {
req.code = response.BAD_REQUEST
return req, err
}
req.code = response.SWITCHING_PROTOCOLS
return req, nil
}
// StatusCode returns the status current
func (r T) StatusCode() response.StatusCode {
return r.code
}
// BuildResponse builds a response.T from the request
func (r *T) BuildResponse() *response.T {
inst := new(response.T)
/* (1) Copy code */
inst.SetStatusCode(r.code)
/* (2) Set Protocol */
if len(r.protocols) > 0 {
inst.SetProtocol(r.protocols[0])
}
/* (4) Process key */
inst.ProcessKey(r.key)
return inst
}
// GetURI returns the actual URI
func (r T) GetURI() string {
return r.request.GetURI()
}

View File

@ -1,130 +0,0 @@
package request
import (
"bytes"
"fmt"
"regexp"
)
// httpMethod represents available http methods
type httpMethod byte
const (
OPTIONS httpMethod = iota
GET
HEAD
POST
PUT
DELETE
)
// RequestLine represents the HTTP Request line
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
type RequestLine struct {
method httpMethod
uri string
version byte
}
// parseRequestLine parses the first HTTP request line
func (r *RequestLine) Parse(b []byte) error {
/* (1) Split by ' ' */
parts := bytes.Split(b, []byte(" "))
/* (2) Fail when missing parts */
if len(parts) != 3 {
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
}
/* (3) Extract HTTP method */
err := r.extractHttpMethod(parts[0])
if err != nil {
return err
}
/* (4) Extract URI */
err = r.extractURI(parts[1])
if err != nil {
return err
}
/* (5) Extract version */
err = r.extractHttpVersion(parts[2])
if err != nil {
return err
}
return nil
}
// GetURI returns the actual URI
func (r RequestLine) GetURI() string {
return r.uri
}
// extractHttpMethod extracts the HTTP method from a []byte
// and checks for errors
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
func (r *RequestLine) extractHttpMethod(b []byte) error {
switch string(b) {
// case "OPTIONS": r.method = OPTIONS
case "GET":
r.method = GET
// case "HEAD": r.method = HEAD
// case "POST": r.method = POST
// case "PUT": r.method = PUT
// case "DELETE": r.method = DELETE
default:
return fmt.Errorf("Invalid HTTP method '%s', expected 'GET'", b)
}
return nil
}
// extractURI extracts the URI from a []byte and checks for errors
// allowed format: /([^/]/)*/?
func (r *RequestLine) extractURI(b []byte) error {
/* (1) Check format */
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
if !checker.Match(b) {
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
}
/* (2) Store */
r.uri = string(b)
return nil
}
// extractHttpVersion extracts the version and checks for errors
// allowed format: [1-9] or [1.9].[0-9]
func (r *RequestLine) extractHttpVersion(b []byte) error {
/* (1) Extract version parts */
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
if !extractor.Match(b) {
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
}
/* (2) Extract version number */
matches := extractor.FindSubmatch(b)
var version byte = matches[1][0] - '0'
/* (3) Extract subversion (if exists) */
var subVersion byte = 0
if len(matches[2]) > 0 {
subVersion = matches[2][0] - '0'
}
/* (4) Store version (x 10 to fit uint8) */
r.version = version*10 + subVersion
return nil
}

View File

@ -1,32 +0,0 @@
package request
import "git.xdrm.io/go/websocket/internal/http/upgrade/response"
// If origin is required
const bypassOriginPolicy = true
// T represents an HTTP Upgrade request
type T struct {
first bool // whether the first line has been read (GET uri HTTP/version)
// status code
code response.StatusCode
// request line
request RequestLine
// data to check origin (depends of reading order)
host string
port uint16 // 0 if not set
origin string
validPolicy bool
// ws data
key []byte
protocols [][]byte
// required fields check
hasConnection bool
hasUpgrade bool
hasVersion bool
}

View File

@ -0,0 +1,94 @@
package upgrade
import (
"bytes"
"fmt"
"regexp"
)
// RequestLine represents the HTTP Request line
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
type RequestLine struct {
uri string
version byte
}
// Read an HTTP request line from a byte array
//
// implements io.Reader
func (rl *RequestLine) Read(b []byte) (int, error) {
var read = len(b)
// split by spaces
parts := bytes.Split(b, []byte(" "))
if len(parts) != 3 {
return read, fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
}
err := rl.extractHttpMethod(parts[0])
if err != nil {
return read, err
}
err = rl.extractURI(parts[1])
if err != nil {
return read, err
}
err = rl.extractHttpVersion(parts[2])
if err != nil {
return read, err
}
return read, nil
}
// URI of the request line
func (rl RequestLine) URI() string {
return rl.uri
}
// extractHttpMethod extracts the HTTP method from a []byte
// and checks for errors
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
func (rl *RequestLine) extractHttpMethod(b []byte) error {
if string(b) != "GET" {
return fmt.Errorf("invalid HTTP method '%s', expected 'GET'", b)
}
return nil
}
// extractURI extracts the URI from a []byte and checks for errors
// allowed format: /([^/]/)*/?
func (rl *RequestLine) extractURI(b []byte) error {
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
if !checker.Match(b) {
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
}
rl.uri = string(b)
return nil
}
// extractHttpVersion extracts the version and checks for errors
// allowed format: [1-9] or [1.9].[0-9]
func (rl *RequestLine) extractHttpVersion(b []byte) error {
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
if !extractor.Match(b) {
return fmt.Errorf("invalid HTTP version, expected INT or INT.INT, got '%s'", b)
}
matches := extractor.FindSubmatch(b)
var version byte = matches[1][0] - '0'
var subversion byte = 0
if len(matches[2]) > 0 {
subversion = matches[2][0] - '0'
}
rl.version = version*10 + subversion
return nil
}

View File

@ -1,4 +1,4 @@
package request package upgrade
import ( import (
"bytes" "bytes"
@ -6,40 +6,21 @@ import (
"testing" "testing"
) )
// /* (1) Parse request */
// req, _ := request.Parse(s)
// /* (3) Build response */
// res := req.BuildResponse()
// /* (4) Write into socket */
// _, err := res.Send(s)
// if err != nil {
// return nil, fmt.Errorf("Upgrade write error: %s", err)
// }
// if res.GetStatusCode() != 101 {
// s.Close()
// return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
// }
func TestEOFSocket(t *testing.T) { func TestEOFSocket(t *testing.T) {
var (
socket = &bytes.Buffer{}
req = &Request{}
)
socket := new(bytes.Buffer) _, err := req.ReadFrom(socket)
if err != io.ErrUnexpectedEOF {
_, err := Parse(socket) t.Fatalf("unexpected error <%v> expected <%v>", err, io.ErrUnexpectedEOF)
if err == nil {
t.Fatalf("Empty socket expected EOF, got no error")
} else if err != io.ErrUnexpectedEOF {
t.Fatalf("Empty socket expected EOF, got '%s'", err)
} }
} }
func TestInvalidRequestLine(t *testing.T) { func TestInvalidRequestLine(t *testing.T) {
socket := new(bytes.Buffer) socket := &bytes.Buffer{}
cases := []struct { cases := []struct {
Reqline string Reqline string
HasError bool HasError bool
@ -76,15 +57,13 @@ func TestInvalidRequestLine(t *testing.T) {
socket.Write([]byte(tc.Reqline)) socket.Write([]byte(tc.Reqline))
socket.Write([]byte("\r\n\r\n")) socket.Write([]byte("\r\n\r\n"))
_, err := Parse(socket) var req = &Request{}
_, err := req.ReadFrom(socket)
if !tc.HasError { if !tc.HasError {
// no error -> ok
if err == nil { if err == nil {
continue continue
// error for the end of the request -> ok // error for the end of the request -> ok
} else if _, ok := err.(*IncompleteRequest); ok { } else if _, ok := err.(ErrIncompleteRequest); ok {
continue continue
} }
@ -97,7 +76,7 @@ func TestInvalidRequestLine(t *testing.T) {
continue continue
} }
ir, ok := err.(*InvalidRequest) ir, ok := err.(*ErrInvalidRequest)
// not InvalidRequest err -> error // not InvalidRequest err -> error
if !ok || ir.Field != "Request-Line" { if !ok || ir.Field != "Request-Line" {
@ -113,7 +92,7 @@ func TestInvalidHost(t *testing.T) {
requestLine := []byte("GET / HTTP/1.1\r\n") requestLine := []byte("GET / HTTP/1.1\r\n")
socket := new(bytes.Buffer) socket := &bytes.Buffer{}
cases := []struct { cases := []struct {
Host string Host string
HasError bool HasError bool
@ -148,15 +127,15 @@ func TestInvalidHost(t *testing.T) {
socket.Write([]byte(tc.Host)) socket.Write([]byte(tc.Host))
socket.Write([]byte("\r\n\r\n")) socket.Write([]byte("\r\n\r\n"))
_, err := Parse(socket) var req = &Request{}
_, err := req.ReadFrom(socket)
if !tc.HasError { if !tc.HasError {
// no error -> ok // no error -> ok
if err == nil { if err == nil {
continue continue
// error for the end of the request -> ok // error for the end of the request -> ok
} else if _, ok := err.(*IncompleteRequest); ok { } else if _, ok := err.(ErrIncompleteRequest); ok {
continue continue
} }
@ -170,7 +149,7 @@ func TestInvalidHost(t *testing.T) {
} }
// check if InvalidRequest // check if InvalidRequest
ir, ok := err.(*InvalidRequest) ir, ok := err.(ErrInvalidRequest)
// not InvalidRequest err -> error // not InvalidRequest err -> error
if ok && ir.Field != "Host" { if ok && ir.Field != "Host" {

View File

@ -0,0 +1,64 @@
package upgrade
import (
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
)
// constants
const (
httpVersion = "1.1"
wsVersion = 13
keySalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
)
// Response is an HTTP Upgrade Response
type Response struct {
StatusCode StatusCode
// Sec-WebSocket-Protocol or nil if missing
Protocol []byte
// processed from Sec-WebSocket-Key
key []byte
}
// ProcessKey processes the accept token according
// to the rfc from the Sec-WebSocket-Key
func (r *Response) ProcessKey(k []byte) {
// ignore empty key
if k == nil || len(k) < 1 {
return
}
// concat with constant salt
salted := append(k, []byte(keySalt)...)
// hash with sha1
digest := sha1.Sum(salted)
// base64 encode
r.key = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
}
// WriteTo writes the response; typically in a socket
//
// implements io.WriterTo
func (r Response) WriteTo(w io.Writer) (int64, error) {
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", httpVersion, r.StatusCode, r.StatusCode)
optionalProtocol := ""
if len(r.Protocol) > 0 {
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.Protocol)
}
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", wsVersion, optionalProtocol)
if r.key != nil {
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.key)
}
headers = fmt.Sprintf("%s\r\n", headers)
combined := []byte(fmt.Sprintf("%s%s", responseLine, headers))
written, err := w.Write(combined)
return int64(written), err
}

View File

@ -1,78 +0,0 @@
package response
import (
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
)
// SetStatusCode sets the status code
func (r *T) SetStatusCode(sc StatusCode) {
r.code = sc
}
// SetProtocols sets the protocols
func (r *T) SetProtocol(p []byte) {
r.protocol = p
}
// ProcessKey processes the accept token according
// to the rfc from the Sec-WebSocket-Key
func (r *T) ProcessKey(k []byte) {
// do nothing for empty key
if k == nil || len(k) == 0 {
r.accept = nil
return
}
/* (1) Concat with constant salt */
mix := append(k, WSSalt...)
/* (2) Hash with sha1 algorithm */
digest := sha1.Sum(mix)
/* (3) Base64 encode it */
r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
}
// Send sends the response through an io.Writer
// typically a socket
func (r T) Send(w io.Writer) (int, error) {
/* (1) Build response line */
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HttpVersion, r.code, r.code.Message())
/* (2) Build headers */
optionalProtocol := ""
if len(r.protocol) > 0 {
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol)
}
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", WSVersion, optionalProtocol)
if r.accept != nil {
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept)
}
headers = fmt.Sprintf("%s\r\n", headers)
/* (3) Build all */
raw := []byte(fmt.Sprintf("%s%s", responseLine, headers))
/* (4) Write */
written, err := w.Write(raw)
return written, err
}
// GetProtocol returns the choosen protocol if set, else nil
func (r T) GetProtocol() []byte {
return r.protocol
}
// GetStatusCode returns the response status code
func (r T) GetStatusCode() StatusCode {
return r.code
}

View File

@ -1,31 +0,0 @@
package response
// StatusCode maps the status codes (and description)
type StatusCode uint16
var SWITCHING_PROTOCOLS StatusCode = 101 // handshake success
var BAD_REQUEST StatusCode = 400 // missing/malformed headers
var FORBIDDEN StatusCode = 403 // invalid origin policy, TLS required
var UPGRADE_REQUIRED StatusCode = 426 // invalid WS version
var NOT_FOUND StatusCode = 404 // unserved or invalid URI
var INTERNAL StatusCode = 500 // custom error
func (sc StatusCode) Message() string {
switch sc {
case SWITCHING_PROTOCOLS:
return "Switching Protocols"
case BAD_REQUEST:
return "Bad Request"
case FORBIDDEN:
return "Forbidden"
case UPGRADE_REQUIRED:
return "Upgrade Required"
case NOT_FOUND:
return "Not Found"
case INTERNAL:
return "Internal Server Error"
default:
return "Unknown Status Code"
}
}

View File

@ -1,15 +0,0 @@
package response
// Constant
const HttpVersion = "1.1"
const WSVersion = 13
var WSSalt []byte = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
// T represents an HTTP Upgrade Response
type T struct {
code StatusCode // status code
accept []byte // processed from Sec-WebSocket-Key
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
}

View File

@ -0,0 +1,39 @@
package upgrade
// StatusCode maps HTTP status codes (and description)
type StatusCode int
const (
// StatusSwitchingProtocols - handshake success
StatusSwitchingProtocols StatusCode = 101
// StatusBadRequest - missing/malformed headers
StatusBadRequest StatusCode = 400
// StatusForbidden - invalid origin policy, TLS required
StatusForbidden StatusCode = 403
// StatusUpgradeRequired - invalid WS version
StatusUpgradeRequired StatusCode = 426
// StatusNotFound - unserved or invalid URI
StatusNotFound StatusCode = 404
// StatusInternal - custom error
StatusInternal StatusCode = 500
)
// String implements the Stringer interface
func (sc StatusCode) String() string {
switch sc {
case StatusSwitchingProtocols:
return "Switching Protocols"
case StatusBadRequest:
return "Bad Request"
case StatusForbidden:
return "Forbidden"
case StatusUpgradeRequired:
return "Upgrade Required"
case StatusNotFound:
return "Not Found"
case StatusInternal:
return "Internal Server Error"
default:
return "Unknown Status Code"
}
}

319
internal/uri/parser.go Normal file
View File

@ -0,0 +1,319 @@
package uri
import (
"fmt"
"strings"
)
// === WILDCARDS ===
//
// The star '*' -> matches 0 or 1 slash-bounded string
// The multi star '**' -> matches 0 or more slash-separated strings
// The dot '.' -> matches 1 slash-bounded string
// The multi dot '..' -> matches 1 or more slash-separated strings
//
// === SCHEME POLICY ===
//
// - The last '/' is optional
// - Any '**' at the very end will match anything that starts with the given prefix
//
// === LIMITATIONS ==
//
// - A scheme must begin with '/'
// - A scheme cannot contain something else than a STRING or WILDCARD between 2 '/' separators
// - A scheme STRING cannot contain the symbols '/' as a character
// - A scheme STRING containing '*' or '.' characters will be treating as STRING only
// - A maximum of 16 slash-separated matchers (STRING or WILDCARD) are allowed
const maxMatch = 16
// Represents an URI matcher
type matcher struct {
pat string // pattern to match (empty if wildcard)
req bool // whether it is required
mul bool // whether multiple matches are allowed
buf []string // matched content (when matching)
}
// Scheme represents an URI scheme
type Scheme []*matcher
// FromString builds an URI scheme from a string pattern
func FromString(s string) (*Scheme, error) {
// handle '/' at the start
if len(s) < 1 || s[0] != '/' {
return nil, fmt.Errorf("invalid URI; must start with '/'")
}
parts := strings.Split(s, "/")
// check max match size
if len(parts)-2 > maxMatch {
for i, p := range parts {
fmt.Printf("%d: '%s'\n", i, p)
}
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
}
sch, err := buildScheme(parts)
if err != nil {
return nil, err
}
opti, err := sch.optimise()
if err != nil {
return nil, err
}
return &opti, nil
}
// Match returns whether the given URI is matched by the scheme
func (s Scheme) Match(uri string) bool {
if len(s) == 0 {
return true
}
// check for string match
clearURI, match := s.matchString(uri)
if !match {
return false
}
// check for non-string match (wildcards)
return s.matchWildcards(clearURI)
}
// GetMatch returns the indexed match (excluding string matchers)
func (s Scheme) GetMatch(n uint8) ([]string, error) {
if n > uint8(len(s)) {
return nil, fmt.Errorf("index out of range")
}
// iterate to find index (exclude strings)
matches := -1
for _, m := range s {
if len(m.pat) > 0 {
continue
}
matches++
// expected index -> return matches
if uint8(matches) == n {
return m.buf, nil
}
}
// nothing found -> return empty set
return nil, fmt.Errorf("index out of range (max: %d)", matches)
}
// GetAllMatch returns all the indexed match (excluding string matchers)
func (s Scheme) GetAllMatch() [][]string {
match := make([][]string, 0, len(s))
for _, m := range s {
if len(m.pat) > 0 {
continue
}
match = append(match, m.buf)
}
return match
}
// buildScheme builds a 'basic' scheme
// from a pattern string
func buildScheme(ss []string) (Scheme, error) {
sch := make(Scheme, 0, maxMatch)
for _, s := range ss {
if len(s) == 0 {
continue
}
m := &matcher{}
switch s {
// card: 0, N
case "**":
m.req = false
m.mul = true
sch = append(sch, m)
// card: 1, N
case "..":
m.req = true
m.mul = true
sch = append(sch, m)
// card: 0, 1
case "*":
m.req = false
m.mul = false
sch = append(sch, m)
// card: 1
case ".":
m.req = true
m.mul = false
sch = append(sch, m)
// card: 1, literal string
default:
m.req = true
m.mul = false
m.pat = fmt.Sprintf("/%s", s)
sch = append(sch, m)
}
}
return sch, nil
}
// optimise optimised the scheme for further parsing
func (s Scheme) optimise() (Scheme, error) {
if len(s) <= 1 {
return s, nil
}
// init reshifted scheme
rshift := make(Scheme, 0, maxMatch)
rshift = append(rshift, s[0])
// iterate over matchers
for p, i, l := 0, 1, len(s); i < l; i++ {
pre, cur := s[p], s[i]
// merge: 2 following literals
if len(pre.pat) > 0 && len(cur.pat) > 0 {
// merge strings into previous
pre.pat = fmt.Sprintf("%s%s", pre.pat, cur.pat)
// delete current
s[i] = nil
}
// increment previous (only if current is not nul)
if s[i] != nil {
rshift = append(rshift, s[i])
p = i
}
}
return rshift, nil
}
// matchString checks the STRING matchers from an URI
// - returns a boolean : false when not matching, true eitherway
// - returns a cleared uri, without STRING data
func (s Scheme) matchString(uri string) (string, bool) {
var (
clearedInput = uri
minOffset = 0
)
for _, m := range s {
ls := len(m.pat)
// ignore no STRING match
if ls == 0 {
continue
}
// get offset in URI (else -1)
off := strings.Index(clearedInput, m.pat)
if off < 0 {
return "", false
}
// fail on invalid offset range
if off < minOffset {
return "", false
}
// check for trailing '/'
hasSlash := 0
if off+ls < len(clearedInput) && clearedInput[off+ls] == '/' {
hasSlash = 1
}
// remove the current string (+trailing slash) from the URI
beg, end := clearedInput[:off], clearedInput[off+ls+hasSlash:]
clearedInput = fmt.Sprintf("%s\a/%s", beg, end) // separate matches with a '\a' character
// update offset range
// +2 slash separators
// -1 because strings begin with 1 slash already
minOffset = len(beg) + 2 - 1
}
// if exists, remove trailing '/'
if clearedInput[len(clearedInput)-1] == '/' {
clearedInput = clearedInput[:len(clearedInput)-1]
}
// if exists, remove trailing '\a'
if clearedInput[len(clearedInput)-1] == '\a' {
clearedInput = clearedInput[:len(clearedInput)-1]
}
return clearedInput, true
}
// matchWildcards check the WILCARDS (non-string) matchers from
// a cleared URI. it returns if the string matches
// + it sets the matchers buffers for later extraction
func (s Scheme) matchWildcards(clear string) bool {
// extract wildcards (ref)
wildcards := make(Scheme, 0, maxMatch)
for _, m := range s {
if len(m.pat) == 0 {
m.buf = nil // flush buffers
wildcards = append(wildcards, m)
}
}
if len(wildcards) == 0 {
return true
}
// break uri by '\a' characters
matches := strings.Split(clear, "\a")[1:]
for n, match := range matches {
// no more matcher
if n >= len(wildcards) {
return false
}
// from index 1 because it begins with '/'
data := strings.Split(match, "/")[1:]
// missing required
if wildcards[n].req && len(data) < 1 {
return false
}
// if not multi but got multi
if !wildcards[n].mul && len(data) > 1 {
return false
}
wildcards[n].buf = data
}
return true
}

View File

@ -1,216 +0,0 @@
package parser
import (
"fmt"
"strings"
)
// buildScheme builds a 'basic' scheme
// from a pattern string
func buildScheme(ss []string) (Scheme, error) {
/* (1) Build scheme */
sch := make(Scheme, 0, maxMatch)
for _, s := range ss {
/* (2) ignore empty */
if len(s) == 0 {
continue
}
m := new(matcher)
switch s {
/* (3) Card: 0, N */
case "**":
m.req = false
m.mul = true
sch = append(sch, m)
/* (4) Card: 1, N */
case "..":
m.req = true
m.mul = true
sch = append(sch, m)
/* (5) Card: 0, 1 */
case "*":
m.req = false
m.mul = false
sch = append(sch, m)
/* (6) Card: 1 */
case ".":
m.req = true
m.mul = false
sch = append(sch, m)
/* (7) Card: 1, literal string */
default:
m.req = true
m.mul = false
m.pat = fmt.Sprintf("/%s", s)
sch = append(sch, m)
}
}
return sch, nil
}
// optimise optimised the scheme for further parsing
func (s Scheme) optimise() (Scheme, error) {
/* (1) Nothing to do if only 1 element */
if len(s) <= 1 {
return s, nil
}
/* (2) Init reshifted scheme */
rshift := make(Scheme, 0, maxMatch)
rshift = append(rshift, s[0])
/* (2) Iterate over matchers */
for p, i, l := 0, 1, len(s); i < l; i++ {
pre, cur := s[p], s[i]
/* Merge: 2 following literals */
if len(pre.pat) > 0 && len(cur.pat) > 0 {
// merge strings into previous
pre.pat = fmt.Sprintf("%s%s", pre.pat, cur.pat)
// delete current
s[i] = nil
}
// increment previous (only if current is not nul)
if s[i] != nil {
rshift = append(rshift, s[i])
p = i
}
}
return rshift, nil
}
// matchString checks the STRING matchers from an URI
// it returns a boolean : false when not matching, true eitherway
// it returns a cleared uri, without STRING data
func (s Scheme) matchString(uri string) (string, bool) {
/* (1) Initialise variables */
clr := uri // contains cleared input string
minOff := 0 // minimum offset
/* (2) Iterate over strings */
for _, m := range s {
ls := len(m.pat)
// {1} If not STRING matcher -> ignore //
if ls == 0 {
continue
}
// {2} Get offset in URI (else -1) //
off := strings.Index(clr, m.pat)
if off < 0 {
return "", false
}
// {3} Fail on invalid offset range //
if off < minOff {
return "", false
}
// {4} Check for trailing '/' //
hasSlash := 0
if off+ls < len(clr) && clr[off+ls] == '/' {
hasSlash = 1
}
// {5} Remove the current string (+trailing slash) from the URI //
beg, end := clr[:off], clr[off+ls+hasSlash:]
clr = fmt.Sprintf("%s\a/%s", beg, end) // separate matches by '\a' character
// {6} Update offset range //
minOff = len(beg) + 2 - 1 // +2 slash separators
// -1 because strings begin with 1 slash already
}
/* (3) If exists, remove trailing '/' */
if clr[len(clr)-1] == '/' {
clr = clr[:len(clr)-1]
}
/* (4) If exists, remove trailing '\a' */
if clr[len(clr)-1] == '\a' {
clr = clr[:len(clr)-1]
}
return clr, true
}
// matchWildcards check the WILCARDS (non-string) matchers from
// a cleared URI. it returns if the string matches
// + it sets the matchers buffers for later extraction
func (s Scheme) matchWildcards(clear string) bool {
/* (1) Extract wildcards (ref) */
wildcards := make(Scheme, 0, maxMatch)
for _, m := range s {
if len(m.pat) == 0 {
m.buf = nil // flush buffers
wildcards = append(wildcards, m)
}
}
/* (2) If no wildcards -> match */
if len(wildcards) == 0 {
return true
}
/* (3) Break uri by '\a' characters */
matches := strings.Split(clear, "\a")[1:]
/* (4) Iterate over matches */
for n, match := range matches {
// {1} If no more matcher //
if n >= len(wildcards) {
return false
}
// {2} Split by '/' //
data := strings.Split(match, "/")[1:] // from index 1 because it begins with '/'
// {3} If required and missing //
if wildcards[n].req && len(data) < 1 {
return false
}
// {4} If not multi but got multi //
if !wildcards[n].mul && len(data) > 1 {
return false
}
// {5} Store data into matcher //
wildcards[n].buf = data
}
/* (5) Match */
return true
}

View File

@ -1,116 +0,0 @@
package parser
import (
"fmt"
"strings"
)
// Build builds an URI scheme from a pattern string
func Build(s string) (*Scheme, error) {
/* (1) Manage '/' at the start */
if len(s) < 1 || s[0] != '/' {
return nil, fmt.Errorf("URI must begin with '/'")
}
/* (2) Split by '/' */
parts := strings.Split(s, "/")
/* (3) Max exceeded */
if len(parts)-2 > maxMatch {
for i, p := range parts {
fmt.Printf("%d: '%s'\n", i, p)
}
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
}
/* (4) Build for each part */
sch, err := buildScheme(parts)
if err != nil {
return nil, err
}
/* (5) Optimise structure */
opti, err := sch.optimise()
if err != nil {
return nil, err
}
return &opti, nil
}
// Match returns if the given URI is matched by the scheme
func (s Scheme) Match(str string) bool {
/* (1) Nothing -> match all */
if len(s) == 0 {
return true
}
/* (2) Check for string match */
clearURI, match := s.matchString(str)
if !match {
return false
}
/* (3) Check for non-string match (wildcards) */
match = s.matchWildcards(clearURI)
if !match {
return false
}
return true
}
// GetMatch returns the indexed match (excluding string matchers)
func (s Scheme) GetMatch(n uint8) ([]string, error) {
/* (1) Index out of range */
if n > uint8(len(s)) {
return nil, fmt.Errorf("Index out of range")
}
/* (2) Iterate to find index (exclude strings) */
ni := -1
for _, m := range s {
// ignore strings
if len(m.pat) > 0 {
continue
}
// increment match counter : ni
ni++
// if expected index -> return matches
if uint8(ni) == n {
return m.buf, nil
}
}
/* (3) If nothing found -> return empty set */
return nil, fmt.Errorf("Index out of range (max: %d)", ni)
}
// GetAllMatch returns all the indexed match (excluding string matchers)
func (s Scheme) GetAllMatch() [][]string {
match := make([][]string, 0, len(s))
for _, m := range s {
// ignore strings
if len(m.pat) > 0 {
continue
}
match = append(match, m.buf)
}
return match
}

View File

@ -1,35 +0,0 @@
package parser
// === WILDCARDS ===
//
// The star '*' -> matches 0 or 1 slash-bounded string
// The multi star '**' -> matches 0 or more slash-separated strings
// The dot '.' -> matches 1 slash-bounded string
// The multi dot '..' -> matches 1 or more slash-separated strings
//
// === SCHEME POLICY ===
//
// - The last '/' is optional
// - Any '**' at the very end will match anything that starts with the given prefix
//
// === LIMITATIONS ==
//
// - A scheme must begin with '/'
// - A scheme cannot contain something else than a STRING or WILDCARD between 2 '/' separators
// - A scheme STRING cannot contain the symbols '/' as a character
// - A scheme STRING containing '*' or '.' characters will be treating as STRING only
// - A maximum of 16 slash-separated matchers (STRING or WILDCARD) are allowed
const maxMatch = 16
// Represents an URI matcher
type matcher struct {
pat string // pattern to match (empty if wildcard)
req bool // whether it is required
mul bool // whether multiple matches are allowed
buf []string // matched content (when matching)
}
// Represents an URI scheme
type Scheme []*matcher

View File

@ -2,52 +2,81 @@ package websocket
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"unicode/utf8" "unicode/utf8"
) )
var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame") // constant error
var ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long") type constErr string
var ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation")
var ErrUnexpectedContinuation = fmt.Errorf("Received unexpected continuation frame") func (c constErr) Error() string { return string(c) }
var ErrInvalidSize = fmt.Errorf("Received invalid payload size")
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") const (
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") // ErrUnmaskedFrame error
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") ErrUnmaskedFrame constErr = "Received unmasked frame"
var ErrReservedBits = fmt.Errorf("Received reserved bits") // ErrTooLongControlFrame error
var CloseFrame = fmt.Errorf("Received close Frame") ErrTooLongControlFrame constErr = "Received a control frame that is fragmented or too long"
// ErrInvalidFragment error
ErrInvalidFragment constErr = "Received invalid fragmentation"
// ErrUnexpectedContinuation error
ErrUnexpectedContinuation constErr = "Received unexpected continuation frame"
// ErrInvalidSize error
ErrInvalidSize constErr = "Received invalid payload size"
// ErrInvalidPayload error
ErrInvalidPayload constErr = "Received invalid utf8 payload"
// ErrInvalidCloseStatus error
ErrInvalidCloseStatus constErr = "Received invalid close status"
// ErrInvalidOpCode error
ErrInvalidOpCode constErr = "Received invalid OpCode"
// ErrReservedBits error
ErrReservedBits constErr = "Received reserved bits"
// ErrCloseFrame error
ErrCloseFrame constErr = "Received close Frame"
)
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
const maximumHeaderSize = 1 + 1 + 8 + 4 const maximumHeaderSize = 1 + 1 + 8 + 4
const maxWriteChunk = 0x7fff const maxWriteChunk = 0x7fff
// Lists websocket close status // MessageError lists websocket close statuses
type MessageError uint16 type MessageError uint16
const ( const (
NONE MessageError = 0 // None used when there is no error
NORMAL MessageError = 1000 None MessageError = 0
GOING_AWAY MessageError = 1001 // Normal error
PROTOCOL_ERR MessageError = 1002 Normal MessageError = 1000
UNACCEPTABLE_OPCODE MessageError = 1003 // GoingAway error
INVALID_PAYLOAD MessageError = 1007 // utf8 GoingAway MessageError = 1001
MESSAGE_TOO_LARGE MessageError = 1009 // ProtocolError error
ProtocolError MessageError = 1002
// UnacceptableOpCode error
UnacceptableOpCode MessageError = 1003
// InvalidPayload error
InvalidPayload MessageError = 1007 // utf8
// MessageTooLarge error
MessageTooLarge MessageError = 1009
) )
// Lists websocket message types // MessageType lists websocket message types
type MessageType byte type MessageType byte
const ( const (
CONTINUATION MessageType = 0x00 // Continuation message type
TEXT MessageType = 0x01 Continuation MessageType = 0x00
BINARY MessageType = 0x02 // Text message type
CLOSE MessageType = 0x08 Text MessageType = 0x01
PING MessageType = 0x09 // Binary message type
PONG MessageType = 0x0a Binary MessageType = 0x02
// Close message type
Close MessageType = 0x08
// Ping message type
Ping MessageType = 0x09
// Pong message type
Pong MessageType = 0x0a
) )
// Represents a websocket message // Message is a websocket message
type Message struct { type Message struct {
Final bool Final bool
Type MessageType Type MessageType
@ -55,36 +84,40 @@ type Message struct {
Data []byte Data []byte
} }
// receive reads a message form reader // ReadFrom reads a message from a reader
func readMessage(reader io.Reader) (*Message, error) { //
// implements io.ReaderFrom
func (m *Message) ReadFrom(reader io.Reader) (int64, error) {
var (
read int64
err error
tmpBuf []byte
mask []byte
cursor int
)
var err error // byte 1: FIN and OpCode
var tmpBuf []byte
var mask []byte
var cursor int
m := new(Message)
/* (2) Byte 1: FIN and OpCode */
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
read += int64(len(tmpBuf))
err = readBytes(reader, tmpBuf) err = readBytes(reader, tmpBuf)
if err != nil { if err != nil {
return m, err return read, err
} }
// check reserved bits // check reserved bits
if tmpBuf[0]&0x70 != 0 { if tmpBuf[0]&0x70 != 0 {
return m, ErrReservedBits return read, ErrReservedBits
} }
m.Final = bool(tmpBuf[0]&0x80 == 0x80) m.Final = bool(tmpBuf[0]&0x80 == 0x80)
m.Type = MessageType(tmpBuf[0] & 0x0f) m.Type = MessageType(tmpBuf[0] & 0x0f)
/* (3) Byte 2: Mask and Length[0] */ // byte 2: mask and length[0]
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
read += int64(len(tmpBuf))
err = readBytes(reader, tmpBuf) err = readBytes(reader, tmpBuf)
if err != nil { if err != nil {
return m, err return read, err
} }
// if mask, byte array not nil // if mask, byte array not nil
@ -95,71 +128,63 @@ func readMessage(reader io.Reader) (*Message, error) {
// payload length // payload length
m.Size = uint(tmpBuf[0] & 0x7f) m.Size = uint(tmpBuf[0] & 0x7f)
/* (4) Extended payload */ // extended payload
if m.Size == 127 { if m.Size == 127 {
tmpBuf = make([]byte, 8) tmpBuf = make([]byte, 8)
read += int64(len(tmpBuf))
err := readBytes(reader, tmpBuf) err := readBytes(reader, tmpBuf)
if err != nil { if err != nil {
return m, err return read, err
} }
m.Size = uint(binary.BigEndian.Uint64(tmpBuf)) m.Size = uint(binary.BigEndian.Uint64(tmpBuf))
} else if m.Size == 126 { } else if m.Size == 126 {
tmpBuf = make([]byte, 2) tmpBuf = make([]byte, 2)
read += int64(len(tmpBuf))
err := readBytes(reader, tmpBuf) err := readBytes(reader, tmpBuf)
if err != nil { if err != nil {
return m, err return read, err
} }
m.Size = uint(binary.BigEndian.Uint16(tmpBuf)) m.Size = uint(binary.BigEndian.Uint16(tmpBuf))
} }
/* (5) Masking key */ // masking key
if mask != nil { if mask != nil {
tmpBuf = make([]byte, 4) tmpBuf = make([]byte, 4)
read += int64(len(tmpBuf))
err := readBytes(reader, tmpBuf) err := readBytes(reader, tmpBuf)
if err != nil { if err != nil {
return m, err return read, err
} }
mask = make([]byte, 4) mask = make([]byte, 4)
copy(mask, tmpBuf) copy(mask, tmpBuf)
} }
/* (6) Read payload by chunks */ // read payload by chunks
m.Data = make([]byte, int(m.Size)) m.Data = make([]byte, int(m.Size))
cursor = 0 cursor = 0
// {1} While we have data to read // // while data to read
for uint(cursor) < m.Size { for uint(cursor) < m.Size {
// {2} Try to read (at least 1 byte) // // try to read (at least 1 byte)
nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1) nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1)
if err != nil { if err != nil {
return m, err return read + int64(cursor) + int64(nbread), err
} }
// {3} Unmask data // // unmask data //
if mask != nil { if mask != nil {
for i, l := cursor, cursor+nbread; i < l; i++ { for i, l := cursor, cursor+nbread; i < l; i++ {
mi := i % 4 // mask index mi := i % 4 // mask index
m.Data[i] = m.Data[i] ^ mask[mi] m.Data[i] = m.Data[i] ^ mask[mi]
} }
} }
// {4} Update cursor //
cursor += nbread cursor += nbread
} }
read += int64(cursor)
// return error if unmasked frame // return error if unmasked frame
// we have to fully read it for read buffer to be clean // we have to fully read it for read buffer to be clean
@ -168,13 +193,14 @@ func readMessage(reader io.Reader) (*Message, error) {
err = ErrUnmaskedFrame err = ErrUnmaskedFrame
} }
return m, err return read, err
} }
// Send sends a frame over a socket // WriteTo writes a message frame over a socket
func (m Message) Send(writer io.Writer) error { //
// implements io.WriterTo
func (m Message) WriteTo(writer io.Writer) (int64, error) {
header := make([]byte, 0, maximumHeaderSize) header := make([]byte, 0, maximumHeaderSize)
// fix size // fix size
@ -182,20 +208,18 @@ func (m Message) Send(writer io.Writer) error {
m.Size = uint(len(m.Data)) m.Size = uint(len(m.Data))
} }
/* (1) Byte 0 : FIN + opcode */ // byte 0 : FIN + opcode
var final byte = 0x80 var final byte = 0x80
if !m.Final { if !m.Final {
final = 0 final = 0
} }
header = append(header, final|byte(m.Type)) header = append(header, final|byte(m.Type))
/* (2) Get payload length */ // get payload length
if m.Size < 126 { // simple if m.Size < 126 { // simple
header = append(header, byte(m.Size)) header = append(header, byte(m.Size))
} else if m.Size <= 0xffff { // extended: 16 bits } else if m.Size <= 0xffff { // extended: 16 bits
header = append(header, 126) header = append(header, 126)
buf := make([]byte, 2) buf := make([]byte, 2)
@ -203,7 +227,6 @@ func (m Message) Send(writer io.Writer) error {
header = append(header, buf...) header = append(header, buf...)
} else if m.Size <= 0xffffffffffffffff { // extended: 64 bits } else if m.Size <= 0xffffffffffffffff { // extended: 64 bits
header = append(header, 127) header = append(header, 127)
buf := make([]byte, 8) buf := make([]byte, 8)
@ -212,16 +235,15 @@ func (m Message) Send(writer io.Writer) error {
} }
/* (3) Build write buffer */ // build write buffer
writeBuf := make([]byte, 0, len(header)+int(m.Size)) writeBuf := make([]byte, 0, len(header)+int(m.Size))
writeBuf = append(writeBuf, header...) writeBuf = append(writeBuf, header...)
writeBuf = append(writeBuf, m.Data[0:m.Size]...) writeBuf = append(writeBuf, m.Data[0:m.Size]...)
/* (4) Send over socket by chunks */ // write by chunks
toWrite := len(header) + int(m.Size) toWrite := len(header) + int(m.Size)
cursor := 0 cursor := 0
for cursor < toWrite { for cursor < toWrite {
maxBoundary := cursor + maxWriteChunk maxBoundary := cursor + maxWriteChunk
if maxBoundary > toWrite { if maxBoundary > toWrite {
maxBoundary = toWrite maxBoundary = toWrite
@ -230,56 +252,54 @@ func (m Message) Send(writer io.Writer) error {
// Try to wrote (at max 1024 bytes) // // Try to wrote (at max 1024 bytes) //
nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary]) nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary])
if err != nil { if err != nil {
return err return int64(nbwritten), err
} }
// Update cursor // // Update cursor //
cursor += nbwritten cursor += nbwritten
}
return int64(cursor), nil
} }
return nil // check for message errors with:
}
// Check for message errors with:
// (m) the current message // (m) the current message
// (fragment) whether there is a fragment in construction // (fragment) whether there is a fragment in construction
// returns the message error // returns the message error
func (m *Message) check(fragment bool) error { func (m *Message) check(fragment bool) error {
/* (1) Invalid first fragment (not TEXT nor BINARY) */ // invalid first fragment (not TEXT nor BINARY)
if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY { if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
return ErrInvalidFragment return ErrInvalidFragment
} }
/* (2) Waiting fragment but received standalone frame */ // waiting fragment but received standalone frame
if fragment && m.Type != CONTINUATION && m.Type != CLOSE && m.Type != PING && m.Type != PONG { if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong {
return ErrInvalidFragment return ErrInvalidFragment
} }
/* (3) Control frame too long */ // control frame too long
if (m.Type == CLOSE || m.Type == PING || m.Type == PONG) && (m.Size > 125 || !m.Final) { if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) {
return ErrTooLongControlFrame return ErrTooLongControlFrame
} }
switch m.Type { switch m.Type {
case CONTINUATION: case Continuation:
// unexpected continuation // unexpected continuation
if !fragment { if !fragment {
return ErrUnexpectedContinuation return ErrUnexpectedContinuation
} }
return nil return nil
case TEXT: case Text:
if m.Final && !utf8.Valid(m.Data) { if m.Final && !utf8.Valid(m.Data) {
return ErrInvalidPayload return ErrInvalidPayload
} }
return nil return nil
case BINARY: case Binary:
return nil return nil
case CLOSE: case Close:
// incomplete code // incomplete code
if m.Size == 1 { if m.Size == 1 {
return ErrInvalidCloseStatus return ErrInvalidCloseStatus
@ -298,20 +318,18 @@ func (m *Message) check(fragment bool) error {
return ErrInvalidCloseStatus return ErrInvalidCloseStatus
} }
} }
return CloseFrame return ErrCloseFrame
case PING: case Ping:
return nil return nil
case PONG: case Pong:
return nil return nil
default: default:
return ErrInvalidOpCode return ErrInvalidOpCode
} }
return nil
} }
// readBytes reads from a reader into a byte array // readBytes reads from a reader into a byte array
@ -320,20 +338,19 @@ func (m *Message) check(fragment bool) error {
// //
// It manages connections which chunks data // It manages connections which chunks data
func readBytes(reader io.Reader, buffer []byte) error { func readBytes(reader io.Reader, buffer []byte) error {
var (
cur = 0
len = len(buffer)
)
var cur, len int = 0, len(buffer) // read until the full size is read
// try to read until the full size is read
for cur < len { for cur < len {
nbread, err := reader.Read(buffer[cur:]) nbread, err := reader.Read(buffer[cur:])
if err != nil { if err != nil {
return err return err
} }
cur += nbread cur += nbread
} }
return nil return nil
} }

View File

@ -41,25 +41,25 @@ func TestSimpleMessageReading(t *testing.T) {
{ // FIN ; TEXT ; hello { // FIN ; TEXT ; hello
"simple hello text message", "simple hello text message",
[]byte{0x81, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f}, []byte{0x81, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
Message{true, TEXT, 5, []byte("hello")}, Message{true, Text, 5, []byte("hello")},
nil, nil,
}, },
{ // FIN ; BINARY ; hello { // FIN ; BINARY ; hello
"simple hello binary message", "simple hello binary message",
[]byte{0x82, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f}, []byte{0x82, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
Message{true, BINARY, 5, []byte("hello")}, Message{true, Binary, 5, []byte("hello")},
nil, nil,
}, },
{ // FIN ; BINARY ; test unmasking { // FIN ; BINARY ; test unmasking
"unmasking test", "unmasking test",
[]byte{0x82, 0x88, 0x01, 0x02, 0x03, 0x04, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80}, []byte{0x82, 0x88, 0x01, 0x02, 0x03, 0x04, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80},
Message{true, BINARY, 8, []byte{0x11, 0x22, 0x33, 0x44, 0x51, 0x62, 0x73, 0x84}}, Message{true, Binary, 8, []byte{0x11, 0x22, 0x33, 0x44, 0x51, 0x62, 0x73, 0x84}},
nil, nil,
}, },
{ // FIN=0 ; TEXT ; { // FIN=0 ; TEXT ;
"non final frame", "non final frame",
[]byte{0x01, 0x82, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02}, []byte{0x01, 0x82, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02},
Message{false, TEXT, 2, []byte{0x01, 0x02}}, Message{false, Text, 2, []byte{0x01, 0x02}},
nil, nil,
}, },
} }
@ -67,11 +67,12 @@ func TestSimpleMessageReading(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
var (
reader = bytes.NewBuffer(tc.ReadBuffer)
msg = &Message{}
)
reader := bytes.NewBuffer(tc.ReadBuffer) _, err := msg.ReadFrom(reader)
got, err := readMessage(reader)
if err != tc.Err { if err != tc.Err {
t.Errorf("Expected %v error, got %v", tc.Err, err) t.Errorf("Expected %v error, got %v", tc.Err, err)
} }
@ -82,23 +83,23 @@ func TestSimpleMessageReading(t *testing.T) {
} }
// check FIN // check FIN
if got.Final != tc.Expected.Final { if msg.Final != tc.Expected.Final {
t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, got.Final) t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, msg.Final)
} }
// check OpCode // check OpCode
if got.Type != tc.Expected.Type { if msg.Type != tc.Expected.Type {
t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, got.Type) t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, msg.Type)
} }
// check Size // check Size
if got.Size != tc.Expected.Size { if msg.Size != tc.Expected.Size {
t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, got.Size) t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, msg.Size)
} }
// check Data // check Data
if string(got.Data) != string(tc.Expected.Data) { if string(msg.Data) != string(tc.Expected.Data) {
t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, got.Data) t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, msg.Data)
} }
}) })
@ -177,17 +178,15 @@ func TestReadEOF(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
var (
reader := bytes.NewBuffer(tc.ReadBuffer) reader = bytes.NewBuffer(tc.ReadBuffer)
msg = &Message{}
got, err := readMessage(reader) )
_, err := msg.ReadFrom(reader)
if tc.eof { if tc.eof {
if err != io.EOF { if err != io.EOF {
t.Errorf("Expected EOF, got %v", err) t.Fatalf("Expected EOF, got %v", err)
} }
return return
} }
@ -195,8 +194,8 @@ func TestReadEOF(t *testing.T) {
t.Errorf("Expected UnmaskedFrameor, got %v", err) t.Errorf("Expected UnmaskedFrameor, got %v", err)
} }
if got.Size != 0x00 { if msg.Size != 0x00 {
t.Errorf("Expected a size of 0, got %d", got.Size) t.Errorf("Expected a size of 0, got %d", msg.Size)
} }
}) })
@ -222,43 +221,43 @@ func TestSimpleMessageSending(t *testing.T) {
}{ }{
{ {
"simple hello text message", "simple hello text message",
Message{true, TEXT, 5, []byte("hello")}, Message{true, Text, 5, []byte("hello")},
[]byte{0x81, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f}, []byte{0x81, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
}, { }, {
"simple hello binary message", "simple hello binary message",
Message{true, BINARY, 5, []byte("hello")}, Message{true, Binary, 5, []byte("hello")},
[]byte{0x82, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f}, []byte{0x82, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
}, { }, {
"other simple binary message", "other simple binary message",
Message{true, BINARY, 8, []byte{0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80}}, Message{true, Binary, 8, []byte{0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80}},
[]byte{0x82, 0x08, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80}, []byte{0x82, 0x08, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80},
}, { }, {
"non final frame", "non final frame",
Message{false, TEXT, 2, []byte{0x01, 0x02}}, Message{false, Text, 2, []byte{0x01, 0x02}},
[]byte{0x01, 0x02, 0x01, 0x02}, []byte{0x01, 0x02, 0x01, 0x02},
}, { }, {
"125 > normal length", "125 > normal length",
Message{true, TEXT, uint(len(m4b1)), m4b1}, Message{true, Text, uint(len(m4b1)), m4b1},
append([]byte{0x81, 0x7e - 1}, m4b1...), append([]byte{0x81, 0x7e - 1}, m4b1...),
}, { }, {
"126 > extended 16 bits length", "126 > extended 16 bits length",
Message{true, TEXT, uint(len(m4b2)), m4b2}, Message{true, Text, uint(len(m4b2)), m4b2},
append([]byte{0x81, 126, 0x00, 0x7e}, m4b2...), append([]byte{0x81, 126, 0x00, 0x7e}, m4b2...),
}, { }, {
"127 > extended 16 bits length", "127 > extended 16 bits length",
Message{true, TEXT, uint(len(m4b3)), m4b3}, Message{true, Text, uint(len(m4b3)), m4b3},
append([]byte{0x81, 126, 0x00, 0x7e + 1}, m4b3...), append([]byte{0x81, 126, 0x00, 0x7e + 1}, m4b3...),
}, { }, {
"fffe > extended 16 bits length", "fffe > extended 16 bits length",
Message{true, TEXT, uint(len(m16b1)), m16b1}, Message{true, Text, uint(len(m16b1)), m16b1},
append([]byte{0x81, 126, 0xff, 0xfe}, m16b1...), append([]byte{0x81, 126, 0xff, 0xfe}, m16b1...),
}, { }, {
"ffff > extended 16 bits length", "ffff > extended 16 bits length",
Message{true, TEXT, uint(len(m16b2)), m16b2}, Message{true, Text, uint(len(m16b2)), m16b2},
append([]byte{0x81, 126, 0xff, 0xff}, m16b2...), append([]byte{0x81, 126, 0xff, 0xff}, m16b2...),
}, { }, {
"10000 > extended 64 bits length", "10000 > extended 64 bits length",
Message{true, TEXT, uint(len(m16b3)), m16b3}, Message{true, Text, uint(len(m16b3)), m16b3},
append([]byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, m16b3...), append([]byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, m16b3...),
}, },
} }
@ -267,10 +266,9 @@ func TestSimpleMessageSending(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
writer := new(bytes.Buffer) writer := &bytes.Buffer{}
err := tc.Base.Send(writer)
_, err := tc.Base.WriteTo(writer)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v", err) t.Errorf("expected no error, got %v", err)
return return
@ -305,22 +303,22 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CONTINUATION must fail", "CONTINUATION must fail",
Message{false, CONTINUATION, 0, []byte{}}, false, ErrInvalidFragment, Message{false, Continuation, 0, []byte{}}, false, ErrInvalidFragment,
}, { }, {
"TEXT must not fail", "TEXT must not fail",
Message{false, TEXT, 0, []byte{}}, false, nil, Message{false, Text, 0, []byte{}}, false, nil,
}, { }, {
"BINARY must not fail", "BINARY must not fail",
Message{false, BINARY, 0, []byte{}}, false, nil, Message{false, Binary, 0, []byte{}}, false, nil,
}, { }, {
"CLOSE must fail", "CLOSE must fail",
Message{false, CLOSE, 0, []byte{}}, false, ErrInvalidFragment, Message{false, Close, 0, []byte{}}, false, ErrInvalidFragment,
}, { }, {
"PING must fail", "PING must fail",
Message{false, PING, 0, []byte{}}, false, ErrInvalidFragment, Message{false, Ping, 0, []byte{}}, false, ErrInvalidFragment,
}, { }, {
"PONG must fail", "PONG must fail",
Message{false, PONG, 0, []byte{}}, false, ErrInvalidFragment, Message{false, Pong, 0, []byte{}}, false, ErrInvalidFragment,
}, },
}, },
}, { }, {
@ -328,22 +326,22 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CONTINUATION must not fail", "CONTINUATION must not fail",
Message{true, CONTINUATION, 0, []byte{}}, true, nil, Message{true, Continuation, 0, []byte{}}, true, nil,
}, { }, {
"TEXT must fail", "TEXT must fail",
Message{true, TEXT, 0, []byte{}}, true, ErrInvalidFragment, Message{true, Text, 0, []byte{}}, true, ErrInvalidFragment,
}, { }, {
"BINARY must fail", "BINARY must fail",
Message{true, BINARY, 0, []byte{}}, true, ErrInvalidFragment, Message{true, Binary, 0, []byte{}}, true, ErrInvalidFragment,
}, { }, {
"CLOSE must not fail", "CLOSE must not fail",
Message{true, CLOSE, 0, []byte{}}, true, CloseFrame, Message{true, Close, 0, []byte{}}, true, ErrCloseFrame,
}, { }, {
"PING must not fail", "PING must not fail",
Message{true, PING, 0, []byte{}}, true, nil, Message{true, Ping, 0, []byte{}}, true, nil,
}, { }, {
"PONG must not fail", "PONG must not fail",
Message{true, PONG, 0, []byte{}}, true, nil, Message{true, Pong, 0, []byte{}}, true, nil,
}, },
}, },
}, { }, {
@ -351,13 +349,13 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE must not fail", "CLOSE must not fail",
Message{true, CLOSE, 125, []byte{0x03, 0xe8, 0}}, false, CloseFrame, Message{true, Close, 125, []byte{0x03, 0xe8, 0}}, false, ErrCloseFrame,
}, { }, {
"PING must not fail", "PING must not fail",
Message{true, PING, 125, []byte{}}, false, nil, Message{true, Ping, 125, []byte{}}, false, nil,
}, { }, {
"PONG must not fail", "PONG must not fail",
Message{true, PONG, 125, []byte{}}, false, nil, Message{true, Pong, 125, []byte{}}, false, nil,
}, },
}, },
}, { }, {
@ -365,13 +363,13 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE must fail", "CLOSE must fail",
Message{true, CLOSE, 126, []byte{0x03, 0xe8, 0}}, false, ErrTooLongControlFrame, Message{true, Close, 126, []byte{0x03, 0xe8, 0}}, false, ErrTooLongControlFrame,
}, { }, {
"PING must fail", "PING must fail",
Message{true, PING, 126, []byte{}}, false, ErrTooLongControlFrame, Message{true, Ping, 126, []byte{}}, false, ErrTooLongControlFrame,
}, { }, {
"PONG must fail", "PONG must fail",
Message{true, PONG, 126, []byte{}}, false, ErrTooLongControlFrame, Message{true, Pong, 126, []byte{}}, false, ErrTooLongControlFrame,
}, },
}, },
}, { }, {
@ -379,13 +377,13 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE must fail", "CLOSE must fail",
Message{false, CLOSE, 126, []byte{0x03, 0xe8, 0}}, false, ErrInvalidFragment, Message{false, Close, 126, []byte{0x03, 0xe8, 0}}, false, ErrInvalidFragment,
}, { }, {
"PING must fail", "PING must fail",
Message{false, PING, 126, []byte{}}, false, ErrInvalidFragment, Message{false, Ping, 126, []byte{}}, false, ErrInvalidFragment,
}, { }, {
"PONG must fail", "PONG must fail",
Message{false, PONG, 126, []byte{}}, false, ErrInvalidFragment, Message{false, Pong, 126, []byte{}}, false, ErrInvalidFragment,
}, },
}, },
}, { }, {
@ -393,10 +391,10 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"no waiting fragment final", "no waiting fragment final",
Message{false, CONTINUATION, 126, nil}, false, ErrInvalidFragment, Message{false, Continuation, 126, nil}, false, ErrInvalidFragment,
}, { }, {
"no waiting fragment non-final", "no waiting fragment non-final",
Message{true, CONTINUATION, 126, nil}, false, ErrUnexpectedContinuation, Message{true, Continuation, 126, nil}, false, ErrUnexpectedContinuation,
}, },
}, },
}, { }, {
@ -404,23 +402,23 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE valid reason", "CLOSE valid reason",
Message{true, CLOSE, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0xa1}}, false, CloseFrame, Message{true, Close, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0xa1}}, false, ErrCloseFrame,
}, { }, {
"CLOSE invalid reason byte 2", "CLOSE invalid reason byte 2",
Message{true, CLOSE, 5, []byte{0x03, 0xe8, 0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload, Message{true, Close, 5, []byte{0x03, 0xe8, 0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload,
}, { }, {
"CLOSE invalid reason byte 3", "CLOSE invalid reason byte 3",
Message{true, CLOSE, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0x28}}, false, ErrInvalidPayload, Message{true, Close, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0x28}}, false, ErrInvalidPayload,
}, },
{ {
"TEXT valid reason", "TEXT valid reason",
Message{true, TEXT, 3, []byte{0xe2, 0x82, 0xa1}}, false, nil, Message{true, Text, 3, []byte{0xe2, 0x82, 0xa1}}, false, nil,
}, { }, {
"TEXT invalid reason byte 2", "TEXT invalid reason byte 2",
Message{true, TEXT, 3, []byte{0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload, Message{true, Text, 3, []byte{0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload,
}, { }, {
"TEXT invalid reason byte 3", "TEXT invalid reason byte 3",
Message{true, TEXT, 3, []byte{0xe2, 0x82, 0x28}}, false, ErrInvalidPayload, Message{true, Text, 3, []byte{0xe2, 0x82, 0x28}}, false, ErrInvalidPayload,
}, },
}, },
}, { }, {
@ -428,82 +426,82 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE only 1 byte", "CLOSE only 1 byte",
Message{true, CLOSE, 1, []byte{0x03}}, false, ErrInvalidCloseStatus, Message{true, Close, 1, []byte{0x03}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1000", "valid CLOSE status 1000",
Message{true, CLOSE, 2, []byte{0x03, 0xe8}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xe8}}, false, ErrCloseFrame,
}, { }, {
"invalid CLOSE status 999 under 1000", "invalid CLOSE status 999 under 1000",
Message{true, CLOSE, 2, []byte{0x03, 0xe7}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xe7}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1001", "valid CLOSE status 1001",
Message{true, CLOSE, 2, []byte{0x03, 0xe9}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xe9}}, false, ErrCloseFrame,
}, { }, {
"valid CLOSE status 1002", "valid CLOSE status 1002",
Message{true, CLOSE, 2, []byte{0x03, 0xea}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xea}}, false, ErrCloseFrame,
}, { }, {
"valid CLOSE status 1003", "valid CLOSE status 1003",
Message{true, CLOSE, 2, []byte{0x03, 0xeb}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xeb}}, false, ErrCloseFrame,
}, { }, {
"invalid CLOSE status 1004", "invalid CLOSE status 1004",
Message{true, CLOSE, 2, []byte{0x03, 0xec}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xec}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1005", "invalid CLOSE status 1005",
Message{true, CLOSE, 2, []byte{0x03, 0xed}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xed}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1006", "invalid CLOSE status 1006",
Message{true, CLOSE, 2, []byte{0x03, 0xee}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xee}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1007", "valid CLOSE status 1007",
Message{true, CLOSE, 2, []byte{0x03, 0xef}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xef}}, false, ErrCloseFrame,
}, { }, {
"valid CLOSE status 1011", "valid CLOSE status 1011",
Message{true, CLOSE, 2, []byte{0x03, 0xf3}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xf3}}, false, ErrCloseFrame,
}, { }, {
"invalid CLOSE status 1012", "invalid CLOSE status 1012",
Message{true, CLOSE, 2, []byte{0x03, 0xf4}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xf4}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1013", "invalid CLOSE status 1013",
Message{true, CLOSE, 2, []byte{0x03, 0xf5}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xf5}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1014", "invalid CLOSE status 1014",
Message{true, CLOSE, 2, []byte{0x03, 0xf6}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xf6}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1015", "invalid CLOSE status 1015",
Message{true, CLOSE, 2, []byte{0x03, 0xf7}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xf7}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1016", "invalid CLOSE status 1016",
Message{true, CLOSE, 2, []byte{0x03, 0xf8}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x03, 0xf8}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1017", "valid CLOSE status 1017",
Message{true, CLOSE, 2, []byte{0x03, 0xf9}}, false, CloseFrame, Message{true, Close, 2, []byte{0x03, 0xf9}}, false, ErrCloseFrame,
}, { }, {
"valid CLOSE status 1099", "valid CLOSE status 1099",
Message{true, CLOSE, 2, []byte{0x04, 0x4b}}, false, CloseFrame, Message{true, Close, 2, []byte{0x04, 0x4b}}, false, ErrCloseFrame,
}, { }, {
"invalid CLOSE status 1100", "invalid CLOSE status 1100",
Message{true, CLOSE, 2, []byte{0x04, 0x4c}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x04, 0x4c}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1101", "valid CLOSE status 1101",
Message{true, CLOSE, 2, []byte{0x04, 0x4d}}, false, CloseFrame, Message{true, Close, 2, []byte{0x04, 0x4d}}, false, ErrCloseFrame,
}, { }, {
"valid CLOSE status 1999", "valid CLOSE status 1999",
Message{true, CLOSE, 2, []byte{0x07, 0xcf}}, false, CloseFrame, Message{true, Close, 2, []byte{0x07, 0xcf}}, false, ErrCloseFrame,
}, { }, {
"invalid CLOSE status 2000", "invalid CLOSE status 2000",
Message{true, CLOSE, 2, []byte{0x07, 0xd0}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x07, 0xd0}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 2001", "valid CLOSE status 2001",
Message{true, CLOSE, 2, []byte{0x07, 0xd1}}, false, CloseFrame, Message{true, Close, 2, []byte{0x07, 0xd1}}, false, ErrCloseFrame,
}, { }, {
"valid CLOSE status 2998", "valid CLOSE status 2998",
Message{true, CLOSE, 2, []byte{0x0b, 0xb6}}, false, CloseFrame, Message{true, Close, 2, []byte{0x0b, 0xb6}}, false, ErrCloseFrame,
}, { }, {
"invalid CLOSE status 2999", "invalid CLOSE status 2999",
Message{true, CLOSE, 2, []byte{0x0b, 0xb7}}, false, ErrInvalidCloseStatus, Message{true, Close, 2, []byte{0x0b, 0xb7}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 3000", "valid CLOSE status 3000",
Message{true, CLOSE, 2, []byte{0x0b, 0xb8}}, false, CloseFrame, Message{true, Close, 2, []byte{0x0b, 0xb8}}, false, ErrCloseFrame,
}, },
}, },
}, { }, {
@ -517,7 +515,7 @@ func TestMessageCheck(t *testing.T) {
{"5", Message{true, 5, 0, []byte{}}, false, ErrInvalidOpCode}, {"5", Message{true, 5, 0, []byte{}}, false, ErrInvalidOpCode},
{"6", Message{true, 6, 0, []byte{}}, false, ErrInvalidOpCode}, {"6", Message{true, 6, 0, []byte{}}, false, ErrInvalidOpCode},
{"7", Message{true, 7, 0, []byte{}}, false, ErrInvalidOpCode}, {"7", Message{true, 7, 0, []byte{}}, false, ErrInvalidOpCode},
{"8", Message{true, 8, 0, []byte{}}, false, CloseFrame}, {"8", Message{true, 8, 0, []byte{}}, false, ErrCloseFrame},
{"9", Message{true, 9, 0, []byte{}}, false, nil}, {"9", Message{true, 9, 0, []byte{}}, false, nil},
{"10", Message{true, 10, 0, []byte{}}, false, nil}, {"10", Message{true, 10, 0, []byte{}}, false, nil},
{"11", Message{true, 11, 0, []byte{}}, false, ErrInvalidOpCode}, {"11", Message{true, 11, 0, []byte{}}, false, ErrInvalidOpCode},

View File

@ -2,18 +2,19 @@ package websocket
import ( import (
"fmt" "fmt"
"git.xdrm.io/go/websocket/internal/uri/parser"
"net" "net"
"git.xdrm.io/go/ws/internal/uri"
) )
// Represents all channels that need a server // All channels that a server features
type serverChannelSet struct { type serverChannelSet struct {
register chan *client register chan *client
unregister chan *client unregister chan *client
broadcast chan Message broadcast chan Message
} }
// Represents a websocket server // Server is a websocket server
type Server struct { type Server struct {
sock net.Listener // listen socket sock net.Listener // listen socket
addr []byte // server listening ip/host addr []byte // server listening ip/host
@ -26,9 +27,8 @@ type Server struct {
ch serverChannelSet ch serverChannelSet
} }
// CreateServer creates a server for a specific HOST and PORT // NewServer creates a server
func CreateServer(host string, port uint16) *Server { func NewServer(host string, port uint16) *Server {
return &Server{ return &Server{
addr: []byte(host), addr: []byte(host),
port: port, port: port,
@ -37,7 +37,7 @@ func CreateServer(host string, port uint16) *Server {
ctl: ControllerSet{ ctl: ControllerSet{
Def: nil, Def: nil,
Uri: make([]*Controller, 0), URI: make([]*Controller, 0),
}, },
ch: serverChannelSet{ ch: serverChannelSet{
@ -46,125 +46,84 @@ func CreateServer(host string, port uint16) *Server {
broadcast: make(chan Message, 1), broadcast: make(chan Message, 1),
}, },
} }
} }
// BindDefault binds a default controller // BindDefault binds a default controller
// it will be called if the URI does not // it will be called if the URI does not
// match another controller // match another controller
func (s *Server) BindDefault(f ControllerFunc) { func (s *Server) BindDefault(f ControllerFunc) {
s.ctl.Def = &Controller{ s.ctl.Def = &Controller{
URI: nil, URI: nil,
Fun: f, Fun: f,
} }
} }
// Bind binds a controller to an URI scheme // Bind a controller to an URI scheme
func (s *Server) Bind(uri string, f ControllerFunc) error { func (s *Server) Bind(uriStr string, f ControllerFunc) error {
uriScheme, err := uri.FromString(uriStr)
/* (1) Build URI parser */
uriScheme, err := parser.Build(uri)
if err != nil { if err != nil {
return fmt.Errorf("Cannot build URI: %s", err) return fmt.Errorf("cannot build URI: %w", err)
} }
/* (2) Create controller */ s.ctl.URI = append(s.ctl.URI, &Controller{
s.ctl.Uri = append(s.ctl.Uri, &Controller{
URI: uriScheme, URI: uriScheme,
Fun: f, Fun: f,
}) })
return nil return nil
} }
// Launch launches the websocket server // Launch the websocket server
func (s *Server) Launch() error { func (s *Server) Launch() error {
var (
err error
url = fmt.Sprintf("%s:%d", s.addr, s.port)
)
var err error
/* (1) Listen socket
---------------------------------------------------------*/
/* (1) Build full url */
url := fmt.Sprintf("%s:%d", s.addr, s.port)
/* (2) Bind socket to listen */
s.sock, err = net.Listen("tcp", url) s.sock, err = net.Listen("tcp", url)
if err != nil { if err != nil {
return fmt.Errorf("Listen socket: %s", err) return fmt.Errorf("listen: %w", err)
} }
defer s.sock.Close() defer s.sock.Close()
fmt.Printf("+ listening on %s\n", url) fmt.Printf("+ listening on %s\n", url)
go s.schedule()
/* (3) Launch scheduler */
go s.scheduler()
/* (2) For each incoming connection (client)
---------------------------------------------------------*/
for { for {
/* (1) Wait for client */
sock, err := s.sock.Accept() sock, err := s.sock.Accept()
if err != nil { if err != nil {
break break
} }
go func() { go func() {
cli, err := newClient(sock, s.ctl, s.ch)
/* (2) Try to create client */
cli, err := buildClient(sock, s.ctl, s.ch)
if err != nil { if err != nil {
fmt.Printf(" - %s\n", err) fmt.Printf(" - %s\n", err)
return return
} }
/* (3) Register client */
s.ch.register <- cli s.ch.register <- cli
}() }()
} }
return nil return nil
} }
// Scheduler schedules clients registration and broadcast // schedule client registration and broadcast
func (s *Server) scheduler() { func (s *Server) schedule() {
for { for {
select { select {
/* (1) New client */
case client := <-s.ch.register: case client := <-s.ch.register:
// fmt.Printf(" + client\n")
s.clients[client.io.sock] = client s.clients[client.io.sock] = client
/* (2) New client */
case client := <-s.ch.unregister: case client := <-s.ch.unregister:
// fmt.Printf(" - client\n")
delete(s.clients, client.io.sock) delete(s.clients, client.io.sock)
/* (3) Broadcast */
case msg := <-s.ch.broadcast: case msg := <-s.ch.broadcast:
fmt.Printf(" + broadcast\n")
for _, c := range s.clients { for _, c := range s.clients {
c.ch.send <- msg c.ch.send <- msg
} }
} }
} }
fmt.Printf("+ server stopped\n")
} }