diff --git a/client.go b/client.go index 1443557..034da9d 100644 --- a/client.go +++ b/client.go @@ -36,36 +36,27 @@ type client struct { status MessageError // close status ; 0 = nothing ; else -> must close } -// Create creates a new client -func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) { +// newClient creates a new client +func newClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) { + req := &upgrade.Request{} + _, err := req.ReadFrom(s) + if err != nil { + return nil, fmt.Errorf("request read: %w", err) + } - /* (1) Manage UPGRADE request - ---------------------------------------------------------*/ - // 1. Parse request - req, _ := upgrade.Parse(s) - - // 3. Build response res := req.BuildResponse() - // 4. Write into socket - _, err := res.Send(s) + _, err = res.WriteTo(s) if err != nil { return nil, fmt.Errorf("upgrade write: %w", err) } - if res.GetStatusCode() != 101 { + if res.StatusCode != 101 { s.Close() - return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.GetStatusCode()) + return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.StatusCode) } - /* (2) Initialise client - ---------------------------------------------------------*/ - // 1. Get upgrade data - clientURI := req.GetURI() - clientProtocol := res.GetProtocol() - - // 2. Initialise client - cli := &client{ + var cli = &client{ io: clientIO{ sock: s, reader: bufio.NewReader(s), @@ -73,8 +64,8 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli }, iface: &Client{ - Protocol: string(clientProtocol), - Arguments: [][]string{{clientURI}}, + Protocol: string(res.Protocol), + Arguments: [][]string{{req.URI()}}, }, ch: clientChannelSet{ @@ -83,59 +74,46 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli }, } - /* (3) Find controller by URI - ---------------------------------------------------------*/ - // 1. Try to find one - controller, arguments := ctl.Match(clientURI) - - // 2. If nothing found -> error + // find controller by URI + controller, arguments := ctl.Match(req.URI()) if controller == nil { - return nil, fmt.Errorf("No controller found, no default controller set") + return nil, fmt.Errorf("no controller found, no default controller set") } - // 3. Copy arguments + // copy args cli.iface.Arguments = arguments - /* (4) Launch client routines - ---------------------------------------------------------*/ - // 1. Launch client controller go controller.Fun( cli.iface, // pass the client cli.ch.receive, // the receiver cli.ch.send, // the sender serverCh.broadcast, // broadcast sender ) - - // 2. Launch message reader go clientReader(cli) - - // 3. Launc writer go clientWriter(cli) - return cli, nil - } -// reader reads and parses messages from the buffer +// clientReader reads and parses messages from the buffer func clientReader(c *client) { - var frag *Message - - closeStatus := Normal - clientAck := true + var ( + frag *Message + closeStatus = Normal + clientAck = true + ) c.io.reading.Add(1) for { - - // 1. if currently closing -> exit + // currently closing -> exit if c.io.closing { fmt.Printf("[reader] killed because closing") break } - // 2. Parse message - msg, err := readMessage(c.io.reader) - + // Parse message + var msg = &Message{} + _, err := msg.ReadFrom(c.io.reader) if err == ErrUnmaskedFrame || err == ErrReservedBits { closeStatus = ProtocolError } @@ -143,7 +121,7 @@ func clientReader(c *client) { break } - // 3. Fail on invalid message + // invalid message msgErr := msg.check(frag != nil) if msgErr != nil { @@ -151,7 +129,7 @@ func clientReader(c *client) { switch msgErr { - // Fail + // fail case ErrUnexpectedContinuation: closeStatus = None clientAck = false @@ -182,7 +160,7 @@ func clientReader(c *client) { } - // 4. Ping <-> Pong + // ping <-> Pong if msg.Type == Ping && c.io.writing { msg.Final = true msg.Type = Pong @@ -190,7 +168,7 @@ func clientReader(c *client) { continue } - // 5. Store first fragment + // store first fragment if frag == nil && !msg.Final { frag = &Message{ Type: msg.Type, @@ -201,7 +179,7 @@ func clientReader(c *client) { continue } - // 6. Store fragments + // store fragments if frag != nil { frag.Final = msg.Final frag.Size += msg.Size @@ -226,7 +204,7 @@ func clientReader(c *client) { } - // 7. Dispatch to receiver + // dispatch to receiver if msg.Type == Text || msg.Type == Binary { c.ch.receive <- *msg } @@ -236,69 +214,59 @@ func clientReader(c *client) { close(c.ch.receive) c.io.reading.Done() - // 8. close channel (if not already done) + // close channel (if not already done) // fmt.Printf("[reader] end\n") c.close(closeStatus, clientAck) } -// writer writes into websocket -// and is triggered by client.ch.send channel +// clientWriter writes to the websocket connection and is triggered by +// client.ch.send channel func clientWriter(c *client) { - c.io.writing = true // if channel still exists for msg := range c.ch.send { - - // 2. Send message - err := msg.Send(c.io.sock) - - // 3. Fail on error + _, err := msg.WriteTo(c.io.sock) if err != nil { fmt.Printf(" [writer] %s\n", err) c.io.writing = false break } - } c.io.writing = false - // 4. close channel (if not already done) + // close channel (if not already done) // fmt.Printf("[writer] end\n") c.close(Normal, true) } -// closes the connection +// close the connection // send CLOSE frame is 'status' is not NONE // wait for the next message (CLOSE acknowledge) if 'clientACK' // then delete client func (c *client) close(status MessageError, clientACK bool) { - - // 1. Fail if already closing + // fail if already closing alreadyClosing := false c.io.closingMu.Lock() alreadyClosing = c.io.closing c.io.closing = true c.io.closingMu.Unlock() - if alreadyClosing { return } - // 2. kill writer' if still running + // kill writer' if still running if c.io.writing { close(c.ch.send) } - // 3. kill reader if still running + // kill reader if still running c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1)) c.io.reading.Wait() if status != None { - - // 3. Build message msg := &Message{ Final: true, Type: Close, @@ -307,40 +275,18 @@ func (c *client) close(status MessageError, clientACK bool) { } binary.BigEndian.PutUint16(msg.Data, uint16(status)) - // 4. Send message - msg.Send(c.io.sock) - // if err != nil { - // fmt.Printf("[close] send error (%s0\n", err) - // } - + msg.WriteTo(c.io.sock) } - // 2. Wait for client CLOSE if needed + // wait for client CLOSE if needed if clientACK { - c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond)) - - /* Wait for message */ - readMessage(c.io.reader) - // if err != nil || msg.Type != CLOSE { - // if err == nil { - // fmt.Printf("[close] received OpCode = %d\n", msg.Type) - // } else { - // fmt.Printf("[close] read error (%v)\n", err) - // } - // } - - // fmt.Printf("[close] received ACK\n") - + var tmpMsg = &Message{} + tmpMsg.ReadFrom(c.io.reader) } - // 3. Close socket c.io.sock.Close() // fmt.Printf("[close] socket closed\n") - // 4. Unregister c.io.kill <- c - - return - } diff --git a/cmd/iface/main.go b/cmd/iface/main.go index 6f39b04..cfae389 100644 --- a/cmd/iface/main.go +++ b/cmd/iface/main.go @@ -1,4 +1,4 @@ -package iface +package main import ( "fmt" @@ -11,12 +11,11 @@ func main() { startTime := time.Now().UnixNano() - // 1. Bind WebSocket server - serv := ws.CreateServer("0.0.0.0", 4444) + // creqte WebSocket server + serv := ws.NewServer("0.0.0.0", 4444) - // 2. Bind default controller + // bind default controller serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { - defer func() { if recover() != nil { fmt.Printf("*** PANIC\n") @@ -24,35 +23,28 @@ func main() { }() for msg := range receiver { - // if receive message -> send it back sender <- msg // close(sender) - } - }) - // 3. Bind to URI + // bnd to URI err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { - fmt.Printf("[uri] connected\n") for msg := range receiver { - fmt.Printf("[uri] received '%s'\n", msg.Data) sender <- msg } - fmt.Printf("[uri] unexpectedly closed\n") - }) if err != nil { panic(err) } - // 4. Launch the server + // launch the server err = serv.Launch() if err != nil { fmt.Printf("[ERROR] %s\n", err) @@ -60,5 +52,4 @@ func main() { } fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3) - } diff --git a/controller.go b/controller.go index f2cc3d1..6ced304 100644 --- a/controller.go +++ b/controller.go @@ -18,7 +18,7 @@ type Controller struct { Fun ControllerFunc // controller function } -// ControllerSet is set of controllers +// ControllerSet contains a set of controllers type ControllerSet struct { Def *Controller // default controller URI []*Controller // uri controllers @@ -27,35 +27,22 @@ type ControllerSet struct { // Match finds a controller for a given URI // also it returns the matching string patterns func (s *ControllerSet) Match(uri string) (*Controller, [][]string) { - - // 1. Initialise argument list arguments := [][]string{{uri}} - // 2. Try each controller for _, c := range s.URI { - - /* 1. If matches */ if c.URI.Match(uri) { - - /* Extract matches */ match := c.URI.GetAllMatch() - - /* Add them to the 'arg' attribute */ arguments = append(arguments, match...) - - /* Mark that we have a controller */ return c, arguments - } - } - // 3. If no controller found -> set default controller + // fallback to default if s.Def != nil { return s.Def, arguments } - // 4. If default is NIL, return empty controller + // no default return nil, arguments } diff --git a/internal/http/reader/reader.go b/internal/http/reader/reader.go index 1bf181b..ff9d48c 100644 --- a/internal/http/reader/reader.go +++ b/internal/http/reader/reader.go @@ -12,7 +12,7 @@ import ( ) // Maximum line length -var maxLineLength = 4096 +const maxLineLength = 4096 // ChunkReader struct type ChunkReader struct { @@ -32,19 +32,17 @@ func NewReader(r io.Reader) *ChunkReader { } -// Read reads a chunk, err is io.EOF when done +// Read reads a chunk, io.EOF when done func (r *ChunkReader) Read() ([]byte, error) { - - // 1. If already ended + // already ended if r.isEnded { return nil, io.EOF } - // 2. Read line + // read line var line []byte line, err := r.reader.ReadSlice('\n') - // 3. manage errors if err == io.EOF { err = io.ErrUnexpectedEOF } @@ -57,10 +55,8 @@ func (r *ChunkReader) Read() ([]byte, error) { return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength) } - // 4. Trim - line = removeTrailingSpace(line) + line = trimSpaces(line) - // 5. Manage ending line if len(line) == 0 { r.isEnded = true return line, io.EOF @@ -70,13 +66,13 @@ func (r *ChunkReader) Read() ([]byte, error) { } -func removeTrailingSpace(b []byte) []byte { - for len(b) > 0 && isASCIISpace(b[len(b)-1]) { +func trimSpaces(b []byte) []byte { + for len(b) > 0 && isSpaceChar(b[len(b)-1]) { b = b[:len(b)-1] } return b } -func isASCIISpace(b byte) bool { +func isSpaceChar(b byte) bool { return b == ' ' || b == '\t' || b == '\r' || b == '\n' } diff --git a/internal/http/upgrade/errors.go b/internal/http/upgrade/errors.go index befb526..7812350 100644 --- a/internal/http/upgrade/errors.go +++ b/internal/http/upgrade/errors.go @@ -4,33 +4,32 @@ import ( "fmt" ) -// invalid request +// ErrInvalidRequest for invalid requests // - multiple-value if only 1 expected -type InvalidRequest struct { +type ErrInvalidRequest struct { Field string Reason string } -func (err InvalidRequest) Error() string { - return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason) +func (err ErrInvalidRequest) Error() string { + return fmt.Sprintf("invalid field '%s': %s", err.Field, err.Reason) } -// Request misses fields (request-line or headers) -type IncompleteRequest struct { - MissingField string +// ErrIncompleteRequest when mandatory request fields are missing (request-line or headers) +// it contains the missing field as a string +type ErrIncompleteRequest string + +func (err ErrIncompleteRequest) Error() string { + return fmt.Sprintf("incomplete request, '%s' is invalid or missing", string(err)) } -func (err IncompleteRequest) Error() string { - return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField) -} - -// Request has a violated origin policy -type InvalidOriginPolicy struct { +// ErrInvalidOriginPolicy when a request has a violated origin policy +type ErrInvalidOriginPolicy struct { Host string Origin string err error } -func (err InvalidOriginPolicy) Error() string { +func (err ErrInvalidOriginPolicy) Error() string { return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err) } diff --git a/internal/http/upgrade/header_check.go b/internal/http/upgrade/header_check.go index 0db6f94..f5d9644 100644 --- a/internal/http/upgrade/header_check.go +++ b/internal/http/upgrade/header_check.go @@ -10,11 +10,11 @@ import ( func (r *Request) extractHostPort(bb HeaderValue) error { if len(bb) != 1 { - return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} + return &ErrInvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} } if len(bb[0]) <= 3 { - return &InvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))} + return &ErrInvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))} } split := strings.Split(string(bb[0]), ":") @@ -29,8 +29,8 @@ func (r *Request) extractHostPort(bb HeaderValue) error { // extract port readPort, err := strconv.ParseUint(split[1], 10, 16) if err != nil { - r.code = BadRequest - return &InvalidRequest{"Host", "cannot read port"} + r.statusCode = StatusBadRequest + return &ErrInvalidRequest{"Host", "cannot read port"} } r.port = uint16(readPort) @@ -39,8 +39,8 @@ func (r *Request) extractHostPort(bb HeaderValue) error { if len(r.origin) > 0 { if err != nil { err = r.checkOriginPolicy() - r.code = Forbidden - return &InvalidOriginPolicy{r.host, r.origin, err} + r.statusCode = StatusForbidden + return &ErrInvalidOriginPolicy{r.host, r.origin, err} } } @@ -57,8 +57,8 @@ func (r *Request) extractOrigin(bb HeaderValue) error { } if len(bb) != 1 { - r.code = Forbidden - return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} + r.statusCode = StatusForbidden + return &ErrInvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} } r.origin = string(bb[0]) @@ -67,8 +67,8 @@ func (r *Request) extractOrigin(bb HeaderValue) error { if len(r.host) > 0 { err := r.checkOriginPolicy() if err != nil { - r.code = Forbidden - return &InvalidOriginPolicy{r.host, r.origin, err} + r.statusCode = StatusForbidden + return &ErrInvalidOriginPolicy{r.host, r.origin, err} } } @@ -96,8 +96,8 @@ func (r *Request) checkConnection(bb HeaderValue) error { } - r.code = BadRequest - return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} + r.statusCode = StatusBadRequest + return &ErrInvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} } @@ -106,8 +106,8 @@ func (r *Request) checkConnection(bb HeaderValue) error { func (r *Request) checkUpgrade(bb HeaderValue) error { if len(bb) != 1 { - r.code = BadRequest - return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} + r.statusCode = StatusBadRequest + return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} } if strings.ToLower(string(bb[0])) == "websocket" { @@ -115,8 +115,8 @@ func (r *Request) checkUpgrade(bb HeaderValue) error { return nil } - r.code = BadRequest - return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} + r.statusCode = StatusBadRequest + return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} } @@ -125,8 +125,8 @@ func (r *Request) checkUpgrade(bb HeaderValue) error { func (r *Request) checkVersion(bb HeaderValue) error { if len(bb) != 1 || string(bb[0]) != "13" { - r.code = UpgradeRequired - return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} + r.statusCode = StatusUpgradeRequired + return &ErrInvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} } r.hasVersion = true @@ -139,8 +139,8 @@ func (r *Request) checkVersion(bb HeaderValue) error { func (r *Request) extractKey(bb HeaderValue) error { if len(bb) != 1 || len(bb[0]) != 24 { - r.code = BadRequest - return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} + r.statusCode = StatusBadRequest + return &ErrInvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} } r.key = bb[0] diff --git a/internal/http/upgrade/line.go b/internal/http/upgrade/line.go deleted file mode 100644 index 842b4a1..0000000 --- a/internal/http/upgrade/line.go +++ /dev/null @@ -1,118 +0,0 @@ -package upgrade - -import ( - "bytes" - "fmt" - "regexp" -) - -// Line represents the HTTP Request line -// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1 -type Line struct { - method Method - uri string - version byte -} - -// Parse parses the first HTTP request line -func (r *Line) Parse(b []byte) error { - - // 1. Split by ' ' - parts := bytes.Split(b, []byte(" ")) - - // 2. Fail when missing parts - if len(parts) != 3 { - return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts)) - } - - // 3. Extract HTTP method - err := r.extractHttpMethod(parts[0]) - if err != nil { - return err - } - - // 4. Extract URI - err = r.extractURI(parts[1]) - if err != nil { - return err - } - - // 5. Extract version - err = r.extractHttpVersion(parts[2]) - if err != nil { - return err - } - - return nil - -} - -// GetURI returns the actual URI -func (r Line) GetURI() string { - return r.uri -} - -// extractHttpMethod extracts the HTTP method from a []byte -// and checks for errors -// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE -func (r *Line) extractHttpMethod(b []byte) error { - - switch string(b) { - // case "OPTIONS": r.method = OPTIONS - case "GET": - r.method = Get - // case "HEAD": r.method = HEAD - // case "POST": r.method = POST - // case "PUT": r.method = PUT - // case "DELETE": r.method = DELETE - - default: - return fmt.Errorf("invalid HTTP method '%s', expected 'GET'", b) - } - - return nil -} - -// extractURI extracts the URI from a []byte and checks for errors -// allowed format: /([^/]/)*/? -func (r *Line) extractURI(b []byte) error { - - // 1. Check format - checker := regexp.MustCompile("^(?:/[^/]+)*/?$") - if !checker.Match(b) { - return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b) - } - - // 2. Store - r.uri = string(b) - - return nil - -} - -// extractHttpVersion extracts the version and checks for errors -// allowed format: [1-9] or [1.9].[0-9] -func (r *Line) extractHttpVersion(b []byte) error { - - // 1. Extract version parts - extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`) - - if !extractor.Match(b) { - return fmt.Errorf("invalid HTTP version, expected INT or INT.INT, got '%s'", b) - } - - // 2. Extract version number - matches := extractor.FindSubmatch(b) - var version byte = matches[1][0] - '0' - - // 3. Extract subversion (if exists) - var subVersion byte = 0 - if len(matches[2]) > 0 { - subVersion = matches[2][0] - '0' - } - - // 4. Store version (x 10 to fit uint8) - r.version = version*10 + subVersion - - return nil -} diff --git a/internal/http/upgrade/methods.go b/internal/http/upgrade/methods.go deleted file mode 100644 index 6b791d6..0000000 --- a/internal/http/upgrade/methods.go +++ /dev/null @@ -1,14 +0,0 @@ -package upgrade - -// Method represents available http methods -type Method uint8 - -// http methods -const ( - Options Method = iota - Get - Head - Post - Put - Delete -) diff --git a/internal/http/upgrade/request.go b/internal/http/upgrade/request.go index 45c1a5f..2524eb3 100644 --- a/internal/http/upgrade/request.go +++ b/internal/http/upgrade/request.go @@ -7,157 +7,139 @@ import ( "git.xdrm.io/go/ws/internal/http/reader" ) -// If origin is required +// whether origin is required const bypassOriginPolicy = true // Request represents an HTTP Upgrade request type Request struct { - first bool // whether the first line has been read (GET uri HTTP/version) + // whether the first line has been read (GET uri HTTP/version) + first bool + statusCode StatusCode + requestLine RequestLine - // status code - code StatusCode - - // request line - request Line - - // data to check origin (depends of reading order) + // data to check origin (depends on reading order) host string port uint16 // 0 if not set origin string validPolicy bool - // ws data + // websocket specific key []byte protocols [][]byte - // required fields check + // mandatory fields to check hasConnection bool hasUpgrade bool hasVersion bool } -// Parse builds an upgrade HTTP request -// from a reader (typically bufio.NewRead of the socket) -func Parse(r io.Reader) (request *Request, err error) { +// ReadFrom reads an upgrade HTTP request ; typically from bufio.NewRead of the +// socket +// +// implements io.ReaderFrom +func (req *Request) ReadFrom(r io.Reader) (int64, error) { + var read int64 - req := &Request{ - code: 500, - } + // reset request + req.statusCode = 500 - /* (1) Parse request - ---------------------------------------------------------*/ - // 1. Get chunk reader - cr := reader.NewReader(r) - if err != nil { - return req, fmt.Errorf("create chunk reader: %w", err) - } - - // 2. Parse header line by line + // parse header line by line + var cr = reader.NewReader(r) for { - line, err := cr.Read() + read += int64(len(line)) if err == io.EOF { break } if err != nil { - return req, err + return read, err } err = req.parseHeader(line) if err != nil { - return req, err + return read, err } - } - // 3. Check completion - err = req.isComplete() + err := req.isComplete() if err != nil { - req.code = BadRequest - return req, err + req.statusCode = StatusBadRequest + return read, err } - req.code = SwitchingProtocols - return req, nil - + req.statusCode = StatusSwitchingProtocols + return read, nil } // StatusCode returns the status current -func (r Request) StatusCode() StatusCode { - return r.code +func (req Request) StatusCode() StatusCode { + return req.statusCode } // BuildResponse builds a response from the request -func (r *Request) BuildResponse() *Response { +func (req *Request) BuildResponse() *Response { - inst := &Response{} - - // 1. Copy code - inst.SetStatusCode(r.code) - - // 2. Set Protocol - if len(r.protocols) > 0 { - inst.SetProtocol(r.protocols[0]) + res := &Response{ + StatusCode: req.statusCode, + Protocol: nil, } - // 4. Process key - inst.ProcessKey(r.key) + if len(req.protocols) > 0 { + res.Protocol = req.protocols[0] + } - return inst + res.ProcessKey(req.key) + + return res } -// GetURI returns the actual URI -func (r Request) GetURI() string { - return r.request.GetURI() +// URI returns the actual URI +func (req Request) URI() string { + return req.requestLine.URI() } // parseHeader parses any http request line // (header and request-line) -func (r *Request) parseHeader(b []byte) error { - - /* (1) First line -> GET {uri} HTTP/{version} - ---------------------------------------------------------*/ - if !r.first { - - err := r.request.Parse(b) +func (req *Request) parseHeader(b []byte) error { + // first line -> GET {uri} HTTP/{version} + if !req.first { + _, err := req.requestLine.Read(b) if err != nil { - r.code = BadRequest - return &InvalidRequest{"Request-Line", err.Error()} + req.statusCode = StatusBadRequest + return &ErrInvalidRequest{"Request-Line", err.Error()} } - r.first = true + req.first = true return nil } - /* (2) Other lines -> Header-Name: Header-Value - ---------------------------------------------------------*/ - // 1. Try to parse header + // other lines -> Header-Name: Header-Value head, err := ReadHeader(b) if err != nil { - r.code = BadRequest + req.statusCode = StatusBadRequest return fmt.Errorf("parse header: %w", err) } // 2. Manage header switch head.Name { case Host: - err = r.extractHostPort(head.Values) + err = req.extractHostPort(head.Values) case Origin: - err = r.extractOrigin(head.Values) + err = req.extractOrigin(head.Values) case Upgrade: - err = r.checkUpgrade(head.Values) + err = req.checkUpgrade(head.Values) case Connection: - err = r.checkConnection(head.Values) + err = req.checkConnection(head.Values) case WSVersion: - err = r.checkVersion(head.Values) + err = req.checkVersion(head.Values) case WSKey: - err = r.extractKey(head.Values) + err = req.extractKey(head.Values) case WSProtocol: - err = r.extractProtocols(head.Values) + err = req.extractProtocols(head.Values) default: return nil @@ -169,48 +151,39 @@ func (r *Request) parseHeader(b []byte) error { } return nil - } // isComplete returns whether the Upgrade Request -// is complete (no missing required item) -func (r Request) isComplete() error { - - // 1. Request-Line - if !r.first { - return &IncompleteRequest{"Request-Line"} +// is complete (no required field missing) +// returns nil on success +func (req Request) isComplete() error { + if !req.first { + return ErrIncompleteRequest("Request-Line") } - // 2. Host - if len(r.host) == 0 { - return &IncompleteRequest{"Host"} + if len(req.host) == 0 { + return ErrIncompleteRequest("Host") } - // 3. Origin - if !bypassOriginPolicy && len(r.origin) == 0 { - return &IncompleteRequest{"Origin"} + if !bypassOriginPolicy && len(req.origin) == 0 { + return ErrIncompleteRequest("Origin") } - // 4. Connection - if !r.hasConnection { - return &IncompleteRequest{"Connection"} + if !req.hasConnection { + return ErrIncompleteRequest("Connection") } - // 5. Upgrade - if !r.hasUpgrade { - return &IncompleteRequest{"Upgrade"} + if !req.hasUpgrade { + return ErrIncompleteRequest("Upgrade") } - // 6. Sec-WebSocket-Version - if !r.hasVersion { - return &IncompleteRequest{"Sec-WebSocket-Version"} + if !req.hasVersion { + return ErrIncompleteRequest("Sec-WebSocket-Version") } - // 7. Sec-WebSocket-Key - if len(r.key) < 1 { - return &IncompleteRequest{"Sec-WebSocket-Key"} + if len(req.key) < 1 { + return ErrIncompleteRequest("Sec-WebSocket-Key") } return nil - } diff --git a/internal/http/upgrade/request_line.go b/internal/http/upgrade/request_line.go new file mode 100644 index 0000000..fe8d1cc --- /dev/null +++ b/internal/http/upgrade/request_line.go @@ -0,0 +1,94 @@ +package upgrade + +import ( + "bytes" + "fmt" + "regexp" +) + +// RequestLine represents the HTTP Request line +// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1 +type RequestLine struct { + uri string + version byte +} + +// Read an HTTP request line from a byte array +// +// implements io.Reader +func (rl *RequestLine) Read(b []byte) (int, error) { + var read = len(b) + + // split by spaces + parts := bytes.Split(b, []byte(" ")) + + if len(parts) != 3 { + return read, fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts)) + } + + err := rl.extractHttpMethod(parts[0]) + if err != nil { + return read, err + } + + err = rl.extractURI(parts[1]) + if err != nil { + return read, err + } + + err = rl.extractHttpVersion(parts[2]) + if err != nil { + return read, err + } + + return read, nil + +} + +// URI of the request line +func (rl RequestLine) URI() string { + return rl.uri +} + +// extractHttpMethod extracts the HTTP method from a []byte +// and checks for errors +// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE +func (rl *RequestLine) extractHttpMethod(b []byte) error { + if string(b) != "GET" { + return fmt.Errorf("invalid HTTP method '%s', expected 'GET'", b) + } + return nil +} + +// extractURI extracts the URI from a []byte and checks for errors +// allowed format: /([^/]/)*/? +func (rl *RequestLine) extractURI(b []byte) error { + checker := regexp.MustCompile("^(?:/[^/]+)*/?$") + if !checker.Match(b) { + return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b) + } + rl.uri = string(b) + return nil + +} + +// extractHttpVersion extracts the version and checks for errors +// allowed format: [1-9] or [1.9].[0-9] +func (rl *RequestLine) extractHttpVersion(b []byte) error { + extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`) + + if !extractor.Match(b) { + return fmt.Errorf("invalid HTTP version, expected INT or INT.INT, got '%s'", b) + } + matches := extractor.FindSubmatch(b) + + var version byte = matches[1][0] - '0' + + var subversion byte = 0 + if len(matches[2]) > 0 { + subversion = matches[2][0] - '0' + } + + rl.version = version*10 + subversion + return nil +} diff --git a/internal/http/upgrade/request_test.go b/internal/http/upgrade/request_test.go index 7af15fa..adf4933 100644 --- a/internal/http/upgrade/request_test.go +++ b/internal/http/upgrade/request_test.go @@ -7,17 +7,15 @@ import ( ) func TestEOFSocket(t *testing.T) { + var ( + socket = &bytes.Buffer{} + req = &Request{} + ) - socket := &bytes.Buffer{} - - _, err := Parse(socket) - - if err == nil { - t.Fatalf("Empty socket expected EOF, got no error") - } else if err != io.ErrUnexpectedEOF { - t.Fatalf("Empty socket expected EOF, got '%s'", err) + _, err := req.ReadFrom(socket) + if err != io.ErrUnexpectedEOF { + t.Fatalf("unexpected error <%v> expected <%v>", err, io.ErrUnexpectedEOF) } - } func TestInvalidRequestLine(t *testing.T) { @@ -59,15 +57,13 @@ func TestInvalidRequestLine(t *testing.T) { socket.Write([]byte(tc.Reqline)) socket.Write([]byte("\r\n\r\n")) - _, err := Parse(socket) - + var req = &Request{} + _, err := req.ReadFrom(socket) if !tc.HasError { - - // no error -> ok if err == nil { continue // error for the end of the request -> ok - } else if _, ok := err.(*IncompleteRequest); ok { + } else if _, ok := err.(ErrIncompleteRequest); ok { continue } @@ -80,7 +76,7 @@ func TestInvalidRequestLine(t *testing.T) { continue } - ir, ok := err.(*InvalidRequest) + ir, ok := err.(*ErrInvalidRequest) // not InvalidRequest err -> error if !ok || ir.Field != "Request-Line" { @@ -131,15 +127,15 @@ func TestInvalidHost(t *testing.T) { socket.Write([]byte(tc.Host)) socket.Write([]byte("\r\n\r\n")) - _, err := Parse(socket) - + var req = &Request{} + _, err := req.ReadFrom(socket) if !tc.HasError { // no error -> ok if err == nil { continue // error for the end of the request -> ok - } else if _, ok := err.(*IncompleteRequest); ok { + } else if _, ok := err.(ErrIncompleteRequest); ok { continue } @@ -153,7 +149,7 @@ func TestInvalidHost(t *testing.T) { } // check if InvalidRequest - ir, ok := err.(*InvalidRequest) + ir, ok := err.(ErrInvalidRequest) // not InvalidRequest err -> error if ok && ir.Field != "Host" { diff --git a/internal/http/upgrade/response.go b/internal/http/upgrade/response.go index cd0b44d..e85d5d1 100644 --- a/internal/http/upgrade/response.go +++ b/internal/http/upgrade/response.go @@ -7,88 +7,58 @@ import ( "io" ) -// HTTPVersion constant -const HTTPVersion = "1.1" +// constants +const ( + httpVersion = "1.1" + wsVersion = 13 + keySalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +) -// UsedWSVersion constant websocket version -const UsedWSVersion = 13 - -// WSSalt constant websocket salt -const WSSalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - -// Response represents an HTTP Upgrade Response +// Response is an HTTP Upgrade Response type Response struct { - code StatusCode // status code - accept []byte // processed from Sec-WebSocket-Key - protocol []byte // set from Sec-WebSocket-Protocol or none if not received -} - -// SetStatusCode sets the status code -func (r *Response) SetStatusCode(sc StatusCode) { - r.code = sc -} - -// SetProtocol sets the protocols -func (r *Response) SetProtocol(p []byte) { - r.protocol = p + StatusCode StatusCode + // Sec-WebSocket-Protocol or nil if missing + Protocol []byte + // processed from Sec-WebSocket-Key + key []byte } // ProcessKey processes the accept token according // to the rfc from the Sec-WebSocket-Key func (r *Response) ProcessKey(k []byte) { - - // do nothing for empty key - if k == nil || len(k) == 0 { - r.accept = nil + // ignore empty key + if k == nil || len(k) < 1 { return } - // 1. Concat with constant salt - mix := append(k, []byte(WSSalt)...) - - // 2. Hash with sha1 algorithm - digest := sha1.Sum(mix) - - // 3. Base64 encode it - r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size])) - + // concat with constant salt + salted := append(k, []byte(keySalt)...) + // hash with sha1 + digest := sha1.Sum(salted) + // base64 encode + r.key = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size])) } -// Send sends the response through an io.Writer -// typically a socket -func (r Response) Send(w io.Writer) (int, error) { +// WriteTo writes the response; typically in a socket +// +// implements io.WriterTo +func (r Response) WriteTo(w io.Writer) (int64, error) { + responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", httpVersion, r.StatusCode, r.StatusCode) - // 1. Build response line - responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HTTPVersion, r.code, r.code) - - // 2. Build headers optionalProtocol := "" - if len(r.protocol) > 0 { - optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol) + if len(r.Protocol) > 0 { + optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.Protocol) } - headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", UsedWSVersion, optionalProtocol) - if r.accept != nil { - headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept) + headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", wsVersion, optionalProtocol) + if r.key != nil { + headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.key) } headers = fmt.Sprintf("%s\r\n", headers) - // 3. Build all - raw := []byte(fmt.Sprintf("%s%s", responseLine, headers)) + combined := []byte(fmt.Sprintf("%s%s", responseLine, headers)) - // 4. Write - written, err := w.Write(raw) - - return written, err + written, err := w.Write(combined) + return int64(written), err } - -// GetProtocol returns the choosen protocol if set, else nil -func (r Response) GetProtocol() []byte { - return r.protocol -} - -// GetStatusCode returns the response status code -func (r Response) GetStatusCode() StatusCode { - return r.code -} diff --git a/internal/http/upgrade/status_code.go b/internal/http/upgrade/status_code.go index aed18be..c64845c 100644 --- a/internal/http/upgrade/status_code.go +++ b/internal/http/upgrade/status_code.go @@ -1,37 +1,37 @@ package upgrade -// StatusCode maps the status codes (and description) -type StatusCode uint16 +// StatusCode maps HTTP status codes (and description) +type StatusCode int const ( - // SwitchingProtocols - handshake success - SwitchingProtocols StatusCode = 101 - // BadRequest - missing/malformed headers - BadRequest StatusCode = 400 - // Forbidden - invalid origin policy, TLS required - Forbidden StatusCode = 403 - // UpgradeRequired - invalid WS version - UpgradeRequired StatusCode = 426 - // NotFound - unserved or invalid URI - NotFound StatusCode = 404 - // Internal - custom error - Internal StatusCode = 500 + // StatusSwitchingProtocols - handshake success + StatusSwitchingProtocols StatusCode = 101 + // StatusBadRequest - missing/malformed headers + StatusBadRequest StatusCode = 400 + // StatusForbidden - invalid origin policy, TLS required + StatusForbidden StatusCode = 403 + // StatusUpgradeRequired - invalid WS version + StatusUpgradeRequired StatusCode = 426 + // StatusNotFound - unserved or invalid URI + StatusNotFound StatusCode = 404 + // StatusInternal - custom error + StatusInternal StatusCode = 500 ) // String implements the Stringer interface func (sc StatusCode) String() string { switch sc { - case SwitchingProtocols: + case StatusSwitchingProtocols: return "Switching Protocols" - case BadRequest: + case StatusBadRequest: return "Bad Request" - case Forbidden: + case StatusForbidden: return "Forbidden" - case UpgradeRequired: + case StatusUpgradeRequired: return "Upgrade Required" - case NotFound: + case StatusNotFound: return "Not Found" - case Internal: + case StatusInternal: return "Internal Server Error" default: return "Unknown Status Code" diff --git a/internal/uri/parser.go b/internal/uri/parser.go index e581c23..7a5ba1d 100644 --- a/internal/uri/parser.go +++ b/internal/uri/parser.go @@ -39,18 +39,16 @@ type matcher struct { // Scheme represents an URI scheme type Scheme []*matcher -// FromString builds an URI scheme from a pattern string +// FromString builds an URI scheme from a string pattern func FromString(s string) (*Scheme, error) { - - // 1. Manage '/' at the start + // handle '/' at the start if len(s) < 1 || s[0] != '/' { return nil, fmt.Errorf("invalid URI; must start with '/'") } - // 2. Split by '/' parts := strings.Split(s, "/") - // 3. Max exceeded + // check max match size if len(parts)-2 > maxMatch { for i, p := range parts { fmt.Printf("%d: '%s'\n", i, p) @@ -58,13 +56,11 @@ func FromString(s string) (*Scheme, error) { return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts)) } - // 4. Build for each part sch, err := buildScheme(parts) if err != nil { return nil, err } - // 5. Optimise structure opti, err := sch.optimise() if err != nil { return nil, err @@ -74,91 +70,67 @@ func FromString(s string) (*Scheme, error) { } -// Match returns if the given URI is matched by the scheme -func (s Scheme) Match(str string) bool { - - // 1. Nothing -> match all +// Match returns whether the given URI is matched by the scheme +func (s Scheme) Match(uri string) bool { if len(s) == 0 { return true } - // 2. Check for string match - clearURI, match := s.matchString(str) + // check for string match + clearURI, match := s.matchString(uri) if !match { return false } - // 3. Check for non-string match (wildcards) - match = s.matchWildcards(clearURI) - if !match { - return false - } - - return true + // check for non-string match (wildcards) + return s.matchWildcards(clearURI) } // GetMatch returns the indexed match (excluding string matchers) func (s Scheme) GetMatch(n uint8) ([]string, error) { - - // 1. Index out of range if n > uint8(len(s)) { return nil, fmt.Errorf("index out of range") } - // 2. Iterate to find index (exclude strings) - ni := -1 + // iterate to find index (exclude strings) + matches := -1 for _, m := range s { - - // ignore strings if len(m.pat) > 0 { continue } - // increment match counter : ni - ni++ + matches++ - // if expected index -> return matches - if uint8(ni) == n { + // expected index -> return matches + if uint8(matches) == n { return m.buf, nil } - } - // 3. If nothing found -> return empty set - return nil, fmt.Errorf("index out of range (max: %d)", ni) + // nothing found -> return empty set + return nil, fmt.Errorf("index out of range (max: %d)", matches) } // GetAllMatch returns all the indexed match (excluding string matchers) func (s Scheme) GetAllMatch() [][]string { - match := make([][]string, 0, len(s)) for _, m := range s { - - // ignore strings if len(m.pat) > 0 { continue } - match = append(match, m.buf) - } - return match - } // buildScheme builds a 'basic' scheme // from a pattern string func buildScheme(ss []string) (Scheme, error) { - - // 1. Build scheme sch := make(Scheme, 0, maxMatch) for _, s := range ss { - - // 2. ignore empty if len(s) == 0 { continue } @@ -167,31 +139,31 @@ func buildScheme(ss []string) (Scheme, error) { switch s { - // 3. Card: 0, N + // card: 0, N case "**": m.req = false m.mul = true sch = append(sch, m) - // 4. Card: 1, N + // card: 1, N case "..": m.req = true m.mul = true sch = append(sch, m) - // 5. Card: 0, 1 + // card: 0, 1 case "*": m.req = false m.mul = false sch = append(sch, m) - // 6. Card: 1 + // card: 1 case ".": m.req = true m.mul = false sch = append(sch, m) - // 7. Card: 1, literal string + // card: 1, literal string default: m.req = true m.mul = false @@ -207,30 +179,26 @@ func buildScheme(ss []string) (Scheme, error) { // optimise optimised the scheme for further parsing func (s Scheme) optimise() (Scheme, error) { - - // 1. Nothing to do if only 1 element if len(s) <= 1 { return s, nil } - // 2. Init reshifted scheme + // init reshifted scheme rshift := make(Scheme, 0, maxMatch) rshift = append(rshift, s[0]) - // 2. Iterate over matchers + // iterate over matchers for p, i, l := 0, 1, len(s); i < l; i++ { pre, cur := s[p], s[i] - /* Merge: 2 following literals */ + // merge: 2 following literals if len(pre.pat) > 0 && len(cur.pat) > 0 { - // merge strings into previous pre.pat = fmt.Sprintf("%s%s", pre.pat, cur.pat) // delete current s[i] = nil - } // increment previous (only if current is not nul) @@ -242,67 +210,65 @@ func (s Scheme) optimise() (Scheme, error) { } return rshift, nil - } // matchString checks the STRING matchers from an URI -// it returns a boolean : false when not matching, true eitherway -// it returns a cleared uri, without STRING data +// - returns a boolean : false when not matching, true eitherway +// - returns a cleared uri, without STRING data func (s Scheme) matchString(uri string) (string, bool) { - // 1. Initialise variables - clr := uri // contains cleared input string - minOff := 0 // minimum offset + var ( + clearedInput = uri + minOffset = 0 + ) - // 2. Iterate over strings for _, m := range s { - ls := len(m.pat) - // {1} If not STRING matcher -> ignore // + // ignore no STRING match if ls == 0 { continue } - // {2} Get offset in URI (else -1) // - off := strings.Index(clr, m.pat) + // get offset in URI (else -1) + off := strings.Index(clearedInput, m.pat) if off < 0 { return "", false } - // {3} Fail on invalid offset range // - if off < minOff { + // fail on invalid offset range + if off < minOffset { return "", false } - // {4} Check for trailing '/' // + // check for trailing '/' hasSlash := 0 - if off+ls < len(clr) && clr[off+ls] == '/' { + if off+ls < len(clearedInput) && clearedInput[off+ls] == '/' { hasSlash = 1 } - // {5} Remove the current string (+trailing slash) from the URI // - beg, end := clr[:off], clr[off+ls+hasSlash:] - clr = fmt.Sprintf("%s\a/%s", beg, end) // separate matches by '\a' character + // remove the current string (+trailing slash) from the URI + beg, end := clearedInput[:off], clearedInput[off+ls+hasSlash:] + clearedInput = fmt.Sprintf("%s\a/%s", beg, end) // separate matches with a '\a' character - // {6} Update offset range // - minOff = len(beg) + 2 - 1 // +2 slash separators + // update offset range + // +2 slash separators // -1 because strings begin with 1 slash already + minOffset = len(beg) + 2 - 1 } - // 3. If exists, remove trailing '/' - if clr[len(clr)-1] == '/' { - clr = clr[:len(clr)-1] + // if exists, remove trailing '/' + if clearedInput[len(clearedInput)-1] == '/' { + clearedInput = clearedInput[:len(clearedInput)-1] } - // 4. If exists, remove trailing '\a' - if clr[len(clr)-1] == '\a' { - clr = clr[:len(clr)-1] + // if exists, remove trailing '\a' + if clearedInput[len(clearedInput)-1] == '\a' { + clearedInput = clearedInput[:len(clearedInput)-1] } - return clr, true - + return clearedInput, true } // matchWildcards check the WILCARDS (non-string) matchers from @@ -310,7 +276,7 @@ func (s Scheme) matchString(uri string) (string, bool) { // + it sets the matchers buffers for later extraction func (s Scheme) matchWildcards(clear string) bool { - // 1. Extract wildcards (ref) + // extract wildcards (ref) wildcards := make(Scheme, 0, maxMatch) for _, m := range s { @@ -320,41 +286,34 @@ func (s Scheme) matchWildcards(clear string) bool { } } - // 2. If no wildcards -> match if len(wildcards) == 0 { return true } - // 3. Break uri by '\a' characters + // break uri by '\a' characters matches := strings.Split(clear, "\a")[1:] - // 4. Iterate over matches for n, match := range matches { - - // {1} If no more matcher // + // no more matcher if n >= len(wildcards) { return false } - // {2} Split by '/' // - data := strings.Split(match, "/")[1:] // from index 1 because it begins with '/' + // from index 1 because it begins with '/' + data := strings.Split(match, "/")[1:] - // {3} If required and missing // + // missing required if wildcards[n].req && len(data) < 1 { return false } - // {4} If not multi but got multi // + // if not multi but got multi if !wildcards[n].mul && len(data) > 1 { return false } - // {5} Store data into matcher // wildcards[n].buf = data - } - // 5. Match return true - } diff --git a/message.go b/message.go index f00f757..dafc0f8 100644 --- a/message.go +++ b/message.go @@ -84,36 +84,40 @@ type Message struct { Data []byte } -// receive reads a message form reader -func readMessage(reader io.Reader) (*Message, error) { +// ReadFrom reads a message from a reader +// +// implements io.ReaderFrom +func (m *Message) ReadFrom(reader io.Reader) (int64, error) { + var ( + read int64 + err error + tmpBuf []byte + mask []byte + cursor int + ) - var err error - var tmpBuf []byte - var mask []byte - var cursor int - - m := &Message{} - - // 2. Byte 1: FIN and OpCode + // byte 1: FIN and OpCode tmpBuf = make([]byte, 1) + read += int64(len(tmpBuf)) err = readBytes(reader, tmpBuf) if err != nil { - return m, err + return read, err } // check reserved bits if tmpBuf[0]&0x70 != 0 { - return m, ErrReservedBits + return read, ErrReservedBits } m.Final = bool(tmpBuf[0]&0x80 == 0x80) m.Type = MessageType(tmpBuf[0] & 0x0f) - // 3. Byte 2: Mask and Length[0] + // byte 2: mask and length[0] tmpBuf = make([]byte, 1) + read += int64(len(tmpBuf)) err = readBytes(reader, tmpBuf) if err != nil { - return m, err + return read, err } // if mask, byte array not nil @@ -124,71 +128,63 @@ func readMessage(reader io.Reader) (*Message, error) { // payload length m.Size = uint(tmpBuf[0] & 0x7f) - // 4. Extended payload + // extended payload if m.Size == 127 { - tmpBuf = make([]byte, 8) + read += int64(len(tmpBuf)) err := readBytes(reader, tmpBuf) if err != nil { - return m, err + return read, err } - m.Size = uint(binary.BigEndian.Uint64(tmpBuf)) } else if m.Size == 126 { - tmpBuf = make([]byte, 2) + read += int64(len(tmpBuf)) err := readBytes(reader, tmpBuf) if err != nil { - return m, err + return read, err } - m.Size = uint(binary.BigEndian.Uint16(tmpBuf)) - } - // 5. Masking key + // masking key if mask != nil { - tmpBuf = make([]byte, 4) + read += int64(len(tmpBuf)) err := readBytes(reader, tmpBuf) if err != nil { - return m, err + return read, err } - mask = make([]byte, 4) copy(mask, tmpBuf) - } - // 6. Read payload by chunks + // read payload by chunks m.Data = make([]byte, int(m.Size)) cursor = 0 - // {1} While we have data to read // + // while data to read for uint(cursor) < m.Size { - // {2} Try to read (at least 1 byte) // + // try to read (at least 1 byte) nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1) if err != nil { - return m, err + return read + int64(cursor) + int64(nbread), err } - // {3} Unmask data // + // unmask data // if mask != nil { for i, l := cursor, cursor+nbread; i < l; i++ { - mi := i % 4 // mask index m.Data[i] = m.Data[i] ^ mask[mi] - } } - // {4} Update cursor // cursor += nbread - } + read += int64(cursor) // return error if unmasked frame // we have to fully read it for read buffer to be clean @@ -197,13 +193,14 @@ func readMessage(reader io.Reader) (*Message, error) { err = ErrUnmaskedFrame } - return m, err + return read, err } -// Send sends a frame over a socket -func (m Message) Send(writer io.Writer) error { - +// WriteTo writes a message frame over a socket +// +// implements io.WriterTo +func (m Message) WriteTo(writer io.Writer) (int64, error) { header := make([]byte, 0, maximumHeaderSize) // fix size @@ -211,20 +208,18 @@ func (m Message) Send(writer io.Writer) error { m.Size = uint(len(m.Data)) } - // 1. Byte 0 : FIN + opcode + // byte 0 : FIN + opcode var final byte = 0x80 if !m.Final { final = 0 } header = append(header, final|byte(m.Type)) - // 2. Get payload length + // get payload length if m.Size < 126 { // simple - header = append(header, byte(m.Size)) } else if m.Size <= 0xffff { // extended: 16 bits - header = append(header, 126) buf := make([]byte, 2) @@ -232,7 +227,6 @@ func (m Message) Send(writer io.Writer) error { header = append(header, buf...) } else if m.Size <= 0xffffffffffffffff { // extended: 64 bits - header = append(header, 127) buf := make([]byte, 8) @@ -241,16 +235,15 @@ func (m Message) Send(writer io.Writer) error { } - // 3. Build write buffer + // build write buffer writeBuf := make([]byte, 0, len(header)+int(m.Size)) writeBuf = append(writeBuf, header...) writeBuf = append(writeBuf, m.Data[0:m.Size]...) - // 4. Send over socket by chunks + // write by chunks toWrite := len(header) + int(m.Size) cursor := 0 for cursor < toWrite { - maxBoundary := cursor + maxWriteChunk if maxBoundary > toWrite { maxBoundary = toWrite @@ -259,34 +252,32 @@ func (m Message) Send(writer io.Writer) error { // Try to wrote (at max 1024 bytes) // nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary]) if err != nil { - return err + return int64(nbwritten), err } // Update cursor // cursor += nbwritten - } - - return nil + return int64(cursor), nil } -// Check for message errors with: +// check for message errors with: // (m) the current message // (fragment) whether there is a fragment in construction // returns the message error func (m *Message) check(fragment bool) error { - // 1. Invalid first fragment (not TEXT nor BINARY) + // invalid first fragment (not TEXT nor BINARY) if !m.Final && !fragment && m.Type != Text && m.Type != Binary { return ErrInvalidFragment } - // 2. Waiting fragment but received standalone frame + // waiting fragment but received standalone frame if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong { return ErrInvalidFragment } - // 3. Control frame too long + // control frame too long if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) { return ErrTooLongControlFrame } @@ -347,20 +338,19 @@ func (m *Message) check(fragment bool) error { // // It manages connections which chunks data func readBytes(reader io.Reader, buffer []byte) error { + var ( + cur = 0 + len = len(buffer) + ) - var cur, len int = 0, len(buffer) - - // try to read until the full size is read + // read until the full size is read for cur < len { - nbread, err := reader.Read(buffer[cur:]) if err != nil { return err } - cur += nbread } - return nil } diff --git a/message_test.go b/message_test.go index 0f0441e..b84cf70 100644 --- a/message_test.go +++ b/message_test.go @@ -67,11 +67,12 @@ func TestSimpleMessageReading(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { + var ( + reader = bytes.NewBuffer(tc.ReadBuffer) + msg = &Message{} + ) - reader := bytes.NewBuffer(tc.ReadBuffer) - - got, err := readMessage(reader) - + _, err := msg.ReadFrom(reader) if err != tc.Err { t.Errorf("Expected %v error, got %v", tc.Err, err) } @@ -82,23 +83,23 @@ func TestSimpleMessageReading(t *testing.T) { } // check FIN - if got.Final != tc.Expected.Final { - t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, got.Final) + if msg.Final != tc.Expected.Final { + t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, msg.Final) } // check OpCode - if got.Type != tc.Expected.Type { - t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, got.Type) + if msg.Type != tc.Expected.Type { + t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, msg.Type) } // check Size - if got.Size != tc.Expected.Size { - t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, got.Size) + if msg.Size != tc.Expected.Size { + t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, msg.Size) } // check Data - if string(got.Data) != string(tc.Expected.Data) { - t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, got.Data) + if string(msg.Data) != string(tc.Expected.Data) { + t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, msg.Data) } }) @@ -177,17 +178,15 @@ func TestReadEOF(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { - - reader := bytes.NewBuffer(tc.ReadBuffer) - - got, err := readMessage(reader) - + var ( + reader = bytes.NewBuffer(tc.ReadBuffer) + msg = &Message{} + ) + _, err := msg.ReadFrom(reader) if tc.eof { - if err != io.EOF { - t.Errorf("Expected EOF, got %v", err) + t.Fatalf("Expected EOF, got %v", err) } - return } @@ -195,8 +194,8 @@ func TestReadEOF(t *testing.T) { t.Errorf("Expected UnmaskedFrameor, got %v", err) } - if got.Size != 0x00 { - t.Errorf("Expected a size of 0, got %d", got.Size) + if msg.Size != 0x00 { + t.Errorf("Expected a size of 0, got %d", msg.Size) } }) @@ -269,8 +268,7 @@ func TestSimpleMessageSending(t *testing.T) { writer := &bytes.Buffer{} - err := tc.Base.Send(writer) - + _, err := tc.Base.WriteTo(writer) if err != nil { t.Errorf("expected no error, got %v", err) return diff --git a/server.go b/server.go index f9a43a6..ac77702 100644 --- a/server.go +++ b/server.go @@ -27,9 +27,8 @@ type Server struct { ch serverChannelSet } -// CreateServer for a specific HOST and PORT -func CreateServer(host string, port uint16) *Server { - +// NewServer creates a server +func NewServer(host string, port uint16) *Server { return &Server{ addr: []byte(host), port: port, @@ -47,116 +46,84 @@ func CreateServer(host string, port uint16) *Server { broadcast: make(chan Message, 1), }, } - } // BindDefault binds a default controller // it will be called if the URI does not // match another controller func (s *Server) BindDefault(f ControllerFunc) { - s.ctl.Def = &Controller{ URI: nil, Fun: f, } - } // Bind a controller to an URI scheme func (s *Server) Bind(uriStr string, f ControllerFunc) error { - - // 1. Build URI parser uriScheme, err := uri.FromString(uriStr) if err != nil { return fmt.Errorf("cannot build URI: %w", err) } - // 2. Create controller s.ctl.URI = append(s.ctl.URI, &Controller{ URI: uriScheme, Fun: f, }) - return nil } // Launch the websocket server func (s *Server) Launch() error { + var ( + err error + url = fmt.Sprintf("%s:%d", s.addr, s.port) + ) - var err error - - /* (1) Listen socket - ---------------------------------------------------------*/ - // 1. Build full url - url := fmt.Sprintf("%s:%d", s.addr, s.port) - - // 2. Bind socket to listen s.sock, err = net.Listen("tcp", url) if err != nil { return fmt.Errorf("listen: %w", err) } - defer s.sock.Close() fmt.Printf("+ listening on %s\n", url) + go s.schedule() - // 3. Launch scheduler - go s.scheduler() - - /* (2) For each incoming connection (client) - ---------------------------------------------------------*/ for { - - // 1. Wait for client sock, err := s.sock.Accept() if err != nil { break } go func() { - - // 2. Try to create client - cli, err := buildClient(sock, s.ctl, s.ch) + cli, err := newClient(sock, s.ctl, s.ch) if err != nil { fmt.Printf(" - %s\n", err) return } - // 3. Register client s.ch.register <- cli - }() - } - return nil - } -// Scheduler schedules clients registration and broadcast -func (s *Server) scheduler() { - +// schedule client registration and broadcast +func (s *Server) schedule() { for { - select { - // 1. Create client case client := <-s.ch.register: s.clients[client.io.sock] = client - // 2. Remove client case client := <-s.ch.unregister: delete(s.clients, client.io.sock) - // 3. Broadcast case msg := <-s.ch.broadcast: for _, c := range s.clients { c.ch.send <- msg } } - } - }