refactor: http/upgrade normalise files

This commit is contained in:
Adrien Marquès 2021-05-14 17:19:02 +02:00
parent db52cfd28f
commit 6c47dbc38f
Signed by: xdrm-brackets
GPG Key ID: D75243CA236D825E
24 changed files with 507 additions and 508 deletions

View File

@ -8,7 +8,7 @@ import (
"sync" "sync"
"time" "time"
"git.xdrm.io/go/ws/internal/http/upgrade/request" "git.xdrm.io/go/ws/internal/http/upgrade"
) )
// Represents a client socket utility (reader, writer, ..) // Represents a client socket utility (reader, writer, ..)
@ -41,13 +41,13 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
/* (1) Manage UPGRADE request /* (1) Manage UPGRADE request
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Parse request */ // 1. Parse request
req, _ := request.Parse(s) req, _ := upgrade.Parse(s)
/* (3) Build response */ // 3. Build response
res := req.BuildResponse() res := req.BuildResponse()
/* (4) Write into socket */ // 4. Write into socket
_, err := res.Send(s) _, err := res.Send(s)
if err != nil { if err != nil {
return nil, fmt.Errorf("Upgrade write error: %s", err) return nil, fmt.Errorf("Upgrade write error: %s", err)
@ -55,16 +55,16 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
if res.GetStatusCode() != 101 { if res.GetStatusCode() != 101 {
s.Close() s.Close()
return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode()) return nil, fmt.Errorf("Upgrade error (HTTP %d)", res.GetStatusCode())
} }
/* (2) Initialise client /* (2) Initialise client
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Get upgrade data */ // 1. Get upgrade data
clientURI := req.GetURI() clientURI := req.GetURI()
clientProtocol := res.GetProtocol() clientProtocol := res.GetProtocol()
/* (2) Initialise client */ // 2. Initialise client
cli := &client{ cli := &client{
io: clientIO{ io: clientIO{
sock: s, sock: s,
@ -74,7 +74,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
iface: &Client{ iface: &Client{
Protocol: string(clientProtocol), Protocol: string(clientProtocol),
Arguments: [][]string{[]string{clientURI}}, Arguments: [][]string{{clientURI}},
}, },
ch: clientChannelSet{ ch: clientChannelSet{
@ -85,20 +85,20 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
/* (3) Find controller by URI /* (3) Find controller by URI
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Try to find one */ // 1. Try to find one
controller, arguments := ctl.Match(clientURI) controller, arguments := ctl.Match(clientURI)
/* (2) If nothing found -> error */ // 2. If nothing found -> error
if controller == nil { if controller == nil {
return nil, fmt.Errorf("No controller found, no default controller set\n") return nil, fmt.Errorf("No controller found, no default controller set")
} }
/* (3) Copy arguments */ // 3. Copy arguments
cli.iface.Arguments = arguments cli.iface.Arguments = arguments
/* (4) Launch client routines /* (4) Launch client routines
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Launch client controller */ // 1. Launch client controller
go controller.Fun( go controller.Fun(
cli.iface, // pass the client cli.iface, // pass the client
cli.ch.receive, // the receiver cli.ch.receive, // the receiver
@ -106,10 +106,10 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
serverCh.broadcast, // broadcast sender serverCh.broadcast, // broadcast sender
) )
/* (2) Launch message reader */ // 2. Launch message reader
go clientReader(cli) go clientReader(cli)
/* (3) Launc writer */ // 3. Launc writer
go clientWriter(cli) go clientWriter(cli)
return cli, nil return cli, nil
@ -127,13 +127,13 @@ func clientReader(c *client) {
for { for {
/* (1) if currently closing -> exit */ // 1. if currently closing -> exit
if c.io.closing { if c.io.closing {
fmt.Printf("[reader] killed because closing") fmt.Printf("[reader] killed because closing")
break break
} }
/* (2) Parse message */ // 2. Parse message
msg, err := readMessage(c.io.reader) msg, err := readMessage(c.io.reader)
if err == ErrUnmaskedFrame || err == ErrReservedBits { if err == ErrUnmaskedFrame || err == ErrReservedBits {
@ -143,7 +143,7 @@ func clientReader(c *client) {
break break
} }
/* (3) Fail on invalid message */ // 3. Fail on invalid message
msgErr := msg.check(frag != nil) msgErr := msg.check(frag != nil)
if msgErr != nil { if msgErr != nil {
@ -182,7 +182,7 @@ func clientReader(c *client) {
} }
/* (4) Ping <-> Pong */ // 4. Ping <-> Pong
if msg.Type == Ping && c.io.writing { if msg.Type == Ping && c.io.writing {
msg.Final = true msg.Final = true
msg.Type = Pong msg.Type = Pong
@ -190,7 +190,7 @@ func clientReader(c *client) {
continue continue
} }
/* (5) Store first fragment */ // 5. Store first fragment
if frag == nil && !msg.Final { if frag == nil && !msg.Final {
frag = &Message{ frag = &Message{
Type: msg.Type, Type: msg.Type,
@ -201,7 +201,7 @@ func clientReader(c *client) {
continue continue
} }
/* (6) Store fragments */ // 6. Store fragments
if frag != nil { if frag != nil {
frag.Final = msg.Final frag.Final = msg.Final
frag.Size += msg.Size frag.Size += msg.Size
@ -226,7 +226,7 @@ func clientReader(c *client) {
} }
/* (7) Dispatch to receiver */ // 7. Dispatch to receiver
if msg.Type == Text || msg.Type == Binary { if msg.Type == Text || msg.Type == Binary {
c.ch.receive <- *msg c.ch.receive <- *msg
} }
@ -236,7 +236,7 @@ func clientReader(c *client) {
close(c.ch.receive) close(c.ch.receive)
c.io.reading.Done() c.io.reading.Done()
/* (8) close channel (if not already done) */ // 8. close channel (if not already done)
// fmt.Printf("[reader] end\n") // fmt.Printf("[reader] end\n")
c.close(closeStatus, clientAck) c.close(closeStatus, clientAck)
@ -250,10 +250,10 @@ func clientWriter(c *client) {
for msg := range c.ch.send { for msg := range c.ch.send {
/* (2) Send message */ // 2. Send message
err := msg.Send(c.io.sock) err := msg.Send(c.io.sock)
/* (3) Fail on error */ // 3. Fail on error
if err != nil { if err != nil {
fmt.Printf(" [writer] %s\n", err) fmt.Printf(" [writer] %s\n", err)
c.io.writing = false c.io.writing = false
@ -264,7 +264,7 @@ func clientWriter(c *client) {
c.io.writing = false c.io.writing = false
/* (4) close channel (if not already done) */ // 4. close channel (if not already done)
// fmt.Printf("[writer] end\n") // fmt.Printf("[writer] end\n")
c.close(Normal, true) c.close(Normal, true)
@ -276,7 +276,7 @@ func clientWriter(c *client) {
// then delete client // then delete client
func (c *client) close(status MessageError, clientACK bool) { func (c *client) close(status MessageError, clientACK bool) {
/* (1) Fail if already closing */ // 1. Fail if already closing
alreadyClosing := false alreadyClosing := false
c.io.closingMu.Lock() c.io.closingMu.Lock()
alreadyClosing = c.io.closing alreadyClosing = c.io.closing
@ -287,18 +287,18 @@ func (c *client) close(status MessageError, clientACK bool) {
return return
} }
/* (2) kill writer' if still running */ // 2. kill writer' if still running
if c.io.writing { if c.io.writing {
close(c.ch.send) close(c.ch.send)
} }
/* (3) kill reader if still running */ // 3. kill reader if still running
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1)) c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
c.io.reading.Wait() c.io.reading.Wait()
if status != None { if status != None {
/* (3) Build message */ // 3. Build message
msg := &Message{ msg := &Message{
Final: true, Final: true,
Type: Close, Type: Close,
@ -307,7 +307,7 @@ func (c *client) close(status MessageError, clientACK bool) {
} }
binary.BigEndian.PutUint16(msg.Data, uint16(status)) binary.BigEndian.PutUint16(msg.Data, uint16(status))
/* (4) Send message */ // 4. Send message
msg.Send(c.io.sock) msg.Send(c.io.sock)
// if err != nil { // if err != nil {
// fmt.Printf("[close] send error (%s0\n", err) // fmt.Printf("[close] send error (%s0\n", err)
@ -315,7 +315,7 @@ func (c *client) close(status MessageError, clientACK bool) {
} }
/* (2) Wait for client CLOSE if needed */ // 2. Wait for client CLOSE if needed
if clientACK { if clientACK {
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond)) c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
@ -334,11 +334,11 @@ func (c *client) close(status MessageError, clientACK bool) {
} }
/* (3) Close socket */ // 3. Close socket
c.io.sock.Close() c.io.sock.Close()
// fmt.Printf("[close] socket closed\n") // fmt.Printf("[close] socket closed\n")
/* (4) Unregister */ // 4. Unregister
c.io.kill <- c c.io.kill <- c
return return

View File

@ -11,10 +11,10 @@ func main() {
startTime := time.Now().UnixNano() startTime := time.Now().UnixNano()
/* (1) Bind WebSocket server */ // 1. Bind WebSocket server
serv := ws.CreateServer("0.0.0.0", 4444) serv := ws.CreateServer("0.0.0.0", 4444)
/* (2) Bind default controller */ // 2. Bind default controller
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
defer func() { defer func() {
@ -33,7 +33,7 @@ func main() {
}) })
/* (3) Bind to URI */ // 3. Bind to URI
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) { err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
fmt.Printf("[uri] connected\n") fmt.Printf("[uri] connected\n")
@ -52,7 +52,7 @@ func main() {
panic(err) panic(err)
} }
/* (4) Launch the server */ // 4. Launch the server
err = serv.Launch() err = serv.Launch()
if err != nil { if err != nil {
fmt.Printf("[ERROR] %s\n", err) fmt.Printf("[ERROR] %s\n", err)

View File

@ -30,10 +30,10 @@ type ControllerSet struct {
// also it returns the matching string patterns // also it returns the matching string patterns
func (s *ControllerSet) Match(uri string) (*Controller, [][]string) { func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
/* (1) Initialise argument list */ // 1. Initialise argument list
arguments := [][]string{{uri}} arguments := [][]string{{uri}}
/* (2) Try each controller */ // 2. Try each controller
for _, c := range s.URI { for _, c := range s.URI {
/* 1. If matches */ /* 1. If matches */
@ -52,12 +52,12 @@ func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
} }
/* (3) If no controller found -> set default controller */ // 3. If no controller found -> set default controller
if s.Def != nil { if s.Def != nil {
return s.Def, arguments return s.Def, arguments
} }
/* (4) If default is NIL, return empty controller */ // 4. If default is NIL, return empty controller
return nil, arguments return nil, arguments
} }

View File

@ -35,16 +35,16 @@ func NewReader(r io.Reader) *ChunkReader {
// Read reads a chunk, err is io.EOF when done // Read reads a chunk, err is io.EOF when done
func (r *ChunkReader) Read() ([]byte, error) { func (r *ChunkReader) Read() ([]byte, error) {
/* (1) If already ended */ // 1. If already ended
if r.isEnded { if r.isEnded {
return nil, io.EOF return nil, io.EOF
} }
/* (2) Read line */ // 2. Read line
var line []byte var line []byte
line, err := r.reader.ReadSlice('\n') line, err := r.reader.ReadSlice('\n')
/* (3) manage errors */ // 3. manage errors
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
@ -57,10 +57,10 @@ func (r *ChunkReader) Read() ([]byte, error) {
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength) return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
} }
/* (4) Trim */ // 4. Trim
line = removeTrailingSpace(line) line = removeTrailingSpace(line)
/* (5) Manage ending line */ // 5. Manage ending line
if len(line) == 0 { if len(line) == 0 {
r.isEnded = true r.isEnded = true
return line, io.EOF return line, io.EOF

View File

@ -1,4 +1,4 @@
package request package upgrade
import ( import (
"fmt" "fmt"

View File

@ -0,0 +1,74 @@
package upgrade
import (
"bytes"
"fmt"
"strings"
)
// HeaderType represents all 'valid' HTTP request headers
type HeaderType uint8
// header types
const (
Unknown HeaderType = iota
Host
Upgrade
Connection
Origin
WSKey
WSProtocol
WSExtensions
WSVersion
)
// HeaderValue represents a unique or multiple header value(s)
type HeaderValue [][]byte
// Header represents the data of a HTTP request header
type Header struct {
Name HeaderType
Values HeaderValue
}
// ReadHeader tries to parse an HTTP header from a byte array
func ReadHeader(b []byte) (*Header, error) {
// 1. Split by ':'
parts := bytes.Split(b, []byte(": "))
if len(parts) != 2 {
return nil, fmt.Errorf("Invalid HTTP header format '%s'", b)
}
// 2. Create instance
inst := &Header{}
// 3. Check for header name
switch strings.ToLower(string(parts[0])) {
case "host":
inst.Name = Host
case "upgrade":
inst.Name = Upgrade
case "connection":
inst.Name = Connection
case "origin":
inst.Name = Origin
case "sec-websocket-key":
inst.Name = WSKey
case "sec-websocket-protocol":
inst.Name = WSProtocol
case "sec-websocket-extensions":
inst.Name = WSExtensions
case "sec-websocket-version":
inst.Name = WSVersion
default:
inst.Name = Unknown
}
// 4. Split values
inst.Values = bytes.Split(parts[1], []byte(", "))
return inst, nil
}

View File

@ -1,16 +1,13 @@
package request package upgrade
import ( import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header"
"git.xdrm.io/go/ws/internal/http/upgrade/response"
) )
// checkHost checks and extracts the Host header // checkHost checks and extracts the Host header
func (r *T) extractHostPort(bb header.HeaderValue) error { func (r *Request) extractHostPort(bb HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
@ -32,7 +29,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
// extract port // extract port
readPort, err := strconv.ParseUint(split[1], 10, 16) readPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil { if err != nil {
r.code = response.BadRequest r.code = BadRequest
return &InvalidRequest{"Host", "cannot read port"} return &InvalidRequest{"Host", "cannot read port"}
} }
@ -42,7 +39,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
if len(r.origin) > 0 { if len(r.origin) > 0 {
if err != nil { if err != nil {
err = r.checkOriginPolicy() err = r.checkOriginPolicy()
r.code = response.Forbidden r.code = Forbidden
return &InvalidOriginPolicy{r.host, r.origin, err} return &InvalidOriginPolicy{r.host, r.origin, err}
} }
} }
@ -52,7 +49,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
} }
// checkOrigin checks the Origin Header // checkOrigin checks the Origin Header
func (r *T) extractOrigin(bb header.HeaderValue) error { func (r *Request) extractOrigin(bb HeaderValue) error {
// bypass // bypass
if bypassOriginPolicy { if bypassOriginPolicy {
@ -60,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
} }
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.Forbidden r.code = Forbidden
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
@ -70,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
if len(r.host) > 0 { if len(r.host) > 0 {
err := r.checkOriginPolicy() err := r.checkOriginPolicy()
if err != nil { if err != nil {
r.code = response.Forbidden r.code = Forbidden
return &InvalidOriginPolicy{r.host, r.origin, err} return &InvalidOriginPolicy{r.host, r.origin, err}
} }
} }
@ -80,7 +77,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
} }
// checkOriginPolicy origin policy based on 'host' value // checkOriginPolicy origin policy based on 'host' value
func (r *T) checkOriginPolicy() error { func (r *Request) checkOriginPolicy() error {
// TODO: Origin policy, for now BYPASS // TODO: Origin policy, for now BYPASS
r.validPolicy = true r.validPolicy = true
return nil return nil
@ -88,7 +85,7 @@ func (r *T) checkOriginPolicy() error {
// checkConnection checks the 'Connection' header // checkConnection checks the 'Connection' header
// it MUST contain 'Upgrade' // it MUST contain 'Upgrade'
func (r *T) checkConnection(bb header.HeaderValue) error { func (r *Request) checkConnection(bb HeaderValue) error {
for _, b := range bb { for _, b := range bb {
@ -99,17 +96,17 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
} }
r.code = response.BadRequest r.code = BadRequest
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
} }
// checkUpgrade checks the 'Upgrade' header // checkUpgrade checks the 'Upgrade' header
// it MUST be 'websocket' // it MUST be 'websocket'
func (r *T) checkUpgrade(bb header.HeaderValue) error { func (r *Request) checkUpgrade(bb HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.BadRequest r.code = BadRequest
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
@ -118,17 +115,17 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
return nil return nil
} }
r.code = response.BadRequest r.code = BadRequest
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
} }
// checkVersion checks the 'Sec-WebSocket-Version' header // checkVersion checks the 'Sec-WebSocket-Version' header
// it MUST be '13' // it MUST be '13'
func (r *T) checkVersion(bb header.HeaderValue) error { func (r *Request) checkVersion(bb HeaderValue) error {
if len(bb) != 1 || string(bb[0]) != "13" { if len(bb) != 1 || string(bb[0]) != "13" {
r.code = response.UpgradeRequired r.code = UpgradeRequired
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
} }
@ -139,10 +136,10 @@ func (r *T) checkVersion(bb header.HeaderValue) error {
// extractKey extracts the 'Sec-WebSocket-Key' header // extractKey extracts the 'Sec-WebSocket-Key' header
// it MUST be 24 bytes (base64) // it MUST be 24 bytes (base64)
func (r *T) extractKey(bb header.HeaderValue) error { func (r *Request) extractKey(bb HeaderValue) error {
if len(bb) != 1 || len(bb[0]) != 24 { if len(bb) != 1 || len(bb[0]) != 24 {
r.code = response.BadRequest r.code = BadRequest
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
} }
@ -154,7 +151,7 @@ func (r *T) extractKey(bb header.HeaderValue) error {
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header // extractProtocols extracts the 'Sec-WebSocket-Protocol' header
// it can contain multiple values // it can contain multiple values
func (r *T) extractProtocols(bb header.HeaderValue) error { func (r *Request) extractProtocols(bb HeaderValue) error {
r.protocols = bb r.protocols = bb

View File

@ -1,4 +1,4 @@
package request package upgrade
import ( import (
"bytes" "bytes"
@ -6,50 +6,38 @@ import (
"regexp" "regexp"
) )
// httpMethod represents available http methods // Line represents the HTTP Request line
type httpMethod byte
const (
OPTIONS httpMethod = iota
GET
HEAD
POST
PUT
DELETE
)
// RequestLine represents the HTTP Request line
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1 // defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
type RequestLine struct { type Line struct {
method httpMethod method Method
uri string uri string
version byte version byte
} }
// parseRequestLine parses the first HTTP request line // Parse parses the first HTTP request line
func (r *RequestLine) Parse(b []byte) error { func (r *Line) Parse(b []byte) error {
/* (1) Split by ' ' */ // 1. Split by ' '
parts := bytes.Split(b, []byte(" ")) parts := bytes.Split(b, []byte(" "))
/* (2) Fail when missing parts */ // 2. Fail when missing parts
if len(parts) != 3 { if len(parts) != 3 {
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts)) return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
} }
/* (3) Extract HTTP method */ // 3. Extract HTTP method
err := r.extractHttpMethod(parts[0]) err := r.extractHttpMethod(parts[0])
if err != nil { if err != nil {
return err return err
} }
/* (4) Extract URI */ // 4. Extract URI
err = r.extractURI(parts[1]) err = r.extractURI(parts[1])
if err != nil { if err != nil {
return err return err
} }
/* (5) Extract version */ // 5. Extract version
err = r.extractHttpVersion(parts[2]) err = r.extractHttpVersion(parts[2])
if err != nil { if err != nil {
return err return err
@ -60,19 +48,19 @@ func (r *RequestLine) Parse(b []byte) error {
} }
// GetURI returns the actual URI // GetURI returns the actual URI
func (r RequestLine) GetURI() string { func (r Line) GetURI() string {
return r.uri return r.uri
} }
// extractHttpMethod extracts the HTTP method from a []byte // extractHttpMethod extracts the HTTP method from a []byte
// and checks for errors // and checks for errors
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE // allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
func (r *RequestLine) extractHttpMethod(b []byte) error { func (r *Line) extractHttpMethod(b []byte) error {
switch string(b) { switch string(b) {
// case "OPTIONS": r.method = OPTIONS // case "OPTIONS": r.method = OPTIONS
case "GET": case "GET":
r.method = GET r.method = Get
// case "HEAD": r.method = HEAD // case "HEAD": r.method = HEAD
// case "POST": r.method = POST // case "POST": r.method = POST
// case "PUT": r.method = PUT // case "PUT": r.method = PUT
@ -87,15 +75,15 @@ func (r *RequestLine) extractHttpMethod(b []byte) error {
// extractURI extracts the URI from a []byte and checks for errors // extractURI extracts the URI from a []byte and checks for errors
// allowed format: /([^/]/)*/? // allowed format: /([^/]/)*/?
func (r *RequestLine) extractURI(b []byte) error { func (r *Line) extractURI(b []byte) error {
/* (1) Check format */ // 1. Check format
checker := regexp.MustCompile("^(?:/[^/]+)*/?$") checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
if !checker.Match(b) { if !checker.Match(b) {
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b) return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
} }
/* (2) Store */ // 2. Store
r.uri = string(b) r.uri = string(b)
return nil return nil
@ -104,26 +92,26 @@ func (r *RequestLine) extractURI(b []byte) error {
// extractHttpVersion extracts the version and checks for errors // extractHttpVersion extracts the version and checks for errors
// allowed format: [1-9] or [1.9].[0-9] // allowed format: [1-9] or [1.9].[0-9]
func (r *RequestLine) extractHttpVersion(b []byte) error { func (r *Line) extractHttpVersion(b []byte) error {
/* (1) Extract version parts */ // 1. Extract version parts
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`) extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
if !extractor.Match(b) { if !extractor.Match(b) {
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b) return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
} }
/* (2) Extract version number */ // 2. Extract version number
matches := extractor.FindSubmatch(b) matches := extractor.FindSubmatch(b)
var version byte = matches[1][0] - '0' var version byte = matches[1][0] - '0'
/* (3) Extract subversion (if exists) */ // 3. Extract subversion (if exists)
var subVersion byte = 0 var subVersion byte = 0
if len(matches[2]) > 0 { if len(matches[2]) > 0 {
subVersion = matches[2][0] - '0' subVersion = matches[2][0] - '0'
} }
/* (4) Store version (x 10 to fit uint8) */ // 4. Store version (x 10 to fit uint8)
r.version = version*10 + subVersion r.version = version*10 + subVersion
return nil return nil

View File

@ -0,0 +1,14 @@
package upgrade
// Method represents available http methods
type Method uint8
// http methods
const (
Options Method = iota
Get
Head
Post
Put
Delete
)

View File

@ -0,0 +1,216 @@
package upgrade
import (
"fmt"
"io"
"git.xdrm.io/go/ws/internal/http/reader"
)
// If origin is required
const bypassOriginPolicy = true
// Request represents an HTTP Upgrade request
type Request struct {
first bool // whether the first line has been read (GET uri HTTP/version)
// status code
code StatusCode
// request line
request Line
// data to check origin (depends of reading order)
host string
port uint16 // 0 if not set
origin string
validPolicy bool
// ws data
key []byte
protocols [][]byte
// required fields check
hasConnection bool
hasUpgrade bool
hasVersion bool
}
// Parse builds an upgrade HTTP request
// from a reader (typically bufio.NewRead of the socket)
func Parse(r io.Reader) (request *Request, err error) {
req := &Request{
code: 500,
}
/* (1) Parse request
---------------------------------------------------------*/
// 1. Get chunk reader
cr := reader.NewReader(r)
if err != nil {
return req, fmt.Errorf("Error while creating chunk reader: %s", err)
}
// 2. Parse header line by line
for {
line, err := cr.Read()
if err == io.EOF {
break
}
if err != nil {
return req, err
}
err = req.parseHeader(line)
if err != nil {
return req, err
}
}
// 3. Check completion
err = req.isComplete()
if err != nil {
req.code = BadRequest
return req, err
}
req.code = SwitchingProtocols
return req, nil
}
// StatusCode returns the status current
func (r Request) StatusCode() StatusCode {
return r.code
}
// BuildResponse builds a response from the request
func (r *Request) BuildResponse() *Response {
inst := &Response{}
// 1. Copy code
inst.SetStatusCode(r.code)
// 2. Set Protocol
if len(r.protocols) > 0 {
inst.SetProtocol(r.protocols[0])
}
// 4. Process key
inst.ProcessKey(r.key)
return inst
}
// GetURI returns the actual URI
func (r Request) GetURI() string {
return r.request.GetURI()
}
// parseHeader parses any http request line
// (header and request-line)
func (r *Request) parseHeader(b []byte) error {
/* (1) First line -> GET {uri} HTTP/{version}
---------------------------------------------------------*/
if !r.first {
err := r.request.Parse(b)
if err != nil {
r.code = BadRequest
return &InvalidRequest{"Request-Line", err.Error()}
}
r.first = true
return nil
}
/* (2) Other lines -> Header-Name: Header-Value
---------------------------------------------------------*/
// 1. Try to parse header
head, err := ReadHeader(b)
if err != nil {
r.code = BadRequest
return fmt.Errorf("Error parsing header: %s", err)
}
// 2. Manage header
switch head.Name {
case Host:
err = r.extractHostPort(head.Values)
case Origin:
err = r.extractOrigin(head.Values)
case Upgrade:
err = r.checkUpgrade(head.Values)
case Connection:
err = r.checkConnection(head.Values)
case WSVersion:
err = r.checkVersion(head.Values)
case WSKey:
err = r.extractKey(head.Values)
case WSProtocol:
err = r.extractProtocols(head.Values)
default:
return nil
}
// dispatch error
if err != nil {
return err
}
return nil
}
// isComplete returns whether the Upgrade Request
// is complete (no missing required item)
func (r Request) isComplete() error {
// 1. Request-Line
if !r.first {
return &IncompleteRequest{"Request-Line"}
}
// 2. Host
if len(r.host) == 0 {
return &IncompleteRequest{"Host"}
}
// 3. Origin
if !bypassOriginPolicy && len(r.origin) == 0 {
return &IncompleteRequest{"Origin"}
}
// 4. Connection
if !r.hasConnection {
return &IncompleteRequest{"Connection"}
}
// 5. Upgrade
if !r.hasUpgrade {
return &IncompleteRequest{"Upgrade"}
}
// 6. Sec-WebSocket-Version
if !r.hasVersion {
return &IncompleteRequest{"Sec-WebSocket-Version"}
}
// 7. Sec-WebSocket-Key
if len(r.key) < 1 {
return &IncompleteRequest{"Sec-WebSocket-Key"}
}
return nil
}

View File

@ -1,41 +0,0 @@
package header
import (
// "regexp"
"fmt"
"strings"
"bytes"
)
// parse tries to return a 'T' (httpHeader) from a byte array
func Parse(b []byte) (*T, error) {
/* (1) Split by ':' */
parts := bytes.Split(b, []byte(": "))
if len(parts) != 2 {
return nil, fmt.Errorf("Invalid HTTP header format '%s'", b)
}
/* (2) Create instance */
inst := new(T)
/* (3) Check for header name */
switch strings.ToLower(string(parts[0])) {
case "host": inst.Name = HOST
case "upgrade": inst.Name = UPGRADE
case "connection": inst.Name = CONNECTION
case "origin": inst.Name = ORIGIN
case "sec-websocket-key": inst.Name = WSKEY
case "sec-websocket-protocol": inst.Name = WSPROTOCOL
case "sec-websocket-extensions": inst.Name = WSEXTENSIONS
case "sec-websocket-version": inst.Name = WSVERSION
default: inst.Name = UNKNOWN
}
/* (4) Split values */
inst.Values = bytes.Split(parts[1], []byte(", "))
return inst, nil
}

View File

@ -1,25 +0,0 @@
package header
// HeaderType represents all 'valid' HTTP request headers
type HeaderType byte
const (
UNKNOWN HeaderType = iota
HOST
UPGRADE
CONNECTION
ORIGIN
WSKEY
WSPROTOCOL
WSEXTENSIONS
WSVERSION
)
// HeaderValue represents a unique or multiple header value(s)
type HeaderValue [][]byte
// T represents the data of a HTTP request header
type T struct{
Name HeaderType
Values HeaderValue
}

View File

@ -1,110 +0,0 @@
package request
import (
"fmt"
"git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header"
"git.xdrm.io/go/ws/internal/http/upgrade/response"
)
// parseHeader parses any http request line
// (header and request-line)
func (r *T) parseHeader(b []byte) error {
/* (1) First line -> GET {uri} HTTP/{version}
---------------------------------------------------------*/
if !r.first {
err := r.request.Parse(b)
if err != nil {
r.code = response.BadRequest
return &InvalidRequest{"Request-Line", err.Error()}
}
r.first = true
return nil
}
/* (2) Other lines -> Header-Name: Header-Value
---------------------------------------------------------*/
/* (1) Try to parse header */
head, err := header.Parse(b)
if err != nil {
r.code = response.BadRequest
return fmt.Errorf("Error parsing header: %s", err)
}
/* (2) Manage header */
switch head.Name {
case header.HOST:
err = r.extractHostPort(head.Values)
case header.ORIGIN:
err = r.extractOrigin(head.Values)
case header.UPGRADE:
err = r.checkUpgrade(head.Values)
case header.CONNECTION:
err = r.checkConnection(head.Values)
case header.WSVERSION:
err = r.checkVersion(head.Values)
case header.WSKEY:
err = r.extractKey(head.Values)
case header.WSPROTOCOL:
err = r.extractProtocols(head.Values)
default:
return nil
}
// dispatch error
if err != nil {
return err
}
return nil
}
// isComplete returns whether the Upgrade Request
// is complete (no missing required item)
func (r T) isComplete() error {
/* (1) Request-Line */
if !r.first {
return &IncompleteRequest{"Request-Line"}
}
/* (2) Host */
if len(r.host) == 0 {
return &IncompleteRequest{"Host"}
}
/* (3) Origin */
if !bypassOriginPolicy && len(r.origin) == 0 {
return &IncompleteRequest{"Origin"}
}
/* (4) Connection */
if !r.hasConnection {
return &IncompleteRequest{"Connection"}
}
/* (5) Upgrade */
if !r.hasUpgrade {
return &IncompleteRequest{"Upgrade"}
}
/* (6) Sec-WebSocket-Version */
if !r.hasVersion {
return &IncompleteRequest{"Sec-WebSocket-Version"}
}
/* (7) Sec-WebSocket-Key */
if len(r.key) < 1 {
return &IncompleteRequest{"Sec-WebSocket-Key"}
}
return nil
}

View File

@ -1,85 +0,0 @@
package request
import (
"fmt"
"io"
"git.xdrm.io/go/ws/internal/http/reader"
"git.xdrm.io/go/ws/internal/http/upgrade/response"
)
// Parse builds an upgrade HTTP request
// from a reader (typically bufio.NewRead of the socket)
func Parse(r io.Reader) (request *T, err error) {
req := new(T)
req.code = 500
/* (1) Parse request
---------------------------------------------------------*/
/* (1) Get chunk reader */
cr := reader.NewReader(r)
if err != nil {
return req, fmt.Errorf("Error while creating chunk reader: %s", err)
}
/* (2) Parse header line by line */
for {
line, err := cr.Read()
if err == io.EOF {
break
}
if err != nil {
return req, err
}
err = req.parseHeader(line)
if err != nil {
return req, err
}
}
/* (3) Check completion */
err = req.isComplete()
if err != nil {
req.code = response.BadRequest
return req, err
}
req.code = response.SwitchingProtocols
return req, nil
}
// StatusCode returns the status current
func (r T) StatusCode() response.StatusCode {
return r.code
}
// BuildResponse builds a response.T from the request
func (r *T) BuildResponse() *response.T {
inst := new(response.T)
/* (1) Copy code */
inst.SetStatusCode(r.code)
/* (2) Set Protocol */
if len(r.protocols) > 0 {
inst.SetProtocol(r.protocols[0])
}
/* (4) Process key */
inst.ProcessKey(r.key)
return inst
}
// GetURI returns the actual URI
func (r T) GetURI() string {
return r.request.GetURI()
}

View File

@ -1,32 +0,0 @@
package request
import "git.xdrm.io/go/ws/internal/http/upgrade/response"
// If origin is required
const bypassOriginPolicy = true
// T represents an HTTP Upgrade request
type T struct {
first bool // whether the first line has been read (GET uri HTTP/version)
// status code
code response.StatusCode
// request line
request RequestLine
// data to check origin (depends of reading order)
host string
port uint16 // 0 if not set
origin string
validPolicy bool
// ws data
key []byte
protocols [][]byte
// required fields check
hasConnection bool
hasUpgrade bool
hasVersion bool
}

View File

@ -1,4 +1,4 @@
package request package upgrade
import ( import (
"bytes" "bytes"
@ -6,13 +6,13 @@ import (
"testing" "testing"
) )
// /* (1) Parse request */ // // 1. Parse request
// req, _ := request.Parse(s) // req, _ := request.Parse(s)
// /* (3) Build response */ // // 3. Build response
// res := req.BuildResponse() // res := req.BuildResponse()
// /* (4) Write into socket */ // // 4. Write into socket
// _, err := res.Send(s) // _, err := res.Send(s)
// if err != nil { // if err != nil {
// return nil, fmt.Errorf("Upgrade write error: %s", err) // return nil, fmt.Errorf("Upgrade write error: %s", err)
@ -25,7 +25,7 @@ import (
func TestEOFSocket(t *testing.T) { func TestEOFSocket(t *testing.T) {
socket := new(bytes.Buffer) socket := &bytes.Buffer{}
_, err := Parse(socket) _, err := Parse(socket)
@ -39,7 +39,7 @@ func TestEOFSocket(t *testing.T) {
func TestInvalidRequestLine(t *testing.T) { func TestInvalidRequestLine(t *testing.T) {
socket := new(bytes.Buffer) socket := &bytes.Buffer{}
cases := []struct { cases := []struct {
Reqline string Reqline string
HasError bool HasError bool
@ -113,7 +113,7 @@ func TestInvalidHost(t *testing.T) {
requestLine := []byte("GET / HTTP/1.1\r\n") requestLine := []byte("GET / HTTP/1.1\r\n")
socket := new(bytes.Buffer) socket := &bytes.Buffer{}
cases := []struct { cases := []struct {
Host string Host string
HasError bool HasError bool

View File

@ -1,4 +1,4 @@
package response package upgrade
import ( import (
"crypto/sha1" "crypto/sha1"
@ -7,19 +7,35 @@ import (
"io" "io"
) )
// HTTPVersion constant
const HTTPVersion = "1.1"
// UsedWSVersion constant websocket version
const UsedWSVersion = 13
// WSSalt constant websocket salt
const WSSalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
// Response represents an HTTP Upgrade Response
type Response struct {
code StatusCode // status code
accept []byte // processed from Sec-WebSocket-Key
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
}
// SetStatusCode sets the status code // SetStatusCode sets the status code
func (r *T) SetStatusCode(sc StatusCode) { func (r *Response) SetStatusCode(sc StatusCode) {
r.code = sc r.code = sc
} }
// SetProtocol sets the protocols // SetProtocol sets the protocols
func (r *T) SetProtocol(p []byte) { func (r *Response) SetProtocol(p []byte) {
r.protocol = p r.protocol = p
} }
// ProcessKey processes the accept token according // ProcessKey processes the accept token according
// to the rfc from the Sec-WebSocket-Key // to the rfc from the Sec-WebSocket-Key
func (r *T) ProcessKey(k []byte) { func (r *Response) ProcessKey(k []byte) {
// do nothing for empty key // do nothing for empty key
if k == nil || len(k) == 0 { if k == nil || len(k) == 0 {
@ -27,40 +43,40 @@ func (r *T) ProcessKey(k []byte) {
return return
} }
/* (1) Concat with constant salt */ // 1. Concat with constant salt
mix := append(k, WSSalt...) mix := append(k, []byte(WSSalt)...)
/* (2) Hash with sha1 algorithm */ // 2. Hash with sha1 algorithm
digest := sha1.Sum(mix) digest := sha1.Sum(mix)
/* (3) Base64 encode it */ // 3. Base64 encode it
r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size])) r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
} }
// Send sends the response through an io.Writer // Send sends the response through an io.Writer
// typically a socket // typically a socket
func (r T) Send(w io.Writer) (int, error) { func (r Response) Send(w io.Writer) (int, error) {
/* (1) Build response line */ // 1. Build response line
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HttpVersion, r.code, r.code.String()) responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HTTPVersion, r.code, r.code)
/* (2) Build headers */ // 2. Build headers
optionalProtocol := "" optionalProtocol := ""
if len(r.protocol) > 0 { if len(r.protocol) > 0 {
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol) optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol)
} }
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", WSVersion, optionalProtocol) headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", UsedWSVersion, optionalProtocol)
if r.accept != nil { if r.accept != nil {
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept) headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept)
} }
headers = fmt.Sprintf("%s\r\n", headers) headers = fmt.Sprintf("%s\r\n", headers)
/* (3) Build all */ // 3. Build all
raw := []byte(fmt.Sprintf("%s%s", responseLine, headers)) raw := []byte(fmt.Sprintf("%s%s", responseLine, headers))
/* (4) Write */ // 4. Write
written, err := w.Write(raw) written, err := w.Write(raw)
return written, err return written, err
@ -68,11 +84,11 @@ func (r T) Send(w io.Writer) (int, error) {
} }
// GetProtocol returns the choosen protocol if set, else nil // GetProtocol returns the choosen protocol if set, else nil
func (r T) GetProtocol() []byte { func (r Response) GetProtocol() []byte {
return r.protocol return r.protocol
} }
// GetStatusCode returns the response status code // GetStatusCode returns the response status code
func (r T) GetStatusCode() StatusCode { func (r Response) GetStatusCode() StatusCode {
return r.code return r.code
} }

View File

@ -1,15 +0,0 @@
package response
// Constant
const HttpVersion = "1.1"
const WSVersion = 13
var WSSalt []byte = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
// T represents an HTTP Upgrade Response
type T struct {
code StatusCode // status code
accept []byte // processed from Sec-WebSocket-Key
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
}

View File

@ -1,9 +1,9 @@
package response package upgrade
// StatusCode maps the status codes (and description) // StatusCode maps the status codes (and description)
type StatusCode uint16 type StatusCode uint16
var ( const (
// SwitchingProtocols - handshake success // SwitchingProtocols - handshake success
SwitchingProtocols StatusCode = 101 SwitchingProtocols StatusCode = 101
// BadRequest - missing/malformed headers // BadRequest - missing/malformed headers

View File

@ -9,45 +9,45 @@ import (
// from a pattern string // from a pattern string
func buildScheme(ss []string) (Scheme, error) { func buildScheme(ss []string) (Scheme, error) {
/* (1) Build scheme */ // 1. Build scheme
sch := make(Scheme, 0, maxMatch) sch := make(Scheme, 0, maxMatch)
for _, s := range ss { for _, s := range ss {
/* (2) ignore empty */ // 2. ignore empty
if len(s) == 0 { if len(s) == 0 {
continue continue
} }
m := new(matcher) m := &matcher{}
switch s { switch s {
/* (3) Card: 0, N */ // 3. Card: 0, N
case "**": case "**":
m.req = false m.req = false
m.mul = true m.mul = true
sch = append(sch, m) sch = append(sch, m)
/* (4) Card: 1, N */ // 4. Card: 1, N
case "..": case "..":
m.req = true m.req = true
m.mul = true m.mul = true
sch = append(sch, m) sch = append(sch, m)
/* (5) Card: 0, 1 */ // 5. Card: 0, 1
case "*": case "*":
m.req = false m.req = false
m.mul = false m.mul = false
sch = append(sch, m) sch = append(sch, m)
/* (6) Card: 1 */ // 6. Card: 1
case ".": case ".":
m.req = true m.req = true
m.mul = false m.mul = false
sch = append(sch, m) sch = append(sch, m)
/* (7) Card: 1, literal string */ // 7. Card: 1, literal string
default: default:
m.req = true m.req = true
m.mul = false m.mul = false
@ -64,16 +64,16 @@ func buildScheme(ss []string) (Scheme, error) {
// optimise optimised the scheme for further parsing // optimise optimised the scheme for further parsing
func (s Scheme) optimise() (Scheme, error) { func (s Scheme) optimise() (Scheme, error) {
/* (1) Nothing to do if only 1 element */ // 1. Nothing to do if only 1 element
if len(s) <= 1 { if len(s) <= 1 {
return s, nil return s, nil
} }
/* (2) Init reshifted scheme */ // 2. Init reshifted scheme
rshift := make(Scheme, 0, maxMatch) rshift := make(Scheme, 0, maxMatch)
rshift = append(rshift, s[0]) rshift = append(rshift, s[0])
/* (2) Iterate over matchers */ // 2. Iterate over matchers
for p, i, l := 0, 1, len(s); i < l; i++ { for p, i, l := 0, 1, len(s); i < l; i++ {
pre, cur := s[p], s[i] pre, cur := s[p], s[i]
@ -106,11 +106,11 @@ func (s Scheme) optimise() (Scheme, error) {
// it returns a cleared uri, without STRING data // it returns a cleared uri, without STRING data
func (s Scheme) matchString(uri string) (string, bool) { func (s Scheme) matchString(uri string) (string, bool) {
/* (1) Initialise variables */ // 1. Initialise variables
clr := uri // contains cleared input string clr := uri // contains cleared input string
minOff := 0 // minimum offset minOff := 0 // minimum offset
/* (2) Iterate over strings */ // 2. Iterate over strings
for _, m := range s { for _, m := range s {
ls := len(m.pat) ls := len(m.pat)
@ -147,12 +147,12 @@ func (s Scheme) matchString(uri string) (string, bool) {
} }
/* (3) If exists, remove trailing '/' */ // 3. If exists, remove trailing '/'
if clr[len(clr)-1] == '/' { if clr[len(clr)-1] == '/' {
clr = clr[:len(clr)-1] clr = clr[:len(clr)-1]
} }
/* (4) If exists, remove trailing '\a' */ // 4. If exists, remove trailing '\a'
if clr[len(clr)-1] == '\a' { if clr[len(clr)-1] == '\a' {
clr = clr[:len(clr)-1] clr = clr[:len(clr)-1]
} }
@ -166,7 +166,7 @@ func (s Scheme) matchString(uri string) (string, bool) {
// + it sets the matchers buffers for later extraction // + it sets the matchers buffers for later extraction
func (s Scheme) matchWildcards(clear string) bool { func (s Scheme) matchWildcards(clear string) bool {
/* (1) Extract wildcards (ref) */ // 1. Extract wildcards (ref)
wildcards := make(Scheme, 0, maxMatch) wildcards := make(Scheme, 0, maxMatch)
for _, m := range s { for _, m := range s {
@ -176,15 +176,15 @@ func (s Scheme) matchWildcards(clear string) bool {
} }
} }
/* (2) If no wildcards -> match */ // 2. If no wildcards -> match
if len(wildcards) == 0 { if len(wildcards) == 0 {
return true return true
} }
/* (3) Break uri by '\a' characters */ // 3. Break uri by '\a' characters
matches := strings.Split(clear, "\a")[1:] matches := strings.Split(clear, "\a")[1:]
/* (4) Iterate over matches */ // 4. Iterate over matches
for n, match := range matches { for n, match := range matches {
// {1} If no more matcher // // {1} If no more matcher //
@ -210,7 +210,7 @@ func (s Scheme) matchWildcards(clear string) bool {
} }
/* (5) Match */ // 5. Match
return true return true
} }

View File

@ -8,15 +8,15 @@ import (
// Build builds an URI scheme from a pattern string // Build builds an URI scheme from a pattern string
func Build(s string) (*Scheme, error) { func Build(s string) (*Scheme, error) {
/* (1) Manage '/' at the start */ // 1. Manage '/' at the start
if len(s) < 1 || s[0] != '/' { if len(s) < 1 || s[0] != '/' {
return nil, fmt.Errorf("URI must begin with '/'") return nil, fmt.Errorf("URI must begin with '/'")
} }
/* (2) Split by '/' */ // 2. Split by '/'
parts := strings.Split(s, "/") parts := strings.Split(s, "/")
/* (3) Max exceeded */ // 3. Max exceeded
if len(parts)-2 > maxMatch { if len(parts)-2 > maxMatch {
for i, p := range parts { for i, p := range parts {
fmt.Printf("%d: '%s'\n", i, p) fmt.Printf("%d: '%s'\n", i, p)
@ -24,13 +24,13 @@ func Build(s string) (*Scheme, error) {
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts)) return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
} }
/* (4) Build for each part */ // 4. Build for each part
sch, err := buildScheme(parts) sch, err := buildScheme(parts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
/* (5) Optimise structure */ // 5. Optimise structure
opti, err := sch.optimise() opti, err := sch.optimise()
if err != nil { if err != nil {
return nil, err return nil, err
@ -43,18 +43,18 @@ func Build(s string) (*Scheme, error) {
// Match returns if the given URI is matched by the scheme // Match returns if the given URI is matched by the scheme
func (s Scheme) Match(str string) bool { func (s Scheme) Match(str string) bool {
/* (1) Nothing -> match all */ // 1. Nothing -> match all
if len(s) == 0 { if len(s) == 0 {
return true return true
} }
/* (2) Check for string match */ // 2. Check for string match
clearURI, match := s.matchString(str) clearURI, match := s.matchString(str)
if !match { if !match {
return false return false
} }
/* (3) Check for non-string match (wildcards) */ // 3. Check for non-string match (wildcards)
match = s.matchWildcards(clearURI) match = s.matchWildcards(clearURI)
if !match { if !match {
return false return false
@ -66,12 +66,12 @@ func (s Scheme) Match(str string) bool {
// GetMatch returns the indexed match (excluding string matchers) // GetMatch returns the indexed match (excluding string matchers)
func (s Scheme) GetMatch(n uint8) ([]string, error) { func (s Scheme) GetMatch(n uint8) ([]string, error) {
/* (1) Index out of range */ // 1. Index out of range
if n > uint8(len(s)) { if n > uint8(len(s)) {
return nil, fmt.Errorf("Index out of range") return nil, fmt.Errorf("Index out of range")
} }
/* (2) Iterate to find index (exclude strings) */ // 2. Iterate to find index (exclude strings)
ni := -1 ni := -1
for _, m := range s { for _, m := range s {
@ -90,7 +90,7 @@ func (s Scheme) GetMatch(n uint8) ([]string, error) {
} }
/* (3) If nothing found -> return empty set */ // 3. If nothing found -> return empty set
return nil, fmt.Errorf("Index out of range (max: %d)", ni) return nil, fmt.Errorf("Index out of range (max: %d)", ni)
} }

View File

@ -2,32 +2,36 @@ package websocket
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"unicode/utf8" "unicode/utf8"
) )
var ( // constant error
type constErr string
func (c constErr) Error() string { return string(c) }
const (
// ErrUnmaskedFrame error // ErrUnmaskedFrame error
ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame") ErrUnmaskedFrame = constErr("Received unmasked frame")
// ErrTooLongControlFrame error // ErrTooLongControlFrame error
ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long") ErrTooLongControlFrame = constErr("Received a control frame that is fragmented or too long")
// ErrInvalidFragment error // ErrInvalidFragment error
ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation") ErrInvalidFragment = constErr("Received invalid fragmentation")
// ErrUnexpectedContinuation error // ErrUnexpectedContinuation error
ErrUnexpectedContinuation = fmt.Errorf("Received unexpected continuation frame") ErrUnexpectedContinuation = constErr("Received unexpected continuation frame")
// ErrInvalidSize error // ErrInvalidSize error
ErrInvalidSize = fmt.Errorf("Received invalid payload size") ErrInvalidSize = constErr("Received invalid payload size")
// ErrInvalidPayload error // ErrInvalidPayload error
ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") ErrInvalidPayload = constErr("Received invalid utf8 payload")
// ErrInvalidCloseStatus error // ErrInvalidCloseStatus error
ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") ErrInvalidCloseStatus = constErr("Received invalid close status")
// ErrInvalidOpCode error // ErrInvalidOpCode error
ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") ErrInvalidOpCode = constErr("Received invalid OpCode")
// ErrReservedBits error // ErrReservedBits error
ErrReservedBits = fmt.Errorf("Received reserved bits") ErrReservedBits = constErr("Received reserved bits")
// ErrCloseFrame error // ErrCloseFrame error
ErrCloseFrame = fmt.Errorf("Received close Frame") ErrCloseFrame = constErr("Received close Frame")
) )
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
@ -88,9 +92,9 @@ func readMessage(reader io.Reader) (*Message, error) {
var mask []byte var mask []byte
var cursor int var cursor int
m := new(Message) m := &Message{}
/* (2) Byte 1: FIN and OpCode */ // 2. Byte 1: FIN and OpCode
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
err = readBytes(reader, tmpBuf) err = readBytes(reader, tmpBuf)
if err != nil { if err != nil {
@ -105,7 +109,7 @@ func readMessage(reader io.Reader) (*Message, error) {
m.Final = bool(tmpBuf[0]&0x80 == 0x80) m.Final = bool(tmpBuf[0]&0x80 == 0x80)
m.Type = MessageType(tmpBuf[0] & 0x0f) m.Type = MessageType(tmpBuf[0] & 0x0f)
/* (3) Byte 2: Mask and Length[0] */ // 3. Byte 2: Mask and Length[0]
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
err = readBytes(reader, tmpBuf) err = readBytes(reader, tmpBuf)
if err != nil { if err != nil {
@ -120,7 +124,7 @@ func readMessage(reader io.Reader) (*Message, error) {
// payload length // payload length
m.Size = uint(tmpBuf[0] & 0x7f) m.Size = uint(tmpBuf[0] & 0x7f)
/* (4) Extended payload */ // 4. Extended payload
if m.Size == 127 { if m.Size == 127 {
tmpBuf = make([]byte, 8) tmpBuf = make([]byte, 8)
@ -143,7 +147,7 @@ func readMessage(reader io.Reader) (*Message, error) {
} }
/* (5) Masking key */ // 5. Masking key
if mask != nil { if mask != nil {
tmpBuf = make([]byte, 4) tmpBuf = make([]byte, 4)
@ -157,7 +161,7 @@ func readMessage(reader io.Reader) (*Message, error) {
} }
/* (6) Read payload by chunks */ // 6. Read payload by chunks
m.Data = make([]byte, int(m.Size)) m.Data = make([]byte, int(m.Size))
cursor = 0 cursor = 0
@ -207,14 +211,14 @@ func (m Message) Send(writer io.Writer) error {
m.Size = uint(len(m.Data)) m.Size = uint(len(m.Data))
} }
/* (1) Byte 0 : FIN + opcode */ // 1. Byte 0 : FIN + opcode
var final byte = 0x80 var final byte = 0x80
if !m.Final { if !m.Final {
final = 0 final = 0
} }
header = append(header, final|byte(m.Type)) header = append(header, final|byte(m.Type))
/* (2) Get payload length */ // 2. Get payload length
if m.Size < 126 { // simple if m.Size < 126 { // simple
header = append(header, byte(m.Size)) header = append(header, byte(m.Size))
@ -237,12 +241,12 @@ func (m Message) Send(writer io.Writer) error {
} }
/* (3) Build write buffer */ // 3. Build write buffer
writeBuf := make([]byte, 0, len(header)+int(m.Size)) writeBuf := make([]byte, 0, len(header)+int(m.Size))
writeBuf = append(writeBuf, header...) writeBuf = append(writeBuf, header...)
writeBuf = append(writeBuf, m.Data[0:m.Size]...) writeBuf = append(writeBuf, m.Data[0:m.Size]...)
/* (4) Send over socket by chunks */ // 4. Send over socket by chunks
toWrite := len(header) + int(m.Size) toWrite := len(header) + int(m.Size)
cursor := 0 cursor := 0
for cursor < toWrite { for cursor < toWrite {
@ -272,17 +276,17 @@ func (m Message) Send(writer io.Writer) error {
// returns the message error // returns the message error
func (m *Message) check(fragment bool) error { func (m *Message) check(fragment bool) error {
/* (1) Invalid first fragment (not TEXT nor BINARY) */ // 1. Invalid first fragment (not TEXT nor BINARY)
if !m.Final && !fragment && m.Type != Text && m.Type != Binary { if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
return ErrInvalidFragment return ErrInvalidFragment
} }
/* (2) Waiting fragment but received standalone frame */ // 2. Waiting fragment but received standalone frame
if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong { if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong {
return ErrInvalidFragment return ErrInvalidFragment
} }
/* (3) Control frame too long */ // 3. Control frame too long
if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) { if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) {
return ErrTooLongControlFrame return ErrTooLongControlFrame
} }
@ -335,8 +339,6 @@ func (m *Message) check(fragment bool) error {
return ErrInvalidOpCode return ErrInvalidOpCode
} }
return nil
} }
// readBytes reads from a reader into a byte array // readBytes reads from a reader into a byte array

View File

@ -267,7 +267,7 @@ func TestSimpleMessageSending(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
writer := new(bytes.Buffer) writer := &bytes.Buffer{}
err := tc.Base.Send(writer) err := tc.Base.Send(writer)

View File

@ -65,13 +65,13 @@ func (s *Server) BindDefault(f ControllerFunc) {
// Bind a controller to an URI scheme // Bind a controller to an URI scheme
func (s *Server) Bind(uri string, f ControllerFunc) error { func (s *Server) Bind(uri string, f ControllerFunc) error {
/* (1) Build URI parser */ // 1. Build URI parser
uriScheme, err := parser.Build(uri) uriScheme, err := parser.Build(uri)
if err != nil { if err != nil {
return fmt.Errorf("Cannot build URI: %s", err) return fmt.Errorf("Cannot build URI: %s", err)
} }
/* (2) Create controller */ // 2. Create controller
s.ctl.URI = append(s.ctl.URI, &Controller{ s.ctl.URI = append(s.ctl.URI, &Controller{
URI: uriScheme, URI: uriScheme,
Fun: f, Fun: f,
@ -88,10 +88,10 @@ func (s *Server) Launch() error {
/* (1) Listen socket /* (1) Listen socket
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Build full url */ // 1. Build full url
url := fmt.Sprintf("%s:%d", s.addr, s.port) url := fmt.Sprintf("%s:%d", s.addr, s.port)
/* (2) Bind socket to listen */ // 2. Bind socket to listen
s.sock, err = net.Listen("tcp", url) s.sock, err = net.Listen("tcp", url)
if err != nil { if err != nil {
return fmt.Errorf("Listen socket: %s", err) return fmt.Errorf("Listen socket: %s", err)
@ -101,14 +101,14 @@ func (s *Server) Launch() error {
fmt.Printf("+ listening on %s\n", url) fmt.Printf("+ listening on %s\n", url)
/* (3) Launch scheduler */ // 3. Launch scheduler
go s.scheduler() go s.scheduler()
/* (2) For each incoming connection (client) /* (2) For each incoming connection (client)
---------------------------------------------------------*/ ---------------------------------------------------------*/
for { for {
/* (1) Wait for client */ // 1. Wait for client
sock, err := s.sock.Accept() sock, err := s.sock.Accept()
if err != nil { if err != nil {
break break
@ -116,14 +116,14 @@ func (s *Server) Launch() error {
go func() { go func() {
/* (2) Try to create client */ // 2. Try to create client
cli, err := buildClient(sock, s.ctl, s.ch) cli, err := buildClient(sock, s.ctl, s.ch)
if err != nil { if err != nil {
fmt.Printf(" - %s\n", err) fmt.Printf(" - %s\n", err)
return return
} }
/* (3) Register client */ // 3. Register client
s.ch.register <- cli s.ch.register <- cli
}() }()
@ -141,15 +141,15 @@ func (s *Server) scheduler() {
select { select {
/* (1) Create client */ // 1. Create client
case client := <-s.ch.register: case client := <-s.ch.register:
s.clients[client.io.sock] = client s.clients[client.io.sock] = client
/* (2) Remove client */ // 2. Remove client
case client := <-s.ch.unregister: case client := <-s.ch.unregister:
delete(s.clients, client.io.sock) delete(s.clients, client.io.sock)
/* (3) Broadcast */ // 3. Broadcast
case msg := <-s.ch.broadcast: case msg := <-s.ch.broadcast:
for _, c := range s.clients { for _, c := range s.clients {
c.ch.send <- msg c.ch.send <- msg