Compare commits
No commits in common. "master" and "ft.buffer.optimization.1" have entirely different histories.
master
...
ft.buffer.
292
client.go
292
client.go
|
@ -1,292 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Represents a client socket utility (reader, writer, ..)
|
|
||||||
type clientIO struct {
|
|
||||||
sock net.Conn
|
|
||||||
reader *bufio.Reader
|
|
||||||
kill chan<- *client // unregisters client
|
|
||||||
closing bool
|
|
||||||
closingMu sync.Mutex
|
|
||||||
reading sync.WaitGroup
|
|
||||||
writing bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Represents all channels that need a client
|
|
||||||
type clientChannelSet struct {
|
|
||||||
receive chan Message
|
|
||||||
send chan Message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Represents a websocket client
|
|
||||||
type client struct {
|
|
||||||
io clientIO
|
|
||||||
iface *Client
|
|
||||||
ch clientChannelSet
|
|
||||||
status MessageError // close status ; 0 = nothing ; else -> must close
|
|
||||||
}
|
|
||||||
|
|
||||||
// newClient creates a new client
|
|
||||||
func newClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) {
|
|
||||||
req := &upgrade.Request{}
|
|
||||||
_, err := req.ReadFrom(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("request read: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
res := req.BuildResponse()
|
|
||||||
|
|
||||||
_, err = res.WriteTo(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("upgrade write: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode != 101 {
|
|
||||||
s.Close()
|
|
||||||
return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cli = &client{
|
|
||||||
io: clientIO{
|
|
||||||
sock: s,
|
|
||||||
reader: bufio.NewReader(s),
|
|
||||||
kill: serverCh.unregister,
|
|
||||||
},
|
|
||||||
|
|
||||||
iface: &Client{
|
|
||||||
Protocol: string(res.Protocol),
|
|
||||||
Arguments: [][]string{{req.URI()}},
|
|
||||||
},
|
|
||||||
|
|
||||||
ch: clientChannelSet{
|
|
||||||
receive: make(chan Message, 1),
|
|
||||||
send: make(chan Message, 1),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// find controller by URI
|
|
||||||
controller, arguments := ctl.Match(req.URI())
|
|
||||||
if controller == nil {
|
|
||||||
return nil, fmt.Errorf("no controller found, no default controller set")
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy args
|
|
||||||
cli.iface.Arguments = arguments
|
|
||||||
|
|
||||||
go controller.Fun(
|
|
||||||
cli.iface, // pass the client
|
|
||||||
cli.ch.receive, // the receiver
|
|
||||||
cli.ch.send, // the sender
|
|
||||||
serverCh.broadcast, // broadcast sender
|
|
||||||
)
|
|
||||||
go clientReader(cli)
|
|
||||||
go clientWriter(cli)
|
|
||||||
return cli, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientReader reads and parses messages from the buffer
|
|
||||||
func clientReader(c *client) {
|
|
||||||
var (
|
|
||||||
frag *Message
|
|
||||||
closeStatus = Normal
|
|
||||||
clientAck = true
|
|
||||||
)
|
|
||||||
|
|
||||||
c.io.reading.Add(1)
|
|
||||||
|
|
||||||
for {
|
|
||||||
// currently closing -> exit
|
|
||||||
if c.io.closing {
|
|
||||||
fmt.Printf("[reader] killed because closing")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse message
|
|
||||||
var msg = &Message{}
|
|
||||||
_, err := msg.ReadFrom(c.io.reader)
|
|
||||||
if err == ErrUnmaskedFrame || err == ErrReservedBits {
|
|
||||||
closeStatus = ProtocolError
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// invalid message
|
|
||||||
msgErr := msg.check(frag != nil)
|
|
||||||
if msgErr != nil {
|
|
||||||
|
|
||||||
mustClose := false
|
|
||||||
|
|
||||||
switch msgErr {
|
|
||||||
|
|
||||||
// fail
|
|
||||||
case ErrUnexpectedContinuation:
|
|
||||||
closeStatus = None
|
|
||||||
clientAck = false
|
|
||||||
mustClose = true
|
|
||||||
|
|
||||||
// proper close
|
|
||||||
case ErrCloseFrame:
|
|
||||||
closeStatus = Normal
|
|
||||||
clientAck = true
|
|
||||||
mustClose = true
|
|
||||||
|
|
||||||
// invalid payload proper close
|
|
||||||
case ErrInvalidPayload:
|
|
||||||
closeStatus = InvalidPayload
|
|
||||||
clientAck = true
|
|
||||||
mustClose = true
|
|
||||||
|
|
||||||
// any other error -> protocol error
|
|
||||||
default:
|
|
||||||
closeStatus = ProtocolError
|
|
||||||
clientAck = true
|
|
||||||
mustClose = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if mustClose {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// ping <-> Pong
|
|
||||||
if msg.Type == Ping && c.io.writing {
|
|
||||||
msg.Final = true
|
|
||||||
msg.Type = Pong
|
|
||||||
c.ch.send <- *msg
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// store first fragment
|
|
||||||
if frag == nil && !msg.Final {
|
|
||||||
frag = &Message{
|
|
||||||
Type: msg.Type,
|
|
||||||
Final: msg.Final,
|
|
||||||
Data: msg.Data,
|
|
||||||
Size: msg.Size,
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// store fragments
|
|
||||||
if frag != nil {
|
|
||||||
frag.Final = msg.Final
|
|
||||||
frag.Size += msg.Size
|
|
||||||
frag.Data = append(frag.Data, msg.Data...)
|
|
||||||
|
|
||||||
if !frag.Final { // continue if not last fragment
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check message errors
|
|
||||||
fragErr := frag.check(false)
|
|
||||||
if fragErr == ErrInvalidPayload {
|
|
||||||
closeStatus = InvalidPayload
|
|
||||||
break
|
|
||||||
} else if fragErr != nil {
|
|
||||||
closeStatus = ProtocolError
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
msg = frag
|
|
||||||
frag = nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch to receiver
|
|
||||||
if msg.Type == Text || msg.Type == Binary {
|
|
||||||
c.ch.receive <- *msg
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
close(c.ch.receive)
|
|
||||||
c.io.reading.Done()
|
|
||||||
|
|
||||||
// close channel (if not already done)
|
|
||||||
// fmt.Printf("[reader] end\n")
|
|
||||||
c.close(closeStatus, clientAck)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientWriter writes to the websocket connection and is triggered by
|
|
||||||
// client.ch.send channel
|
|
||||||
func clientWriter(c *client) {
|
|
||||||
c.io.writing = true // if channel still exists
|
|
||||||
|
|
||||||
for msg := range c.ch.send {
|
|
||||||
_, err := msg.WriteTo(c.io.sock)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf(" [writer] %s\n", err)
|
|
||||||
c.io.writing = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.io.writing = false
|
|
||||||
|
|
||||||
// close channel (if not already done)
|
|
||||||
// fmt.Printf("[writer] end\n")
|
|
||||||
c.close(Normal, true)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// close the connection
|
|
||||||
// send CLOSE frame is 'status' is not NONE
|
|
||||||
// wait for the next message (CLOSE acknowledge) if 'clientACK'
|
|
||||||
// then delete client
|
|
||||||
func (c *client) close(status MessageError, clientACK bool) {
|
|
||||||
// fail if already closing
|
|
||||||
alreadyClosing := false
|
|
||||||
c.io.closingMu.Lock()
|
|
||||||
alreadyClosing = c.io.closing
|
|
||||||
c.io.closing = true
|
|
||||||
c.io.closingMu.Unlock()
|
|
||||||
if alreadyClosing {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// kill writer' if still running
|
|
||||||
if c.io.writing {
|
|
||||||
close(c.ch.send)
|
|
||||||
}
|
|
||||||
|
|
||||||
// kill reader if still running
|
|
||||||
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
|
|
||||||
c.io.reading.Wait()
|
|
||||||
|
|
||||||
if status != None {
|
|
||||||
msg := &Message{
|
|
||||||
Final: true,
|
|
||||||
Type: Close,
|
|
||||||
Size: 2,
|
|
||||||
Data: make([]byte, 2),
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint16(msg.Data, uint16(status))
|
|
||||||
|
|
||||||
msg.WriteTo(c.io.sock)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for client CLOSE if needed
|
|
||||||
if clientACK {
|
|
||||||
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
|
|
||||||
var tmpMsg = &Message{}
|
|
||||||
tmpMsg.ReadFrom(c.io.reader)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.io.sock.Close()
|
|
||||||
// fmt.Printf("[close] socket closed\n")
|
|
||||||
|
|
||||||
c.io.kill <- c
|
|
||||||
}
|
|
|
@ -1,55 +1,57 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"git.xdrm.io/gws/ws"
|
||||||
"time"
|
"time"
|
||||||
|
"fmt"
|
||||||
ws "git.xdrm.io/go/ws"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
func main(){
|
func main(){
|
||||||
|
|
||||||
startTime := time.Now().UnixNano()
|
startTime := time.Now().UnixNano()
|
||||||
|
|
||||||
// creqte WebSocket server
|
/* (1) Bind WebSocket server */
|
||||||
serv := ws.NewServer("0.0.0.0", 4444)
|
serv := ws.CreateServer("0.0.0.0", 4444)
|
||||||
|
|
||||||
// bind default controller
|
/* (2) Bind default controller */
|
||||||
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- *ws.Message, bc chan<- *ws.Message){
|
||||||
defer func() {
|
|
||||||
if recover() != nil {
|
|
||||||
fmt.Printf("*** PANIC\n")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for msg := range receiver {
|
for msg := range receiver {
|
||||||
|
|
||||||
// if receive message -> send it back
|
// if receive message -> send it back
|
||||||
sender <- msg
|
sender <- &msg
|
||||||
// close(sender)
|
sender <- nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// bnd to URI
|
/* (3) Bind to URI */
|
||||||
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- *ws.Message, bc chan<- *ws.Message){
|
||||||
|
|
||||||
fmt.Printf("[uri] connected\n")
|
fmt.Printf("[uri] connected\n")
|
||||||
|
|
||||||
for msg := range receiver{
|
for msg := range receiver{
|
||||||
|
|
||||||
fmt.Printf("[uri] received '%s'\n", msg.Data)
|
fmt.Printf("[uri] received '%s'\n", msg.Data)
|
||||||
sender <- msg
|
sender <- &msg
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("[uri] unexpectedly closed\n")
|
fmt.Printf("[uri] unexpectedly closed\n")
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// launch the server
|
})
|
||||||
|
if err != nil { panic(err) }
|
||||||
|
|
||||||
|
/* (4) Launch the server */
|
||||||
err = serv.Launch()
|
err = serv.Launch()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[ERROR] %s\n", err)
|
fmt.Printf("[ERROR] %s\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3)
|
fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3)
|
||||||
|
|
||||||
}
|
}
|
|
@ -1,48 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import "git.xdrm.io/go/ws/internal/uri"
|
|
||||||
|
|
||||||
// Client contains available information about a client
|
|
||||||
type Client struct {
|
|
||||||
Protocol string // choosen protocol (Sec-WebSocket-Protocol)
|
|
||||||
Arguments [][]string // URI parameters, index 0 is full URI, then matching groups
|
|
||||||
Store interface{} // store (for client implementation-specific data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ControllerFunc is a websocket controller callback function
|
|
||||||
type ControllerFunc func(*Client, <-chan Message, chan<- Message, chan<- Message)
|
|
||||||
|
|
||||||
// Controller is a websocket controller
|
|
||||||
type Controller struct {
|
|
||||||
URI *uri.Scheme // uri scheme
|
|
||||||
Fun ControllerFunc // controller function
|
|
||||||
}
|
|
||||||
|
|
||||||
// ControllerSet contains a set of controllers
|
|
||||||
type ControllerSet struct {
|
|
||||||
Def *Controller // default controller
|
|
||||||
URI []*Controller // uri controllers
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match finds a controller for a given URI
|
|
||||||
// also it returns the matching string patterns
|
|
||||||
func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
|
|
||||||
arguments := [][]string{{uri}}
|
|
||||||
|
|
||||||
for _, c := range s.URI {
|
|
||||||
if c.URI.Match(uri) {
|
|
||||||
match := c.URI.GetAllMatch()
|
|
||||||
arguments = append(arguments, match...)
|
|
||||||
return c, arguments
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fallback to default
|
|
||||||
if s.Def != nil {
|
|
||||||
return s.Def, arguments
|
|
||||||
}
|
|
||||||
|
|
||||||
// no default
|
|
||||||
return nil, arguments
|
|
||||||
|
|
||||||
}
|
|
|
@ -6,43 +6,48 @@ package reader
|
||||||
// the golang standard library
|
// the golang standard library
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"bufio"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Maximum line length
|
// Maximum line length
|
||||||
const maxLineLength = 4096
|
var maxLineLength = 4096
|
||||||
|
|
||||||
// ChunkReader struct
|
// Chunk reader
|
||||||
type ChunkReader struct {
|
type chunkReader struct {
|
||||||
reader *bufio.Reader // the reader
|
reader *bufio.Reader // the reader
|
||||||
isEnded bool // If we are done (2 consecutive CRLF)
|
isEnded bool // If we are done (2 consecutive CRLF)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReader creates a new reader
|
|
||||||
func NewReader(r io.Reader) *ChunkReader {
|
// New creates a new reader
|
||||||
|
func NewReader(r io.Reader) (reader *chunkReader) {
|
||||||
|
|
||||||
br, ok := r.(*bufio.Reader)
|
br, ok := r.(*bufio.Reader)
|
||||||
if !ok {
|
if !ok {
|
||||||
br = bufio.NewReader(r)
|
br = bufio.NewReader(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ChunkReader{reader: br}
|
return &chunkReader{reader: br}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads a chunk, io.EOF when done
|
|
||||||
func (r *ChunkReader) Read() ([]byte, error) {
|
|
||||||
// already ended
|
// Read reads a chunk, err is io.EOF when done
|
||||||
|
func (r *chunkReader) Read() ([]byte, error){
|
||||||
|
|
||||||
|
/* (1) If already ended */
|
||||||
if r.isEnded {
|
if r.isEnded {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
// read line
|
/* (2) Read line */
|
||||||
var line []byte
|
var line []byte
|
||||||
line, err := r.reader.ReadSlice('\n')
|
line, err := r.reader.ReadSlice('\n')
|
||||||
|
|
||||||
|
/* (3) manage errors */
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -55,8 +60,10 @@ func (r *ChunkReader) Read() ([]byte, error) {
|
||||||
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
|
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
line = trimSpaces(line)
|
/* (4) Trim */
|
||||||
|
line = removeTrailingSpace(line)
|
||||||
|
|
||||||
|
/* (5) Manage ending line */
|
||||||
if len(line) == 0 {
|
if len(line) == 0 {
|
||||||
r.isEnded = true
|
r.isEnded = true
|
||||||
return line, io.EOF
|
return line, io.EOF
|
||||||
|
@ -66,13 +73,15 @@ func (r *ChunkReader) Read() ([]byte, error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func trimSpaces(b []byte) []byte {
|
|
||||||
for len(b) > 0 && isSpaceChar(b[len(b)-1]) {
|
|
||||||
|
func removeTrailingSpace(b []byte) []byte{
|
||||||
|
for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
|
||||||
b = b[:len(b)-1]
|
b = b[:len(b)-1]
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func isSpaceChar(b byte) bool {
|
func isASCIISpace(b byte) bool {
|
||||||
return b == ' ' || b == '\t' || b == '\r' || b =='\n'
|
return b == ' ' || b == '\t' || b == '\r' || b =='\n'
|
||||||
}
|
}
|
|
@ -1,35 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrInvalidRequest for invalid requests
|
|
||||||
// - multiple-value if only 1 expected
|
|
||||||
type ErrInvalidRequest struct {
|
|
||||||
Field string
|
|
||||||
Reason string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (err ErrInvalidRequest) Error() string {
|
|
||||||
return fmt.Sprintf("invalid field '%s': %s", err.Field, err.Reason)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrIncompleteRequest when mandatory request fields are missing (request-line or headers)
|
|
||||||
// it contains the missing field as a string
|
|
||||||
type ErrIncompleteRequest string
|
|
||||||
|
|
||||||
func (err ErrIncompleteRequest) Error() string {
|
|
||||||
return fmt.Sprintf("incomplete request, '%s' is invalid or missing", string(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrInvalidOriginPolicy when a request has a violated origin policy
|
|
||||||
type ErrInvalidOriginPolicy struct {
|
|
||||||
Host string
|
|
||||||
Origin string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (err ErrInvalidOriginPolicy) Error() string {
|
|
||||||
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HeaderType represents all 'valid' HTTP request headers
|
|
||||||
type HeaderType uint8
|
|
||||||
|
|
||||||
// header types
|
|
||||||
const (
|
|
||||||
Unknown HeaderType = iota
|
|
||||||
Host
|
|
||||||
Upgrade
|
|
||||||
Connection
|
|
||||||
Origin
|
|
||||||
WSKey
|
|
||||||
WSProtocol
|
|
||||||
WSExtensions
|
|
||||||
WSVersion
|
|
||||||
)
|
|
||||||
|
|
||||||
// HeaderValue represents a unique or multiple header value(s)
|
|
||||||
type HeaderValue [][]byte
|
|
||||||
|
|
||||||
// Header represents the data of a HTTP request header
|
|
||||||
type Header struct {
|
|
||||||
Name HeaderType
|
|
||||||
Values HeaderValue
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadHeader tries to parse an HTTP header from a byte array
|
|
||||||
func ReadHeader(b []byte) (*Header, error) {
|
|
||||||
|
|
||||||
// 1. Split by ':'
|
|
||||||
parts := bytes.Split(b, []byte(": "))
|
|
||||||
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil, fmt.Errorf("invalid HTTP header format '%s'", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Create instance
|
|
||||||
inst := &Header{}
|
|
||||||
|
|
||||||
// 3. Check for header name
|
|
||||||
switch strings.ToLower(string(parts[0])) {
|
|
||||||
case "host":
|
|
||||||
inst.Name = Host
|
|
||||||
case "upgrade":
|
|
||||||
inst.Name = Upgrade
|
|
||||||
case "connection":
|
|
||||||
inst.Name = Connection
|
|
||||||
case "origin":
|
|
||||||
inst.Name = Origin
|
|
||||||
case "sec-websocket-key":
|
|
||||||
inst.Name = WSKey
|
|
||||||
case "sec-websocket-protocol":
|
|
||||||
inst.Name = WSProtocol
|
|
||||||
case "sec-websocket-extensions":
|
|
||||||
inst.Name = WSExtensions
|
|
||||||
case "sec-websocket-version":
|
|
||||||
inst.Name = WSVersion
|
|
||||||
default:
|
|
||||||
inst.Name = Unknown
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. Split values
|
|
||||||
inst.Values = bytes.Split(parts[1], []byte(", "))
|
|
||||||
|
|
||||||
return inst, nil
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,159 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// checkHost checks and extracts the Host header
|
|
||||||
func (r *Request) extractHostPort(bb HeaderValue) error {
|
|
||||||
|
|
||||||
if len(bb) != 1 {
|
|
||||||
return &ErrInvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(bb[0]) <= 3 {
|
|
||||||
return &ErrInvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))}
|
|
||||||
}
|
|
||||||
|
|
||||||
split := strings.Split(string(bb[0]), ":")
|
|
||||||
|
|
||||||
r.host = split[0]
|
|
||||||
|
|
||||||
// no port
|
|
||||||
if len(split) < 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract port
|
|
||||||
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
r.statusCode = StatusBadRequest
|
|
||||||
return &ErrInvalidRequest{"Host", "cannot read port"}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.port = uint16(readPort)
|
|
||||||
|
|
||||||
// if 'Origin' header is already read, check it
|
|
||||||
if len(r.origin) > 0 {
|
|
||||||
if err != nil {
|
|
||||||
err = r.checkOriginPolicy()
|
|
||||||
r.statusCode = StatusForbidden
|
|
||||||
return &ErrInvalidOriginPolicy{r.host, r.origin, err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkOrigin checks the Origin Header
|
|
||||||
func (r *Request) extractOrigin(bb HeaderValue) error {
|
|
||||||
|
|
||||||
// bypass
|
|
||||||
if bypassOriginPolicy {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(bb) != 1 {
|
|
||||||
r.statusCode = StatusForbidden
|
|
||||||
return &ErrInvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.origin = string(bb[0])
|
|
||||||
|
|
||||||
// if host already stored, check origin policy
|
|
||||||
if len(r.host) > 0 {
|
|
||||||
err := r.checkOriginPolicy()
|
|
||||||
if err != nil {
|
|
||||||
r.statusCode = StatusForbidden
|
|
||||||
return &ErrInvalidOriginPolicy{r.host, r.origin, err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkOriginPolicy origin policy based on 'host' value
|
|
||||||
func (r *Request) checkOriginPolicy() error {
|
|
||||||
// TODO: Origin policy, for now BYPASS
|
|
||||||
r.validPolicy = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkConnection checks the 'Connection' header
|
|
||||||
// it MUST contain 'Upgrade'
|
|
||||||
func (r *Request) checkConnection(bb HeaderValue) error {
|
|
||||||
|
|
||||||
for _, b := range bb {
|
|
||||||
|
|
||||||
if strings.ToLower(string(b)) == "upgrade" {
|
|
||||||
r.hasConnection = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
r.statusCode = StatusBadRequest
|
|
||||||
return &ErrInvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkUpgrade checks the 'Upgrade' header
|
|
||||||
// it MUST be 'websocket'
|
|
||||||
func (r *Request) checkUpgrade(bb HeaderValue) error {
|
|
||||||
|
|
||||||
if len(bb) != 1 {
|
|
||||||
r.statusCode = StatusBadRequest
|
|
||||||
return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.ToLower(string(bb[0])) == "websocket" {
|
|
||||||
r.hasUpgrade = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.statusCode = StatusBadRequest
|
|
||||||
return &ErrInvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkVersion checks the 'Sec-WebSocket-Version' header
|
|
||||||
// it MUST be '13'
|
|
||||||
func (r *Request) checkVersion(bb HeaderValue) error {
|
|
||||||
|
|
||||||
if len(bb) != 1 || string(bb[0]) != "13" {
|
|
||||||
r.statusCode = StatusUpgradeRequired
|
|
||||||
return &ErrInvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.hasVersion = true
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractKey extracts the 'Sec-WebSocket-Key' header
|
|
||||||
// it MUST be 24 bytes (base64)
|
|
||||||
func (r *Request) extractKey(bb HeaderValue) error {
|
|
||||||
|
|
||||||
if len(bb) != 1 || len(bb[0]) != 24 {
|
|
||||||
r.statusCode = StatusBadRequest
|
|
||||||
return &ErrInvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.key = bb[0]
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header
|
|
||||||
// it can contain multiple values
|
|
||||||
func (r *Request) extractProtocols(bb HeaderValue) error {
|
|
||||||
|
|
||||||
r.protocols = bb
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,189 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/http/reader"
|
|
||||||
)
|
|
||||||
|
|
||||||
// whether origin is required
|
|
||||||
const bypassOriginPolicy = true
|
|
||||||
|
|
||||||
// Request represents an HTTP Upgrade request
|
|
||||||
type Request struct {
|
|
||||||
// whether the first line has been read (GET uri HTTP/version)
|
|
||||||
first bool
|
|
||||||
statusCode StatusCode
|
|
||||||
requestLine RequestLine
|
|
||||||
|
|
||||||
// data to check origin (depends on reading order)
|
|
||||||
host string
|
|
||||||
port uint16 // 0 if not set
|
|
||||||
origin string
|
|
||||||
validPolicy bool
|
|
||||||
|
|
||||||
// websocket specific
|
|
||||||
key []byte
|
|
||||||
protocols [][]byte
|
|
||||||
|
|
||||||
// mandatory fields to check
|
|
||||||
hasConnection bool
|
|
||||||
hasUpgrade bool
|
|
||||||
hasVersion bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadFrom reads an upgrade HTTP request ; typically from bufio.NewRead of the
|
|
||||||
// socket
|
|
||||||
//
|
|
||||||
// implements io.ReaderFrom
|
|
||||||
func (req *Request) ReadFrom(r io.Reader) (int64, error) {
|
|
||||||
var read int64
|
|
||||||
|
|
||||||
// reset request
|
|
||||||
req.statusCode = 500
|
|
||||||
|
|
||||||
// parse header line by line
|
|
||||||
var cr = reader.NewReader(r)
|
|
||||||
for {
|
|
||||||
line, err := cr.Read()
|
|
||||||
read += int64(len(line))
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.parseHeader(line)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := req.isComplete()
|
|
||||||
if err != nil {
|
|
||||||
req.statusCode = StatusBadRequest
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.statusCode = StatusSwitchingProtocols
|
|
||||||
return read, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StatusCode returns the status current
|
|
||||||
func (req Request) StatusCode() StatusCode {
|
|
||||||
return req.statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildResponse builds a response from the request
|
|
||||||
func (req *Request) BuildResponse() *Response {
|
|
||||||
|
|
||||||
res := &Response{
|
|
||||||
StatusCode: req.statusCode,
|
|
||||||
Protocol: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.protocols) > 0 {
|
|
||||||
res.Protocol = req.protocols[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
res.ProcessKey(req.key)
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// URI returns the actual URI
|
|
||||||
func (req Request) URI() string {
|
|
||||||
return req.requestLine.URI()
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseHeader parses any http request line
|
|
||||||
// (header and request-line)
|
|
||||||
func (req *Request) parseHeader(b []byte) error {
|
|
||||||
// first line -> GET {uri} HTTP/{version}
|
|
||||||
if !req.first {
|
|
||||||
|
|
||||||
_, err := req.requestLine.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
req.statusCode = StatusBadRequest
|
|
||||||
return &ErrInvalidRequest{"Request-Line", err.Error()}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.first = true
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// other lines -> Header-Name: Header-Value
|
|
||||||
head, err := ReadHeader(b)
|
|
||||||
if err != nil {
|
|
||||||
req.statusCode = StatusBadRequest
|
|
||||||
return fmt.Errorf("parse header: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Manage header
|
|
||||||
switch head.Name {
|
|
||||||
case Host:
|
|
||||||
err = req.extractHostPort(head.Values)
|
|
||||||
case Origin:
|
|
||||||
err = req.extractOrigin(head.Values)
|
|
||||||
case Upgrade:
|
|
||||||
err = req.checkUpgrade(head.Values)
|
|
||||||
case Connection:
|
|
||||||
err = req.checkConnection(head.Values)
|
|
||||||
case WSVersion:
|
|
||||||
err = req.checkVersion(head.Values)
|
|
||||||
case WSKey:
|
|
||||||
err = req.extractKey(head.Values)
|
|
||||||
case WSProtocol:
|
|
||||||
err = req.extractProtocols(head.Values)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isComplete returns whether the Upgrade Request
|
|
||||||
// is complete (no required field missing)
|
|
||||||
// returns nil on success
|
|
||||||
func (req Request) isComplete() error {
|
|
||||||
if !req.first {
|
|
||||||
return ErrIncompleteRequest("Request-Line")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.host) == 0 {
|
|
||||||
return ErrIncompleteRequest("Host")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bypassOriginPolicy && len(req.origin) == 0 {
|
|
||||||
return ErrIncompleteRequest("Origin")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !req.hasConnection {
|
|
||||||
return ErrIncompleteRequest("Connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !req.hasUpgrade {
|
|
||||||
return ErrIncompleteRequest("Upgrade")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !req.hasVersion {
|
|
||||||
return ErrIncompleteRequest("Sec-WebSocket-Version")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.key) < 1 {
|
|
||||||
return ErrIncompleteRequest("Sec-WebSocket-Key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
// invalid request
|
||||||
|
// - multiple-value if only 1 expected
|
||||||
|
type InvalidRequest struct {
|
||||||
|
Field string
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err InvalidRequest) Error() string {
|
||||||
|
return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request misses fields (request-line or headers)
|
||||||
|
type IncompleteRequest struct {
|
||||||
|
MissingField string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err IncompleteRequest) Error() string {
|
||||||
|
return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Request has a violated origin policy
|
||||||
|
type InvalidOriginPolicy struct {
|
||||||
|
Host string
|
||||||
|
Origin string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
func (err InvalidOriginPolicy) Error() string {
|
||||||
|
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
|
||||||
|
}
|
|
@ -0,0 +1,163 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.xdrm.io/gws/internal/http/upgrade/response"
|
||||||
|
"git.xdrm.io/gws/internal/http/upgrade/request/parser/header"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// checkHost checks and extracts the Host header
|
||||||
|
func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
if len(bb) != 1 {
|
||||||
|
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(bb[0]) <= 3 {
|
||||||
|
return &InvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))}
|
||||||
|
}
|
||||||
|
|
||||||
|
split := strings.Split(string(bb[0]), ":")
|
||||||
|
|
||||||
|
r.host = split[0]
|
||||||
|
|
||||||
|
// no port
|
||||||
|
if len(split) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract port
|
||||||
|
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return &InvalidRequest{"Host", "cannot read port"}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.port = uint16(readPort)
|
||||||
|
|
||||||
|
// if 'Origin' header is already read, check it
|
||||||
|
if len(r.origin) > 0 {
|
||||||
|
if err != nil {
|
||||||
|
err = r.checkOriginPolicy()
|
||||||
|
r.code = response.FORBIDDEN
|
||||||
|
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// checkOrigin checks the Origin Header
|
||||||
|
func (r *T) extractOrigin(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
// bypass
|
||||||
|
if bypassOriginPolicy { return nil }
|
||||||
|
|
||||||
|
if len(bb) != 1 {
|
||||||
|
r.code = response.FORBIDDEN
|
||||||
|
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.origin = string(bb[0])
|
||||||
|
|
||||||
|
// if host already stored, check origin policy
|
||||||
|
if len(r.host) > 0 {
|
||||||
|
err := r.checkOriginPolicy()
|
||||||
|
if err != nil {
|
||||||
|
r.code = response.FORBIDDEN
|
||||||
|
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkOriginPolicy origin policy based on 'host' value
|
||||||
|
func (r *T) checkOriginPolicy() error {
|
||||||
|
// TODO: Origin policy, for now BYPASS
|
||||||
|
r.validPolicy = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkConnection checks the 'Connection' header
|
||||||
|
// it MUST contain 'Upgrade'
|
||||||
|
func (r *T) checkConnection(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
for _, b := range bb {
|
||||||
|
|
||||||
|
if strings.ToLower( string(b) ) == "upgrade" {
|
||||||
|
r.hasConnection = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkUpgrade checks the 'Upgrade' header
|
||||||
|
// it MUST be 'websocket'
|
||||||
|
func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
if len(bb) != 1 {
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.ToLower( string(bb[0]) ) == "websocket" {
|
||||||
|
r.hasUpgrade = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkVersion checks the 'Sec-WebSocket-Version' header
|
||||||
|
// it MUST be '13'
|
||||||
|
func (r *T) checkVersion(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
if len(bb) != 1 || string(bb[0]) != "13" {
|
||||||
|
r.code = response.UPGRADE_REQUIRED
|
||||||
|
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.hasVersion = true
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractKey extracts the 'Sec-WebSocket-Key' header
|
||||||
|
// it MUST be 24 bytes (base64)
|
||||||
|
func (r *T) extractKey(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
if len(bb) != 1 || len(bb[0]) != 24 {
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.key = bb[0]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header
|
||||||
|
// it can contain multiple values
|
||||||
|
func (r *T) extractProtocols(bb header.HeaderValue) error {
|
||||||
|
|
||||||
|
r.protocols = bb
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package header
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "regexp"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parse tries to return a 'T' (httpHeader) from a byte array
|
||||||
|
func Parse(b []byte) (*T, error) {
|
||||||
|
|
||||||
|
/* (1) Split by ':' */
|
||||||
|
parts := bytes.Split(b, []byte(": "))
|
||||||
|
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, fmt.Errorf("Invalid HTTP header format '%s'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Create instance */
|
||||||
|
inst := new(T)
|
||||||
|
|
||||||
|
/* (3) Check for header name */
|
||||||
|
switch strings.ToLower(string(parts[0])) {
|
||||||
|
case "host": inst.Name = HOST
|
||||||
|
case "upgrade": inst.Name = UPGRADE
|
||||||
|
case "connection": inst.Name = CONNECTION
|
||||||
|
case "origin": inst.Name = ORIGIN
|
||||||
|
case "sec-websocket-key": inst.Name = WSKEY
|
||||||
|
case "sec-websocket-protocol": inst.Name = WSPROTOCOL
|
||||||
|
case "sec-websocket-extensions": inst.Name = WSEXTENSIONS
|
||||||
|
case "sec-websocket-version": inst.Name = WSVERSION
|
||||||
|
default: inst.Name = UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Split values */
|
||||||
|
inst.Values = bytes.Split(parts[1], []byte(", "))
|
||||||
|
|
||||||
|
return inst, nil
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package header
|
||||||
|
|
||||||
|
// HeaderType represents all 'valid' HTTP request headers
|
||||||
|
type HeaderType byte
|
||||||
|
const (
|
||||||
|
UNKNOWN HeaderType = iota
|
||||||
|
HOST
|
||||||
|
UPGRADE
|
||||||
|
CONNECTION
|
||||||
|
ORIGIN
|
||||||
|
WSKEY
|
||||||
|
WSPROTOCOL
|
||||||
|
WSEXTENSIONS
|
||||||
|
WSVERSION
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderValue represents a unique or multiple header value(s)
|
||||||
|
type HeaderValue [][]byte
|
||||||
|
|
||||||
|
|
||||||
|
// T represents the data of a HTTP request header
|
||||||
|
type T struct{
|
||||||
|
Name HeaderType
|
||||||
|
Values HeaderValue
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.xdrm.io/gws/internal/http/upgrade/response"
|
||||||
|
"fmt"
|
||||||
|
"git.xdrm.io/gws/internal/http/upgrade/request/parser/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
// parseHeader parses any http request line
|
||||||
|
// (header and request-line)
|
||||||
|
func (r *T) parseHeader(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) First line -> GET {uri} HTTP/{version}
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
if !r.first {
|
||||||
|
|
||||||
|
err := r.request.Parse(b)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return &InvalidRequest{"Request-Line", err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.first = true
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) Other lines -> Header-Name: Header-Value
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Try to parse header */
|
||||||
|
head, err := header.Parse(b)
|
||||||
|
if err != nil {
|
||||||
|
r.code = response.BAD_REQUEST
|
||||||
|
return fmt.Errorf("Error parsing header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Manage header */
|
||||||
|
switch head.Name {
|
||||||
|
case header.HOST: err = r.extractHostPort(head.Values)
|
||||||
|
case header.ORIGIN: err = r.extractOrigin(head.Values)
|
||||||
|
case header.UPGRADE: err = r.checkUpgrade(head.Values)
|
||||||
|
case header.CONNECTION: err = r.checkConnection(head.Values)
|
||||||
|
case header.WSVERSION: err = r.checkVersion(head.Values)
|
||||||
|
case header.WSKEY: err = r.extractKey(head.Values)
|
||||||
|
case header.WSPROTOCOL: err = r.extractProtocols(head.Values)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// dispatch error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// isComplete returns whether the Upgrade Request
|
||||||
|
// is complete (no missing required item)
|
||||||
|
func (r T) isComplete() error {
|
||||||
|
|
||||||
|
/* (1) Request-Line */
|
||||||
|
if !r.first {
|
||||||
|
return &IncompleteRequest{"Request-Line"}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Host */
|
||||||
|
if len(r.host) == 0 {
|
||||||
|
return &IncompleteRequest{"Host"}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Origin */
|
||||||
|
if !bypassOriginPolicy && len(r.origin) == 0 {
|
||||||
|
return &IncompleteRequest{"Origin"}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Connection */
|
||||||
|
if !r.hasConnection {
|
||||||
|
return &IncompleteRequest{"Connection"}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) Upgrade */
|
||||||
|
if !r.hasUpgrade {
|
||||||
|
return &IncompleteRequest{"Upgrade"}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (6) Sec-WebSocket-Version */
|
||||||
|
if !r.hasVersion {
|
||||||
|
return &IncompleteRequest{"Sec-WebSocket-Version"}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (7) Sec-WebSocket-Key */
|
||||||
|
if len(r.key) < 1 {
|
||||||
|
return &IncompleteRequest{"Sec-WebSocket-Key"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.xdrm.io/gws/internal/http/upgrade/response"
|
||||||
|
"git.xdrm.io/gws/internal/http/reader"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse builds an upgrade HTTP request
|
||||||
|
// from a reader (typically bufio.NewRead of the socket)
|
||||||
|
func Parse(r io.Reader) (request *T, err error) {
|
||||||
|
|
||||||
|
req := new(T)
|
||||||
|
req.code = 500
|
||||||
|
|
||||||
|
|
||||||
|
/* (1) Parse request
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Get chunk reader */
|
||||||
|
cr := reader.NewReader(r)
|
||||||
|
if err != nil {
|
||||||
|
return req, fmt.Errorf("Error while creating chunk reader: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Parse header line by line */
|
||||||
|
for {
|
||||||
|
|
||||||
|
line, err := cr.Read()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = req.parseHeader(line)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Check completion */
|
||||||
|
err = req.isComplete()
|
||||||
|
if err != nil {
|
||||||
|
req.code = response.BAD_REQUEST
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
req.code = response.SWITCHING_PROTOCOLS
|
||||||
|
return req, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// StatusCode returns the status current
|
||||||
|
func (r T) StatusCode() response.StatusCode {
|
||||||
|
return r.code
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// BuildResponse builds a response.T from the request
|
||||||
|
func (r *T) BuildResponse() *response.T{
|
||||||
|
|
||||||
|
inst := new(response.T)
|
||||||
|
|
||||||
|
/* (1) Copy code */
|
||||||
|
inst.SetStatusCode(r.code)
|
||||||
|
|
||||||
|
/* (2) Set Protocol */
|
||||||
|
if len(r.protocols) > 0 {
|
||||||
|
inst.SetProtocol(r.protocols[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Process key */
|
||||||
|
inst.ProcessKey(r.key)
|
||||||
|
|
||||||
|
return inst
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// GetURI returns the actual URI
|
||||||
|
func (r T) GetURI() string{
|
||||||
|
return r.request.GetURI()
|
||||||
|
}
|
|
@ -0,0 +1,135 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"bytes"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpMethod represents available http methods
|
||||||
|
type httpMethod byte
|
||||||
|
const (
|
||||||
|
OPTIONS httpMethod = iota
|
||||||
|
GET
|
||||||
|
HEAD
|
||||||
|
POST
|
||||||
|
PUT
|
||||||
|
DELETE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
// RequestLine represents the HTTP Request line
|
||||||
|
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||||
|
type RequestLine struct {
|
||||||
|
method httpMethod
|
||||||
|
uri string
|
||||||
|
version byte
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// parseRequestLine parses the first HTTP request line
|
||||||
|
func (r *RequestLine) Parse(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) Split by ' ' */
|
||||||
|
parts := bytes.Split(b, []byte(" "))
|
||||||
|
|
||||||
|
/* (2) Fail when missing parts */
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Extract HTTP method */
|
||||||
|
err := r.extractHttpMethod(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Extract URI */
|
||||||
|
err = r.extractURI(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) Extract version */
|
||||||
|
err = r.extractHttpVersion(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// GetURI returns the actual URI
|
||||||
|
func (r RequestLine) GetURI() string {
|
||||||
|
return r.uri
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// extractHttpMethod extracts the HTTP method from a []byte
|
||||||
|
// and checks for errors
|
||||||
|
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||||
|
func (r *RequestLine) extractHttpMethod(b []byte) error {
|
||||||
|
|
||||||
|
switch string(b) {
|
||||||
|
// case "OPTIONS": r.method = OPTIONS
|
||||||
|
case "GET": r.method = GET
|
||||||
|
// case "HEAD": r.method = HEAD
|
||||||
|
// case "POST": r.method = POST
|
||||||
|
// case "PUT": r.method = PUT
|
||||||
|
// case "DELETE": r.method = DELETE
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Invalid HTTP method '%s', expected 'GET'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// extractURI extracts the URI from a []byte and checks for errors
|
||||||
|
// allowed format: /([^/]/)*/?
|
||||||
|
func (r *RequestLine) extractURI(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) Check format */
|
||||||
|
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||||
|
if !checker.Match(b) {
|
||||||
|
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Store */
|
||||||
|
r.uri = string(b)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// extractHttpVersion extracts the version and checks for errors
|
||||||
|
// allowed format: [1-9] or [1.9].[0-9]
|
||||||
|
func (r *RequestLine) extractHttpVersion(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) Extract version parts */
|
||||||
|
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
|
||||||
|
|
||||||
|
if !extractor.Match(b) {
|
||||||
|
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Extract version number */
|
||||||
|
matches := extractor.FindSubmatch(b)
|
||||||
|
var version byte = matches[1][0] - '0'
|
||||||
|
|
||||||
|
/* (3) Extract subversion (if exists) */
|
||||||
|
var subVersion byte = 0
|
||||||
|
if len(matches[2]) > 0 {
|
||||||
|
subVersion = matches[2][0] - '0'
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Store version (x 10 to fit uint8) */
|
||||||
|
r.version = version * 10 + subVersion
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,183 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
// /* (1) Parse request */
|
||||||
|
// req, _ := request.Parse(s)
|
||||||
|
|
||||||
|
// /* (3) Build response */
|
||||||
|
// res := req.BuildResponse()
|
||||||
|
|
||||||
|
// /* (4) Write into socket */
|
||||||
|
// _, err := res.Send(s)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("Upgrade write error: %s", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if res.GetStatusCode() != 101 {
|
||||||
|
// s.Close()
|
||||||
|
// return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
|
||||||
|
// }
|
||||||
|
|
||||||
|
func TestEOFSocket(t *testing.T){
|
||||||
|
|
||||||
|
socket := new(bytes.Buffer)
|
||||||
|
|
||||||
|
_, err := Parse(socket)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Empty socket expected EOF, got no error")
|
||||||
|
} else if err != io.ErrUnexpectedEOF {
|
||||||
|
t.Fatalf("Empty socket expected EOF, got '%s'", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidRequestLine(t *testing.T){
|
||||||
|
|
||||||
|
socket := new(bytes.Buffer)
|
||||||
|
cases := []struct{
|
||||||
|
Reqline string
|
||||||
|
HasError bool
|
||||||
|
}{
|
||||||
|
{ "abc", true },
|
||||||
|
{ "a c", true },
|
||||||
|
{ "a c", true },
|
||||||
|
{ "a c", true },
|
||||||
|
{ "a b c", true },
|
||||||
|
|
||||||
|
{ "GET invaliduri HTTP/1.1", true },
|
||||||
|
{ "GET /validuri HTTP/1.1", false },
|
||||||
|
|
||||||
|
{ "POST /validuri HTTP/1.1", true },
|
||||||
|
{ "PUT /validuri HTTP/1.1", true },
|
||||||
|
{ "DELETE /validuri HTTP/1.1", true },
|
||||||
|
{ "OPTIONS /validuri HTTP/1.1", true },
|
||||||
|
{ "UNKNOWN /validuri HTTP/1.1", true },
|
||||||
|
|
||||||
|
{ "GET / HTTP", true },
|
||||||
|
{ "GET / HTTP/", true },
|
||||||
|
{ "GET / 1.1", true },
|
||||||
|
{ "GET / 1", true },
|
||||||
|
{ "GET / HTTP/52", true },
|
||||||
|
{ "GET / HTTP/1.", true },
|
||||||
|
{ "GET / HTTP/.1", true },
|
||||||
|
{ "GET / HTTP/1.1", false },
|
||||||
|
{ "GET / HTTP/2", false },
|
||||||
|
}
|
||||||
|
|
||||||
|
for ti, tc := range cases {
|
||||||
|
|
||||||
|
socket.Reset()
|
||||||
|
socket.Write( []byte(tc.Reqline) )
|
||||||
|
socket.Write( []byte("\r\n\r\n") )
|
||||||
|
|
||||||
|
_, err := Parse(socket)
|
||||||
|
|
||||||
|
if !tc.HasError {
|
||||||
|
|
||||||
|
// no error -> ok
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
// error for the end of the request -> ok
|
||||||
|
} else if _, ok := err.(*IncompleteRequest); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Errorf("[%d] Expected no error", ti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// missing required error -> error
|
||||||
|
if tc.HasError && err == nil {
|
||||||
|
t.Errorf("[%d] Expected error", ti)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ir, ok := err.(*InvalidRequest);
|
||||||
|
|
||||||
|
// not InvalidRequest err -> error
|
||||||
|
if !ok || ir.Field != "Request-Line" {
|
||||||
|
t.Errorf("[%d] expected InvalidRequest", ti)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidHost(t *testing.T){
|
||||||
|
|
||||||
|
requestLine := []byte( "GET / HTTP/1.1\r\n" )
|
||||||
|
|
||||||
|
socket := new(bytes.Buffer)
|
||||||
|
cases := []struct{
|
||||||
|
Host string
|
||||||
|
HasError bool
|
||||||
|
}{
|
||||||
|
{ "1", true },
|
||||||
|
{ "12", true },
|
||||||
|
{ "123", true },
|
||||||
|
{ "1234", false },
|
||||||
|
|
||||||
|
{ "singlevalue", false },
|
||||||
|
{ "multi value", true },
|
||||||
|
|
||||||
|
{ "singlevalue:1", false },
|
||||||
|
{ "singlevalue:", true },
|
||||||
|
{ "singlevalue:x", true },
|
||||||
|
{ "xx:x", true },
|
||||||
|
{ ":xxx", true },
|
||||||
|
{ "xxx:", true },
|
||||||
|
{ "a:12", false },
|
||||||
|
|
||||||
|
{ "google.com", false },
|
||||||
|
{ "8.8.8.8", false },
|
||||||
|
{ "google.com:8080", false },
|
||||||
|
{ "8.8.8.8:8080", false },
|
||||||
|
}
|
||||||
|
|
||||||
|
for ti, tc := range cases {
|
||||||
|
|
||||||
|
socket.Reset()
|
||||||
|
socket.Write(requestLine)
|
||||||
|
socket.Write( []byte("Host: ") )
|
||||||
|
socket.Write( []byte(tc.Host) )
|
||||||
|
socket.Write( []byte("\r\n\r\n") )
|
||||||
|
|
||||||
|
_, err := Parse(socket)
|
||||||
|
|
||||||
|
if !tc.HasError {
|
||||||
|
|
||||||
|
|
||||||
|
// no error -> ok
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
// error for the end of the request -> ok
|
||||||
|
} else if _, ok := err.(*IncompleteRequest); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Errorf("[%d] Expected no error; %s", ti, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// missing required error -> error
|
||||||
|
if tc.HasError && err == nil {
|
||||||
|
t.Errorf("[%d] Expected error", ti)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if InvalidRequest
|
||||||
|
ir, ok := err.(*InvalidRequest);
|
||||||
|
|
||||||
|
// not InvalidRequest err -> error
|
||||||
|
if ok && ir.Field != "Host" {
|
||||||
|
t.Errorf("[%d] expected InvalidRequest", ti)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import "git.xdrm.io/gws/internal/http/upgrade/response"
|
||||||
|
|
||||||
|
// If origin is required
|
||||||
|
const bypassOriginPolicy = true
|
||||||
|
|
||||||
|
// T represents an HTTP Upgrade request
|
||||||
|
type T struct {
|
||||||
|
first bool // whether the first line has been read (GET uri HTTP/version)
|
||||||
|
|
||||||
|
// status code
|
||||||
|
code response.StatusCode
|
||||||
|
|
||||||
|
// request line
|
||||||
|
request RequestLine
|
||||||
|
|
||||||
|
// data to check origin (depends of reading order)
|
||||||
|
host string
|
||||||
|
port uint16 // 0 if not set
|
||||||
|
origin string
|
||||||
|
validPolicy bool
|
||||||
|
|
||||||
|
// ws data
|
||||||
|
key []byte
|
||||||
|
protocols [][]byte
|
||||||
|
|
||||||
|
// required fields check
|
||||||
|
hasConnection bool
|
||||||
|
hasUpgrade bool
|
||||||
|
hasVersion bool
|
||||||
|
}
|
|
@ -1,94 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RequestLine represents the HTTP Request line
|
|
||||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
|
||||||
type RequestLine struct {
|
|
||||||
uri string
|
|
||||||
version byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read an HTTP request line from a byte array
|
|
||||||
//
|
|
||||||
// implements io.Reader
|
|
||||||
func (rl *RequestLine) Read(b []byte) (int, error) {
|
|
||||||
var read = len(b)
|
|
||||||
|
|
||||||
// split by spaces
|
|
||||||
parts := bytes.Split(b, []byte(" "))
|
|
||||||
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return read, fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rl.extractHttpMethod(parts[0])
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rl.extractURI(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rl.extractHttpVersion(parts[2])
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return read, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// URI of the request line
|
|
||||||
func (rl RequestLine) URI() string {
|
|
||||||
return rl.uri
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractHttpMethod extracts the HTTP method from a []byte
|
|
||||||
// and checks for errors
|
|
||||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
|
||||||
func (rl *RequestLine) extractHttpMethod(b []byte) error {
|
|
||||||
if string(b) != "GET" {
|
|
||||||
return fmt.Errorf("invalid HTTP method '%s', expected 'GET'", b)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractURI extracts the URI from a []byte and checks for errors
|
|
||||||
// allowed format: /([^/]/)*/?
|
|
||||||
func (rl *RequestLine) extractURI(b []byte) error {
|
|
||||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
|
||||||
if !checker.Match(b) {
|
|
||||||
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
|
||||||
}
|
|
||||||
rl.uri = string(b)
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractHttpVersion extracts the version and checks for errors
|
|
||||||
// allowed format: [1-9] or [1.9].[0-9]
|
|
||||||
func (rl *RequestLine) extractHttpVersion(b []byte) error {
|
|
||||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
|
|
||||||
|
|
||||||
if !extractor.Match(b) {
|
|
||||||
return fmt.Errorf("invalid HTTP version, expected INT or INT.INT, got '%s'", b)
|
|
||||||
}
|
|
||||||
matches := extractor.FindSubmatch(b)
|
|
||||||
|
|
||||||
var version byte = matches[1][0] - '0'
|
|
||||||
|
|
||||||
var subversion byte = 0
|
|
||||||
if len(matches[2]) > 0 {
|
|
||||||
subversion = matches[2][0] - '0'
|
|
||||||
}
|
|
||||||
|
|
||||||
rl.version = version*10 + subversion
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,162 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEOFSocket(t *testing.T) {
|
|
||||||
var (
|
|
||||||
socket = &bytes.Buffer{}
|
|
||||||
req = &Request{}
|
|
||||||
)
|
|
||||||
|
|
||||||
_, err := req.ReadFrom(socket)
|
|
||||||
if err != io.ErrUnexpectedEOF {
|
|
||||||
t.Fatalf("unexpected error <%v> expected <%v>", err, io.ErrUnexpectedEOF)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvalidRequestLine(t *testing.T) {
|
|
||||||
|
|
||||||
socket := &bytes.Buffer{}
|
|
||||||
cases := []struct {
|
|
||||||
Reqline string
|
|
||||||
HasError bool
|
|
||||||
}{
|
|
||||||
{"abc", true},
|
|
||||||
{"a c", true},
|
|
||||||
{"a c", true},
|
|
||||||
{"a c", true},
|
|
||||||
{"a b c", true},
|
|
||||||
|
|
||||||
{"GET invaliduri HTTP/1.1", true},
|
|
||||||
{"GET /validuri HTTP/1.1", false},
|
|
||||||
|
|
||||||
{"POST /validuri HTTP/1.1", true},
|
|
||||||
{"PUT /validuri HTTP/1.1", true},
|
|
||||||
{"DELETE /validuri HTTP/1.1", true},
|
|
||||||
{"OPTIONS /validuri HTTP/1.1", true},
|
|
||||||
{"UNKNOWN /validuri HTTP/1.1", true},
|
|
||||||
|
|
||||||
{"GET / HTTP", true},
|
|
||||||
{"GET / HTTP/", true},
|
|
||||||
{"GET / 1.1", true},
|
|
||||||
{"GET / 1", true},
|
|
||||||
{"GET / HTTP/52", true},
|
|
||||||
{"GET / HTTP/1.", true},
|
|
||||||
{"GET / HTTP/.1", true},
|
|
||||||
{"GET / HTTP/1.1", false},
|
|
||||||
{"GET / HTTP/2", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for ti, tc := range cases {
|
|
||||||
|
|
||||||
socket.Reset()
|
|
||||||
socket.Write([]byte(tc.Reqline))
|
|
||||||
socket.Write([]byte("\r\n\r\n"))
|
|
||||||
|
|
||||||
var req = &Request{}
|
|
||||||
_, err := req.ReadFrom(socket)
|
|
||||||
if !tc.HasError {
|
|
||||||
if err == nil {
|
|
||||||
continue
|
|
||||||
// error for the end of the request -> ok
|
|
||||||
} else if _, ok := err.(ErrIncompleteRequest); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Errorf("[%d] Expected no error", ti)
|
|
||||||
}
|
|
||||||
|
|
||||||
// missing required error -> error
|
|
||||||
if tc.HasError && err == nil {
|
|
||||||
t.Errorf("[%d] Expected error", ti)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ir, ok := err.(*ErrInvalidRequest)
|
|
||||||
|
|
||||||
// not InvalidRequest err -> error
|
|
||||||
if !ok || ir.Field != "Request-Line" {
|
|
||||||
t.Errorf("[%d] expected InvalidRequest", ti)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvalidHost(t *testing.T) {
|
|
||||||
|
|
||||||
requestLine := []byte("GET / HTTP/1.1\r\n")
|
|
||||||
|
|
||||||
socket := &bytes.Buffer{}
|
|
||||||
cases := []struct {
|
|
||||||
Host string
|
|
||||||
HasError bool
|
|
||||||
}{
|
|
||||||
{"1", true},
|
|
||||||
{"12", true},
|
|
||||||
{"123", true},
|
|
||||||
{"1234", false},
|
|
||||||
|
|
||||||
{"singlevalue", false},
|
|
||||||
{"multi value", true},
|
|
||||||
|
|
||||||
{"singlevalue:1", false},
|
|
||||||
{"singlevalue:", true},
|
|
||||||
{"singlevalue:x", true},
|
|
||||||
{"xx:x", true},
|
|
||||||
{":xxx", true},
|
|
||||||
{"xxx:", true},
|
|
||||||
{"a:12", false},
|
|
||||||
|
|
||||||
{"google.com", false},
|
|
||||||
{"8.8.8.8", false},
|
|
||||||
{"google.com:8080", false},
|
|
||||||
{"8.8.8.8:8080", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for ti, tc := range cases {
|
|
||||||
|
|
||||||
socket.Reset()
|
|
||||||
socket.Write(requestLine)
|
|
||||||
socket.Write([]byte("Host: "))
|
|
||||||
socket.Write([]byte(tc.Host))
|
|
||||||
socket.Write([]byte("\r\n\r\n"))
|
|
||||||
|
|
||||||
var req = &Request{}
|
|
||||||
_, err := req.ReadFrom(socket)
|
|
||||||
if !tc.HasError {
|
|
||||||
|
|
||||||
// no error -> ok
|
|
||||||
if err == nil {
|
|
||||||
continue
|
|
||||||
// error for the end of the request -> ok
|
|
||||||
} else if _, ok := err.(ErrIncompleteRequest); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Errorf("[%d] Expected no error; %s", ti, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// missing required error -> error
|
|
||||||
if tc.HasError && err == nil {
|
|
||||||
t.Errorf("[%d] Expected error", ti)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if InvalidRequest
|
|
||||||
ir, ok := err.(ErrInvalidRequest)
|
|
||||||
|
|
||||||
// not InvalidRequest err -> error
|
|
||||||
if ok && ir.Field != "Host" {
|
|
||||||
t.Errorf("[%d] expected InvalidRequest", ti)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// constants
|
|
||||||
const (
|
|
||||||
httpVersion = "1.1"
|
|
||||||
wsVersion = 13
|
|
||||||
keySalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Response is an HTTP Upgrade Response
|
|
||||||
type Response struct {
|
|
||||||
StatusCode StatusCode
|
|
||||||
// Sec-WebSocket-Protocol or nil if missing
|
|
||||||
Protocol []byte
|
|
||||||
// processed from Sec-WebSocket-Key
|
|
||||||
key []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessKey processes the accept token according
|
|
||||||
// to the rfc from the Sec-WebSocket-Key
|
|
||||||
func (r *Response) ProcessKey(k []byte) {
|
|
||||||
// ignore empty key
|
|
||||||
if k == nil || len(k) < 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// concat with constant salt
|
|
||||||
salted := append(k, []byte(keySalt)...)
|
|
||||||
// hash with sha1
|
|
||||||
digest := sha1.Sum(salted)
|
|
||||||
// base64 encode
|
|
||||||
r.key = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteTo writes the response; typically in a socket
|
|
||||||
//
|
|
||||||
// implements io.WriterTo
|
|
||||||
func (r Response) WriteTo(w io.Writer) (int64, error) {
|
|
||||||
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", httpVersion, r.StatusCode, r.StatusCode)
|
|
||||||
|
|
||||||
optionalProtocol := ""
|
|
||||||
if len(r.Protocol) > 0 {
|
|
||||||
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.Protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", wsVersion, optionalProtocol)
|
|
||||||
if r.key != nil {
|
|
||||||
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.key)
|
|
||||||
}
|
|
||||||
headers = fmt.Sprintf("%s\r\n", headers)
|
|
||||||
|
|
||||||
combined := []byte(fmt.Sprintf("%s%s", responseLine, headers))
|
|
||||||
|
|
||||||
written, err := w.Write(combined)
|
|
||||||
return int64(written), err
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
package response
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"encoding/base64"
|
||||||
|
"crypto/sha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// SetStatusCode sets the status code
|
||||||
|
func (r *T) SetStatusCode(sc StatusCode) {
|
||||||
|
r.code = sc
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetProtocols sets the protocols
|
||||||
|
func (r *T) SetProtocol(p []byte) {
|
||||||
|
r.protocol = p
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessKey processes the accept token according
|
||||||
|
// to the rfc from the Sec-WebSocket-Key
|
||||||
|
func (r *T) ProcessKey(k []byte) {
|
||||||
|
|
||||||
|
// do nothing for empty key
|
||||||
|
if k == nil || len(k) == 0 {
|
||||||
|
r.accept = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (1) Concat with constant salt */
|
||||||
|
mix := append(k, WSSalt...)
|
||||||
|
|
||||||
|
/* (2) Hash with sha1 algorithm */
|
||||||
|
digest := sha1.Sum(mix)
|
||||||
|
|
||||||
|
/* (3) Base64 encode it */
|
||||||
|
r.accept = []byte( base64.StdEncoding.EncodeToString( digest[:sha1.Size] ) )
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Send sends the response through an io.Writer
|
||||||
|
// typically a socket
|
||||||
|
func (r T) Send(w io.Writer) (int, error) {
|
||||||
|
|
||||||
|
/* (1) Build response line */
|
||||||
|
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HttpVersion, r.code, r.code.Message())
|
||||||
|
|
||||||
|
/* (2) Build headers */
|
||||||
|
optionalProtocol := ""
|
||||||
|
if len(r.protocol) > 0 {
|
||||||
|
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", WSVersion, optionalProtocol)
|
||||||
|
if r.accept != nil {
|
||||||
|
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept)
|
||||||
|
}
|
||||||
|
headers = fmt.Sprintf("%s\r\n", headers)
|
||||||
|
|
||||||
|
/* (3) Build all */
|
||||||
|
raw := []byte(fmt.Sprintf("%s%s", responseLine, headers))
|
||||||
|
|
||||||
|
/* (4) Write */
|
||||||
|
written, err := w.Write(raw)
|
||||||
|
|
||||||
|
return written, err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProtocol returns the choosen protocol if set, else nil
|
||||||
|
func (r T) GetProtocol() []byte {
|
||||||
|
return r.protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// GetStatusCode returns the response status code
|
||||||
|
func (r T) GetStatusCode() StatusCode {
|
||||||
|
return r.code
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package response
|
||||||
|
|
||||||
|
// StatusCode maps the status codes (and description)
|
||||||
|
type StatusCode uint16
|
||||||
|
|
||||||
|
var SWITCHING_PROTOCOLS StatusCode = 101 // handshake success
|
||||||
|
var BAD_REQUEST StatusCode = 400 // missing/malformed headers
|
||||||
|
var FORBIDDEN StatusCode = 403 // invalid origin policy, TLS required
|
||||||
|
var UPGRADE_REQUIRED StatusCode = 426 // invalid WS version
|
||||||
|
var NOT_FOUND StatusCode = 404 // unserved or invalid URI
|
||||||
|
var INTERNAL StatusCode = 500 // custom error
|
||||||
|
|
||||||
|
func (sc StatusCode) Message() string {
|
||||||
|
|
||||||
|
switch sc {
|
||||||
|
case SWITCHING_PROTOCOLS: return "Switching Protocols"
|
||||||
|
case BAD_REQUEST: return "Bad Request"
|
||||||
|
case FORBIDDEN: return "Forbidden"
|
||||||
|
case UPGRADE_REQUIRED: return "Upgrade Required"
|
||||||
|
case NOT_FOUND: return "Not Found"
|
||||||
|
case INTERNAL: return "Internal Server Error"
|
||||||
|
default:
|
||||||
|
return "Unknown Status Code"
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package response
|
||||||
|
|
||||||
|
// Constant
|
||||||
|
const HttpVersion = "1.1"
|
||||||
|
const WSVersion = 13
|
||||||
|
var WSSalt []byte = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
|
||||||
|
// T represents an HTTP Upgrade Response
|
||||||
|
type T struct {
|
||||||
|
|
||||||
|
code StatusCode // status code
|
||||||
|
accept []byte // processed from Sec-WebSocket-Key
|
||||||
|
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
|
||||||
|
|
||||||
|
}
|
|
@ -1,39 +0,0 @@
|
||||||
package upgrade
|
|
||||||
|
|
||||||
// StatusCode maps HTTP status codes (and description)
|
|
||||||
type StatusCode int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// StatusSwitchingProtocols - handshake success
|
|
||||||
StatusSwitchingProtocols StatusCode = 101
|
|
||||||
// StatusBadRequest - missing/malformed headers
|
|
||||||
StatusBadRequest StatusCode = 400
|
|
||||||
// StatusForbidden - invalid origin policy, TLS required
|
|
||||||
StatusForbidden StatusCode = 403
|
|
||||||
// StatusUpgradeRequired - invalid WS version
|
|
||||||
StatusUpgradeRequired StatusCode = 426
|
|
||||||
// StatusNotFound - unserved or invalid URI
|
|
||||||
StatusNotFound StatusCode = 404
|
|
||||||
// StatusInternal - custom error
|
|
||||||
StatusInternal StatusCode = 500
|
|
||||||
)
|
|
||||||
|
|
||||||
// String implements the Stringer interface
|
|
||||||
func (sc StatusCode) String() string {
|
|
||||||
switch sc {
|
|
||||||
case StatusSwitchingProtocols:
|
|
||||||
return "Switching Protocols"
|
|
||||||
case StatusBadRequest:
|
|
||||||
return "Bad Request"
|
|
||||||
case StatusForbidden:
|
|
||||||
return "Forbidden"
|
|
||||||
case StatusUpgradeRequired:
|
|
||||||
return "Upgrade Required"
|
|
||||||
case StatusNotFound:
|
|
||||||
return "Not Found"
|
|
||||||
case StatusInternal:
|
|
||||||
return "Internal Server Error"
|
|
||||||
default:
|
|
||||||
return "Unknown Status Code"
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,319 +0,0 @@
|
||||||
package uri
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// === WILDCARDS ===
|
|
||||||
//
|
|
||||||
// The star '*' -> matches 0 or 1 slash-bounded string
|
|
||||||
// The multi star '**' -> matches 0 or more slash-separated strings
|
|
||||||
// The dot '.' -> matches 1 slash-bounded string
|
|
||||||
// The multi dot '..' -> matches 1 or more slash-separated strings
|
|
||||||
//
|
|
||||||
// === SCHEME POLICY ===
|
|
||||||
//
|
|
||||||
// - The last '/' is optional
|
|
||||||
// - Any '**' at the very end will match anything that starts with the given prefix
|
|
||||||
//
|
|
||||||
// === LIMITATIONS ==
|
|
||||||
//
|
|
||||||
// - A scheme must begin with '/'
|
|
||||||
// - A scheme cannot contain something else than a STRING or WILDCARD between 2 '/' separators
|
|
||||||
// - A scheme STRING cannot contain the symbols '/' as a character
|
|
||||||
// - A scheme STRING containing '*' or '.' characters will be treating as STRING only
|
|
||||||
// - A maximum of 16 slash-separated matchers (STRING or WILDCARD) are allowed
|
|
||||||
|
|
||||||
const maxMatch = 16
|
|
||||||
|
|
||||||
// Represents an URI matcher
|
|
||||||
type matcher struct {
|
|
||||||
pat string // pattern to match (empty if wildcard)
|
|
||||||
req bool // whether it is required
|
|
||||||
mul bool // whether multiple matches are allowed
|
|
||||||
|
|
||||||
buf []string // matched content (when matching)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scheme represents an URI scheme
|
|
||||||
type Scheme []*matcher
|
|
||||||
|
|
||||||
// FromString builds an URI scheme from a string pattern
|
|
||||||
func FromString(s string) (*Scheme, error) {
|
|
||||||
// handle '/' at the start
|
|
||||||
if len(s) < 1 || s[0] != '/' {
|
|
||||||
return nil, fmt.Errorf("invalid URI; must start with '/'")
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(s, "/")
|
|
||||||
|
|
||||||
// check max match size
|
|
||||||
if len(parts)-2 > maxMatch {
|
|
||||||
for i, p := range parts {
|
|
||||||
fmt.Printf("%d: '%s'\n", i, p)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
|
|
||||||
}
|
|
||||||
|
|
||||||
sch, err := buildScheme(parts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
opti, err := sch.optimise()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &opti, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match returns whether the given URI is matched by the scheme
|
|
||||||
func (s Scheme) Match(uri string) bool {
|
|
||||||
if len(s) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for string match
|
|
||||||
clearURI, match := s.matchString(uri)
|
|
||||||
if !match {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for non-string match (wildcards)
|
|
||||||
return s.matchWildcards(clearURI)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMatch returns the indexed match (excluding string matchers)
|
|
||||||
func (s Scheme) GetMatch(n uint8) ([]string, error) {
|
|
||||||
if n > uint8(len(s)) {
|
|
||||||
return nil, fmt.Errorf("index out of range")
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate to find index (exclude strings)
|
|
||||||
matches := -1
|
|
||||||
for _, m := range s {
|
|
||||||
if len(m.pat) > 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
matches++
|
|
||||||
|
|
||||||
// expected index -> return matches
|
|
||||||
if uint8(matches) == n {
|
|
||||||
return m.buf, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// nothing found -> return empty set
|
|
||||||
return nil, fmt.Errorf("index out of range (max: %d)", matches)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllMatch returns all the indexed match (excluding string matchers)
|
|
||||||
func (s Scheme) GetAllMatch() [][]string {
|
|
||||||
match := make([][]string, 0, len(s))
|
|
||||||
|
|
||||||
for _, m := range s {
|
|
||||||
if len(m.pat) > 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
match = append(match, m.buf)
|
|
||||||
}
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildScheme builds a 'basic' scheme
|
|
||||||
// from a pattern string
|
|
||||||
func buildScheme(ss []string) (Scheme, error) {
|
|
||||||
sch := make(Scheme, 0, maxMatch)
|
|
||||||
|
|
||||||
for _, s := range ss {
|
|
||||||
if len(s) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &matcher{}
|
|
||||||
|
|
||||||
switch s {
|
|
||||||
|
|
||||||
// card: 0, N
|
|
||||||
case "**":
|
|
||||||
m.req = false
|
|
||||||
m.mul = true
|
|
||||||
sch = append(sch, m)
|
|
||||||
|
|
||||||
// card: 1, N
|
|
||||||
case "..":
|
|
||||||
m.req = true
|
|
||||||
m.mul = true
|
|
||||||
sch = append(sch, m)
|
|
||||||
|
|
||||||
// card: 0, 1
|
|
||||||
case "*":
|
|
||||||
m.req = false
|
|
||||||
m.mul = false
|
|
||||||
sch = append(sch, m)
|
|
||||||
|
|
||||||
// card: 1
|
|
||||||
case ".":
|
|
||||||
m.req = true
|
|
||||||
m.mul = false
|
|
||||||
sch = append(sch, m)
|
|
||||||
|
|
||||||
// card: 1, literal string
|
|
||||||
default:
|
|
||||||
m.req = true
|
|
||||||
m.mul = false
|
|
||||||
m.pat = fmt.Sprintf("/%s", s)
|
|
||||||
sch = append(sch, m)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return sch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// optimise optimised the scheme for further parsing
|
|
||||||
func (s Scheme) optimise() (Scheme, error) {
|
|
||||||
if len(s) <= 1 {
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// init reshifted scheme
|
|
||||||
rshift := make(Scheme, 0, maxMatch)
|
|
||||||
rshift = append(rshift, s[0])
|
|
||||||
|
|
||||||
// iterate over matchers
|
|
||||||
for p, i, l := 0, 1, len(s); i < l; i++ {
|
|
||||||
|
|
||||||
pre, cur := s[p], s[i]
|
|
||||||
|
|
||||||
// merge: 2 following literals
|
|
||||||
if len(pre.pat) > 0 && len(cur.pat) > 0 {
|
|
||||||
// merge strings into previous
|
|
||||||
pre.pat = fmt.Sprintf("%s%s", pre.pat, cur.pat)
|
|
||||||
|
|
||||||
// delete current
|
|
||||||
s[i] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// increment previous (only if current is not nul)
|
|
||||||
if s[i] != nil {
|
|
||||||
rshift = append(rshift, s[i])
|
|
||||||
p = i
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return rshift, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// matchString checks the STRING matchers from an URI
|
|
||||||
// - returns a boolean : false when not matching, true eitherway
|
|
||||||
// - returns a cleared uri, without STRING data
|
|
||||||
func (s Scheme) matchString(uri string) (string, bool) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
clearedInput = uri
|
|
||||||
minOffset = 0
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, m := range s {
|
|
||||||
ls := len(m.pat)
|
|
||||||
|
|
||||||
// ignore no STRING match
|
|
||||||
if ls == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// get offset in URI (else -1)
|
|
||||||
off := strings.Index(clearedInput, m.pat)
|
|
||||||
if off < 0 {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
// fail on invalid offset range
|
|
||||||
if off < minOffset {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for trailing '/'
|
|
||||||
hasSlash := 0
|
|
||||||
if off+ls < len(clearedInput) && clearedInput[off+ls] == '/' {
|
|
||||||
hasSlash = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the current string (+trailing slash) from the URI
|
|
||||||
beg, end := clearedInput[:off], clearedInput[off+ls+hasSlash:]
|
|
||||||
clearedInput = fmt.Sprintf("%s\a/%s", beg, end) // separate matches with a '\a' character
|
|
||||||
|
|
||||||
// update offset range
|
|
||||||
// +2 slash separators
|
|
||||||
// -1 because strings begin with 1 slash already
|
|
||||||
minOffset = len(beg) + 2 - 1
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// if exists, remove trailing '/'
|
|
||||||
if clearedInput[len(clearedInput)-1] == '/' {
|
|
||||||
clearedInput = clearedInput[:len(clearedInput)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// if exists, remove trailing '\a'
|
|
||||||
if clearedInput[len(clearedInput)-1] == '\a' {
|
|
||||||
clearedInput = clearedInput[:len(clearedInput)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
return clearedInput, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// matchWildcards check the WILCARDS (non-string) matchers from
|
|
||||||
// a cleared URI. it returns if the string matches
|
|
||||||
// + it sets the matchers buffers for later extraction
|
|
||||||
func (s Scheme) matchWildcards(clear string) bool {
|
|
||||||
|
|
||||||
// extract wildcards (ref)
|
|
||||||
wildcards := make(Scheme, 0, maxMatch)
|
|
||||||
|
|
||||||
for _, m := range s {
|
|
||||||
if len(m.pat) == 0 {
|
|
||||||
m.buf = nil // flush buffers
|
|
||||||
wildcards = append(wildcards, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wildcards) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// break uri by '\a' characters
|
|
||||||
matches := strings.Split(clear, "\a")[1:]
|
|
||||||
|
|
||||||
for n, match := range matches {
|
|
||||||
// no more matcher
|
|
||||||
if n >= len(wildcards) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// from index 1 because it begins with '/'
|
|
||||||
data := strings.Split(match, "/")[1:]
|
|
||||||
|
|
||||||
// missing required
|
|
||||||
if wildcards[n].req && len(data) < 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// if not multi but got multi
|
|
||||||
if !wildcards[n].mul && len(data) > 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
wildcards[n].buf = data
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
|
@ -0,0 +1,215 @@
|
||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
// buildScheme builds a 'basic' scheme
|
||||||
|
// from a pattern string
|
||||||
|
func buildScheme(ss []string) (Scheme, error) {
|
||||||
|
|
||||||
|
/* (1) Build scheme */
|
||||||
|
sch := make(Scheme, 0, maxMatch)
|
||||||
|
|
||||||
|
for _, s := range ss {
|
||||||
|
|
||||||
|
/* (2) ignore empty */
|
||||||
|
if len(s) == 0 { continue }
|
||||||
|
|
||||||
|
m := new(matcher)
|
||||||
|
|
||||||
|
switch s {
|
||||||
|
|
||||||
|
/* (3) Card: 0, N */
|
||||||
|
case "**":
|
||||||
|
m.req = false
|
||||||
|
m.mul = true
|
||||||
|
sch = append(sch, m)
|
||||||
|
|
||||||
|
/* (4) Card: 1, N */
|
||||||
|
case "..":
|
||||||
|
m.req = true
|
||||||
|
m.mul = true
|
||||||
|
sch = append(sch, m)
|
||||||
|
|
||||||
|
/* (5) Card: 0, 1 */
|
||||||
|
case "*":
|
||||||
|
m.req = false
|
||||||
|
m.mul = false
|
||||||
|
sch = append(sch, m)
|
||||||
|
|
||||||
|
/* (6) Card: 1 */
|
||||||
|
case ".":
|
||||||
|
m.req = true
|
||||||
|
m.mul = false
|
||||||
|
sch = append(sch, m)
|
||||||
|
|
||||||
|
/* (7) Card: 1, literal string */
|
||||||
|
default:
|
||||||
|
m.req = true
|
||||||
|
m.mul = false
|
||||||
|
m.pat = fmt.Sprintf("/%s", s)
|
||||||
|
sch = append(sch, m)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return sch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// optimise optimised the scheme for further parsing
|
||||||
|
func (s Scheme) optimise() (Scheme, error) {
|
||||||
|
|
||||||
|
/* (1) Nothing to do if only 1 element */
|
||||||
|
if len(s) <= 1 {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Init reshifted scheme */
|
||||||
|
rshift := make(Scheme, 0, maxMatch)
|
||||||
|
rshift = append(rshift, s[0])
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) Iterate over matchers */
|
||||||
|
for p, i, l := 0, 1, len(s) ; i < l ; i++ {
|
||||||
|
|
||||||
|
pre, cur := s[p], s[i]
|
||||||
|
|
||||||
|
/* Merge: 2 following literals */
|
||||||
|
if len(pre.pat) > 0 && len(cur.pat) > 0 {
|
||||||
|
|
||||||
|
// merge strings into previous
|
||||||
|
pre.pat = fmt.Sprintf("%s%s", pre.pat, cur.pat)
|
||||||
|
|
||||||
|
// delete current
|
||||||
|
s[i] = nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// increment previous (only if current is not nul)
|
||||||
|
if s[i] != nil {
|
||||||
|
rshift = append(rshift, s[i])
|
||||||
|
p = i
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return rshift, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// matchString checks the STRING matchers from an URI
|
||||||
|
// it returns a boolean : false when not matching, true eitherway
|
||||||
|
// it returns a cleared uri, without STRING data
|
||||||
|
func (s Scheme) matchString(uri string) (string, bool) {
|
||||||
|
|
||||||
|
/* (1) Initialise variables */
|
||||||
|
clr := uri // contains cleared input string
|
||||||
|
minOff := 0 // minimum offset
|
||||||
|
|
||||||
|
/* (2) Iterate over strings */
|
||||||
|
for _, m := range s {
|
||||||
|
|
||||||
|
|
||||||
|
ls := len(m.pat)
|
||||||
|
|
||||||
|
// {1} If not STRING matcher -> ignore //
|
||||||
|
if ls == 0 { continue }
|
||||||
|
|
||||||
|
// {2} Get offset in URI (else -1) //
|
||||||
|
off := strings.Index(clr, m.pat)
|
||||||
|
if off < 0 { return "", false }
|
||||||
|
|
||||||
|
// {3} Fail on invalid offset range //
|
||||||
|
if off < minOff { return "", false }
|
||||||
|
|
||||||
|
// {4} Check for trailing '/' //
|
||||||
|
hasSlash := 0
|
||||||
|
if off+ls < len(clr) && clr[off+ls] == '/' {
|
||||||
|
hasSlash = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// {5} Remove the current string (+trailing slash) from the URI //
|
||||||
|
beg, end := clr[:off], clr[off+ls+hasSlash:]
|
||||||
|
clr = fmt.Sprintf("%s\a/%s", beg, end) // separate matches by '\a' character
|
||||||
|
|
||||||
|
// {6} Update offset range //
|
||||||
|
minOff = len(beg) + 2 - 1 // +2 slash separators
|
||||||
|
// -1 because strings begin with 1 slash already
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) If exists, remove trailing '/' */
|
||||||
|
if clr[len(clr)-1] == '/' {
|
||||||
|
clr = clr[:len(clr)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) If exists, remove trailing '\a' */
|
||||||
|
if clr[len(clr)-1] == '\a' {
|
||||||
|
clr = clr[:len(clr)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return clr, true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// matchWildcards check the WILCARDS (non-string) matchers from
|
||||||
|
// a cleared URI. it returns if the string matches
|
||||||
|
// + it sets the matchers buffers for later extraction
|
||||||
|
func (s Scheme) matchWildcards(clear string) bool {
|
||||||
|
|
||||||
|
/* (1) Extract wildcards (ref) */
|
||||||
|
wildcards := make(Scheme, 0, maxMatch)
|
||||||
|
|
||||||
|
for _, m := range s {
|
||||||
|
if len(m.pat) == 0 {
|
||||||
|
m.buf = nil // flush buffers
|
||||||
|
wildcards = append(wildcards, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) If no wildcards -> match */
|
||||||
|
if len(wildcards) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Break uri by '\a' characters */
|
||||||
|
matches := strings.Split(clear, "\a")[1:]
|
||||||
|
|
||||||
|
/* (4) Iterate over matches */
|
||||||
|
for n, match := range matches {
|
||||||
|
|
||||||
|
// {1} If no more matcher //
|
||||||
|
if n >= len(wildcards) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// {2} Split by '/' //
|
||||||
|
data := strings.Split(match, "/")[1:] // from index 1 because it begins with '/'
|
||||||
|
|
||||||
|
// {3} If required and missing //
|
||||||
|
if wildcards[n].req && len(data) < 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// {4} If not multi but got multi //
|
||||||
|
if !wildcards[n].mul && len(data) > 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// {5} Store data into matcher //
|
||||||
|
wildcards[n].buf = data
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) Match */
|
||||||
|
return true
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Build builds an URI scheme from a pattern string
|
||||||
|
func Build(s string) (*Scheme, error){
|
||||||
|
|
||||||
|
/* (1) Manage '/' at the start */
|
||||||
|
if len(s) < 1 || s[0] != '/' {
|
||||||
|
return nil, fmt.Errorf("URI must begin with '/'")
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Split by '/' */
|
||||||
|
parts := strings.Split(s, "/")
|
||||||
|
|
||||||
|
/* (3) Max exceeded */
|
||||||
|
if len(parts)-2 > maxMatch {
|
||||||
|
for i, p := range parts {
|
||||||
|
fmt.Printf("%d: '%s'\n", i, p);
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Build for each part */
|
||||||
|
sch, err := buildScheme(parts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) Optimise structure */
|
||||||
|
opti, err := sch.optimise()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &opti, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Match returns if the given URI is matched by the scheme
|
||||||
|
func (s Scheme) Match(str string) bool {
|
||||||
|
|
||||||
|
/* (1) Nothing -> match all */
|
||||||
|
if len(s) == 0 { return true }
|
||||||
|
|
||||||
|
/* (2) Check for string match */
|
||||||
|
clearURI, match := s.matchString(str)
|
||||||
|
if !match {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Check for non-string match (wildcards) */
|
||||||
|
match = s.matchWildcards(clearURI)
|
||||||
|
if !match {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// GetMatch returns the indexed match (excluding string matchers)
|
||||||
|
func (s Scheme) GetMatch(n uint8) ([]string, error) {
|
||||||
|
|
||||||
|
/* (1) Index out of range */
|
||||||
|
if n > uint8(len(s)) {
|
||||||
|
return nil, fmt.Errorf("Index out of range")
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Iterate to find index (exclude strings) */
|
||||||
|
ni := -1
|
||||||
|
for _, m := range s {
|
||||||
|
|
||||||
|
// ignore strings
|
||||||
|
if len(m.pat) > 0 { continue }
|
||||||
|
|
||||||
|
// increment match counter : ni
|
||||||
|
ni++
|
||||||
|
|
||||||
|
// if expected index -> return matches
|
||||||
|
if uint8(ni) == n {
|
||||||
|
return m.buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) If nothing found -> return empty set */
|
||||||
|
return nil, fmt.Errorf("Index out of range (max: %d)", ni)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// GetAllMatch returns all the indexed match (excluding string matchers)
|
||||||
|
func (s Scheme) GetAllMatch() [][]string {
|
||||||
|
|
||||||
|
match := make([][]string, 0, len(s))
|
||||||
|
|
||||||
|
for _, m := range s {
|
||||||
|
|
||||||
|
// ignore strings
|
||||||
|
if len(m.pat) > 0 { continue }
|
||||||
|
|
||||||
|
match = append(match, m.buf)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return match
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
package parser
|
||||||
|
|
||||||
|
// === WILDCARDS ===
|
||||||
|
//
|
||||||
|
// The star '*' -> matches 0 or 1 slash-bounded string
|
||||||
|
// The multi star '**' -> matches 0 or more slash-separated strings
|
||||||
|
// The dot '.' -> matches 1 slash-bounded string
|
||||||
|
// The multi dot '..' -> matches 1 or more slash-separated strings
|
||||||
|
//
|
||||||
|
// === SCHEME POLICY ===
|
||||||
|
//
|
||||||
|
// - The last '/' is optional
|
||||||
|
// - Any '**' at the very end will match anything that starts with the given prefix
|
||||||
|
//
|
||||||
|
// === LIMITATIONS ==
|
||||||
|
//
|
||||||
|
// - A scheme must begin with '/'
|
||||||
|
// - A scheme cannot contain something else than a STRING or WILDCARD between 2 '/' separators
|
||||||
|
// - A scheme STRING cannot contain the symbols '/' as a character
|
||||||
|
// - A scheme STRING containing '*' or '.' characters will be treating as STRING only
|
||||||
|
// - A maximum of 16 slash-separated matchers (STRING or WILDCARD) are allowed
|
||||||
|
|
||||||
|
const maxMatch = 16
|
||||||
|
|
||||||
|
// Represents an URI matcher
|
||||||
|
type matcher struct {
|
||||||
|
pat string // pattern to match (empty if wildcard)
|
||||||
|
req bool // whether it is required
|
||||||
|
mul bool // whether multiple matches are allowed
|
||||||
|
|
||||||
|
buf []string // matched content (when matching)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Represents an URI scheme
|
||||||
|
type Scheme []*matcher
|
356
message.go
356
message.go
|
@ -1,356 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"unicode/utf8"
|
|
||||||
)
|
|
||||||
|
|
||||||
// constant error
|
|
||||||
type constErr string
|
|
||||||
|
|
||||||
func (c constErr) Error() string { return string(c) }
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ErrUnmaskedFrame error
|
|
||||||
ErrUnmaskedFrame constErr = "Received unmasked frame"
|
|
||||||
// ErrTooLongControlFrame error
|
|
||||||
ErrTooLongControlFrame constErr = "Received a control frame that is fragmented or too long"
|
|
||||||
// ErrInvalidFragment error
|
|
||||||
ErrInvalidFragment constErr = "Received invalid fragmentation"
|
|
||||||
// ErrUnexpectedContinuation error
|
|
||||||
ErrUnexpectedContinuation constErr = "Received unexpected continuation frame"
|
|
||||||
// ErrInvalidSize error
|
|
||||||
ErrInvalidSize constErr = "Received invalid payload size"
|
|
||||||
// ErrInvalidPayload error
|
|
||||||
ErrInvalidPayload constErr = "Received invalid utf8 payload"
|
|
||||||
// ErrInvalidCloseStatus error
|
|
||||||
ErrInvalidCloseStatus constErr = "Received invalid close status"
|
|
||||||
// ErrInvalidOpCode error
|
|
||||||
ErrInvalidOpCode constErr = "Received invalid OpCode"
|
|
||||||
// ErrReservedBits error
|
|
||||||
ErrReservedBits constErr = "Received reserved bits"
|
|
||||||
// ErrCloseFrame error
|
|
||||||
ErrCloseFrame constErr = "Received close Frame"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
|
||||||
const maximumHeaderSize = 1 + 1 + 8 + 4
|
|
||||||
const maxWriteChunk = 0x7fff
|
|
||||||
|
|
||||||
// MessageError lists websocket close statuses
|
|
||||||
type MessageError uint16
|
|
||||||
|
|
||||||
const (
|
|
||||||
// None used when there is no error
|
|
||||||
None MessageError = 0
|
|
||||||
// Normal error
|
|
||||||
Normal MessageError = 1000
|
|
||||||
// GoingAway error
|
|
||||||
GoingAway MessageError = 1001
|
|
||||||
// ProtocolError error
|
|
||||||
ProtocolError MessageError = 1002
|
|
||||||
// UnacceptableOpCode error
|
|
||||||
UnacceptableOpCode MessageError = 1003
|
|
||||||
// InvalidPayload error
|
|
||||||
InvalidPayload MessageError = 1007 // utf8
|
|
||||||
// MessageTooLarge error
|
|
||||||
MessageTooLarge MessageError = 1009
|
|
||||||
)
|
|
||||||
|
|
||||||
// MessageType lists websocket message types
|
|
||||||
type MessageType byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Continuation message type
|
|
||||||
Continuation MessageType = 0x00
|
|
||||||
// Text message type
|
|
||||||
Text MessageType = 0x01
|
|
||||||
// Binary message type
|
|
||||||
Binary MessageType = 0x02
|
|
||||||
// Close message type
|
|
||||||
Close MessageType = 0x08
|
|
||||||
// Ping message type
|
|
||||||
Ping MessageType = 0x09
|
|
||||||
// Pong message type
|
|
||||||
Pong MessageType = 0x0a
|
|
||||||
)
|
|
||||||
|
|
||||||
// Message is a websocket message
|
|
||||||
type Message struct {
|
|
||||||
Final bool
|
|
||||||
Type MessageType
|
|
||||||
Size uint
|
|
||||||
Data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadFrom reads a message from a reader
|
|
||||||
//
|
|
||||||
// implements io.ReaderFrom
|
|
||||||
func (m *Message) ReadFrom(reader io.Reader) (int64, error) {
|
|
||||||
var (
|
|
||||||
read int64
|
|
||||||
err error
|
|
||||||
tmpBuf []byte
|
|
||||||
mask []byte
|
|
||||||
cursor int
|
|
||||||
)
|
|
||||||
|
|
||||||
// byte 1: FIN and OpCode
|
|
||||||
tmpBuf = make([]byte, 1)
|
|
||||||
read += int64(len(tmpBuf))
|
|
||||||
err = readBytes(reader, tmpBuf)
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// check reserved bits
|
|
||||||
if tmpBuf[0]&0x70 != 0 {
|
|
||||||
return read, ErrReservedBits
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Final = bool(tmpBuf[0]&0x80 == 0x80)
|
|
||||||
m.Type = MessageType(tmpBuf[0] & 0x0f)
|
|
||||||
|
|
||||||
// byte 2: mask and length[0]
|
|
||||||
tmpBuf = make([]byte, 1)
|
|
||||||
read += int64(len(tmpBuf))
|
|
||||||
err = readBytes(reader, tmpBuf)
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if mask, byte array not nil
|
|
||||||
if tmpBuf[0]&0x80 == 0x80 {
|
|
||||||
mask = make([]byte, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// payload length
|
|
||||||
m.Size = uint(tmpBuf[0] & 0x7f)
|
|
||||||
|
|
||||||
// extended payload
|
|
||||||
if m.Size == 127 {
|
|
||||||
tmpBuf = make([]byte, 8)
|
|
||||||
read += int64(len(tmpBuf))
|
|
||||||
err := readBytes(reader, tmpBuf)
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
m.Size = uint(binary.BigEndian.Uint64(tmpBuf))
|
|
||||||
|
|
||||||
} else if m.Size == 126 {
|
|
||||||
tmpBuf = make([]byte, 2)
|
|
||||||
read += int64(len(tmpBuf))
|
|
||||||
err := readBytes(reader, tmpBuf)
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
m.Size = uint(binary.BigEndian.Uint16(tmpBuf))
|
|
||||||
}
|
|
||||||
|
|
||||||
// masking key
|
|
||||||
if mask != nil {
|
|
||||||
tmpBuf = make([]byte, 4)
|
|
||||||
read += int64(len(tmpBuf))
|
|
||||||
err := readBytes(reader, tmpBuf)
|
|
||||||
if err != nil {
|
|
||||||
return read, err
|
|
||||||
}
|
|
||||||
mask = make([]byte, 4)
|
|
||||||
copy(mask, tmpBuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// read payload by chunks
|
|
||||||
m.Data = make([]byte, int(m.Size))
|
|
||||||
|
|
||||||
cursor = 0
|
|
||||||
|
|
||||||
// while data to read
|
|
||||||
for uint(cursor) < m.Size {
|
|
||||||
|
|
||||||
// try to read (at least 1 byte)
|
|
||||||
nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1)
|
|
||||||
if err != nil {
|
|
||||||
return read + int64(cursor) + int64(nbread), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmask data //
|
|
||||||
if mask != nil {
|
|
||||||
for i, l := cursor, cursor+nbread; i < l; i++ {
|
|
||||||
mi := i % 4 // mask index
|
|
||||||
m.Data[i] = m.Data[i] ^ mask[mi]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cursor += nbread
|
|
||||||
}
|
|
||||||
read += int64(cursor)
|
|
||||||
|
|
||||||
// return error if unmasked frame
|
|
||||||
// we have to fully read it for read buffer to be clean
|
|
||||||
err = nil
|
|
||||||
if mask == nil {
|
|
||||||
err = ErrUnmaskedFrame
|
|
||||||
}
|
|
||||||
|
|
||||||
return read, err
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteTo writes a message frame over a socket
|
|
||||||
//
|
|
||||||
// implements io.WriterTo
|
|
||||||
func (m Message) WriteTo(writer io.Writer) (int64, error) {
|
|
||||||
header := make([]byte, 0, maximumHeaderSize)
|
|
||||||
|
|
||||||
// fix size
|
|
||||||
if uint(len(m.Data)) <= m.Size {
|
|
||||||
m.Size = uint(len(m.Data))
|
|
||||||
}
|
|
||||||
|
|
||||||
// byte 0 : FIN + opcode
|
|
||||||
var final byte = 0x80
|
|
||||||
if !m.Final {
|
|
||||||
final = 0
|
|
||||||
}
|
|
||||||
header = append(header, final|byte(m.Type))
|
|
||||||
|
|
||||||
// get payload length
|
|
||||||
if m.Size < 126 { // simple
|
|
||||||
header = append(header, byte(m.Size))
|
|
||||||
|
|
||||||
} else if m.Size <= 0xffff { // extended: 16 bits
|
|
||||||
header = append(header, 126)
|
|
||||||
|
|
||||||
buf := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(buf, uint16(m.Size))
|
|
||||||
header = append(header, buf...)
|
|
||||||
|
|
||||||
} else if m.Size <= 0xffffffffffffffff { // extended: 64 bits
|
|
||||||
header = append(header, 127)
|
|
||||||
|
|
||||||
buf := make([]byte, 8)
|
|
||||||
binary.BigEndian.PutUint64(buf, uint64(m.Size))
|
|
||||||
header = append(header, buf...)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// build write buffer
|
|
||||||
writeBuf := make([]byte, 0, len(header)+int(m.Size))
|
|
||||||
writeBuf = append(writeBuf, header...)
|
|
||||||
writeBuf = append(writeBuf, m.Data[0:m.Size]...)
|
|
||||||
|
|
||||||
// write by chunks
|
|
||||||
toWrite := len(header) + int(m.Size)
|
|
||||||
cursor := 0
|
|
||||||
for cursor < toWrite {
|
|
||||||
maxBoundary := cursor + maxWriteChunk
|
|
||||||
if maxBoundary > toWrite {
|
|
||||||
maxBoundary = toWrite
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to wrote (at max 1024 bytes) //
|
|
||||||
nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary])
|
|
||||||
if err != nil {
|
|
||||||
return int64(nbwritten), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update cursor //
|
|
||||||
cursor += nbwritten
|
|
||||||
}
|
|
||||||
return int64(cursor), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for message errors with:
|
|
||||||
// (m) the current message
|
|
||||||
// (fragment) whether there is a fragment in construction
|
|
||||||
// returns the message error
|
|
||||||
func (m *Message) check(fragment bool) error {
|
|
||||||
|
|
||||||
// invalid first fragment (not TEXT nor BINARY)
|
|
||||||
if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
|
|
||||||
return ErrInvalidFragment
|
|
||||||
}
|
|
||||||
|
|
||||||
// waiting fragment but received standalone frame
|
|
||||||
if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong {
|
|
||||||
return ErrInvalidFragment
|
|
||||||
}
|
|
||||||
|
|
||||||
// control frame too long
|
|
||||||
if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) {
|
|
||||||
return ErrTooLongControlFrame
|
|
||||||
}
|
|
||||||
|
|
||||||
switch m.Type {
|
|
||||||
case Continuation:
|
|
||||||
// unexpected continuation
|
|
||||||
if !fragment {
|
|
||||||
return ErrUnexpectedContinuation
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case Text:
|
|
||||||
if m.Final && !utf8.Valid(m.Data) {
|
|
||||||
return ErrInvalidPayload
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case Binary:
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case Close:
|
|
||||||
// incomplete code
|
|
||||||
if m.Size == 1 {
|
|
||||||
return ErrInvalidCloseStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
// invalid utf-8 reason
|
|
||||||
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
|
|
||||||
return ErrInvalidPayload
|
|
||||||
}
|
|
||||||
|
|
||||||
// invalid code
|
|
||||||
if m.Size >= 2 {
|
|
||||||
c := binary.BigEndian.Uint16(m.Data[0:2])
|
|
||||||
|
|
||||||
if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 {
|
|
||||||
return ErrInvalidCloseStatus
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ErrCloseFrame
|
|
||||||
|
|
||||||
case Ping:
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case Pong:
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return ErrInvalidOpCode
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// readBytes reads from a reader into a byte array
|
|
||||||
// until the byte length is fully filled with data
|
|
||||||
// loops while there is no error
|
|
||||||
//
|
|
||||||
// It manages connections which chunks data
|
|
||||||
func readBytes(reader io.Reader, buffer []byte) error {
|
|
||||||
var (
|
|
||||||
cur = 0
|
|
||||||
len = len(buffer)
|
|
||||||
)
|
|
||||||
|
|
||||||
// read until the full size is read
|
|
||||||
for cur < len {
|
|
||||||
nbread, err := reader.Read(buffer[cur:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cur += nbread
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
552
message_test.go
552
message_test.go
|
@ -1,552 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSimpleMessageReading(t *testing.T) {
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
Name string
|
|
||||||
ReadBuffer []byte
|
|
||||||
Expected Message
|
|
||||||
Err error
|
|
||||||
}{
|
|
||||||
{ // FIN ; TEXT ; Unmasked -> error
|
|
||||||
"must fail on unmasked frame",
|
|
||||||
[]byte{0x81, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
|
|
||||||
Message{},
|
|
||||||
ErrUnmaskedFrame,
|
|
||||||
},
|
|
||||||
{ // FIN ; TEXT ; Unmasked -> error
|
|
||||||
"must fail because of RSV bit 1 set",
|
|
||||||
[]byte{0x81 | 0x40, 0x10, 0x00, 0x00, 0x00, 0x00},
|
|
||||||
Message{},
|
|
||||||
ErrReservedBits,
|
|
||||||
},
|
|
||||||
{ // FIN ; TEXT ; Unmasked -> error
|
|
||||||
"must fail because of RSV bit 2 set",
|
|
||||||
[]byte{0x81 | 0x20, 0x10, 0x00, 0x00, 0x00, 0x00},
|
|
||||||
Message{},
|
|
||||||
ErrReservedBits,
|
|
||||||
},
|
|
||||||
{ // FIN ; TEXT ; Unmasked -> error
|
|
||||||
"must fail because of RSV bit 3 set",
|
|
||||||
[]byte{0x81 | 0x10, 0x10, 0x00, 0x00, 0x00, 0x00},
|
|
||||||
Message{},
|
|
||||||
ErrReservedBits,
|
|
||||||
},
|
|
||||||
{ // FIN ; TEXT ; hello
|
|
||||||
"simple hello text message",
|
|
||||||
[]byte{0x81, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
|
|
||||||
Message{true, Text, 5, []byte("hello")},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{ // FIN ; BINARY ; hello
|
|
||||||
"simple hello binary message",
|
|
||||||
[]byte{0x82, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
|
|
||||||
Message{true, Binary, 5, []byte("hello")},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{ // FIN ; BINARY ; test unmasking
|
|
||||||
"unmasking test",
|
|
||||||
[]byte{0x82, 0x88, 0x01, 0x02, 0x03, 0x04, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80},
|
|
||||||
Message{true, Binary, 8, []byte{0x11, 0x22, 0x33, 0x44, 0x51, 0x62, 0x73, 0x84}},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
{ // FIN=0 ; TEXT ;
|
|
||||||
"non final frame",
|
|
||||||
[]byte{0x01, 0x82, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02},
|
|
||||||
Message{false, Text, 2, []byte{0x01, 0x02}},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
|
||||||
var (
|
|
||||||
reader = bytes.NewBuffer(tc.ReadBuffer)
|
|
||||||
msg = &Message{}
|
|
||||||
)
|
|
||||||
|
|
||||||
_, err := msg.ReadFrom(reader)
|
|
||||||
if err != tc.Err {
|
|
||||||
t.Errorf("Expected %v error, got %v", tc.Err, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// do not check message if error expected
|
|
||||||
if tc.Err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check FIN
|
|
||||||
if msg.Final != tc.Expected.Final {
|
|
||||||
t.Errorf("Expected FIN=%t, got %t", tc.Expected.Final, msg.Final)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check OpCode
|
|
||||||
if msg.Type != tc.Expected.Type {
|
|
||||||
t.Errorf("Expected TYPE=%x, got %x", tc.Expected.Type, msg.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check Size
|
|
||||||
if msg.Size != tc.Expected.Size {
|
|
||||||
t.Errorf("Expected SIZE=%d, got %d", tc.Expected.Size, msg.Size)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check Data
|
|
||||||
if string(msg.Data) != string(tc.Expected.Data) {
|
|
||||||
t.Errorf("Expected Data='%s', got '%d'", tc.Expected.Data, msg.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadEOF(t *testing.T) {
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
Name string
|
|
||||||
ReadBuffer []byte
|
|
||||||
eof bool
|
|
||||||
unmaskedError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"no byte",
|
|
||||||
[]byte{},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"only opcode",
|
|
||||||
[]byte{0x82},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"only opcode and 0 length",
|
|
||||||
[]byte{0x82, 0x00},
|
|
||||||
false, true,
|
|
||||||
}, {
|
|
||||||
"missing extended 16 bits length",
|
|
||||||
[]byte{0x82, 126},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"incomplete extended 16 bits length",
|
|
||||||
[]byte{0x82, 126, 0x00},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"complete extended 16 bits length",
|
|
||||||
[]byte{0x82, 126, 0x00, 0x00},
|
|
||||||
false, true,
|
|
||||||
}, {
|
|
||||||
"missing extended 64 bits length",
|
|
||||||
[]byte{0x82, 127},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"incomplete extended 64 bits length",
|
|
||||||
[]byte{0x82, 127, 0x00},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"incomplete extended 64 bits length",
|
|
||||||
[]byte{0x82, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"complete extended 64 bits length",
|
|
||||||
[]byte{0x82, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
|
||||||
false, true,
|
|
||||||
}, {
|
|
||||||
"missing mask",
|
|
||||||
[]byte{0x82, 0x80},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"incomplete mask 1",
|
|
||||||
[]byte{0x82, 0x80, 0x00},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"incomplete mask 2",
|
|
||||||
[]byte{0x82, 0x80, 0x00, 0x00, 0x00},
|
|
||||||
true, false,
|
|
||||||
}, {
|
|
||||||
"complete mask",
|
|
||||||
[]byte{0x82, 0x80, 0x00, 0x00, 0x00, 0x00},
|
|
||||||
false, false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
|
||||||
var (
|
|
||||||
reader = bytes.NewBuffer(tc.ReadBuffer)
|
|
||||||
msg = &Message{}
|
|
||||||
)
|
|
||||||
_, err := msg.ReadFrom(reader)
|
|
||||||
if tc.eof {
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Fatalf("Expected EOF, got %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.unmaskedError && err != ErrUnmaskedFrame {
|
|
||||||
t.Errorf("Expected UnmaskedFrameor, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.Size != 0x00 {
|
|
||||||
t.Errorf("Expected a size of 0, got %d", msg.Size)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSimpleMessageSending(t *testing.T) {
|
|
||||||
|
|
||||||
m4b1 := make([]byte, 0x7e-1)
|
|
||||||
m4b2 := make([]byte, 0x7e)
|
|
||||||
m4b3 := make([]byte, 0x7e+1)
|
|
||||||
|
|
||||||
m16b1 := make([]byte, 0xffff-1)
|
|
||||||
m16b2 := make([]byte, 0xffff)
|
|
||||||
m16b3 := make([]byte, 0xffff+1)
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
Name string
|
|
||||||
Base Message
|
|
||||||
Expected []byte
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"simple hello text message",
|
|
||||||
Message{true, Text, 5, []byte("hello")},
|
|
||||||
[]byte{0x81, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
|
|
||||||
}, {
|
|
||||||
"simple hello binary message",
|
|
||||||
Message{true, Binary, 5, []byte("hello")},
|
|
||||||
[]byte{0x82, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
|
|
||||||
}, {
|
|
||||||
"other simple binary message",
|
|
||||||
Message{true, Binary, 8, []byte{0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80}},
|
|
||||||
[]byte{0x82, 0x08, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80},
|
|
||||||
}, {
|
|
||||||
"non final frame",
|
|
||||||
Message{false, Text, 2, []byte{0x01, 0x02}},
|
|
||||||
[]byte{0x01, 0x02, 0x01, 0x02},
|
|
||||||
}, {
|
|
||||||
"125 > normal length",
|
|
||||||
Message{true, Text, uint(len(m4b1)), m4b1},
|
|
||||||
append([]byte{0x81, 0x7e - 1}, m4b1...),
|
|
||||||
}, {
|
|
||||||
"126 > extended 16 bits length",
|
|
||||||
Message{true, Text, uint(len(m4b2)), m4b2},
|
|
||||||
append([]byte{0x81, 126, 0x00, 0x7e}, m4b2...),
|
|
||||||
}, {
|
|
||||||
"127 > extended 16 bits length",
|
|
||||||
Message{true, Text, uint(len(m4b3)), m4b3},
|
|
||||||
append([]byte{0x81, 126, 0x00, 0x7e + 1}, m4b3...),
|
|
||||||
}, {
|
|
||||||
"fffe > extended 16 bits length",
|
|
||||||
Message{true, Text, uint(len(m16b1)), m16b1},
|
|
||||||
append([]byte{0x81, 126, 0xff, 0xfe}, m16b1...),
|
|
||||||
}, {
|
|
||||||
"ffff > extended 16 bits length",
|
|
||||||
Message{true, Text, uint(len(m16b2)), m16b2},
|
|
||||||
append([]byte{0x81, 126, 0xff, 0xff}, m16b2...),
|
|
||||||
}, {
|
|
||||||
"10000 > extended 64 bits length",
|
|
||||||
Message{true, Text, uint(len(m16b3)), m16b3},
|
|
||||||
append([]byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, m16b3...),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
|
||||||
|
|
||||||
writer := &bytes.Buffer{}
|
|
||||||
|
|
||||||
_, err := tc.Base.WriteTo(writer)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("expected no error, got %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check buffer
|
|
||||||
if writer.String() != string(tc.Expected) {
|
|
||||||
t.Errorf("expected '%.20x', got '%.20x'", tc.Expected, writer.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageCheck(t *testing.T) {
|
|
||||||
|
|
||||||
type Case struct {
|
|
||||||
Name string
|
|
||||||
Msg Message
|
|
||||||
WaitingFragment bool
|
|
||||||
Expected error
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
Name string
|
|
||||||
Cases []Case
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
Name: "first fragment type",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CONTINUATION must fail",
|
|
||||||
Message{false, Continuation, 0, []byte{}}, false, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"TEXT must not fail",
|
|
||||||
Message{false, Text, 0, []byte{}}, false, nil,
|
|
||||||
}, {
|
|
||||||
"BINARY must not fail",
|
|
||||||
Message{false, Binary, 0, []byte{}}, false, nil,
|
|
||||||
}, {
|
|
||||||
"CLOSE must fail",
|
|
||||||
Message{false, Close, 0, []byte{}}, false, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"PING must fail",
|
|
||||||
Message{false, Ping, 0, []byte{}}, false, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"PONG must fail",
|
|
||||||
Message{false, Pong, 0, []byte{}}, false, ErrInvalidFragment,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "frame during fragment",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CONTINUATION must not fail",
|
|
||||||
Message{true, Continuation, 0, []byte{}}, true, nil,
|
|
||||||
}, {
|
|
||||||
"TEXT must fail",
|
|
||||||
Message{true, Text, 0, []byte{}}, true, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"BINARY must fail",
|
|
||||||
Message{true, Binary, 0, []byte{}}, true, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"CLOSE must not fail",
|
|
||||||
Message{true, Close, 0, []byte{}}, true, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"PING must not fail",
|
|
||||||
Message{true, Ping, 0, []byte{}}, true, nil,
|
|
||||||
}, {
|
|
||||||
"PONG must not fail",
|
|
||||||
Message{true, Pong, 0, []byte{}}, true, nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "125-length control frame",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CLOSE must not fail",
|
|
||||||
Message{true, Close, 125, []byte{0x03, 0xe8, 0}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"PING must not fail",
|
|
||||||
Message{true, Ping, 125, []byte{}}, false, nil,
|
|
||||||
}, {
|
|
||||||
"PONG must not fail",
|
|
||||||
Message{true, Pong, 125, []byte{}}, false, nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "126-length control frame",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CLOSE must fail",
|
|
||||||
Message{true, Close, 126, []byte{0x03, 0xe8, 0}}, false, ErrTooLongControlFrame,
|
|
||||||
}, {
|
|
||||||
"PING must fail",
|
|
||||||
Message{true, Ping, 126, []byte{}}, false, ErrTooLongControlFrame,
|
|
||||||
}, {
|
|
||||||
"PONG must fail",
|
|
||||||
Message{true, Pong, 126, []byte{}}, false, ErrTooLongControlFrame,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "fragmented control frame",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CLOSE must fail",
|
|
||||||
Message{false, Close, 126, []byte{0x03, 0xe8, 0}}, false, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"PING must fail",
|
|
||||||
Message{false, Ping, 126, []byte{}}, false, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"PONG must fail",
|
|
||||||
Message{false, Pong, 126, []byte{}}, false, ErrInvalidFragment,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "unexpected continuation frame",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"no waiting fragment final",
|
|
||||||
Message{false, Continuation, 126, nil}, false, ErrInvalidFragment,
|
|
||||||
}, {
|
|
||||||
"no waiting fragment non-final",
|
|
||||||
Message{true, Continuation, 126, nil}, false, ErrUnexpectedContinuation,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "utf8 check",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CLOSE valid reason",
|
|
||||||
Message{true, Close, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0xa1}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"CLOSE invalid reason byte 2",
|
|
||||||
Message{true, Close, 5, []byte{0x03, 0xe8, 0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload,
|
|
||||||
}, {
|
|
||||||
"CLOSE invalid reason byte 3",
|
|
||||||
Message{true, Close, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0x28}}, false, ErrInvalidPayload,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"TEXT valid reason",
|
|
||||||
Message{true, Text, 3, []byte{0xe2, 0x82, 0xa1}}, false, nil,
|
|
||||||
}, {
|
|
||||||
"TEXT invalid reason byte 2",
|
|
||||||
Message{true, Text, 3, []byte{0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload,
|
|
||||||
}, {
|
|
||||||
"TEXT invalid reason byte 3",
|
|
||||||
Message{true, Text, 3, []byte{0xe2, 0x82, 0x28}}, false, ErrInvalidPayload,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "CLOSE status",
|
|
||||||
Cases: []Case{
|
|
||||||
{
|
|
||||||
"CLOSE only 1 byte",
|
|
||||||
Message{true, Close, 1, []byte{0x03}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1000",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xe8}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 999 under 1000",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xe7}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1001",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xe9}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1002",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xea}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1003",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xeb}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1004",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xec}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1005",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xed}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1006",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xee}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1007",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xef}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1011",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf3}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1012",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf4}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1013",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf5}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1014",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf6}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1015",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf7}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1016",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf8}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1017",
|
|
||||||
Message{true, Close, 2, []byte{0x03, 0xf9}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1099",
|
|
||||||
Message{true, Close, 2, []byte{0x04, 0x4b}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 1100",
|
|
||||||
Message{true, Close, 2, []byte{0x04, 0x4c}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1101",
|
|
||||||
Message{true, Close, 2, []byte{0x04, 0x4d}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 1999",
|
|
||||||
Message{true, Close, 2, []byte{0x07, 0xcf}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 2000",
|
|
||||||
Message{true, Close, 2, []byte{0x07, 0xd0}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 2001",
|
|
||||||
Message{true, Close, 2, []byte{0x07, 0xd1}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 2998",
|
|
||||||
Message{true, Close, 2, []byte{0x0b, 0xb6}}, false, ErrCloseFrame,
|
|
||||||
}, {
|
|
||||||
"invalid CLOSE status 2999",
|
|
||||||
Message{true, Close, 2, []byte{0x0b, 0xb7}}, false, ErrInvalidCloseStatus,
|
|
||||||
}, {
|
|
||||||
"valid CLOSE status 3000",
|
|
||||||
Message{true, Close, 2, []byte{0x0b, 0xb8}}, false, ErrCloseFrame,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Name: "OpCode check",
|
|
||||||
Cases: []Case{
|
|
||||||
{"0", Message{true, 0, 0, []byte{}}, false, ErrUnexpectedContinuation},
|
|
||||||
{"1", Message{true, 1, 0, []byte{}}, false, nil},
|
|
||||||
{"2", Message{true, 2, 0, []byte{}}, false, nil},
|
|
||||||
{"3", Message{true, 3, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"4", Message{true, 4, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"5", Message{true, 5, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"6", Message{true, 6, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"7", Message{true, 7, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"8", Message{true, 8, 0, []byte{}}, false, ErrCloseFrame},
|
|
||||||
{"9", Message{true, 9, 0, []byte{}}, false, nil},
|
|
||||||
{"10", Message{true, 10, 0, []byte{}}, false, nil},
|
|
||||||
{"11", Message{true, 11, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"12", Message{true, 12, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"13", Message{true, 13, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"14", Message{true, 14, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
{"15", Message{true, 15, 0, []byte{}}, false, ErrInvalidOpCode},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tcc := range cases {
|
|
||||||
|
|
||||||
t.Run(tcc.Name, func(t *testing.T) {
|
|
||||||
|
|
||||||
for _, tc := range tcc.Cases {
|
|
||||||
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
|
||||||
|
|
||||||
actual := tc.Msg.check(tc.WaitingFragment)
|
|
||||||
|
|
||||||
if actual != tc.Expected {
|
|
||||||
t.Errorf("expected '%v', got '%v'", tc.Expected, actual)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
129
server.go
129
server.go
|
@ -1,129 +0,0 @@
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/uri"
|
|
||||||
)
|
|
||||||
|
|
||||||
// All channels that a server features
|
|
||||||
type serverChannelSet struct {
|
|
||||||
register chan *client
|
|
||||||
unregister chan *client
|
|
||||||
broadcast chan Message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server is a websocket server
|
|
||||||
type Server struct {
|
|
||||||
sock net.Listener // listen socket
|
|
||||||
addr []byte // server listening ip/host
|
|
||||||
port uint16 // server listening port
|
|
||||||
|
|
||||||
clients map[net.Conn]*client
|
|
||||||
|
|
||||||
ctl ControllerSet // controllers
|
|
||||||
|
|
||||||
ch serverChannelSet
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServer creates a server
|
|
||||||
func NewServer(host string, port uint16) *Server {
|
|
||||||
return &Server{
|
|
||||||
addr: []byte(host),
|
|
||||||
port: port,
|
|
||||||
|
|
||||||
clients: make(map[net.Conn]*client, 0),
|
|
||||||
|
|
||||||
ctl: ControllerSet{
|
|
||||||
Def: nil,
|
|
||||||
URI: make([]*Controller, 0),
|
|
||||||
},
|
|
||||||
|
|
||||||
ch: serverChannelSet{
|
|
||||||
register: make(chan *client, 1),
|
|
||||||
unregister: make(chan *client, 1),
|
|
||||||
broadcast: make(chan Message, 1),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BindDefault binds a default controller
|
|
||||||
// it will be called if the URI does not
|
|
||||||
// match another controller
|
|
||||||
func (s *Server) BindDefault(f ControllerFunc) {
|
|
||||||
s.ctl.Def = &Controller{
|
|
||||||
URI: nil,
|
|
||||||
Fun: f,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind a controller to an URI scheme
|
|
||||||
func (s *Server) Bind(uriStr string, f ControllerFunc) error {
|
|
||||||
uriScheme, err := uri.FromString(uriStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot build URI: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.ctl.URI = append(s.ctl.URI, &Controller{
|
|
||||||
URI: uriScheme,
|
|
||||||
Fun: f,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch the websocket server
|
|
||||||
func (s *Server) Launch() error {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
url = fmt.Sprintf("%s:%d", s.addr, s.port)
|
|
||||||
)
|
|
||||||
|
|
||||||
s.sock, err = net.Listen("tcp", url)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("listen: %w", err)
|
|
||||||
}
|
|
||||||
defer s.sock.Close()
|
|
||||||
|
|
||||||
fmt.Printf("+ listening on %s\n", url)
|
|
||||||
go s.schedule()
|
|
||||||
|
|
||||||
for {
|
|
||||||
sock, err := s.sock.Accept()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
cli, err := newClient(sock, s.ctl, s.ch)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf(" - %s\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.ch.register <- cli
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// schedule client registration and broadcast
|
|
||||||
func (s *Server) schedule() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
|
|
||||||
case client := <-s.ch.register:
|
|
||||||
s.clients[client.io.sock] = client
|
|
||||||
|
|
||||||
case client := <-s.ch.unregister:
|
|
||||||
delete(s.clients, client.io.sock)
|
|
||||||
|
|
||||||
case msg := <-s.ch.broadcast:
|
|
||||||
for _, c := range s.clients {
|
|
||||||
c.ch.send <- msg
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,325 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unicode/utf8"
|
||||||
|
"time"
|
||||||
|
"sync"
|
||||||
|
"bufio"
|
||||||
|
"encoding/binary"
|
||||||
|
"git.xdrm.io/gws/internal/http/upgrade/request"
|
||||||
|
"net"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Represents a client socket utility (reader, writer, ..)
|
||||||
|
type clientIO struct {
|
||||||
|
sock net.Conn
|
||||||
|
reader *bufio.Reader
|
||||||
|
kill chan<- *client // unregisters client
|
||||||
|
closing bool
|
||||||
|
closingMu sync.Mutex
|
||||||
|
reading sync.WaitGroup
|
||||||
|
writing sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents all channels that need a client
|
||||||
|
type clientChannelSet struct{
|
||||||
|
receive chan Message
|
||||||
|
send chan *Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents a websocket client
|
||||||
|
type client struct {
|
||||||
|
io clientIO
|
||||||
|
iface *Client
|
||||||
|
ch clientChannelSet
|
||||||
|
status MessageError // close status ; 0 = nothing ; else -> must close
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Create creates a new client
|
||||||
|
func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error){
|
||||||
|
|
||||||
|
/* (1) Manage UPGRADE request
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Parse request */
|
||||||
|
req, _ := request.Parse(s)
|
||||||
|
|
||||||
|
/* (3) Build response */
|
||||||
|
res := req.BuildResponse()
|
||||||
|
|
||||||
|
/* (4) Write into socket */
|
||||||
|
_, err := res.Send(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Upgrade write error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.GetStatusCode() != 101 {
|
||||||
|
s.Close()
|
||||||
|
return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) Initialise client
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Get upgrade data */
|
||||||
|
clientURI := req.GetURI()
|
||||||
|
clientProtocol := res.GetProtocol()
|
||||||
|
|
||||||
|
/* (2) Initialise client */
|
||||||
|
cli := &client{
|
||||||
|
io: clientIO{
|
||||||
|
sock: s,
|
||||||
|
reader: bufio.NewReader(s),
|
||||||
|
kill: serverCh.unregister,
|
||||||
|
},
|
||||||
|
|
||||||
|
iface: &Client{
|
||||||
|
Protocol: string(clientProtocol),
|
||||||
|
Arguments: [][]string{ []string{ clientURI } },
|
||||||
|
},
|
||||||
|
|
||||||
|
ch: clientChannelSet{
|
||||||
|
receive: make(chan Message, 1),
|
||||||
|
send: make(chan *Message, 2),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* (3) Find controller by URI
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Try to find one */
|
||||||
|
controller, arguments := ctl.Match(clientURI);
|
||||||
|
|
||||||
|
/* (2) If nothing found -> error */
|
||||||
|
if controller == nil {
|
||||||
|
return nil, fmt.Errorf("No controller found, no default controller set\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Copy arguments */
|
||||||
|
cli.iface.Arguments = arguments
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* (4) Launch client routines
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Launch client controller */
|
||||||
|
go controller.Fun(
|
||||||
|
cli.iface, // pass the client
|
||||||
|
cli.ch.receive, // the receiver
|
||||||
|
cli.ch.send, // the sender
|
||||||
|
serverCh.broadcast, // broadcast sender
|
||||||
|
)
|
||||||
|
|
||||||
|
/* (2) Launch message reader */
|
||||||
|
go clientReader(cli)
|
||||||
|
|
||||||
|
/* (3) Launc writer */
|
||||||
|
go clientWriter(cli)
|
||||||
|
|
||||||
|
return cli, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// reader reads and parses messages from the buffer
|
||||||
|
func clientReader(c *client){
|
||||||
|
|
||||||
|
errorCode := NORMAL
|
||||||
|
clientAck := true
|
||||||
|
c.io.reading.Add(1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
|
||||||
|
/* if currently closing -> exit */
|
||||||
|
if c.io.closing {
|
||||||
|
fmt.Printf("[reader] killed because closing")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
/*** Parse message ***/
|
||||||
|
msg, err := readMessage(c.io.reader)
|
||||||
|
if err != nil {
|
||||||
|
// fmt.Printf(" [reader] %s\n", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) CLOSE */
|
||||||
|
if msg.Type == CLOSE {
|
||||||
|
// fmt.Printf(" [reader] CLOSE ; size %d\n", msg.Size)
|
||||||
|
// if msg.Size >= 2 {
|
||||||
|
// errCode := binary.BigEndian.Uint16(msg.Data[0:2])
|
||||||
|
// fmt.Printf(" ; status %d\n", errCode)
|
||||||
|
// fmt.Printf(" ; msg '%s'\n", msg.Data[2:])
|
||||||
|
// }
|
||||||
|
clientAck = false
|
||||||
|
break
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) PING size error */
|
||||||
|
if msg.Type == PING && msg.Size > 125 {
|
||||||
|
fmt.Printf(" [reader] PING payload too big\n")
|
||||||
|
// fmt.Printf("[reader] PING err\n")
|
||||||
|
errorCode = PROTOCOL_ERR
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (6) Send PONG */
|
||||||
|
if msg.Type == PING {
|
||||||
|
// fmt.Printf("[reader] PING -> PONG\n")
|
||||||
|
msg.Final = true
|
||||||
|
msg.Type = PONG
|
||||||
|
c.ch.send <- msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (7) Invalid UTF8 */
|
||||||
|
if msg.Type == TEXT && !utf8.Valid(msg.Data) {
|
||||||
|
fmt.Printf(" [reader] invalid utf-8\n")
|
||||||
|
errorCode = INVALID_PAYLOAD
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (8) Unknown opcode */
|
||||||
|
if msg.Type != TEXT && msg.Type != BINARY {
|
||||||
|
fmt.Printf(" [reader] unknown OpCode %d\n", msg.Type)
|
||||||
|
errorCode = PROTOCOL_ERR
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (9) Dispatch to receiver */
|
||||||
|
c.ch.receive <- *msg
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
c.io.reading.Done()
|
||||||
|
|
||||||
|
/* (8) close channel (if not already done) */
|
||||||
|
// fmt.Printf("[reader] end\n")
|
||||||
|
c.close(errorCode, clientAck)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// writer writes into websocket
|
||||||
|
// and is triggered by client.ch.send channel
|
||||||
|
func clientWriter(c *client){
|
||||||
|
|
||||||
|
c.io.writing.Add(1)
|
||||||
|
|
||||||
|
for msg := range c.ch.send {
|
||||||
|
|
||||||
|
/* (1) If empty message -> close properly */
|
||||||
|
if msg == nil {
|
||||||
|
fmt.Printf(" [writer] nil\n")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Send message */
|
||||||
|
err := msg.Send(c.io.sock)
|
||||||
|
|
||||||
|
/* (3) Fail on error */
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" [writer] %s\n", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
c.io.writing.Done()
|
||||||
|
|
||||||
|
/* (4) close channel (if not already done) */
|
||||||
|
// fmt.Printf("[writer] end\n")
|
||||||
|
c.close(NORMAL, true)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// close writes the error message (if needed)
|
||||||
|
// and it closes the socket
|
||||||
|
// if 'clientACK' is true, reads the next message (CLOSE acknowledge)
|
||||||
|
// before closing the socket
|
||||||
|
func (c *client) close(status MessageError, clientACK bool){
|
||||||
|
|
||||||
|
/* (1) Fail if already closing */
|
||||||
|
alreadyClosing := false
|
||||||
|
c.io.closingMu.Lock()
|
||||||
|
alreadyClosing = c.io.closing
|
||||||
|
c.io.closing = true
|
||||||
|
c.io.closingMu.Unlock()
|
||||||
|
|
||||||
|
if alreadyClosing {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) kill 'c.reader()' if already running */
|
||||||
|
c.io.sock.SetReadDeadline(time.Now().Add(time.Second*-1))
|
||||||
|
// fmt.Printf("[close] wait read stop\n")
|
||||||
|
c.io.reading.Wait()
|
||||||
|
close(c.ch.receive)
|
||||||
|
// close(c.ch.send)
|
||||||
|
|
||||||
|
|
||||||
|
if status == NONE {
|
||||||
|
status = NORMAL
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Build message */
|
||||||
|
msg := &Message{
|
||||||
|
Final: true,
|
||||||
|
Type: CLOSE,
|
||||||
|
Size: 2,
|
||||||
|
Data: make([]byte, 2),
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(msg.Data, uint16(status))
|
||||||
|
// msg.Data = append(msg.Data, []byte("(closing)")...)
|
||||||
|
msg.Size = uint( len(msg.Data) )
|
||||||
|
|
||||||
|
/* (4) Send message */
|
||||||
|
err := msg.Send(c.io.sock)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[close] send error (%s0\n", err)
|
||||||
|
}
|
||||||
|
// fmt.Printf("[close] frame sent\n")
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) Wait for client CLOSE if needed */
|
||||||
|
if clientACK {
|
||||||
|
|
||||||
|
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||||
|
|
||||||
|
/* Wait for message */
|
||||||
|
msg, err := readMessage(c.io.reader)
|
||||||
|
if err != nil || msg.Type != CLOSE {
|
||||||
|
if err == nil {
|
||||||
|
fmt.Printf("[close] received OpCode = %d\n", msg.Type)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[close] read error (%v)\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Printf("[close] received ACK\n")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Close socket */
|
||||||
|
c.io.sock.Close()
|
||||||
|
// fmt.Printf("[close] socket closed\n")
|
||||||
|
|
||||||
|
/* (4) Unregister */
|
||||||
|
c.io.kill <- c
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.xdrm.io/gws/internal/uri/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Represents available information about a client
|
||||||
|
type Client struct {
|
||||||
|
Protocol string // choosen protocol (Sec-WebSocket-Protocol)
|
||||||
|
Arguments [][]string // URI parameters, index 0 is full URI, then matching groups
|
||||||
|
Store interface{} // store (for client implementation-specific data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents a websocket controller callback function
|
||||||
|
type ControllerFunc func(*Client, <-chan Message, chan<- *Message, chan<- *Message)
|
||||||
|
|
||||||
|
// Represents a websocket controller
|
||||||
|
type Controller struct {
|
||||||
|
URI *parser.Scheme // uri scheme
|
||||||
|
Fun ControllerFunc // controller function
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents a controller set
|
||||||
|
type ControllerSet struct {
|
||||||
|
Def *Controller // default controller
|
||||||
|
Uri []*Controller // uri controllers
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Match finds a controller for a given URI
|
||||||
|
// also it returns the matching string patterns
|
||||||
|
func (s *ControllerSet) Match(uri string) (*Controller, [][]string){
|
||||||
|
|
||||||
|
/* (1) Initialise argument list */
|
||||||
|
arguments := [][]string{ []string{ uri } }
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) Try each controller */
|
||||||
|
for _, c := range s.Uri {
|
||||||
|
|
||||||
|
/* 1. If matches */
|
||||||
|
if c.URI.Match(uri) {
|
||||||
|
|
||||||
|
/* Extract matches */
|
||||||
|
match := c.URI.GetAllMatch()
|
||||||
|
|
||||||
|
/* Add them to the 'arg' attribute */
|
||||||
|
arguments = append(arguments, match...)
|
||||||
|
|
||||||
|
/* Mark that we have a controller */
|
||||||
|
return c, arguments
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) If no controller found -> set default controller */
|
||||||
|
if s.Def != nil {
|
||||||
|
return s.Def, arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) If default is NIL, return empty controller */
|
||||||
|
return nil, arguments
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,192 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
||||||
|
const maximumHeaderSize = 1 + 1 + 8 + 4
|
||||||
|
|
||||||
|
// Lists websocket close status
|
||||||
|
type MessageError uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
NONE MessageError = 0
|
||||||
|
NORMAL MessageError = 1000
|
||||||
|
GOING_AWAY MessageError = 1001
|
||||||
|
PROTOCOL_ERR MessageError = 1002
|
||||||
|
UNACCEPTABLE_OPCODE MessageError = 1003
|
||||||
|
INVALID_PAYLOAD MessageError = 1007 // utf8
|
||||||
|
MESSAGE_TOO_LARGE MessageError = 1009
|
||||||
|
)
|
||||||
|
|
||||||
|
// Lists websocket message types
|
||||||
|
type MessageType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
CONTINUATION MessageType = 0x0
|
||||||
|
TEXT MessageType = 0x1
|
||||||
|
BINARY MessageType = 0x2
|
||||||
|
CLOSE MessageType = 0x8
|
||||||
|
PING MessageType = 0x9
|
||||||
|
PONG MessageType = 0xa
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
// Represents a websocket message
|
||||||
|
type Message struct {
|
||||||
|
Type MessageType
|
||||||
|
Data []byte
|
||||||
|
Size uint
|
||||||
|
Final bool
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// receive reads a message form reader
|
||||||
|
func readMessage(reader io.Reader) (*Message, error){
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var tmpBuf []byte
|
||||||
|
var mask []byte
|
||||||
|
var cursor int
|
||||||
|
|
||||||
|
m := new(Message)
|
||||||
|
|
||||||
|
/* (2) Byte 1: FIN and OpCode */
|
||||||
|
tmpBuf = make([]byte, 1)
|
||||||
|
_, err = reader.Read(tmpBuf)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
|
||||||
|
|
||||||
|
m.Final = bool( tmpBuf[0] & 0x80 == 0x80 )
|
||||||
|
m.Type = MessageType( tmpBuf[0] & 0x0f )
|
||||||
|
|
||||||
|
/* (3) Byte 2: Mask and Length[0] */
|
||||||
|
tmpBuf = make([]byte, 1)
|
||||||
|
_, err = reader.Read(tmpBuf)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
|
||||||
|
// if mask, byte array not nil
|
||||||
|
if tmpBuf[0] & 0x80 == 0x80 {
|
||||||
|
mask = make([]byte, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// payload length
|
||||||
|
m.Size = uint( tmpBuf[0] & 0x7f )
|
||||||
|
|
||||||
|
/* (4) Extended payload */
|
||||||
|
if m.Size == 127 {
|
||||||
|
|
||||||
|
tmpBuf = make([]byte, 8)
|
||||||
|
_, err := reader.Read(tmpBuf)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
|
||||||
|
m.Size = uint( binary.BigEndian.Uint64(tmpBuf) )
|
||||||
|
|
||||||
|
} else if m.Size == 126 {
|
||||||
|
|
||||||
|
tmpBuf = make([]byte, 2)
|
||||||
|
_, err := reader.Read(tmpBuf)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
|
||||||
|
m.Size = uint( binary.BigEndian.Uint16(tmpBuf) )
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) Masking key */
|
||||||
|
if mask != nil {
|
||||||
|
|
||||||
|
tmpBuf = make([]byte, 4)
|
||||||
|
_, err := reader.Read(tmpBuf)
|
||||||
|
if err != nil { return nil, err }
|
||||||
|
|
||||||
|
mask = make([]byte, 4)
|
||||||
|
copy(mask, tmpBuf)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (6) Read payload by chunks */
|
||||||
|
m.Data = make([]byte, int(m.Size))
|
||||||
|
|
||||||
|
// If empty payload
|
||||||
|
if m.Size <= 0 {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = 0
|
||||||
|
|
||||||
|
// {1} While we have data to read //
|
||||||
|
for uint(cursor) < m.Size {
|
||||||
|
|
||||||
|
// {2} Try to read (at least 1 byte) //
|
||||||
|
nbread, err := io.ReadAtLeast(reader, m.Data[cursor:m.Size], 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// {3} Unmask data //
|
||||||
|
if mask != nil {
|
||||||
|
for i, l := cursor, cursor+nbread ; i < l ; i++ {
|
||||||
|
|
||||||
|
mi := i % 4 // mask index
|
||||||
|
m.Data[i] = m.Data[i] ^ mask[mi]
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// {4} Update cursor //
|
||||||
|
cursor += nbread
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Send sends a frame over a socket
|
||||||
|
func (m Message) Send(socket net.Conn) error {
|
||||||
|
|
||||||
|
header := make([]byte, 0, maximumHeaderSize)
|
||||||
|
|
||||||
|
/* (1) Byte 0 : FIN + opcode */
|
||||||
|
header = append(header, 0x80 | byte(TEXT) )
|
||||||
|
|
||||||
|
/* (2) Get payload length */
|
||||||
|
if m.Size < 126 { // simple
|
||||||
|
|
||||||
|
header = append(header, byte(m.Size) )
|
||||||
|
|
||||||
|
} else if m.Size < 0xffff { // extended: 16 bits
|
||||||
|
|
||||||
|
header = append(header, 126)
|
||||||
|
|
||||||
|
buf := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(m.Size))
|
||||||
|
header = append(header, buf...)
|
||||||
|
|
||||||
|
} else if m.Size < 0xffffffffffffffff { // extended: 64 bits
|
||||||
|
|
||||||
|
header = append(header, 127)
|
||||||
|
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(buf, uint64(m.Size))
|
||||||
|
header = append(header, buf...)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Build write buffer */
|
||||||
|
writeBuf := make([]byte, 0, len(header) + int(m.Size))
|
||||||
|
writeBuf = append(writeBuf, header...)
|
||||||
|
writeBuf = append(writeBuf, m.Data...)
|
||||||
|
|
||||||
|
/* (4) Send over socket */
|
||||||
|
_, err := socket.Write(writeBuf)
|
||||||
|
if err != nil { return err }
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,178 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"fmt"
|
||||||
|
"git.xdrm.io/gws/internal/uri/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Represents all channels that need a server
|
||||||
|
type serverChannelSet struct{
|
||||||
|
register chan *client
|
||||||
|
unregister chan *client
|
||||||
|
broadcast chan *Message
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Represents a websocket server
|
||||||
|
type Server struct {
|
||||||
|
sock net.Listener // listen socket
|
||||||
|
addr []byte // server listening ip/host
|
||||||
|
port uint16 // server listening port
|
||||||
|
|
||||||
|
clients map[net.Conn]*client
|
||||||
|
|
||||||
|
ctl ControllerSet // controllers
|
||||||
|
|
||||||
|
ch serverChannelSet
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// CreateServer creates a server for a specific HOST and PORT
|
||||||
|
func CreateServer(host string, port uint16) *Server{
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
addr: []byte(host),
|
||||||
|
port: port,
|
||||||
|
|
||||||
|
clients: make(map[net.Conn]*client, 0),
|
||||||
|
|
||||||
|
ctl: ControllerSet{
|
||||||
|
Def: nil,
|
||||||
|
Uri: make([]*Controller, 0),
|
||||||
|
},
|
||||||
|
|
||||||
|
ch: serverChannelSet{
|
||||||
|
register: make(chan *client, 1),
|
||||||
|
unregister: make(chan *client, 1),
|
||||||
|
broadcast: make(chan *Message, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// BindDefault binds a default controller
|
||||||
|
// it will be called if the URI does not
|
||||||
|
// match another controller
|
||||||
|
func (s *Server) BindDefault(f ControllerFunc){
|
||||||
|
|
||||||
|
s.ctl.Def = &Controller{
|
||||||
|
URI: nil,
|
||||||
|
Fun: f,
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Bind binds a controller to an URI scheme
|
||||||
|
func (s *Server) Bind(uri string, f ControllerFunc) error {
|
||||||
|
|
||||||
|
/* (1) Build URI parser */
|
||||||
|
uriScheme, err := parser.Build(uri)
|
||||||
|
if err != nil { return fmt.Errorf("Cannot build URI: %s", err) }
|
||||||
|
|
||||||
|
/* (2) Create controller */
|
||||||
|
s.ctl.Uri = append(s.ctl.Uri, &Controller{
|
||||||
|
URI: uriScheme,
|
||||||
|
Fun: f,
|
||||||
|
} )
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Launch launches the websocket server
|
||||||
|
func (s *Server) Launch() error {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
/* (1) Listen socket
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
/* (1) Build full url */
|
||||||
|
url := fmt.Sprintf("%s:%d", s.addr, s.port)
|
||||||
|
|
||||||
|
/* (2) Bind socket to listen */
|
||||||
|
s.sock, err = net.Listen("tcp", url)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Listen socket: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer s.sock.Close()
|
||||||
|
|
||||||
|
fmt.Printf("+ listening on %s\n", url)
|
||||||
|
|
||||||
|
/* (3) Launch scheduler */
|
||||||
|
go s.scheduler()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* (2) For each incoming connection (client)
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
for {
|
||||||
|
|
||||||
|
/* (1) Wait for client */
|
||||||
|
sock, err := s.sock.Accept()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(){
|
||||||
|
|
||||||
|
/* (2) Try to create client */
|
||||||
|
cli, err := buildClient(sock, s.ctl, s.ch)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" - %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Register client */
|
||||||
|
s.ch.register <- cli
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Scheduler schedules clients registration and broadcast
|
||||||
|
func (s *Server) scheduler(){
|
||||||
|
|
||||||
|
for {
|
||||||
|
|
||||||
|
select {
|
||||||
|
|
||||||
|
/* (1) New client */
|
||||||
|
case client := <- s.ch.register:
|
||||||
|
|
||||||
|
// fmt.Printf(" + client\n")
|
||||||
|
s.clients[client.io.sock] = client
|
||||||
|
|
||||||
|
/* (2) New client */
|
||||||
|
case client := <- s.ch.unregister:
|
||||||
|
|
||||||
|
// fmt.Printf(" - client\n")
|
||||||
|
delete(s.clients, client.io.sock)
|
||||||
|
|
||||||
|
/* (3) Broadcast */
|
||||||
|
case msg := <- s.ch.broadcast:
|
||||||
|
|
||||||
|
fmt.Printf(" + broadcast\n")
|
||||||
|
|
||||||
|
for _, c := range s.clients{
|
||||||
|
c.ch.send <- msg
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("+ server stopped\n")
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue