diff --git a/client.go b/client.go index 3a91a2a..b9827e4 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "git.xdrm.io/go/ws/internal/http/upgrade/request" + "git.xdrm.io/go/ws/internal/http/upgrade" ) // Represents a client socket utility (reader, writer, ..) @@ -41,13 +41,13 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli /* (1) Manage UPGRADE request ---------------------------------------------------------*/ - /* (1) Parse request */ - req, _ := request.Parse(s) + // 1. Parse request + req, _ := upgrade.Parse(s) - /* (3) Build response */ + // 3. Build response res := req.BuildResponse() - /* (4) Write into socket */ + // 4. Write into socket _, err := res.Send(s) if err != nil { return nil, fmt.Errorf("Upgrade write error: %s", err) @@ -55,16 +55,16 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli if res.GetStatusCode() != 101 { s.Close() - return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode()) + return nil, fmt.Errorf("Upgrade error (HTTP %d)", res.GetStatusCode()) } /* (2) Initialise client ---------------------------------------------------------*/ - /* (1) Get upgrade data */ + // 1. Get upgrade data clientURI := req.GetURI() clientProtocol := res.GetProtocol() - /* (2) Initialise client */ + // 2. Initialise client cli := &client{ io: clientIO{ sock: s, @@ -74,7 +74,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli iface: &Client{ Protocol: string(clientProtocol), - Arguments: [][]string{[]string{clientURI}}, + Arguments: [][]string{{clientURI}}, }, ch: clientChannelSet{ @@ -85,20 +85,20 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli /* (3) Find controller by URI ---------------------------------------------------------*/ - /* (1) Try to find one */ + // 1. Try to find one controller, arguments := ctl.Match(clientURI) - /* (2) If nothing found -> error */ + // 2. If nothing found -> error if controller == nil { - return nil, fmt.Errorf("No controller found, no default controller set\n") + return nil, fmt.Errorf("No controller found, no default controller set") } - /* (3) Copy arguments */ + // 3. Copy arguments cli.iface.Arguments = arguments /* (4) Launch client routines ---------------------------------------------------------*/ - /* (1) Launch client controller */ + // 1. Launch client controller go controller.Fun( cli.iface, // pass the client cli.ch.receive, // the receiver @@ -106,10 +106,10 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli serverCh.broadcast, // broadcast sender ) - /* (2) Launch message reader */ + // 2. Launch message reader go clientReader(cli) - /* (3) Launc writer */ + // 3. Launc writer go clientWriter(cli) return cli, nil @@ -127,13 +127,13 @@ func clientReader(c *client) { for { - /* (1) if currently closing -> exit */ + // 1. if currently closing -> exit if c.io.closing { fmt.Printf("[reader] killed because closing") break } - /* (2) Parse message */ + // 2. Parse message msg, err := readMessage(c.io.reader) if err == ErrUnmaskedFrame || err == ErrReservedBits { @@ -143,7 +143,7 @@ func clientReader(c *client) { break } - /* (3) Fail on invalid message */ + // 3. Fail on invalid message msgErr := msg.check(frag != nil) if msgErr != nil { @@ -182,7 +182,7 @@ func clientReader(c *client) { } - /* (4) Ping <-> Pong */ + // 4. Ping <-> Pong if msg.Type == Ping && c.io.writing { msg.Final = true msg.Type = Pong @@ -190,7 +190,7 @@ func clientReader(c *client) { continue } - /* (5) Store first fragment */ + // 5. Store first fragment if frag == nil && !msg.Final { frag = &Message{ Type: msg.Type, @@ -201,7 +201,7 @@ func clientReader(c *client) { continue } - /* (6) Store fragments */ + // 6. Store fragments if frag != nil { frag.Final = msg.Final frag.Size += msg.Size @@ -226,7 +226,7 @@ func clientReader(c *client) { } - /* (7) Dispatch to receiver */ + // 7. Dispatch to receiver if msg.Type == Text || msg.Type == Binary { c.ch.receive <- *msg } @@ -236,7 +236,7 @@ func clientReader(c *client) { close(c.ch.receive) c.io.reading.Done() - /* (8) close channel (if not already done) */ + // 8. close channel (if not already done) // fmt.Printf("[reader] end\n") c.close(closeStatus, clientAck) @@ -250,10 +250,10 @@ func clientWriter(c *client) { for msg := range c.ch.send { - /* (2) Send message */ + // 2. Send message err := msg.Send(c.io.sock) - /* (3) Fail on error */ + // 3. Fail on error if err != nil { fmt.Printf(" [writer] %s\n", err) c.io.writing = false @@ -264,7 +264,7 @@ func clientWriter(c *client) { c.io.writing = false - /* (4) close channel (if not already done) */ + // 4. close channel (if not already done) // fmt.Printf("[writer] end\n") c.close(Normal, true) @@ -276,7 +276,7 @@ func clientWriter(c *client) { // then delete client func (c *client) close(status MessageError, clientACK bool) { - /* (1) Fail if already closing */ + // 1. Fail if already closing alreadyClosing := false c.io.closingMu.Lock() alreadyClosing = c.io.closing @@ -287,18 +287,18 @@ func (c *client) close(status MessageError, clientACK bool) { return } - /* (2) kill writer' if still running */ + // 2. kill writer' if still running if c.io.writing { close(c.ch.send) } - /* (3) kill reader if still running */ + // 3. kill reader if still running c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1)) c.io.reading.Wait() if status != None { - /* (3) Build message */ + // 3. Build message msg := &Message{ Final: true, Type: Close, @@ -307,7 +307,7 @@ func (c *client) close(status MessageError, clientACK bool) { } binary.BigEndian.PutUint16(msg.Data, uint16(status)) - /* (4) Send message */ + // 4. Send message msg.Send(c.io.sock) // if err != nil { // fmt.Printf("[close] send error (%s0\n", err) @@ -315,7 +315,7 @@ func (c *client) close(status MessageError, clientACK bool) { } - /* (2) Wait for client CLOSE if needed */ + // 2. Wait for client CLOSE if needed if clientACK { c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond)) @@ -334,11 +334,11 @@ func (c *client) close(status MessageError, clientACK bool) { } - /* (3) Close socket */ + // 3. Close socket c.io.sock.Close() // fmt.Printf("[close] socket closed\n") - /* (4) Unregister */ + // 4. Unregister c.io.kill <- c return diff --git a/cmd/iface/main.go b/cmd/iface/main.go index 83d59d0..6f39b04 100644 --- a/cmd/iface/main.go +++ b/cmd/iface/main.go @@ -11,10 +11,10 @@ func main() { startTime := time.Now().UnixNano() - /* (1) Bind WebSocket server */ + // 1. Bind WebSocket server serv := ws.CreateServer("0.0.0.0", 4444) - /* (2) Bind default controller */ + // 2. Bind default controller serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { defer func() { @@ -33,7 +33,7 @@ func main() { }) - /* (3) Bind to URI */ + // 3. Bind to URI err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { fmt.Printf("[uri] connected\n") @@ -52,7 +52,7 @@ func main() { panic(err) } - /* (4) Launch the server */ + // 4. Launch the server err = serv.Launch() if err != nil { fmt.Printf("[ERROR] %s\n", err) diff --git a/controller.go b/controller.go index 4f8635e..e0f9bfd 100644 --- a/controller.go +++ b/controller.go @@ -30,10 +30,10 @@ type ControllerSet struct { // also it returns the matching string patterns func (s *ControllerSet) Match(uri string) (*Controller, [][]string) { - /* (1) Initialise argument list */ + // 1. Initialise argument list arguments := [][]string{{uri}} - /* (2) Try each controller */ + // 2. Try each controller for _, c := range s.URI { /* 1. If matches */ @@ -52,12 +52,12 @@ func (s *ControllerSet) Match(uri string) (*Controller, [][]string) { } - /* (3) If no controller found -> set default controller */ + // 3. If no controller found -> set default controller if s.Def != nil { return s.Def, arguments } - /* (4) If default is NIL, return empty controller */ + // 4. If default is NIL, return empty controller return nil, arguments } diff --git a/internal/http/reader/reader.go b/internal/http/reader/reader.go index badd424..1bf181b 100644 --- a/internal/http/reader/reader.go +++ b/internal/http/reader/reader.go @@ -35,16 +35,16 @@ func NewReader(r io.Reader) *ChunkReader { // Read reads a chunk, err is io.EOF when done func (r *ChunkReader) Read() ([]byte, error) { - /* (1) If already ended */ + // 1. If already ended if r.isEnded { return nil, io.EOF } - /* (2) Read line */ + // 2. Read line var line []byte line, err := r.reader.ReadSlice('\n') - /* (3) manage errors */ + // 3. manage errors if err == io.EOF { err = io.ErrUnexpectedEOF } @@ -57,10 +57,10 @@ func (r *ChunkReader) Read() ([]byte, error) { return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength) } - /* (4) Trim */ + // 4. Trim line = removeTrailingSpace(line) - /* (5) Manage ending line */ + // 5. Manage ending line if len(line) == 0 { r.isEnded = true return line, io.EOF diff --git a/internal/http/upgrade/request/errors.go b/internal/http/upgrade/errors.go similarity index 98% rename from internal/http/upgrade/request/errors.go rename to internal/http/upgrade/errors.go index 7a31351..befb526 100644 --- a/internal/http/upgrade/request/errors.go +++ b/internal/http/upgrade/errors.go @@ -1,4 +1,4 @@ -package request +package upgrade import ( "fmt" diff --git a/internal/http/upgrade/header.go b/internal/http/upgrade/header.go new file mode 100644 index 0000000..1e5f06b --- /dev/null +++ b/internal/http/upgrade/header.go @@ -0,0 +1,74 @@ +package upgrade + +import ( + "bytes" + "fmt" + "strings" +) + +// HeaderType represents all 'valid' HTTP request headers +type HeaderType uint8 + +// header types +const ( + Unknown HeaderType = iota + Host + Upgrade + Connection + Origin + WSKey + WSProtocol + WSExtensions + WSVersion +) + +// HeaderValue represents a unique or multiple header value(s) +type HeaderValue [][]byte + +// Header represents the data of a HTTP request header +type Header struct { + Name HeaderType + Values HeaderValue +} + +// ReadHeader tries to parse an HTTP header from a byte array +func ReadHeader(b []byte) (*Header, error) { + + // 1. Split by ':' + parts := bytes.Split(b, []byte(": ")) + + if len(parts) != 2 { + return nil, fmt.Errorf("Invalid HTTP header format '%s'", b) + } + + // 2. Create instance + inst := &Header{} + + // 3. Check for header name + switch strings.ToLower(string(parts[0])) { + case "host": + inst.Name = Host + case "upgrade": + inst.Name = Upgrade + case "connection": + inst.Name = Connection + case "origin": + inst.Name = Origin + case "sec-websocket-key": + inst.Name = WSKey + case "sec-websocket-protocol": + inst.Name = WSProtocol + case "sec-websocket-extensions": + inst.Name = WSExtensions + case "sec-websocket-version": + inst.Name = WSVersion + default: + inst.Name = Unknown + } + + // 4. Split values + inst.Values = bytes.Split(parts[1], []byte(", ")) + + return inst, nil + +} diff --git a/internal/http/upgrade/request/header_check.go b/internal/http/upgrade/header_check.go similarity index 76% rename from internal/http/upgrade/request/header_check.go rename to internal/http/upgrade/header_check.go index ad57ec1..0db6f94 100644 --- a/internal/http/upgrade/request/header_check.go +++ b/internal/http/upgrade/header_check.go @@ -1,16 +1,13 @@ -package request +package upgrade import ( "fmt" "strconv" "strings" - - "git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header" - "git.xdrm.io/go/ws/internal/http/upgrade/response" ) // checkHost checks and extracts the Host header -func (r *T) extractHostPort(bb header.HeaderValue) error { +func (r *Request) extractHostPort(bb HeaderValue) error { if len(bb) != 1 { return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} @@ -32,7 +29,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error { // extract port readPort, err := strconv.ParseUint(split[1], 10, 16) if err != nil { - r.code = response.BadRequest + r.code = BadRequest return &InvalidRequest{"Host", "cannot read port"} } @@ -42,7 +39,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error { if len(r.origin) > 0 { if err != nil { err = r.checkOriginPolicy() - r.code = response.Forbidden + r.code = Forbidden return &InvalidOriginPolicy{r.host, r.origin, err} } } @@ -52,7 +49,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error { } // checkOrigin checks the Origin Header -func (r *T) extractOrigin(bb header.HeaderValue) error { +func (r *Request) extractOrigin(bb HeaderValue) error { // bypass if bypassOriginPolicy { @@ -60,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error { } if len(bb) != 1 { - r.code = response.Forbidden + r.code = Forbidden return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} } @@ -70,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error { if len(r.host) > 0 { err := r.checkOriginPolicy() if err != nil { - r.code = response.Forbidden + r.code = Forbidden return &InvalidOriginPolicy{r.host, r.origin, err} } } @@ -80,7 +77,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error { } // checkOriginPolicy origin policy based on 'host' value -func (r *T) checkOriginPolicy() error { +func (r *Request) checkOriginPolicy() error { // TODO: Origin policy, for now BYPASS r.validPolicy = true return nil @@ -88,7 +85,7 @@ func (r *T) checkOriginPolicy() error { // checkConnection checks the 'Connection' header // it MUST contain 'Upgrade' -func (r *T) checkConnection(bb header.HeaderValue) error { +func (r *Request) checkConnection(bb HeaderValue) error { for _, b := range bb { @@ -99,17 +96,17 @@ func (r *T) checkConnection(bb header.HeaderValue) error { } - r.code = response.BadRequest + r.code = BadRequest return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} } // checkUpgrade checks the 'Upgrade' header // it MUST be 'websocket' -func (r *T) checkUpgrade(bb header.HeaderValue) error { +func (r *Request) checkUpgrade(bb HeaderValue) error { if len(bb) != 1 { - r.code = response.BadRequest + r.code = BadRequest return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} } @@ -118,17 +115,17 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error { return nil } - r.code = response.BadRequest + r.code = BadRequest return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} } // checkVersion checks the 'Sec-WebSocket-Version' header // it MUST be '13' -func (r *T) checkVersion(bb header.HeaderValue) error { +func (r *Request) checkVersion(bb HeaderValue) error { if len(bb) != 1 || string(bb[0]) != "13" { - r.code = response.UpgradeRequired + r.code = UpgradeRequired return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} } @@ -139,10 +136,10 @@ func (r *T) checkVersion(bb header.HeaderValue) error { // extractKey extracts the 'Sec-WebSocket-Key' header // it MUST be 24 bytes (base64) -func (r *T) extractKey(bb header.HeaderValue) error { +func (r *Request) extractKey(bb HeaderValue) error { if len(bb) != 1 || len(bb[0]) != 24 { - r.code = response.BadRequest + r.code = BadRequest return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} } @@ -154,7 +151,7 @@ func (r *T) extractKey(bb header.HeaderValue) error { // extractProtocols extracts the 'Sec-WebSocket-Protocol' header // it can contain multiple values -func (r *T) extractProtocols(bb header.HeaderValue) error { +func (r *Request) extractProtocols(bb HeaderValue) error { r.protocols = bb diff --git a/internal/http/upgrade/request/request-line.go b/internal/http/upgrade/line.go similarity index 66% rename from internal/http/upgrade/request/request-line.go rename to internal/http/upgrade/line.go index 839d763..a463220 100644 --- a/internal/http/upgrade/request/request-line.go +++ b/internal/http/upgrade/line.go @@ -1,4 +1,4 @@ -package request +package upgrade import ( "bytes" @@ -6,50 +6,38 @@ import ( "regexp" ) -// httpMethod represents available http methods -type httpMethod byte - -const ( - OPTIONS httpMethod = iota - GET - HEAD - POST - PUT - DELETE -) - -// RequestLine represents the HTTP Request line +// Line represents the HTTP Request line // defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1 -type RequestLine struct { - method httpMethod +type Line struct { + method Method uri string version byte } -// parseRequestLine parses the first HTTP request line -func (r *RequestLine) Parse(b []byte) error { +// Parse parses the first HTTP request line +func (r *Line) Parse(b []byte) error { - /* (1) Split by ' ' */ + // 1. Split by ' ' parts := bytes.Split(b, []byte(" ")) - /* (2) Fail when missing parts */ + // 2. Fail when missing parts if len(parts) != 3 { return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts)) } - /* (3) Extract HTTP method */ + // 3. Extract HTTP method err := r.extractHttpMethod(parts[0]) if err != nil { return err } - /* (4) Extract URI */ + // 4. Extract URI err = r.extractURI(parts[1]) if err != nil { return err } - /* (5) Extract version */ + // 5. Extract version err = r.extractHttpVersion(parts[2]) if err != nil { return err @@ -60,19 +48,19 @@ func (r *RequestLine) Parse(b []byte) error { } // GetURI returns the actual URI -func (r RequestLine) GetURI() string { +func (r Line) GetURI() string { return r.uri } // extractHttpMethod extracts the HTTP method from a []byte // and checks for errors // allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE -func (r *RequestLine) extractHttpMethod(b []byte) error { +func (r *Line) extractHttpMethod(b []byte) error { switch string(b) { // case "OPTIONS": r.method = OPTIONS case "GET": - r.method = GET + r.method = Get // case "HEAD": r.method = HEAD // case "POST": r.method = POST // case "PUT": r.method = PUT @@ -87,15 +75,15 @@ func (r *RequestLine) extractHttpMethod(b []byte) error { // extractURI extracts the URI from a []byte and checks for errors // allowed format: /([^/]/)*/? -func (r *RequestLine) extractURI(b []byte) error { +func (r *Line) extractURI(b []byte) error { - /* (1) Check format */ + // 1. Check format checker := regexp.MustCompile("^(?:/[^/]+)*/?$") if !checker.Match(b) { return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b) } - /* (2) Store */ + // 2. Store r.uri = string(b) return nil @@ -104,26 +92,26 @@ func (r *RequestLine) extractURI(b []byte) error { // extractHttpVersion extracts the version and checks for errors // allowed format: [1-9] or [1.9].[0-9] -func (r *RequestLine) extractHttpVersion(b []byte) error { +func (r *Line) extractHttpVersion(b []byte) error { - /* (1) Extract version parts */ + // 1. Extract version parts extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`) if !extractor.Match(b) { return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b) } - /* (2) Extract version number */ + // 2. Extract version number matches := extractor.FindSubmatch(b) var version byte = matches[1][0] - '0' - /* (3) Extract subversion (if exists) */ + // 3. Extract subversion (if exists) var subVersion byte = 0 if len(matches[2]) > 0 { subVersion = matches[2][0] - '0' } - /* (4) Store version (x 10 to fit uint8) */ + // 4. Store version (x 10 to fit uint8) r.version = version*10 + subVersion return nil diff --git a/internal/http/upgrade/methods.go b/internal/http/upgrade/methods.go new file mode 100644 index 0000000..6b791d6 --- /dev/null +++ b/internal/http/upgrade/methods.go @@ -0,0 +1,14 @@ +package upgrade + +// Method represents available http methods +type Method uint8 + +// http methods +const ( + Options Method = iota + Get + Head + Post + Put + Delete +) diff --git a/internal/http/upgrade/request.go b/internal/http/upgrade/request.go new file mode 100644 index 0000000..1a7c0f8 --- /dev/null +++ b/internal/http/upgrade/request.go @@ -0,0 +1,216 @@ +package upgrade + +import ( + "fmt" + "io" + + "git.xdrm.io/go/ws/internal/http/reader" +) + +// If origin is required +const bypassOriginPolicy = true + +// Request represents an HTTP Upgrade request +type Request struct { + first bool // whether the first line has been read (GET uri HTTP/version) + + // status code + code StatusCode + + // request line + request Line + + // data to check origin (depends of reading order) + host string + port uint16 // 0 if not set + origin string + validPolicy bool + + // ws data + key []byte + protocols [][]byte + + // required fields check + hasConnection bool + hasUpgrade bool + hasVersion bool +} + +// Parse builds an upgrade HTTP request +// from a reader (typically bufio.NewRead of the socket) +func Parse(r io.Reader) (request *Request, err error) { + + req := &Request{ + code: 500, + } + + /* (1) Parse request + ---------------------------------------------------------*/ + // 1. Get chunk reader + cr := reader.NewReader(r) + if err != nil { + return req, fmt.Errorf("Error while creating chunk reader: %s", err) + } + + // 2. Parse header line by line + for { + + line, err := cr.Read() + if err == io.EOF { + break + } + + if err != nil { + return req, err + } + + err = req.parseHeader(line) + + if err != nil { + return req, err + } + + } + + // 3. Check completion + err = req.isComplete() + if err != nil { + req.code = BadRequest + return req, err + } + + req.code = SwitchingProtocols + return req, nil + +} + +// StatusCode returns the status current +func (r Request) StatusCode() StatusCode { + return r.code +} + +// BuildResponse builds a response from the request +func (r *Request) BuildResponse() *Response { + + inst := &Response{} + + // 1. Copy code + inst.SetStatusCode(r.code) + + // 2. Set Protocol + if len(r.protocols) > 0 { + inst.SetProtocol(r.protocols[0]) + } + + // 4. Process key + inst.ProcessKey(r.key) + + return inst +} + +// GetURI returns the actual URI +func (r Request) GetURI() string { + return r.request.GetURI() +} + +// parseHeader parses any http request line +// (header and request-line) +func (r *Request) parseHeader(b []byte) error { + + /* (1) First line -> GET {uri} HTTP/{version} + ---------------------------------------------------------*/ + if !r.first { + + err := r.request.Parse(b) + + if err != nil { + r.code = BadRequest + return &InvalidRequest{"Request-Line", err.Error()} + } + + r.first = true + return nil + + } + + /* (2) Other lines -> Header-Name: Header-Value + ---------------------------------------------------------*/ + // 1. Try to parse header + head, err := ReadHeader(b) + if err != nil { + r.code = BadRequest + return fmt.Errorf("Error parsing header: %s", err) + } + + // 2. Manage header + switch head.Name { + case Host: + err = r.extractHostPort(head.Values) + case Origin: + err = r.extractOrigin(head.Values) + case Upgrade: + err = r.checkUpgrade(head.Values) + case Connection: + err = r.checkConnection(head.Values) + case WSVersion: + err = r.checkVersion(head.Values) + case WSKey: + err = r.extractKey(head.Values) + case WSProtocol: + err = r.extractProtocols(head.Values) + + default: + return nil + } + + // dispatch error + if err != nil { + return err + } + + return nil + +} + +// isComplete returns whether the Upgrade Request +// is complete (no missing required item) +func (r Request) isComplete() error { + + // 1. Request-Line + if !r.first { + return &IncompleteRequest{"Request-Line"} + } + + // 2. Host + if len(r.host) == 0 { + return &IncompleteRequest{"Host"} + } + + // 3. Origin + if !bypassOriginPolicy && len(r.origin) == 0 { + return &IncompleteRequest{"Origin"} + } + + // 4. Connection + if !r.hasConnection { + return &IncompleteRequest{"Connection"} + } + + // 5. Upgrade + if !r.hasUpgrade { + return &IncompleteRequest{"Upgrade"} + } + + // 6. Sec-WebSocket-Version + if !r.hasVersion { + return &IncompleteRequest{"Sec-WebSocket-Version"} + } + + // 7. Sec-WebSocket-Key + if len(r.key) < 1 { + return &IncompleteRequest{"Sec-WebSocket-Key"} + } + + return nil + +} diff --git a/internal/http/upgrade/request/parser/header/public.go b/internal/http/upgrade/request/parser/header/public.go deleted file mode 100644 index c00b6d0..0000000 --- a/internal/http/upgrade/request/parser/header/public.go +++ /dev/null @@ -1,41 +0,0 @@ -package header - -import ( - // "regexp" - "fmt" - "strings" - "bytes" -) - -// parse tries to return a 'T' (httpHeader) from a byte array -func Parse(b []byte) (*T, error) { - - /* (1) Split by ':' */ - parts := bytes.Split(b, []byte(": ")) - - if len(parts) != 2 { - return nil, fmt.Errorf("Invalid HTTP header format '%s'", b) - } - - /* (2) Create instance */ - inst := new(T) - - /* (3) Check for header name */ - switch strings.ToLower(string(parts[0])) { - case "host": inst.Name = HOST - case "upgrade": inst.Name = UPGRADE - case "connection": inst.Name = CONNECTION - case "origin": inst.Name = ORIGIN - case "sec-websocket-key": inst.Name = WSKEY - case "sec-websocket-protocol": inst.Name = WSPROTOCOL - case "sec-websocket-extensions": inst.Name = WSEXTENSIONS - case "sec-websocket-version": inst.Name = WSVERSION - default: inst.Name = UNKNOWN - } - - /* (4) Split values */ - inst.Values = bytes.Split(parts[1], []byte(", ")) - - return inst, nil - -} \ No newline at end of file diff --git a/internal/http/upgrade/request/parser/header/types.go b/internal/http/upgrade/request/parser/header/types.go deleted file mode 100644 index 994e7df..0000000 --- a/internal/http/upgrade/request/parser/header/types.go +++ /dev/null @@ -1,25 +0,0 @@ -package header - -// HeaderType represents all 'valid' HTTP request headers -type HeaderType byte -const ( - UNKNOWN HeaderType = iota - HOST - UPGRADE - CONNECTION - ORIGIN - WSKEY - WSPROTOCOL - WSEXTENSIONS - WSVERSION -) - -// HeaderValue represents a unique or multiple header value(s) -type HeaderValue [][]byte - - -// T represents the data of a HTTP request header -type T struct{ - Name HeaderType - Values HeaderValue -} \ No newline at end of file diff --git a/internal/http/upgrade/request/private.go b/internal/http/upgrade/request/private.go deleted file mode 100644 index e8021ff..0000000 --- a/internal/http/upgrade/request/private.go +++ /dev/null @@ -1,110 +0,0 @@ -package request - -import ( - "fmt" - - "git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header" - "git.xdrm.io/go/ws/internal/http/upgrade/response" -) - -// parseHeader parses any http request line -// (header and request-line) -func (r *T) parseHeader(b []byte) error { - - /* (1) First line -> GET {uri} HTTP/{version} - ---------------------------------------------------------*/ - if !r.first { - - err := r.request.Parse(b) - - if err != nil { - r.code = response.BadRequest - return &InvalidRequest{"Request-Line", err.Error()} - } - - r.first = true - return nil - - } - - /* (2) Other lines -> Header-Name: Header-Value - ---------------------------------------------------------*/ - /* (1) Try to parse header */ - head, err := header.Parse(b) - if err != nil { - r.code = response.BadRequest - return fmt.Errorf("Error parsing header: %s", err) - } - - /* (2) Manage header */ - switch head.Name { - case header.HOST: - err = r.extractHostPort(head.Values) - case header.ORIGIN: - err = r.extractOrigin(head.Values) - case header.UPGRADE: - err = r.checkUpgrade(head.Values) - case header.CONNECTION: - err = r.checkConnection(head.Values) - case header.WSVERSION: - err = r.checkVersion(head.Values) - case header.WSKEY: - err = r.extractKey(head.Values) - case header.WSPROTOCOL: - err = r.extractProtocols(head.Values) - - default: - return nil - } - - // dispatch error - if err != nil { - return err - } - - return nil - -} - -// isComplete returns whether the Upgrade Request -// is complete (no missing required item) -func (r T) isComplete() error { - - /* (1) Request-Line */ - if !r.first { - return &IncompleteRequest{"Request-Line"} - } - - /* (2) Host */ - if len(r.host) == 0 { - return &IncompleteRequest{"Host"} - } - - /* (3) Origin */ - if !bypassOriginPolicy && len(r.origin) == 0 { - return &IncompleteRequest{"Origin"} - } - - /* (4) Connection */ - if !r.hasConnection { - return &IncompleteRequest{"Connection"} - } - - /* (5) Upgrade */ - if !r.hasUpgrade { - return &IncompleteRequest{"Upgrade"} - } - - /* (6) Sec-WebSocket-Version */ - if !r.hasVersion { - return &IncompleteRequest{"Sec-WebSocket-Version"} - } - - /* (7) Sec-WebSocket-Key */ - if len(r.key) < 1 { - return &IncompleteRequest{"Sec-WebSocket-Key"} - } - - return nil - -} diff --git a/internal/http/upgrade/request/public.go b/internal/http/upgrade/request/public.go deleted file mode 100644 index c8e8b34..0000000 --- a/internal/http/upgrade/request/public.go +++ /dev/null @@ -1,85 +0,0 @@ -package request - -import ( - "fmt" - "io" - - "git.xdrm.io/go/ws/internal/http/reader" - "git.xdrm.io/go/ws/internal/http/upgrade/response" -) - -// Parse builds an upgrade HTTP request -// from a reader (typically bufio.NewRead of the socket) -func Parse(r io.Reader) (request *T, err error) { - - req := new(T) - req.code = 500 - - /* (1) Parse request - ---------------------------------------------------------*/ - /* (1) Get chunk reader */ - cr := reader.NewReader(r) - if err != nil { - return req, fmt.Errorf("Error while creating chunk reader: %s", err) - } - - /* (2) Parse header line by line */ - for { - - line, err := cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return req, err - } - - err = req.parseHeader(line) - - if err != nil { - return req, err - } - - } - - /* (3) Check completion */ - err = req.isComplete() - if err != nil { - req.code = response.BadRequest - return req, err - } - - req.code = response.SwitchingProtocols - return req, nil - -} - -// StatusCode returns the status current -func (r T) StatusCode() response.StatusCode { - return r.code -} - -// BuildResponse builds a response.T from the request -func (r *T) BuildResponse() *response.T { - - inst := new(response.T) - - /* (1) Copy code */ - inst.SetStatusCode(r.code) - - /* (2) Set Protocol */ - if len(r.protocols) > 0 { - inst.SetProtocol(r.protocols[0]) - } - - /* (4) Process key */ - inst.ProcessKey(r.key) - - return inst -} - -// GetURI returns the actual URI -func (r T) GetURI() string { - return r.request.GetURI() -} diff --git a/internal/http/upgrade/request/types.go b/internal/http/upgrade/request/types.go deleted file mode 100644 index 282ef91..0000000 --- a/internal/http/upgrade/request/types.go +++ /dev/null @@ -1,32 +0,0 @@ -package request - -import "git.xdrm.io/go/ws/internal/http/upgrade/response" - -// If origin is required -const bypassOriginPolicy = true - -// T represents an HTTP Upgrade request -type T struct { - first bool // whether the first line has been read (GET uri HTTP/version) - - // status code - code response.StatusCode - - // request line - request RequestLine - - // data to check origin (depends of reading order) - host string - port uint16 // 0 if not set - origin string - validPolicy bool - - // ws data - key []byte - protocols [][]byte - - // required fields check - hasConnection bool - hasUpgrade bool - hasVersion bool -} diff --git a/internal/http/upgrade/request/request_test.go b/internal/http/upgrade/request_test.go similarity index 94% rename from internal/http/upgrade/request/request_test.go rename to internal/http/upgrade/request_test.go index ec78c5d..3d12b30 100644 --- a/internal/http/upgrade/request/request_test.go +++ b/internal/http/upgrade/request_test.go @@ -1,4 +1,4 @@ -package request +package upgrade import ( "bytes" @@ -6,13 +6,13 @@ import ( "testing" ) -// /* (1) Parse request */ +// // 1. Parse request // req, _ := request.Parse(s) -// /* (3) Build response */ +// // 3. Build response // res := req.BuildResponse() -// /* (4) Write into socket */ +// // 4. Write into socket // _, err := res.Send(s) // if err != nil { // return nil, fmt.Errorf("Upgrade write error: %s", err) @@ -25,7 +25,7 @@ import ( func TestEOFSocket(t *testing.T) { - socket := new(bytes.Buffer) + socket := &bytes.Buffer{} _, err := Parse(socket) @@ -39,7 +39,7 @@ func TestEOFSocket(t *testing.T) { func TestInvalidRequestLine(t *testing.T) { - socket := new(bytes.Buffer) + socket := &bytes.Buffer{} cases := []struct { Reqline string HasError bool @@ -113,7 +113,7 @@ func TestInvalidHost(t *testing.T) { requestLine := []byte("GET / HTTP/1.1\r\n") - socket := new(bytes.Buffer) + socket := &bytes.Buffer{} cases := []struct { Host string HasError bool diff --git a/internal/http/upgrade/response/public.go b/internal/http/upgrade/response.go similarity index 50% rename from internal/http/upgrade/response/public.go rename to internal/http/upgrade/response.go index 8152ff9..cd0b44d 100644 --- a/internal/http/upgrade/response/public.go +++ b/internal/http/upgrade/response.go @@ -1,4 +1,4 @@ -package response +package upgrade import ( "crypto/sha1" @@ -7,19 +7,35 @@ import ( "io" ) +// HTTPVersion constant +const HTTPVersion = "1.1" + +// UsedWSVersion constant websocket version +const UsedWSVersion = 13 + +// WSSalt constant websocket salt +const WSSalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +// Response represents an HTTP Upgrade Response +type Response struct { + code StatusCode // status code + accept []byte // processed from Sec-WebSocket-Key + protocol []byte // set from Sec-WebSocket-Protocol or none if not received +} + // SetStatusCode sets the status code -func (r *T) SetStatusCode(sc StatusCode) { +func (r *Response) SetStatusCode(sc StatusCode) { r.code = sc } // SetProtocol sets the protocols -func (r *T) SetProtocol(p []byte) { +func (r *Response) SetProtocol(p []byte) { r.protocol = p } // ProcessKey processes the accept token according // to the rfc from the Sec-WebSocket-Key -func (r *T) ProcessKey(k []byte) { +func (r *Response) ProcessKey(k []byte) { // do nothing for empty key if k == nil || len(k) == 0 { @@ -27,40 +43,40 @@ func (r *T) ProcessKey(k []byte) { return } - /* (1) Concat with constant salt */ - mix := append(k, WSSalt...) + // 1. Concat with constant salt + mix := append(k, []byte(WSSalt)...) - /* (2) Hash with sha1 algorithm */ + // 2. Hash with sha1 algorithm digest := sha1.Sum(mix) - /* (3) Base64 encode it */ + // 3. Base64 encode it r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size])) } // Send sends the response through an io.Writer // typically a socket -func (r T) Send(w io.Writer) (int, error) { +func (r Response) Send(w io.Writer) (int, error) { - /* (1) Build response line */ - responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HttpVersion, r.code, r.code.String()) + // 1. Build response line + responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HTTPVersion, r.code, r.code) - /* (2) Build headers */ + // 2. Build headers optionalProtocol := "" if len(r.protocol) > 0 { optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol) } - headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", WSVersion, optionalProtocol) + headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", UsedWSVersion, optionalProtocol) if r.accept != nil { headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept) } headers = fmt.Sprintf("%s\r\n", headers) - /* (3) Build all */ + // 3. Build all raw := []byte(fmt.Sprintf("%s%s", responseLine, headers)) - /* (4) Write */ + // 4. Write written, err := w.Write(raw) return written, err @@ -68,11 +84,11 @@ func (r T) Send(w io.Writer) (int, error) { } // GetProtocol returns the choosen protocol if set, else nil -func (r T) GetProtocol() []byte { +func (r Response) GetProtocol() []byte { return r.protocol } // GetStatusCode returns the response status code -func (r T) GetStatusCode() StatusCode { +func (r Response) GetStatusCode() StatusCode { return r.code } diff --git a/internal/http/upgrade/response/types.go b/internal/http/upgrade/response/types.go deleted file mode 100644 index 6784ea8..0000000 --- a/internal/http/upgrade/response/types.go +++ /dev/null @@ -1,15 +0,0 @@ -package response - -// Constant -const HttpVersion = "1.1" -const WSVersion = 13 - -var WSSalt []byte = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -// T represents an HTTP Upgrade Response -type T struct { - code StatusCode // status code - accept []byte // processed from Sec-WebSocket-Key - protocol []byte // set from Sec-WebSocket-Protocol or none if not received - -} diff --git a/internal/http/upgrade/response/status_code.go b/internal/http/upgrade/status_code.go similarity index 97% rename from internal/http/upgrade/response/status_code.go rename to internal/http/upgrade/status_code.go index 39344e7..aed18be 100644 --- a/internal/http/upgrade/response/status_code.go +++ b/internal/http/upgrade/status_code.go @@ -1,9 +1,9 @@ -package response +package upgrade // StatusCode maps the status codes (and description) type StatusCode uint16 -var ( +const ( // SwitchingProtocols - handshake success SwitchingProtocols StatusCode = 101 // BadRequest - missing/malformed headers diff --git a/internal/uri/parser/private.go b/internal/uri/parser/private.go index e2a22c1..7daa62c 100644 --- a/internal/uri/parser/private.go +++ b/internal/uri/parser/private.go @@ -9,45 +9,45 @@ import ( // from a pattern string func buildScheme(ss []string) (Scheme, error) { - /* (1) Build scheme */ + // 1. Build scheme sch := make(Scheme, 0, maxMatch) for _, s := range ss { - /* (2) ignore empty */ + // 2. ignore empty if len(s) == 0 { continue } - m := new(matcher) + m := &matcher{} switch s { - /* (3) Card: 0, N */ + // 3. Card: 0, N case "**": m.req = false m.mul = true sch = append(sch, m) - /* (4) Card: 1, N */ + // 4. Card: 1, N case "..": m.req = true m.mul = true sch = append(sch, m) - /* (5) Card: 0, 1 */ + // 5. Card: 0, 1 case "*": m.req = false m.mul = false sch = append(sch, m) - /* (6) Card: 1 */ + // 6. Card: 1 case ".": m.req = true m.mul = false sch = append(sch, m) - /* (7) Card: 1, literal string */ + // 7. Card: 1, literal string default: m.req = true m.mul = false @@ -64,16 +64,16 @@ func buildScheme(ss []string) (Scheme, error) { // optimise optimised the scheme for further parsing func (s Scheme) optimise() (Scheme, error) { - /* (1) Nothing to do if only 1 element */ + // 1. Nothing to do if only 1 element if len(s) <= 1 { return s, nil } - /* (2) Init reshifted scheme */ + // 2. Init reshifted scheme rshift := make(Scheme, 0, maxMatch) rshift = append(rshift, s[0]) - /* (2) Iterate over matchers */ + // 2. Iterate over matchers for p, i, l := 0, 1, len(s); i < l; i++ { pre, cur := s[p], s[i] @@ -106,11 +106,11 @@ func (s Scheme) optimise() (Scheme, error) { // it returns a cleared uri, without STRING data func (s Scheme) matchString(uri string) (string, bool) { - /* (1) Initialise variables */ + // 1. Initialise variables clr := uri // contains cleared input string minOff := 0 // minimum offset - /* (2) Iterate over strings */ + // 2. Iterate over strings for _, m := range s { ls := len(m.pat) @@ -147,12 +147,12 @@ func (s Scheme) matchString(uri string) (string, bool) { } - /* (3) If exists, remove trailing '/' */ + // 3. If exists, remove trailing '/' if clr[len(clr)-1] == '/' { clr = clr[:len(clr)-1] } - /* (4) If exists, remove trailing '\a' */ + // 4. If exists, remove trailing '\a' if clr[len(clr)-1] == '\a' { clr = clr[:len(clr)-1] } @@ -166,7 +166,7 @@ func (s Scheme) matchString(uri string) (string, bool) { // + it sets the matchers buffers for later extraction func (s Scheme) matchWildcards(clear string) bool { - /* (1) Extract wildcards (ref) */ + // 1. Extract wildcards (ref) wildcards := make(Scheme, 0, maxMatch) for _, m := range s { @@ -176,15 +176,15 @@ func (s Scheme) matchWildcards(clear string) bool { } } - /* (2) If no wildcards -> match */ + // 2. If no wildcards -> match if len(wildcards) == 0 { return true } - /* (3) Break uri by '\a' characters */ + // 3. Break uri by '\a' characters matches := strings.Split(clear, "\a")[1:] - /* (4) Iterate over matches */ + // 4. Iterate over matches for n, match := range matches { // {1} If no more matcher // @@ -210,7 +210,7 @@ func (s Scheme) matchWildcards(clear string) bool { } - /* (5) Match */ + // 5. Match return true } diff --git a/internal/uri/parser/public.go b/internal/uri/parser/public.go index 9726ed6..e564ddc 100644 --- a/internal/uri/parser/public.go +++ b/internal/uri/parser/public.go @@ -8,15 +8,15 @@ import ( // Build builds an URI scheme from a pattern string func Build(s string) (*Scheme, error) { - /* (1) Manage '/' at the start */ + // 1. Manage '/' at the start if len(s) < 1 || s[0] != '/' { return nil, fmt.Errorf("URI must begin with '/'") } - /* (2) Split by '/' */ + // 2. Split by '/' parts := strings.Split(s, "/") - /* (3) Max exceeded */ + // 3. Max exceeded if len(parts)-2 > maxMatch { for i, p := range parts { fmt.Printf("%d: '%s'\n", i, p) @@ -24,13 +24,13 @@ func Build(s string) (*Scheme, error) { return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts)) } - /* (4) Build for each part */ + // 4. Build for each part sch, err := buildScheme(parts) if err != nil { return nil, err } - /* (5) Optimise structure */ + // 5. Optimise structure opti, err := sch.optimise() if err != nil { return nil, err @@ -43,18 +43,18 @@ func Build(s string) (*Scheme, error) { // Match returns if the given URI is matched by the scheme func (s Scheme) Match(str string) bool { - /* (1) Nothing -> match all */ + // 1. Nothing -> match all if len(s) == 0 { return true } - /* (2) Check for string match */ + // 2. Check for string match clearURI, match := s.matchString(str) if !match { return false } - /* (3) Check for non-string match (wildcards) */ + // 3. Check for non-string match (wildcards) match = s.matchWildcards(clearURI) if !match { return false @@ -66,12 +66,12 @@ func (s Scheme) Match(str string) bool { // GetMatch returns the indexed match (excluding string matchers) func (s Scheme) GetMatch(n uint8) ([]string, error) { - /* (1) Index out of range */ + // 1. Index out of range if n > uint8(len(s)) { return nil, fmt.Errorf("Index out of range") } - /* (2) Iterate to find index (exclude strings) */ + // 2. Iterate to find index (exclude strings) ni := -1 for _, m := range s { @@ -90,7 +90,7 @@ func (s Scheme) GetMatch(n uint8) ([]string, error) { } - /* (3) If nothing found -> return empty set */ + // 3. If nothing found -> return empty set return nil, fmt.Errorf("Index out of range (max: %d)", ni) } diff --git a/message.go b/message.go index d117342..9ddc43e 100644 --- a/message.go +++ b/message.go @@ -2,32 +2,36 @@ package websocket import ( "encoding/binary" - "fmt" "io" "unicode/utf8" ) -var ( +// constant error +type constErr string + +func (c constErr) Error() string { return string(c) } + +const ( // ErrUnmaskedFrame error - ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame") + ErrUnmaskedFrame = constErr("Received unmasked frame") // ErrTooLongControlFrame error - ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long") + ErrTooLongControlFrame = constErr("Received a control frame that is fragmented or too long") // ErrInvalidFragment error - ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation") + ErrInvalidFragment = constErr("Received invalid fragmentation") // ErrUnexpectedContinuation error - ErrUnexpectedContinuation = fmt.Errorf("Received unexpected continuation frame") + ErrUnexpectedContinuation = constErr("Received unexpected continuation frame") // ErrInvalidSize error - ErrInvalidSize = fmt.Errorf("Received invalid payload size") + ErrInvalidSize = constErr("Received invalid payload size") // ErrInvalidPayload error - ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") + ErrInvalidPayload = constErr("Received invalid utf8 payload") // ErrInvalidCloseStatus error - ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") + ErrInvalidCloseStatus = constErr("Received invalid close status") // ErrInvalidOpCode error - ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") + ErrInvalidOpCode = constErr("Received invalid OpCode") // ErrReservedBits error - ErrReservedBits = fmt.Errorf("Received reserved bits") + ErrReservedBits = constErr("Received reserved bits") // ErrCloseFrame error - ErrCloseFrame = fmt.Errorf("Received close Frame") + ErrCloseFrame = constErr("Received close Frame") ) // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask @@ -88,9 +92,9 @@ func readMessage(reader io.Reader) (*Message, error) { var mask []byte var cursor int - m := new(Message) + m := &Message{} - /* (2) Byte 1: FIN and OpCode */ + // 2. Byte 1: FIN and OpCode tmpBuf = make([]byte, 1) err = readBytes(reader, tmpBuf) if err != nil { @@ -105,7 +109,7 @@ func readMessage(reader io.Reader) (*Message, error) { m.Final = bool(tmpBuf[0]&0x80 == 0x80) m.Type = MessageType(tmpBuf[0] & 0x0f) - /* (3) Byte 2: Mask and Length[0] */ + // 3. Byte 2: Mask and Length[0] tmpBuf = make([]byte, 1) err = readBytes(reader, tmpBuf) if err != nil { @@ -120,7 +124,7 @@ func readMessage(reader io.Reader) (*Message, error) { // payload length m.Size = uint(tmpBuf[0] & 0x7f) - /* (4) Extended payload */ + // 4. Extended payload if m.Size == 127 { tmpBuf = make([]byte, 8) @@ -143,7 +147,7 @@ func readMessage(reader io.Reader) (*Message, error) { } - /* (5) Masking key */ + // 5. Masking key if mask != nil { tmpBuf = make([]byte, 4) @@ -157,7 +161,7 @@ func readMessage(reader io.Reader) (*Message, error) { } - /* (6) Read payload by chunks */ + // 6. Read payload by chunks m.Data = make([]byte, int(m.Size)) cursor = 0 @@ -207,14 +211,14 @@ func (m Message) Send(writer io.Writer) error { m.Size = uint(len(m.Data)) } - /* (1) Byte 0 : FIN + opcode */ + // 1. Byte 0 : FIN + opcode var final byte = 0x80 if !m.Final { final = 0 } header = append(header, final|byte(m.Type)) - /* (2) Get payload length */ + // 2. Get payload length if m.Size < 126 { // simple header = append(header, byte(m.Size)) @@ -237,12 +241,12 @@ func (m Message) Send(writer io.Writer) error { } - /* (3) Build write buffer */ + // 3. Build write buffer writeBuf := make([]byte, 0, len(header)+int(m.Size)) writeBuf = append(writeBuf, header...) writeBuf = append(writeBuf, m.Data[0:m.Size]...) - /* (4) Send over socket by chunks */ + // 4. Send over socket by chunks toWrite := len(header) + int(m.Size) cursor := 0 for cursor < toWrite { @@ -272,17 +276,17 @@ func (m Message) Send(writer io.Writer) error { // returns the message error func (m *Message) check(fragment bool) error { - /* (1) Invalid first fragment (not TEXT nor BINARY) */ + // 1. Invalid first fragment (not TEXT nor BINARY) if !m.Final && !fragment && m.Type != Text && m.Type != Binary { return ErrInvalidFragment } - /* (2) Waiting fragment but received standalone frame */ + // 2. Waiting fragment but received standalone frame if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong { return ErrInvalidFragment } - /* (3) Control frame too long */ + // 3. Control frame too long if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) { return ErrTooLongControlFrame } @@ -335,8 +339,6 @@ func (m *Message) check(fragment bool) error { return ErrInvalidOpCode } - - return nil } // readBytes reads from a reader into a byte array diff --git a/message_test.go b/message_test.go index 1657a23..0f0441e 100644 --- a/message_test.go +++ b/message_test.go @@ -267,7 +267,7 @@ func TestSimpleMessageSending(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { - writer := new(bytes.Buffer) + writer := &bytes.Buffer{} err := tc.Base.Send(writer) diff --git a/server.go b/server.go index 141c9a3..7bc47d3 100644 --- a/server.go +++ b/server.go @@ -65,13 +65,13 @@ func (s *Server) BindDefault(f ControllerFunc) { // Bind a controller to an URI scheme func (s *Server) Bind(uri string, f ControllerFunc) error { - /* (1) Build URI parser */ + // 1. Build URI parser uriScheme, err := parser.Build(uri) if err != nil { return fmt.Errorf("Cannot build URI: %s", err) } - /* (2) Create controller */ + // 2. Create controller s.ctl.URI = append(s.ctl.URI, &Controller{ URI: uriScheme, Fun: f, @@ -88,10 +88,10 @@ func (s *Server) Launch() error { /* (1) Listen socket ---------------------------------------------------------*/ - /* (1) Build full url */ + // 1. Build full url url := fmt.Sprintf("%s:%d", s.addr, s.port) - /* (2) Bind socket to listen */ + // 2. Bind socket to listen s.sock, err = net.Listen("tcp", url) if err != nil { return fmt.Errorf("Listen socket: %s", err) @@ -101,14 +101,14 @@ func (s *Server) Launch() error { fmt.Printf("+ listening on %s\n", url) - /* (3) Launch scheduler */ + // 3. Launch scheduler go s.scheduler() /* (2) For each incoming connection (client) ---------------------------------------------------------*/ for { - /* (1) Wait for client */ + // 1. Wait for client sock, err := s.sock.Accept() if err != nil { break @@ -116,14 +116,14 @@ func (s *Server) Launch() error { go func() { - /* (2) Try to create client */ + // 2. Try to create client cli, err := buildClient(sock, s.ctl, s.ch) if err != nil { fmt.Printf(" - %s\n", err) return } - /* (3) Register client */ + // 3. Register client s.ch.register <- cli }() @@ -141,15 +141,15 @@ func (s *Server) scheduler() { select { - /* (1) Create client */ + // 1. Create client case client := <-s.ch.register: s.clients[client.io.sock] = client - /* (2) Remove client */ + // 2. Remove client case client := <-s.ch.unregister: delete(s.clients, client.io.sock) - /* (3) Broadcast */ + // 3. Broadcast case msg := <-s.ch.broadcast: for _, c := range s.clients { c.ch.send <- msg