refactor: semantic and idiomatic renames, fixes API changes
This commit is contained in:
parent
9e69b66289
commit
e7c84eddf8
146
client.go
146
client.go
|
@ -36,36 +36,27 @@ type client struct {
|
|||
status MessageError // close status ; 0 = nothing ; else -> must close
|
||||
}
|
||||
|
||||
// Create creates a new client
|
||||
func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) {
|
||||
// newClient creates a new client
|
||||
func newClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) {
|
||||
req := &upgrade.Request{}
|
||||
_, err := req.ReadFrom(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request read: %w", err)
|
||||
}
|
||||
|
||||
/* (1) Manage UPGRADE request
|
||||
---------------------------------------------------------*/
|
||||
// 1. Parse request
|
||||
req, _ := upgrade.Parse(s)
|
||||
|
||||
// 3. Build response
|
||||
res := req.BuildResponse()
|
||||
|
||||
// 4. Write into socket
|
||||
_, err := res.Send(s)
|
||||
_, err = res.WriteTo(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upgrade write: %w", err)
|
||||
}
|
||||
|
||||
if res.GetStatusCode() != 101 {
|
||||
if res.StatusCode != 101 {
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.GetStatusCode())
|
||||
return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.StatusCode)
|
||||
}
|
||||
|
||||
/* (2) Initialise client
|
||||
---------------------------------------------------------*/
|
||||
// 1. Get upgrade data
|
||||
clientURI := req.GetURI()
|
||||
clientProtocol := res.GetProtocol()
|
||||
|
||||
// 2. Initialise client
|
||||
cli := &client{
|
||||
var cli = &client{
|
||||
io: clientIO{
|
||||
sock: s,
|
||||
reader: bufio.NewReader(s),
|
||||
|
@ -73,8 +64,8 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
|||
},
|
||||
|
||||
iface: &Client{
|
||||
Protocol: string(clientProtocol),
|
||||
Arguments: [][]string{{clientURI}},
|
||||
Protocol: string(res.Protocol),
|
||||
Arguments: [][]string{{req.URI()}},
|
||||
},
|
||||
|
||||
ch: clientChannelSet{
|
||||
|
@ -83,59 +74,46 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
|||
},
|
||||
}
|
||||
|
||||
/* (3) Find controller by URI
|
||||
---------------------------------------------------------*/
|
||||
// 1. Try to find one
|
||||
controller, arguments := ctl.Match(clientURI)
|
||||
|
||||
// 2. If nothing found -> error
|
||||
// find controller by URI
|
||||
controller, arguments := ctl.Match(req.URI())
|
||||
if controller == nil {
|
||||
return nil, fmt.Errorf("No controller found, no default controller set")
|
||||
return nil, fmt.Errorf("no controller found, no default controller set")
|
||||
}
|
||||
|
||||
// 3. Copy arguments
|
||||
// copy args
|
||||
cli.iface.Arguments = arguments
|
||||
|
||||
/* (4) Launch client routines
|
||||
---------------------------------------------------------*/
|
||||
// 1. Launch client controller
|
||||
go controller.Fun(
|
||||
cli.iface, // pass the client
|
||||
cli.ch.receive, // the receiver
|
||||
cli.ch.send, // the sender
|
||||
serverCh.broadcast, // broadcast sender
|
||||
)
|
||||
|
||||
// 2. Launch message reader
|
||||
go clientReader(cli)
|
||||
|
||||
// 3. Launc writer
|
||||
go clientWriter(cli)
|
||||
|
||||
return cli, nil
|
||||
|
||||
}
|
||||
|
||||
// reader reads and parses messages from the buffer
|
||||
// clientReader reads and parses messages from the buffer
|
||||
func clientReader(c *client) {
|
||||
var frag *Message
|
||||
|
||||
closeStatus := Normal
|
||||
clientAck := true
|
||||
var (
|
||||
frag *Message
|
||||
closeStatus = Normal
|
||||
clientAck = true
|
||||
)
|
||||
|
||||
c.io.reading.Add(1)
|
||||
|
||||
for {
|
||||
|
||||
// 1. if currently closing -> exit
|
||||
// currently closing -> exit
|
||||
if c.io.closing {
|
||||
fmt.Printf("[reader] killed because closing")
|
||||
break
|
||||
}
|
||||
|
||||
// 2. Parse message
|
||||
msg, err := readMessage(c.io.reader)
|
||||
|
||||
// Parse message
|
||||
var msg = &Message{}
|
||||
_, err := msg.ReadFrom(c.io.reader)
|
||||
if err == ErrUnmaskedFrame || err == ErrReservedBits {
|
||||
closeStatus = ProtocolError
|
||||
}
|
||||
|
@ -143,7 +121,7 @@ func clientReader(c *client) {
|
|||
break
|
||||
}
|
||||
|
||||
// 3. Fail on invalid message
|
||||
// invalid message
|
||||
msgErr := msg.check(frag != nil)
|
||||
if msgErr != nil {
|
||||
|
||||
|
@ -151,7 +129,7 @@ func clientReader(c *client) {
|
|||
|
||||
switch msgErr {
|
||||
|
||||
// Fail
|
||||
// fail
|
||||
case ErrUnexpectedContinuation:
|
||||
closeStatus = None
|
||||
clientAck = false
|
||||
|
@ -182,7 +160,7 @@ func clientReader(c *client) {
|
|||
|
||||
}
|
||||
|
||||
// 4. Ping <-> Pong
|
||||
// ping <-> Pong
|
||||
if msg.Type == Ping && c.io.writing {
|
||||
msg.Final = true
|
||||
msg.Type = Pong
|
||||
|
@ -190,7 +168,7 @@ func clientReader(c *client) {
|
|||
continue
|
||||
}
|
||||
|
||||
// 5. Store first fragment
|
||||
// store first fragment
|
||||
if frag == nil && !msg.Final {
|
||||
frag = &Message{
|
||||
Type: msg.Type,
|
||||
|
@ -201,7 +179,7 @@ func clientReader(c *client) {
|
|||
continue
|
||||
}
|
||||
|
||||
// 6. Store fragments
|
||||
// store fragments
|
||||
if frag != nil {
|
||||
frag.Final = msg.Final
|
||||
frag.Size += msg.Size
|
||||
|
@ -226,7 +204,7 @@ func clientReader(c *client) {
|
|||
|
||||
}
|
||||
|
||||
// 7. Dispatch to receiver
|
||||
// dispatch to receiver
|
||||
if msg.Type == Text || msg.Type == Binary {
|
||||
c.ch.receive <- *msg
|
||||
}
|
||||
|
@ -236,69 +214,59 @@ func clientReader(c *client) {
|
|||
close(c.ch.receive)
|
||||
c.io.reading.Done()
|
||||
|
||||
// 8. close channel (if not already done)
|
||||
// close channel (if not already done)
|
||||
// fmt.Printf("[reader] end\n")
|
||||
c.close(closeStatus, clientAck)
|
||||
|
||||
}
|
||||
|
||||
// writer writes into websocket
|
||||
// and is triggered by client.ch.send channel
|
||||
// clientWriter writes to the websocket connection and is triggered by
|
||||
// client.ch.send channel
|
||||
func clientWriter(c *client) {
|
||||
|
||||
c.io.writing = true // if channel still exists
|
||||
|
||||
for msg := range c.ch.send {
|
||||
|
||||
// 2. Send message
|
||||
err := msg.Send(c.io.sock)
|
||||
|
||||
// 3. Fail on error
|
||||
_, err := msg.WriteTo(c.io.sock)
|
||||
if err != nil {
|
||||
fmt.Printf(" [writer] %s\n", err)
|
||||
c.io.writing = false
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
c.io.writing = false
|
||||
|
||||
// 4. close channel (if not already done)
|
||||
// close channel (if not already done)
|
||||
// fmt.Printf("[writer] end\n")
|
||||
c.close(Normal, true)
|
||||
|
||||
}
|
||||
|
||||
// closes the connection
|
||||
// close the connection
|
||||
// send CLOSE frame is 'status' is not NONE
|
||||
// wait for the next message (CLOSE acknowledge) if 'clientACK'
|
||||
// then delete client
|
||||
func (c *client) close(status MessageError, clientACK bool) {
|
||||
|
||||
// 1. Fail if already closing
|
||||
// fail if already closing
|
||||
alreadyClosing := false
|
||||
c.io.closingMu.Lock()
|
||||
alreadyClosing = c.io.closing
|
||||
c.io.closing = true
|
||||
c.io.closingMu.Unlock()
|
||||
|
||||
if alreadyClosing {
|
||||
return
|
||||
}
|
||||
|
||||
// 2. kill writer' if still running
|
||||
// kill writer' if still running
|
||||
if c.io.writing {
|
||||
close(c.ch.send)
|
||||
}
|
||||
|
||||
// 3. kill reader if still running
|
||||
// kill reader if still running
|
||||
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
|
||||
c.io.reading.Wait()
|
||||
|
||||
if status != None {
|
||||
|
||||
// 3. Build message
|
||||
msg := &Message{
|
||||
Final: true,
|
||||
Type: Close,
|
||||
|
@ -307,40 +275,18 @@ func (c *client) close(status MessageError, clientACK bool) {
|
|||
}
|
||||
binary.BigEndian.PutUint16(msg.Data, uint16(status))
|
||||
|
||||
// 4. Send message
|
||||
msg.Send(c.io.sock)
|
||||
// if err != nil {
|
||||
// fmt.Printf("[close] send error (%s0\n", err)
|
||||
// }
|
||||
|
||||
msg.WriteTo(c.io.sock)
|
||||
}
|
||||
|
||||
// 2. Wait for client CLOSE if needed
|
||||
// wait for client CLOSE if needed
|
||||
if clientACK {
|
||||
|
||||
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
|
||||
/* Wait for message */
|
||||
readMessage(c.io.reader)
|
||||
// if err != nil || msg.Type != CLOSE {
|
||||
// if err == nil {
|
||||
// fmt.Printf("[close] received OpCode = %d\n", msg.Type)
|
||||
// } else {
|
||||
// fmt.Printf("[close] read error (%v)\n", err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// fmt.Printf("[close] received ACK\n")
|
||||
|
||||
var tmpMsg = &Message{}
|
||||
tmpMsg.ReadFrom(c.io.reader)
|
||||
}
|
||||
|
||||
// 3. Close socket
|
||||
c.io.sock.Close()
|
||||
// fmt.Printf("[close] socket closed\n")
|
||||
|
||||
// 4. Unregister
|
||||
c.io.kill <- c
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package iface
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -11,12 +11,11 @@ func main() {
|
|||
|
||||
startTime := time.Now().UnixNano()
|
||||
|
||||
// 1. Bind WebSocket server
|
||||
serv := ws.CreateServer("0.0.0.0", 4444)
|
||||
// creqte WebSocket server
|
||||
serv := ws.NewServer("0.0.0.0", 4444)
|
||||
|
||||
// 2. Bind default controller
|
||||
// bind default controller
|
||||
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
||||
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
fmt.Printf("*** PANIC\n")
|
||||
|
@ -24,35 +23,28 @@ func main() {
|
|||
}()
|
||||
|
||||
for msg := range receiver {
|
||||
|
||||
// if receive message -> send it back
|
||||
sender <- msg
|
||||
// close(sender)
|
||||
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// 3. Bind to URI
|
||||
// bnd to URI
|
||||
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
||||
|
||||
fmt.Printf("[uri] connected\n")
|
||||
|
||||
for msg := range receiver {
|
||||
|
||||
fmt.Printf("[uri] received '%s'\n", msg.Data)
|
||||
sender <- msg
|
||||
|
||||
}
|
||||
|
||||
fmt.Printf("[uri] unexpectedly closed\n")
|
||||
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 4. Launch the server
|
||||
// launch the server
|
||||
err = serv.Launch()
|
||||
if err != nil {
|
||||
fmt.Printf("[ERROR] %s\n", err)
|
||||
|
@ -60,5 +52,4 @@ func main() {
|
|||
}
|
||||
|
||||
fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3)
|
||||
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type Controller struct {
|
|||
Fun ControllerFunc // controller function
|
||||
}
|
||||
|
||||
// ControllerSet is set of controllers
|
||||
// ControllerSet contains a set of controllers
|
||||
type ControllerSet struct {
|
||||
Def *Controller // default controller
|
||||
URI []*Controller // uri controllers
|
||||
|
@ -27,35 +27,22 @@ type ControllerSet struct {
|
|||
// Match finds a controller for a given URI
|
||||
// also it returns the matching string patterns
|
||||
func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
|
||||
|
||||
// 1. Initialise argument list
|
||||
arguments := [][]string{{uri}}
|
||||
|
||||
// 2. Try each controller
|
||||
for _, c := range s.URI {
|
||||
|
||||
/* 1. If matches */
|
||||
if c.URI.Match(uri) {
|
||||
|
||||
/* Extract matches */
|
||||
match := c.URI.GetAllMatch()
|
||||
|
||||
/* Add them to the 'arg' attribute */
|
||||
arguments = append(arguments, match...)
|
||||
|
||||
/* Mark that we have a controller */
|
||||
return c, arguments
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 3. If no controller found -> set default controller
|
||||
// fallback to default
|
||||
if s.Def != nil {
|
||||
return s.Def, arguments
|
||||
}
|
||||
|
||||
// 4. If default is NIL, return empty controller
|
||||
// no default
|
||||
return nil, arguments
|
||||
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
// Maximum line length
|
||||
var maxLineLength = 4096
|
||||
const maxLineLength = 4096
|
||||
|
||||
// ChunkReader struct
|
||||
type ChunkReader struct {
|
||||
|
@ -32,19 +32,17 @@ func NewReader(r io.Reader) *ChunkReader {
|
|||
|
||||
}
|
||||
|
||||
// Read reads a chunk, err is io.EOF when done
|
||||
// Read reads a chunk, io.EOF when done
|
||||
func (r *ChunkReader) Read() ([]byte, error) {
|
||||
|
||||
// 1. If already ended
|
||||
// already ended
|
||||
if r.isEnded {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// 2. Read line
|
||||
// read line
|
||||
var line []byte
|
||||
line, err := r.reader.ReadSlice('\n')
|
||||
|
||||
// 3. manage errors
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
|
@ -57,10 +55,8 @@ func (r *ChunkReader) Read() ([]byte, error) {
|
|||
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
|
||||
}
|
||||
|
||||
// 4. Trim
|
||||
line = removeTrailingSpace(line)
|
||||
line = trimSpaces(line)
|
||||
|
||||
// 5. Manage ending line
|
||||
if len(line) == 0 {
|
||||
r.isEnded = true
|
||||
return line, io.EOF
|
||||
|
@ -70,13 +66,13 @@ func (r *ChunkReader) Read() ([]byte, error) {
|
|||
|
||||
}
|
||||
|
||||
func removeTrailingSpace(b []byte) []byte {
|
||||
for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
|
||||
func trimSpaces(b []byte) []byte {
|
||||
for len(b) > 0 && isSpaceChar(b[len(b)-1]) {
|
||||
b = b[:len(b)-1]
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func isASCIISpace(b byte) bool {
|
||||
func isSpaceChar(b byte) bool {
|
||||
return b == ' ' || b == '\t' || b == '\r' || b == '\n'
|
||||
}
|
||||
|
|
|
@ -4,33 +4,32 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// invalid request
|
||||
// ErrInvalidRequest for invalid requests
|
||||
// - multiple-value if only 1 expected
|
||||
type InvalidRequest struct {
|
||||
type ErrInvalidRequest struct {
|
||||
Field string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (err InvalidRequest) Error() string {
|
||||
return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason)
|
||||
func (err ErrInvalidRequest) Error() string {
|
||||
return fmt.Sprintf("invalid field '%s': %s", err.Field, err.Reason)
|
||||
}
|
||||
|
||||
// Request misses fields (request-line or headers)
|
||||
type IncompleteRequest struct {
|
||||
MissingField string
|
||||
// ErrIncompleteRequest when mandatory request fields are missing (request-line or headers)
|
||||
// it contains the missing field as a string
|
||||
type ErrIncompleteRequest string
|
||||
|
||||
func (err ErrIncompleteRequest) Error() string {
|
||||
return fmt.Sprintf("incomplete request, '%s' is invalid or missing", string(err))
|
||||
}
|
||||
|
||||
func (err IncompleteRequest) Error() string {
|
||||
return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField)
|
||||
}
|
||||
|
||||
// Request has a violated origin policy
|
||||
type InvalidOriginPolicy struct {
|
||||
// ErrInvalidOriginPolicy when a request has a violated origin policy
|
||||
type ErrInvalidOriginPolicy struct {
|
||||
Host string
|
||||
Origin string
|
||||
err error
|
||||
}
|
||||
|
||||
func (err InvalidOriginPolicy) Error() string {
|
||||
func (err ErrInvalidOriginPolicy) Error() string {
|
||||
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
|
||||
}
|
||||
|
|
|
@ -10,11 +10,11 @@ import (
|
|||
func (r *Request) extractHostPort(bb HeaderValue) error {
|
||||
|
||||
if len(bb) != 1 {
|
||||
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
return &ErrInvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
}
|
||||
|
||||
if len(bb[0]) <= 3 {
|
||||
return &InvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))}
|
||||
return &ErrInvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))}
|
||||
}
|
||||
|
||||
split := strings.Split(string(bb[0]), ":")
|
||||
|
@ -29,8 +29,8 @@ func (r *Request) extractHostPort(bb HeaderValue) error {
|
|||
// extract port
|
||||
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||
if err != nil {
|
||||
r.code = BadRequest
|
||||
return &InvalidRequest{"Host", "cannot read port"}
|
||||
r.statusCode = StatusBadRequest
|
||||
return &ErrInvalidRequest{"Host", "cannot read port"}
|
||||
}
|
||||
|
||||
r.port = uint16(readPort)
|
||||
|
@ -39,8 +39,8 @@ func (r *Request) extractHostPort(bb HeaderValue) error {
|
|||
if len(r.origin) > 0 {
|
||||
if err != nil {
|
||||
err = r.checkOriginPolicy()
|
||||
r.code = Forbidden
|
||||
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||
r.statusCode = StatusForbidden
|
||||
return &ErrInvalidOriginPolicy{r.host, r.origin, err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,8 +57,8 @@ func (r *Request) extractOrigin(bb HeaderValue) error {
|
|||
}
|
||||
|
||||
if len(bb) != 1 {
|
||||
r.code = Forbidden
|
||||
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
r.statusCode = StatusForbidden
|
||||
return &ErrInvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
}
|
||||
|
||||
r.origin = string(bb[0])
|
||||
|
@ -67,8 +67,8 @@ func (r *Request) extractOrigin(bb HeaderValue) error {
|
|||
if len(r.host) > 0 {
|
||||
err := r.checkOriginPolicy()
|
||||
if err != nil {
|
||||
r.code = Forbidden
|
||||
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||
r.statusCode = StatusForbidden
|
||||
return &ErrInvalidOriginPolicy{r.host, r.origin, err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,8 +96,8 @@ func (r *Request) checkConnection(bb HeaderValue) error {
|
|||
|
||||
}
|
||||
|
||||
r.code = BadRequest
|
||||
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
||||
r.statusCode = StatusBadRequest
|
||||
return &ErrInvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
||||
|
||||
}
|
||||
|
||||
|
@ -106,8 +106,8 @@ func (r *Request) checkConnection(bb HeaderValue) error {
|
|||
func (r *Request) checkUpgrade(bb HeaderValue) error {
|
||||
|
||||
if len(bb) != 1 {
|
||||
r.code = BadRequest
|
||||
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
r.statusCode = StatusBadRequest
|
||||
return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
}
|
||||
|
||||
if strings.ToLower(string(bb[0])) == "websocket" {
|
||||
|
@ -115,8 +115,8 @@ func (r *Request) checkUpgrade(bb HeaderValue) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
r.code = BadRequest
|
||||
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
||||
r.statusCode = StatusBadRequest
|
||||
return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
||||
|
||||
}
|
||||
|
||||
|
@ -125,8 +125,8 @@ func (r *Request) checkUpgrade(bb HeaderValue) error {
|
|||
func (r *Request) checkVersion(bb HeaderValue) error {
|
||||
|
||||
if len(bb) != 1 || string(bb[0]) != "13" {
|
||||
r.code = UpgradeRequired
|
||||
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
||||
r.statusCode = StatusUpgradeRequired
|
||||
return &ErrInvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
||||
}
|
||||
|
||||
r.hasVersion = true
|
||||
|
@ -139,8 +139,8 @@ func (r *Request) checkVersion(bb HeaderValue) error {
|
|||
func (r *Request) extractKey(bb HeaderValue) error {
|
||||
|
||||
if len(bb) != 1 || len(bb[0]) != 24 {
|
||||
r.code = BadRequest
|
||||
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
||||
r.statusCode = StatusBadRequest
|
||||
return &ErrInvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
||||
}
|
||||
|
||||
r.key = bb[0]
|
||||
|
|
|
@ -1,118 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Line represents the HTTP Request line
|
||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||
type Line struct {
|
||||
method Method
|
||||
uri string
|
||||
version byte
|
||||
}
|
||||
|
||||
// Parse parses the first HTTP request line
|
||||
func (r *Line) Parse(b []byte) error {
|
||||
|
||||
// 1. Split by ' '
|
||||
parts := bytes.Split(b, []byte(" "))
|
||||
|
||||
// 2. Fail when missing parts
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
||||
}
|
||||
|
||||
// 3. Extract HTTP method
|
||||
err := r.extractHttpMethod(parts[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. Extract URI
|
||||
err = r.extractURI(parts[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 5. Extract version
|
||||
err = r.extractHttpVersion(parts[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// GetURI returns the actual URI
|
||||
func (r Line) GetURI() string {
|
||||
return r.uri
|
||||
}
|
||||
|
||||
// extractHttpMethod extracts the HTTP method from a []byte
|
||||
// and checks for errors
|
||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||
func (r *Line) extractHttpMethod(b []byte) error {
|
||||
|
||||
switch string(b) {
|
||||
// case "OPTIONS": r.method = OPTIONS
|
||||
case "GET":
|
||||
r.method = Get
|
||||
// case "HEAD": r.method = HEAD
|
||||
// case "POST": r.method = POST
|
||||
// case "PUT": r.method = PUT
|
||||
// case "DELETE": r.method = DELETE
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid HTTP method '%s', expected 'GET'", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractURI extracts the URI from a []byte and checks for errors
|
||||
// allowed format: /([^/]/)*/?
|
||||
func (r *Line) extractURI(b []byte) error {
|
||||
|
||||
// 1. Check format
|
||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||
if !checker.Match(b) {
|
||||
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
||||
}
|
||||
|
||||
// 2. Store
|
||||
r.uri = string(b)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// extractHttpVersion extracts the version and checks for errors
|
||||
// allowed format: [1-9] or [1.9].[0-9]
|
||||
func (r *Line) extractHttpVersion(b []byte) error {
|
||||
|
||||
// 1. Extract version parts
|
||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
|
||||
|
||||
if !extractor.Match(b) {
|
||||
return fmt.Errorf("invalid HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||
}
|
||||
|
||||
// 2. Extract version number
|
||||
matches := extractor.FindSubmatch(b)
|
||||
var version byte = matches[1][0] - '0'
|
||||
|
||||
// 3. Extract subversion (if exists)
|
||||
var subVersion byte = 0
|
||||
if len(matches[2]) > 0 {
|
||||
subVersion = matches[2][0] - '0'
|
||||
}
|
||||
|
||||
// 4. Store version (x 10 to fit uint8)
|
||||
r.version = version*10 + subVersion
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
// Method represents available http methods
|
||||
type Method uint8
|
||||
|
||||
// http methods
|
||||
const (
|
||||
Options Method = iota
|
||||
Get
|
||||
Head
|
||||
Post
|
||||
Put
|
||||
Delete
|
||||
)
|
|
@ -7,157 +7,139 @@ import (
|
|||
"git.xdrm.io/go/ws/internal/http/reader"
|
||||
)
|
||||
|
||||
// If origin is required
|
||||
// whether origin is required
|
||||
const bypassOriginPolicy = true
|
||||
|
||||
// Request represents an HTTP Upgrade request
|
||||
type Request struct {
|
||||
first bool // whether the first line has been read (GET uri HTTP/version)
|
||||
// whether the first line has been read (GET uri HTTP/version)
|
||||
first bool
|
||||
statusCode StatusCode
|
||||
requestLine RequestLine
|
||||
|
||||
// status code
|
||||
code StatusCode
|
||||
|
||||
// request line
|
||||
request Line
|
||||
|
||||
// data to check origin (depends of reading order)
|
||||
// data to check origin (depends on reading order)
|
||||
host string
|
||||
port uint16 // 0 if not set
|
||||
origin string
|
||||
validPolicy bool
|
||||
|
||||
// ws data
|
||||
// websocket specific
|
||||
key []byte
|
||||
protocols [][]byte
|
||||
|
||||
// required fields check
|
||||
// mandatory fields to check
|
||||
hasConnection bool
|
||||
hasUpgrade bool
|
||||
hasVersion bool
|
||||
}
|
||||
|
||||
// Parse builds an upgrade HTTP request
|
||||
// from a reader (typically bufio.NewRead of the socket)
|
||||
func Parse(r io.Reader) (request *Request, err error) {
|
||||
// ReadFrom reads an upgrade HTTP request ; typically from bufio.NewRead of the
|
||||
// socket
|
||||
//
|
||||
// implements io.ReaderFrom
|
||||
func (req *Request) ReadFrom(r io.Reader) (int64, error) {
|
||||
var read int64
|
||||
|
||||
req := &Request{
|
||||
code: 500,
|
||||
}
|
||||
// reset request
|
||||
req.statusCode = 500
|
||||
|
||||
/* (1) Parse request
|
||||
---------------------------------------------------------*/
|
||||
// 1. Get chunk reader
|
||||
cr := reader.NewReader(r)
|
||||
if err != nil {
|
||||
return req, fmt.Errorf("create chunk reader: %w", err)
|
||||
}
|
||||
|
||||
// 2. Parse header line by line
|
||||
// parse header line by line
|
||||
var cr = reader.NewReader(r)
|
||||
for {
|
||||
|
||||
line, err := cr.Read()
|
||||
read += int64(len(line))
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return req, err
|
||||
return read, err
|
||||
}
|
||||
|
||||
err = req.parseHeader(line)
|
||||
|
||||
if err != nil {
|
||||
return req, err
|
||||
return read, err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 3. Check completion
|
||||
err = req.isComplete()
|
||||
err := req.isComplete()
|
||||
if err != nil {
|
||||
req.code = BadRequest
|
||||
return req, err
|
||||
req.statusCode = StatusBadRequest
|
||||
return read, err
|
||||
}
|
||||
|
||||
req.code = SwitchingProtocols
|
||||
return req, nil
|
||||
|
||||
req.statusCode = StatusSwitchingProtocols
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// StatusCode returns the status current
|
||||
func (r Request) StatusCode() StatusCode {
|
||||
return r.code
|
||||
func (req Request) StatusCode() StatusCode {
|
||||
return req.statusCode
|
||||
}
|
||||
|
||||
// BuildResponse builds a response from the request
|
||||
func (r *Request) BuildResponse() *Response {
|
||||
func (req *Request) BuildResponse() *Response {
|
||||
|
||||
inst := &Response{}
|
||||
|
||||
// 1. Copy code
|
||||
inst.SetStatusCode(r.code)
|
||||
|
||||
// 2. Set Protocol
|
||||
if len(r.protocols) > 0 {
|
||||
inst.SetProtocol(r.protocols[0])
|
||||
res := &Response{
|
||||
StatusCode: req.statusCode,
|
||||
Protocol: nil,
|
||||
}
|
||||
|
||||
// 4. Process key
|
||||
inst.ProcessKey(r.key)
|
||||
if len(req.protocols) > 0 {
|
||||
res.Protocol = req.protocols[0]
|
||||
}
|
||||
|
||||
return inst
|
||||
res.ProcessKey(req.key)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// GetURI returns the actual URI
|
||||
func (r Request) GetURI() string {
|
||||
return r.request.GetURI()
|
||||
// URI returns the actual URI
|
||||
func (req Request) URI() string {
|
||||
return req.requestLine.URI()
|
||||
}
|
||||
|
||||
// parseHeader parses any http request line
|
||||
// (header and request-line)
|
||||
func (r *Request) parseHeader(b []byte) error {
|
||||
|
||||
/* (1) First line -> GET {uri} HTTP/{version}
|
||||
---------------------------------------------------------*/
|
||||
if !r.first {
|
||||
|
||||
err := r.request.Parse(b)
|
||||
func (req *Request) parseHeader(b []byte) error {
|
||||
// first line -> GET {uri} HTTP/{version}
|
||||
if !req.first {
|
||||
|
||||
_, err := req.requestLine.Read(b)
|
||||
if err != nil {
|
||||
r.code = BadRequest
|
||||
return &InvalidRequest{"Request-Line", err.Error()}
|
||||
req.statusCode = StatusBadRequest
|
||||
return &ErrInvalidRequest{"Request-Line", err.Error()}
|
||||
}
|
||||
|
||||
r.first = true
|
||||
req.first = true
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
/* (2) Other lines -> Header-Name: Header-Value
|
||||
---------------------------------------------------------*/
|
||||
// 1. Try to parse header
|
||||
// other lines -> Header-Name: Header-Value
|
||||
head, err := ReadHeader(b)
|
||||
if err != nil {
|
||||
r.code = BadRequest
|
||||
req.statusCode = StatusBadRequest
|
||||
return fmt.Errorf("parse header: %w", err)
|
||||
}
|
||||
|
||||
// 2. Manage header
|
||||
switch head.Name {
|
||||
case Host:
|
||||
err = r.extractHostPort(head.Values)
|
||||
err = req.extractHostPort(head.Values)
|
||||
case Origin:
|
||||
err = r.extractOrigin(head.Values)
|
||||
err = req.extractOrigin(head.Values)
|
||||
case Upgrade:
|
||||
err = r.checkUpgrade(head.Values)
|
||||
err = req.checkUpgrade(head.Values)
|
||||
case Connection:
|
||||
err = r.checkConnection(head.Values)
|
||||
err = req.checkConnection(head.Values)
|
||||
case WSVersion:
|
||||
err = r.checkVersion(head.Values)
|
||||
err = req.checkVersion(head.Values)
|
||||
case WSKey:
|
||||
err = r.extractKey(head.Values)
|
||||
err = req.extractKey(head.Values)
|
||||
case WSProtocol:
|
||||
err = r.extractProtocols(head.Values)
|
||||
err = req.extractProtocols(head.Values)
|
||||
|
||||
default:
|
||||
return nil
|
||||
|
@ -169,48 +151,39 @@ func (r *Request) parseHeader(b []byte) error {
|
|||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// isComplete returns whether the Upgrade Request
|
||||
// is complete (no missing required item)
|
||||
func (r Request) isComplete() error {
|
||||
|
||||
// 1. Request-Line
|
||||
if !r.first {
|
||||
return &IncompleteRequest{"Request-Line"}
|
||||
// is complete (no required field missing)
|
||||
// returns nil on success
|
||||
func (req Request) isComplete() error {
|
||||
if !req.first {
|
||||
return ErrIncompleteRequest("Request-Line")
|
||||
}
|
||||
|
||||
// 2. Host
|
||||
if len(r.host) == 0 {
|
||||
return &IncompleteRequest{"Host"}
|
||||
if len(req.host) == 0 {
|
||||
return ErrIncompleteRequest("Host")
|
||||
}
|
||||
|
||||
// 3. Origin
|
||||
if !bypassOriginPolicy && len(r.origin) == 0 {
|
||||
return &IncompleteRequest{"Origin"}
|
||||
if !bypassOriginPolicy && len(req.origin) == 0 {
|
||||
return ErrIncompleteRequest("Origin")
|
||||
}
|
||||
|
||||
// 4. Connection
|
||||
if !r.hasConnection {
|
||||
return &IncompleteRequest{"Connection"}
|
||||
if !req.hasConnection {
|
||||
return ErrIncompleteRequest("Connection")
|
||||
}
|
||||
|
||||
// 5. Upgrade
|
||||
if !r.hasUpgrade {
|
||||
return &IncompleteRequest{"Upgrade"}
|
||||
if !req.hasUpgrade {
|
||||
return ErrIncompleteRequest("Upgrade")
|
||||
}
|
||||
|
||||
// 6. Sec-WebSocket-Version
|
||||
if !r.hasVersion {
|
||||
return &IncompleteRequest{"Sec-WebSocket-Version"}
|
||||
if !req.hasVersion {
|
||||
return ErrIncompleteRequest("Sec-WebSocket-Version")
|
||||
}
|
||||
|
||||
// 7. Sec-WebSocket-Key
|
||||
if len(r.key) < 1 {
|
||||
return &IncompleteRequest{"Sec-WebSocket-Key"}
|
||||
if len(req.key) < 1 {
|
||||
return ErrIncompleteRequest("Sec-WebSocket-Key")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// RequestLine represents the HTTP Request line
|
||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||
type RequestLine struct {
|
||||
uri string
|
||||
version byte
|
||||
}
|
||||
|
||||
// Read an HTTP request line from a byte array
|
||||
//
|
||||
// implements io.Reader
|
||||
func (rl *RequestLine) Read(b []byte) (int, error) {
|
||||
var read = len(b)
|
||||
|
||||
// split by spaces
|
||||
parts := bytes.Split(b, []byte(" "))
|
||||
|
||||
if len(parts) != 3 {
|
||||
return read, fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
||||
}
|
||||
|
||||
err := rl.extractHttpMethod(parts[0])
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
|
||||
err = rl.extractURI(parts[1])
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
|
||||
err = rl.extractHttpVersion(parts[2])
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
|
||||
return read, nil
|
||||
|
||||
}
|
||||
|
||||
// URI of the request line
|
||||
func (rl RequestLine) URI() string {
|
||||
return rl.uri
|
||||
}
|
||||
|
||||
// extractHttpMethod extracts the HTTP method from a []byte
|
||||
// and checks for errors
|
||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||
func (rl *RequestLine) extractHttpMethod(b []byte) error {
|
||||
if string(b) != "GET" {
|
||||
return fmt.Errorf("invalid HTTP method '%s', expected 'GET'", b)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractURI extracts the URI from a []byte and checks for errors
|
||||
// allowed format: /([^/]/)*/?
|
||||
func (rl *RequestLine) extractURI(b []byte) error {
|
||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||
if !checker.Match(b) {
|
||||
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
||||
}
|
||||
rl.uri = string(b)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// extractHttpVersion extracts the version and checks for errors
|
||||
// allowed format: [1-9] or [1.9].[0-9]
|
||||
func (rl *RequestLine) extractHttpVersion(b []byte) error {
|
||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
|
||||
|
||||
if !extractor.Match(b) {
|
||||
return fmt.Errorf("invalid HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||
}
|
||||
matches := extractor.FindSubmatch(b)
|
||||
|
||||
var version byte = matches[1][0] - '0'
|
||||
|
||||
var subversion byte = 0
|
||||
if len(matches[2]) > 0 {
|
||||
subversion = matches[2][0] - '0'
|
||||
}
|
||||
|
||||
rl.version = version*10 + subversion
|
||||
return nil
|
||||
}
|
|
@ -7,17 +7,15 @@ import (
|
|||
)
|
||||
|
||||
func TestEOFSocket(t *testing.T) {
|
||||
var (
|
||||
socket = &bytes.Buffer{}
|
||||
req = &Request{}
|
||||
)
|
||||
|
||||
socket := &bytes.Buffer{}
|
||||
|
||||
_, err := Parse(socket)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Empty socket expected EOF, got no error")
|
||||
} else if err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("Empty socket expected EOF, got '%s'", err)
|
||||
_, err := req.ReadFrom(socket)
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("unexpected error <%v> expected <%v>", err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInvalidRequestLine(t *testing.T) {
|
||||
|
@ -59,15 +57,13 @@ func TestInvalidRequestLine(t *testing.T) {
|
|||
socket.Write([]byte(tc.Reqline))
|
||||
socket.Write([]byte("\r\n\r\n"))
|
||||
|
||||
_, err := Parse(socket)
|
||||
|
||||
var req = &Request{}
|
||||
_, err := req.ReadFrom(socket)
|
||||
if !tc.HasError {
|
||||
|
||||
// no error -> ok
|
||||
if err == nil {
|
||||
continue
|
||||
// error for the end of the request -> ok
|
||||
} else if _, ok := err.(*IncompleteRequest); ok {
|
||||
} else if _, ok := err.(ErrIncompleteRequest); ok {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -80,7 +76,7 @@ func TestInvalidRequestLine(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
|
||||
ir, ok := err.(*InvalidRequest)
|
||||
ir, ok := err.(*ErrInvalidRequest)
|
||||
|
||||
// not InvalidRequest err -> error
|
||||
if !ok || ir.Field != "Request-Line" {
|
||||
|
@ -131,15 +127,15 @@ func TestInvalidHost(t *testing.T) {
|
|||
socket.Write([]byte(tc.Host))
|
||||
socket.Write([]byte("\r\n\r\n"))
|
||||
|
||||
_, err := Parse(socket)
|
||||
|
||||
var req = &Request{}
|
||||
_, err := req.ReadFrom(socket)
|
||||
if !tc.HasError {
|
||||
|
||||
// no error -> ok
|
||||
if err == nil {
|
||||
continue
|
||||
// error for the end of the request -> ok
|
||||
} else if _, ok := err.(*IncompleteRequest); ok {
|
||||
} else if _, ok := err.(ErrIncompleteRequest); ok {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -153,7 +149,7 @@ func TestInvalidHost(t *testing.T) {
|
|||
}
|
||||
|
||||
// check if InvalidRequest
|
||||
ir, ok := err.(*InvalidRequest)
|
||||
ir, ok := err.(ErrInvalidRequest)
|
||||
|
||||
// not InvalidRequest err -> error
|
||||
if ok && ir.Field != "Host" {
|
||||
|
|
|
@ -7,88 +7,58 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
// HTTPVersion constant
|
||||
const HTTPVersion = "1.1"
|
||||
// constants
|
||||
const (
|
||||
httpVersion = "1.1"
|
||||
wsVersion = 13
|
||||
keySalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
)
|
||||
|
||||
// UsedWSVersion constant websocket version
|
||||
const UsedWSVersion = 13
|
||||
|
||||
// WSSalt constant websocket salt
|
||||
const WSSalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
// Response represents an HTTP Upgrade Response
|
||||
// Response is an HTTP Upgrade Response
|
||||
type Response struct {
|
||||
code StatusCode // status code
|
||||
accept []byte // processed from Sec-WebSocket-Key
|
||||
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
|
||||
}
|
||||
|
||||
// SetStatusCode sets the status code
|
||||
func (r *Response) SetStatusCode(sc StatusCode) {
|
||||
r.code = sc
|
||||
}
|
||||
|
||||
// SetProtocol sets the protocols
|
||||
func (r *Response) SetProtocol(p []byte) {
|
||||
r.protocol = p
|
||||
StatusCode StatusCode
|
||||
// Sec-WebSocket-Protocol or nil if missing
|
||||
Protocol []byte
|
||||
// processed from Sec-WebSocket-Key
|
||||
key []byte
|
||||
}
|
||||
|
||||
// ProcessKey processes the accept token according
|
||||
// to the rfc from the Sec-WebSocket-Key
|
||||
func (r *Response) ProcessKey(k []byte) {
|
||||
|
||||
// do nothing for empty key
|
||||
if k == nil || len(k) == 0 {
|
||||
r.accept = nil
|
||||
// ignore empty key
|
||||
if k == nil || len(k) < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// 1. Concat with constant salt
|
||||
mix := append(k, []byte(WSSalt)...)
|
||||
|
||||
// 2. Hash with sha1 algorithm
|
||||
digest := sha1.Sum(mix)
|
||||
|
||||
// 3. Base64 encode it
|
||||
r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
|
||||
|
||||
// concat with constant salt
|
||||
salted := append(k, []byte(keySalt)...)
|
||||
// hash with sha1
|
||||
digest := sha1.Sum(salted)
|
||||
// base64 encode
|
||||
r.key = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
|
||||
}
|
||||
|
||||
// Send sends the response through an io.Writer
|
||||
// typically a socket
|
||||
func (r Response) Send(w io.Writer) (int, error) {
|
||||
// WriteTo writes the response; typically in a socket
|
||||
//
|
||||
// implements io.WriterTo
|
||||
func (r Response) WriteTo(w io.Writer) (int64, error) {
|
||||
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", httpVersion, r.StatusCode, r.StatusCode)
|
||||
|
||||
// 1. Build response line
|
||||
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HTTPVersion, r.code, r.code)
|
||||
|
||||
// 2. Build headers
|
||||
optionalProtocol := ""
|
||||
if len(r.protocol) > 0 {
|
||||
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol)
|
||||
if len(r.Protocol) > 0 {
|
||||
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.Protocol)
|
||||
}
|
||||
|
||||
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", UsedWSVersion, optionalProtocol)
|
||||
if r.accept != nil {
|
||||
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept)
|
||||
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", wsVersion, optionalProtocol)
|
||||
if r.key != nil {
|
||||
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.key)
|
||||
}
|
||||
headers = fmt.Sprintf("%s\r\n", headers)
|
||||
|
||||
// 3. Build all
|
||||
raw := []byte(fmt.Sprintf("%s%s", responseLine, headers))
|
||||
combined := []byte(fmt.Sprintf("%s%s", responseLine, headers))
|
||||
|
||||
// 4. Write
|
||||
written, err := w.Write(raw)
|
||||
|
||||
return written, err
|
||||
written, err := w.Write(combined)
|
||||
return int64(written), err
|
||||
|
||||
}
|
||||
|
||||
// GetProtocol returns the choosen protocol if set, else nil
|
||||
func (r Response) GetProtocol() []byte {
|
||||
return r.protocol
|
||||
}
|
||||
|
||||
// GetStatusCode returns the response status code
|
||||
func (r Response) GetStatusCode() StatusCode {
|
||||
return r.code
|
||||
}
|
||||
|
|
|
@ -1,37 +1,37 @@
|
|||
package upgrade
|
||||
|
||||
// StatusCode maps the status codes (and description)
|
||||
type StatusCode uint16
|
||||
// StatusCode maps HTTP status codes (and description)
|
||||
type StatusCode int
|
||||
|
||||
const (
|
||||
// SwitchingProtocols - handshake success
|
||||
SwitchingProtocols StatusCode = 101
|
||||
// BadRequest - missing/malformed headers
|
||||
BadRequest StatusCode = 400
|
||||
// Forbidden - invalid origin policy, TLS required
|
||||
Forbidden StatusCode = 403
|
||||
// UpgradeRequired - invalid WS version
|
||||
UpgradeRequired StatusCode = 426
|
||||
// NotFound - unserved or invalid URI
|
||||
NotFound StatusCode = 404
|
||||
// Internal - custom error
|
||||
Internal StatusCode = 500
|
||||
// StatusSwitchingProtocols - handshake success
|
||||
StatusSwitchingProtocols StatusCode = 101
|
||||
// StatusBadRequest - missing/malformed headers
|
||||
StatusBadRequest StatusCode = 400
|
||||
// StatusForbidden - invalid origin policy, TLS required
|
||||
StatusForbidden StatusCode = 403
|
||||
// StatusUpgradeRequired - invalid WS version
|
||||
StatusUpgradeRequired StatusCode = 426
|
||||
// StatusNotFound - unserved or invalid URI
|
||||
StatusNotFound StatusCode = 404
|
||||
// StatusInternal - custom error
|
||||
StatusInternal StatusCode = 500
|
||||
)
|
||||
|
||||
// String implements the Stringer interface
|
||||
func (sc StatusCode) String() string {
|
||||
switch sc {
|
||||
case SwitchingProtocols:
|
||||
case StatusSwitchingProtocols:
|
||||
return "Switching Protocols"
|
||||
case BadRequest:
|
||||
case StatusBadRequest:
|
||||
return "Bad Request"
|
||||
case Forbidden:
|
||||
case StatusForbidden:
|
||||
return "Forbidden"
|
||||
case UpgradeRequired:
|
||||
case StatusUpgradeRequired:
|
||||
return "Upgrade Required"
|
||||
case NotFound:
|
||||
case StatusNotFound:
|
||||
return "Not Found"
|
||||
case Internal:
|
||||
case StatusInternal:
|
||||
return "Internal Server Error"
|
||||
default:
|
||||
return "Unknown Status Code"
|
||||
|
|
|
@ -39,18 +39,16 @@ type matcher struct {
|
|||
// Scheme represents an URI scheme
|
||||
type Scheme []*matcher
|
||||
|
||||
// FromString builds an URI scheme from a pattern string
|
||||
// FromString builds an URI scheme from a string pattern
|
||||
func FromString(s string) (*Scheme, error) {
|
||||
|
||||
// 1. Manage '/' at the start
|
||||
// handle '/' at the start
|
||||
if len(s) < 1 || s[0] != '/' {
|
||||
return nil, fmt.Errorf("invalid URI; must start with '/'")
|
||||
}
|
||||
|
||||
// 2. Split by '/'
|
||||
parts := strings.Split(s, "/")
|
||||
|
||||
// 3. Max exceeded
|
||||
// check max match size
|
||||
if len(parts)-2 > maxMatch {
|
||||
for i, p := range parts {
|
||||
fmt.Printf("%d: '%s'\n", i, p)
|
||||
|
@ -58,13 +56,11 @@ func FromString(s string) (*Scheme, error) {
|
|||
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
|
||||
}
|
||||
|
||||
// 4. Build for each part
|
||||
sch, err := buildScheme(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. Optimise structure
|
||||
opti, err := sch.optimise()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -74,91 +70,67 @@ func FromString(s string) (*Scheme, error) {
|
|||
|
||||
}
|
||||
|
||||
// Match returns if the given URI is matched by the scheme
|
||||
func (s Scheme) Match(str string) bool {
|
||||
|
||||
// 1. Nothing -> match all
|
||||
// Match returns whether the given URI is matched by the scheme
|
||||
func (s Scheme) Match(uri string) bool {
|
||||
if len(s) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// 2. Check for string match
|
||||
clearURI, match := s.matchString(str)
|
||||
// check for string match
|
||||
clearURI, match := s.matchString(uri)
|
||||
if !match {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3. Check for non-string match (wildcards)
|
||||
match = s.matchWildcards(clearURI)
|
||||
if !match {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
// check for non-string match (wildcards)
|
||||
return s.matchWildcards(clearURI)
|
||||
}
|
||||
|
||||
// GetMatch returns the indexed match (excluding string matchers)
|
||||
func (s Scheme) GetMatch(n uint8) ([]string, error) {
|
||||
|
||||
// 1. Index out of range
|
||||
if n > uint8(len(s)) {
|
||||
return nil, fmt.Errorf("index out of range")
|
||||
}
|
||||
|
||||
// 2. Iterate to find index (exclude strings)
|
||||
ni := -1
|
||||
// iterate to find index (exclude strings)
|
||||
matches := -1
|
||||
for _, m := range s {
|
||||
|
||||
// ignore strings
|
||||
if len(m.pat) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// increment match counter : ni
|
||||
ni++
|
||||
matches++
|
||||
|
||||
// if expected index -> return matches
|
||||
if uint8(ni) == n {
|
||||
// expected index -> return matches
|
||||
if uint8(matches) == n {
|
||||
return m.buf, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 3. If nothing found -> return empty set
|
||||
return nil, fmt.Errorf("index out of range (max: %d)", ni)
|
||||
// nothing found -> return empty set
|
||||
return nil, fmt.Errorf("index out of range (max: %d)", matches)
|
||||
|
||||
}
|
||||
|
||||
// GetAllMatch returns all the indexed match (excluding string matchers)
|
||||
func (s Scheme) GetAllMatch() [][]string {
|
||||
|
||||
match := make([][]string, 0, len(s))
|
||||
|
||||
for _, m := range s {
|
||||
|
||||
// ignore strings
|
||||
if len(m.pat) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
match = append(match, m.buf)
|
||||
|
||||
}
|
||||
|
||||
return match
|
||||
|
||||
}
|
||||
|
||||
// buildScheme builds a 'basic' scheme
|
||||
// from a pattern string
|
||||
func buildScheme(ss []string) (Scheme, error) {
|
||||
|
||||
// 1. Build scheme
|
||||
sch := make(Scheme, 0, maxMatch)
|
||||
|
||||
for _, s := range ss {
|
||||
|
||||
// 2. ignore empty
|
||||
if len(s) == 0 {
|
||||
continue
|
||||
}
|
||||
|
@ -167,31 +139,31 @@ func buildScheme(ss []string) (Scheme, error) {
|
|||
|
||||
switch s {
|
||||
|
||||
// 3. Card: 0, N
|
||||
// card: 0, N
|
||||
case "**":
|
||||
m.req = false
|
||||
m.mul = true
|
||||
sch = append(sch, m)
|
||||
|
||||
// 4. Card: 1, N
|
||||
// card: 1, N
|
||||
case "..":
|
||||
m.req = true
|
||||
m.mul = true
|
||||
sch = append(sch, m)
|
||||
|
||||
// 5. Card: 0, 1
|
||||
// card: 0, 1
|
||||
case "*":
|
||||
m.req = false
|
||||
m.mul = false
|
||||
sch = append(sch, m)
|
||||
|
||||
// 6. Card: 1
|
||||
// card: 1
|
||||
case ".":
|
||||
m.req = true
|
||||
m.mul = false
|
||||
sch = append(sch, m)
|
||||
|
||||
// 7. Card: 1, literal string
|
||||
// card: 1, literal string
|
||||
default:
|
||||
m.req = true
|
||||
m.mul = false
|
||||
|
@ -207,30 +179,26 @@ func buildScheme(ss []string) (Scheme, error) {
|
|||
|
||||
// optimise optimised the scheme for further parsing
|
||||
func (s Scheme) optimise() (Scheme, error) {
|
||||
|
||||
// 1. Nothing to do if only 1 element
|
||||
if len(s) <= 1 {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// 2. Init reshifted scheme
|
||||
// init reshifted scheme
|
||||
rshift := make(Scheme, 0, maxMatch)
|
||||
rshift = append(rshift, s[0])
|
||||
|
||||
// 2. Iterate over matchers
|
||||
// iterate over matchers
|
||||
for p, i, l := 0, 1, len(s); i < l; i++ {
|
||||
|
||||
pre, cur := s[p], s[i]
|
||||
|
||||
/* Merge: 2 following literals */
|
||||
// merge: 2 following literals
|
||||
if len(pre.pat) > 0 && len(cur.pat) > 0 {
|
||||
|
||||
// merge strings into previous
|
||||
pre.pat = fmt.Sprintf("%s%s", pre.pat, cur.pat)
|
||||
|
||||
// delete current
|
||||
s[i] = nil
|
||||
|
||||
}
|
||||
|
||||
// increment previous (only if current is not nul)
|
||||
|
@ -242,67 +210,65 @@ func (s Scheme) optimise() (Scheme, error) {
|
|||
}
|
||||
|
||||
return rshift, nil
|
||||
|
||||
}
|
||||
|
||||
// matchString checks the STRING matchers from an URI
|
||||
// it returns a boolean : false when not matching, true eitherway
|
||||
// it returns a cleared uri, without STRING data
|
||||
// - returns a boolean : false when not matching, true eitherway
|
||||
// - returns a cleared uri, without STRING data
|
||||
func (s Scheme) matchString(uri string) (string, bool) {
|
||||
|
||||
// 1. Initialise variables
|
||||
clr := uri // contains cleared input string
|
||||
minOff := 0 // minimum offset
|
||||
var (
|
||||
clearedInput = uri
|
||||
minOffset = 0
|
||||
)
|
||||
|
||||
// 2. Iterate over strings
|
||||
for _, m := range s {
|
||||
|
||||
ls := len(m.pat)
|
||||
|
||||
// {1} If not STRING matcher -> ignore //
|
||||
// ignore no STRING match
|
||||
if ls == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// {2} Get offset in URI (else -1) //
|
||||
off := strings.Index(clr, m.pat)
|
||||
// get offset in URI (else -1)
|
||||
off := strings.Index(clearedInput, m.pat)
|
||||
if off < 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// {3} Fail on invalid offset range //
|
||||
if off < minOff {
|
||||
// fail on invalid offset range
|
||||
if off < minOffset {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// {4} Check for trailing '/' //
|
||||
// check for trailing '/'
|
||||
hasSlash := 0
|
||||
if off+ls < len(clr) && clr[off+ls] == '/' {
|
||||
if off+ls < len(clearedInput) && clearedInput[off+ls] == '/' {
|
||||
hasSlash = 1
|
||||
}
|
||||
|
||||
// {5} Remove the current string (+trailing slash) from the URI //
|
||||
beg, end := clr[:off], clr[off+ls+hasSlash:]
|
||||
clr = fmt.Sprintf("%s\a/%s", beg, end) // separate matches by '\a' character
|
||||
// remove the current string (+trailing slash) from the URI
|
||||
beg, end := clearedInput[:off], clearedInput[off+ls+hasSlash:]
|
||||
clearedInput = fmt.Sprintf("%s\a/%s", beg, end) // separate matches with a '\a' character
|
||||
|
||||
// {6} Update offset range //
|
||||
minOff = len(beg) + 2 - 1 // +2 slash separators
|
||||
// update offset range
|
||||
// +2 slash separators
|
||||
// -1 because strings begin with 1 slash already
|
||||
minOffset = len(beg) + 2 - 1
|
||||
|
||||
}
|
||||
|
||||
// 3. If exists, remove trailing '/'
|
||||
if clr[len(clr)-1] == '/' {
|
||||
clr = clr[:len(clr)-1]
|
||||
// if exists, remove trailing '/'
|
||||
if clearedInput[len(clearedInput)-1] == '/' {
|
||||
clearedInput = clearedInput[:len(clearedInput)-1]
|
||||
}
|
||||
|
||||
// 4. If exists, remove trailing '\a'
|
||||
if clr[len(clr)-1] == '\a' {
|
||||
clr = clr[:len(clr)-1]
|
||||
// if exists, remove trailing '\a'
|
||||
if clearedInput[len(clearedInput)-1] == '\a' {
|
||||
clearedInput = clearedInput[:len(clearedInput)-1]
|
||||
}
|
||||
|
||||
return clr, true
|
||||
|
||||
return clearedInput, true
|
||||
}
|
||||
|
||||
// matchWildcards check the WILCARDS (non-string) matchers from
|
||||
|
@ -310,7 +276,7 @@ func (s Scheme) matchString(uri string) (string, bool) {
|
|||
// + it sets the matchers buffers for later extraction
|
||||
func (s Scheme) matchWildcards(clear string) bool {
|
||||
|
||||
// 1. Extract wildcards (ref)
|
||||
// extract wildcards (ref)
|
||||
wildcards := make(Scheme, 0, maxMatch)
|
||||
|
||||
for _, m := range s {
|
||||
|
@ -320,41 +286,34 @@ func (s Scheme) matchWildcards(clear string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
// 2. If no wildcards -> match
|
||||
if len(wildcards) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// 3. Break uri by '\a' characters
|
||||
// break uri by '\a' characters
|
||||
matches := strings.Split(clear, "\a")[1:]
|
||||
|
||||
// 4. Iterate over matches
|
||||
for n, match := range matches {
|
||||
|
||||
// {1} If no more matcher //
|
||||
// no more matcher
|
||||
if n >= len(wildcards) {
|
||||
return false
|
||||
}
|
||||
|
||||
// {2} Split by '/' //
|
||||
data := strings.Split(match, "/")[1:] // from index 1 because it begins with '/'
|
||||
// from index 1 because it begins with '/'
|
||||
data := strings.Split(match, "/")[1:]
|
||||
|
||||
// {3} If required and missing //
|
||||
// missing required
|
||||
if wildcards[n].req && len(data) < 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// {4} If not multi but got multi //
|
||||
// if not multi but got multi
|
||||
if !wildcards[n].mul && len(data) > 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// {5} Store data into matcher //
|
||||
wildcards[n].buf = data
|
||||
|
||||
}
|
||||
|
||||
// 5. Match
|
||||
return true
|
||||
|
||||
}
|
||||
|
|
114
message.go
114
message.go
|
@ -84,36 +84,40 @@ type Message struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
// receive reads a message form reader
|
||||
func readMessage(reader io.Reader) (*Message, error) {
|
||||
// ReadFrom reads a message from a reader
|
||||
//
|
||||
// implements io.ReaderFrom
|
||||
func (m *Message) ReadFrom(reader io.Reader) (int64, error) {
|
||||
var (
|
||||
read int64
|
||||
err error
|
||||
tmpBuf []byte
|
||||
mask []byte
|
||||
cursor int
|
||||
)
|
||||
|
||||
var err error
|
||||
var tmpBuf []byte
|
||||
var mask []byte
|
||||
var cursor int
|
||||
|
||||
m := &Message{}
|
||||
|
||||
// 2. Byte 1: FIN and OpCode
|
||||
// byte 1: FIN and OpCode
|
||||
tmpBuf = make([]byte, 1)
|
||||
read += int64(len(tmpBuf))
|
||||
err = readBytes(reader, tmpBuf)
|
||||
if err != nil {
|
||||
return m, err
|
||||
return read, err
|
||||
}
|
||||
|
||||
// check reserved bits
|
||||
if tmpBuf[0]&0x70 != 0 {
|
||||
return m, ErrReservedBits
|
||||
return read, ErrReservedBits
|
||||
}
|
||||
|
||||
m.Final = bool(tmpBuf[0]&0x80 == 0x80)
|
||||
m.Type = MessageType(tmpBuf[0] & 0x0f)
|
||||
|
||||
// 3. Byte 2: Mask and Length[0]
|
||||
// byte 2: mask and length[0]
|
||||
tmpBuf = make([]byte, 1)
|
||||
read += int64(len(tmpBuf))
|
||||
err = readBytes(reader, tmpBuf)
|
||||
if err != nil {
|
||||
return m, err
|
||||
return read, err
|
||||
}
|
||||
|
||||
// if mask, byte array not nil
|
||||
|
@ -124,71 +128,63 @@ func readMessage(reader io.Reader) (*Message, error) {
|
|||
// payload length
|
||||
m.Size = uint(tmpBuf[0] & 0x7f)
|
||||
|
||||
// 4. Extended payload
|
||||
// extended payload
|
||||
if m.Size == 127 {
|
||||
|
||||
tmpBuf = make([]byte, 8)
|
||||
read += int64(len(tmpBuf))
|
||||
err := readBytes(reader, tmpBuf)
|
||||
if err != nil {
|
||||
return m, err
|
||||
return read, err
|
||||
}
|
||||
|
||||
m.Size = uint(binary.BigEndian.Uint64(tmpBuf))
|
||||
|
||||
} else if m.Size == 126 {
|
||||
|
||||
tmpBuf = make([]byte, 2)
|
||||
read += int64(len(tmpBuf))
|
||||
err := readBytes(reader, tmpBuf)
|
||||
if err != nil {
|
||||
return m, err
|
||||
return read, err
|
||||
}
|
||||
|
||||
m.Size = uint(binary.BigEndian.Uint16(tmpBuf))
|
||||
|
||||
}
|
||||
|
||||
// 5. Masking key
|
||||
// masking key
|
||||
if mask != nil {
|
||||
|
||||
tmpBuf = make([]byte, 4)
|
||||
read += int64(len(tmpBuf))
|
||||
err := readBytes(reader, tmpBuf)
|
||||
if err != nil {
|
||||
return m, err
|
||||
return read, err
|
||||
}
|
||||
|
||||
mask = make([]byte, 4)
|
||||
copy(mask, tmpBuf)
|
||||
|
||||
}
|
||||
|
||||
// 6. Read payload by chunks
|
||||
// read payload by chunks
|
||||
m.Data = make([]byte, int(m.Size))
|
||||
|
||||
cursor = 0
|
||||
|
||||
// {1} While we have data to read //
|
||||
// while data to read
|
||||
for uint(cursor) < m.Size {
|
||||
|
||||
// {2} Try to read (at least 1 byte) //
|
||||
// try to read (at least 1 byte)
|
||||
nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1)
|
||||
if err != nil {
|
||||
return m, err
|
||||
return read + int64(cursor) + int64(nbread), err
|
||||
}
|
||||
|
||||
// {3} Unmask data //
|
||||
// unmask data //
|
||||
if mask != nil {
|
||||
for i, l := cursor, cursor+nbread; i < l; i++ {
|
||||
|
||||
mi := i % 4 // mask index
|
||||
m.Data[i] = m.Data[i] ^ mask[mi]
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// {4} Update cursor //
|
||||
cursor += nbread
|
||||
|
||||
}
|
||||
read += int64(cursor)
|
||||
|
||||
// return error if unmasked frame
|
||||
// we have to fully read it for read buffer to be clean
|
||||
|
@ -197,13 +193,14 @@ func readMessage(reader io.Reader) (*Message, error) {
|
|||
err = ErrUnmaskedFrame
|
||||
}
|
||||
|
||||
return m, err
|
||||
return read, err
|
||||
|
||||
}
|
||||
|
||||
// Send sends a frame over a socket
|
||||
func (m Message) Send(writer io.Writer) error {
|
||||
|
||||
// WriteTo writes a message frame over a socket
|
||||
//
|
||||
// implements io.WriterTo
|
||||
func (m Message) WriteTo(writer io.Writer) (int64, error) {
|
||||
header := make([]byte, 0, maximumHeaderSize)
|
||||
|
||||
// fix size
|
||||
|
@ -211,20 +208,18 @@ func (m Message) Send(writer io.Writer) error {
|
|||
m.Size = uint(len(m.Data))
|
||||
}
|
||||
|
||||
// 1. Byte 0 : FIN + opcode
|
||||
// byte 0 : FIN + opcode
|
||||
var final byte = 0x80
|
||||
if !m.Final {
|
||||
final = 0
|
||||
}
|
||||
header = append(header, final|byte(m.Type))
|
||||
|
||||
// 2. Get payload length
|
||||
// get payload length
|
||||
if m.Size < 126 { // simple
|
||||
|
||||
header = append(header, byte(m.Size))
|
||||
|
||||
} else if m.Size <= 0xffff { // extended: 16 bits
|
||||
|
||||
header = append(header, 126)
|
||||
|
||||
buf := make([]byte, 2)
|
||||
|
@ -232,7 +227,6 @@ func (m Message) Send(writer io.Writer) error {
|
|||
header = append(header, buf...)
|
||||
|
||||
} else if m.Size <= 0xffffffffffffffff { // extended: 64 bits
|
||||
|
||||
header = append(header, 127)
|
||||
|
||||
buf := make([]byte, 8)
|
||||
|
@ -241,16 +235,15 @@ func (m Message) Send(writer io.Writer) error {
|
|||
|
||||
}
|
||||
|
||||
// 3. Build write buffer
|
||||
// build write buffer
|
||||
writeBuf := make([]byte, 0, len(header)+int(m.Size))
|
||||
writeBuf = append(writeBuf, header...)
|
||||
writeBuf = append(writeBuf, m.Data[0:m.Size]...)
|
||||
|
||||
// 4. Send over socket by chunks
|
||||
// write by chunks
|
||||
toWrite := len(header) + int(m.Size)
|
||||
cursor := 0
|
||||
for cursor < toWrite {
|
||||
|
||||
maxBoundary := cursor + maxWriteChunk
|
||||
if maxBoundary > toWrite {
|
||||
maxBoundary = toWrite
|
||||
|
@ -259,34 +252,32 @@ func (m Message) Send(writer io.Writer) error {
|
|||
// Try to wrote (at max 1024 bytes) //
|
||||
nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary])
|
||||
if err != nil {
|
||||
return err
|
||||
return int64(nbwritten), err
|
||||
}
|
||||
|
||||
// Update cursor //
|
||||
cursor += nbwritten
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
return int64(cursor), nil
|
||||
}
|
||||
|
||||
// Check for message errors with:
|
||||
// check for message errors with:
|
||||
// (m) the current message
|
||||
// (fragment) whether there is a fragment in construction
|
||||
// returns the message error
|
||||
func (m *Message) check(fragment bool) error {
|
||||
|
||||
// 1. Invalid first fragment (not TEXT nor BINARY)
|
||||
// invalid first fragment (not TEXT nor BINARY)
|
||||
if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
|
||||
return ErrInvalidFragment
|
||||
}
|
||||
|
||||
// 2. Waiting fragment but received standalone frame
|
||||
// waiting fragment but received standalone frame
|
||||
if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong {
|
||||
return ErrInvalidFragment
|
||||
}
|
||||
|
||||
// 3. Control frame too long
|
||||
// control frame too long
|
||||
if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) {
|
||||
return ErrTooLongControlFrame
|
||||
}
|
||||
|
@ -347,20 +338,19 @@ func (m *Message) check(fragment bool) error {
|
|||
//
|
||||
// It manages connections which chunks data
|
||||
func readBytes(reader io.Reader, buffer []byte) error {
|
||||
var (
|
||||
cur = 0
|
||||
len = len(buffer)
|
||||
)
|
||||
|
||||
var cur, len int = 0, len(buffer)
|
||||
|
||||
// try to read until the full size is read
|
||||
// read until the full size is read
|
||||
for cur < len {
|
||||
|
||||
nbread, err := reader.Read(buffer[cur:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cur += nbread
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
|
|
@ -67,11 +67,12 @@ func TestSimpleMessageReading(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
var (
|
||||
reader = bytes.NewBuffer(tc.ReadBuffer)
|
||||
msg = &Message{}
|
||||
)
|
||||
|
||||
reader := bytes.NewBuffer(tc.ReadBuffer)
|
||||
|
||||
got, err := readMessage(reader)
|
||||
|
||||
_, err := msg.ReadFrom(reader)
|
||||
if err != tc.Err {
|
||||
t.Errorf("Expected %v error, got %v", tc.Err, err)
|
||||
}
|
||||
|
@ -82,23 +83,23 @@ func TestSimpleMessageReading(t *testing.T) {
|
|||
}
|
||||
|
||||
// check FIN
|
||||
if got.Final != tc.Expected.Final {
|
||||
t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, got.Final)
|
||||
if msg.Final != tc.Expected.Final {
|
||||
t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, msg.Final)
|
||||
}
|
||||
|
||||
// check OpCode
|
||||
if got.Type != tc.Expected.Type {
|
||||
t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, got.Type)
|
||||
if msg.Type != tc.Expected.Type {
|
||||
t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, msg.Type)
|
||||
}
|
||||
|
||||
// check Size
|
||||
if got.Size != tc.Expected.Size {
|
||||
t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, got.Size)
|
||||
if msg.Size != tc.Expected.Size {
|
||||
t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, msg.Size)
|
||||
}
|
||||
|
||||
// check Data
|
||||
if string(got.Data) != string(tc.Expected.Data) {
|
||||
t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, got.Data)
|
||||
if string(msg.Data) != string(tc.Expected.Data) {
|
||||
t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, msg.Data)
|
||||
}
|
||||
|
||||
})
|
||||
|
@ -177,17 +178,15 @@ func TestReadEOF(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
|
||||
reader := bytes.NewBuffer(tc.ReadBuffer)
|
||||
|
||||
got, err := readMessage(reader)
|
||||
|
||||
var (
|
||||
reader = bytes.NewBuffer(tc.ReadBuffer)
|
||||
msg = &Message{}
|
||||
)
|
||||
_, err := msg.ReadFrom(reader)
|
||||
if tc.eof {
|
||||
|
||||
if err != io.EOF {
|
||||
t.Errorf("Expected EOF, got %v", err)
|
||||
t.Fatalf("Expected EOF, got %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -195,8 +194,8 @@ func TestReadEOF(t *testing.T) {
|
|||
t.Errorf("Expected UnmaskedFrameor, got %v", err)
|
||||
}
|
||||
|
||||
if got.Size != 0x00 {
|
||||
t.Errorf("Expected a size of 0, got %d", got.Size)
|
||||
if msg.Size != 0x00 {
|
||||
t.Errorf("Expected a size of 0, got %d", msg.Size)
|
||||
}
|
||||
|
||||
})
|
||||
|
@ -269,8 +268,7 @@ func TestSimpleMessageSending(t *testing.T) {
|
|||
|
||||
writer := &bytes.Buffer{}
|
||||
|
||||
err := tc.Base.Send(writer)
|
||||
|
||||
_, err := tc.Base.WriteTo(writer)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
return
|
||||
|
|
53
server.go
53
server.go
|
@ -27,9 +27,8 @@ type Server struct {
|
|||
ch serverChannelSet
|
||||
}
|
||||
|
||||
// CreateServer for a specific HOST and PORT
|
||||
func CreateServer(host string, port uint16) *Server {
|
||||
|
||||
// NewServer creates a server
|
||||
func NewServer(host string, port uint16) *Server {
|
||||
return &Server{
|
||||
addr: []byte(host),
|
||||
port: port,
|
||||
|
@ -47,116 +46,84 @@ func CreateServer(host string, port uint16) *Server {
|
|||
broadcast: make(chan Message, 1),
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// BindDefault binds a default controller
|
||||
// it will be called if the URI does not
|
||||
// match another controller
|
||||
func (s *Server) BindDefault(f ControllerFunc) {
|
||||
|
||||
s.ctl.Def = &Controller{
|
||||
URI: nil,
|
||||
Fun: f,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Bind a controller to an URI scheme
|
||||
func (s *Server) Bind(uriStr string, f ControllerFunc) error {
|
||||
|
||||
// 1. Build URI parser
|
||||
uriScheme, err := uri.FromString(uriStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot build URI: %w", err)
|
||||
}
|
||||
|
||||
// 2. Create controller
|
||||
s.ctl.URI = append(s.ctl.URI, &Controller{
|
||||
URI: uriScheme,
|
||||
Fun: f,
|
||||
})
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Launch the websocket server
|
||||
func (s *Server) Launch() error {
|
||||
var (
|
||||
err error
|
||||
url = fmt.Sprintf("%s:%d", s.addr, s.port)
|
||||
)
|
||||
|
||||
var err error
|
||||
|
||||
/* (1) Listen socket
|
||||
---------------------------------------------------------*/
|
||||
// 1. Build full url
|
||||
url := fmt.Sprintf("%s:%d", s.addr, s.port)
|
||||
|
||||
// 2. Bind socket to listen
|
||||
s.sock, err = net.Listen("tcp", url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
|
||||
defer s.sock.Close()
|
||||
|
||||
fmt.Printf("+ listening on %s\n", url)
|
||||
go s.schedule()
|
||||
|
||||
// 3. Launch scheduler
|
||||
go s.scheduler()
|
||||
|
||||
/* (2) For each incoming connection (client)
|
||||
---------------------------------------------------------*/
|
||||
for {
|
||||
|
||||
// 1. Wait for client
|
||||
sock, err := s.sock.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
// 2. Try to create client
|
||||
cli, err := buildClient(sock, s.ctl, s.ch)
|
||||
cli, err := newClient(sock, s.ctl, s.ch)
|
||||
if err != nil {
|
||||
fmt.Printf(" - %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Register client
|
||||
s.ch.register <- cli
|
||||
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Scheduler schedules clients registration and broadcast
|
||||
func (s *Server) scheduler() {
|
||||
|
||||
// schedule client registration and broadcast
|
||||
func (s *Server) schedule() {
|
||||
for {
|
||||
|
||||
select {
|
||||
|
||||
// 1. Create client
|
||||
case client := <-s.ch.register:
|
||||
s.clients[client.io.sock] = client
|
||||
|
||||
// 2. Remove client
|
||||
case client := <-s.ch.unregister:
|
||||
delete(s.clients, client.io.sock)
|
||||
|
||||
// 3. Broadcast
|
||||
case msg := <-s.ch.broadcast:
|
||||
for _, c := range s.clients {
|
||||
c.ch.send <- msg
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue