refactor: http/upgrade normalise files
This commit is contained in:
parent
db52cfd28f
commit
6c47dbc38f
70
client.go
70
client.go
|
@ -8,7 +8,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade/request"
|
"git.xdrm.io/go/ws/internal/http/upgrade"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Represents a client socket utility (reader, writer, ..)
|
// Represents a client socket utility (reader, writer, ..)
|
||||||
|
@ -41,13 +41,13 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
||||||
|
|
||||||
/* (1) Manage UPGRADE request
|
/* (1) Manage UPGRADE request
|
||||||
---------------------------------------------------------*/
|
---------------------------------------------------------*/
|
||||||
/* (1) Parse request */
|
// 1. Parse request
|
||||||
req, _ := request.Parse(s)
|
req, _ := upgrade.Parse(s)
|
||||||
|
|
||||||
/* (3) Build response */
|
// 3. Build response
|
||||||
res := req.BuildResponse()
|
res := req.BuildResponse()
|
||||||
|
|
||||||
/* (4) Write into socket */
|
// 4. Write into socket
|
||||||
_, err := res.Send(s)
|
_, err := res.Send(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Upgrade write error: %s", err)
|
return nil, fmt.Errorf("Upgrade write error: %s", err)
|
||||||
|
@ -55,16 +55,16 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
||||||
|
|
||||||
if res.GetStatusCode() != 101 {
|
if res.GetStatusCode() != 101 {
|
||||||
s.Close()
|
s.Close()
|
||||||
return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
|
return nil, fmt.Errorf("Upgrade error (HTTP %d)", res.GetStatusCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Initialise client
|
/* (2) Initialise client
|
||||||
---------------------------------------------------------*/
|
---------------------------------------------------------*/
|
||||||
/* (1) Get upgrade data */
|
// 1. Get upgrade data
|
||||||
clientURI := req.GetURI()
|
clientURI := req.GetURI()
|
||||||
clientProtocol := res.GetProtocol()
|
clientProtocol := res.GetProtocol()
|
||||||
|
|
||||||
/* (2) Initialise client */
|
// 2. Initialise client
|
||||||
cli := &client{
|
cli := &client{
|
||||||
io: clientIO{
|
io: clientIO{
|
||||||
sock: s,
|
sock: s,
|
||||||
|
@ -74,7 +74,7 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
||||||
|
|
||||||
iface: &Client{
|
iface: &Client{
|
||||||
Protocol: string(clientProtocol),
|
Protocol: string(clientProtocol),
|
||||||
Arguments: [][]string{[]string{clientURI}},
|
Arguments: [][]string{{clientURI}},
|
||||||
},
|
},
|
||||||
|
|
||||||
ch: clientChannelSet{
|
ch: clientChannelSet{
|
||||||
|
@ -85,20 +85,20 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
||||||
|
|
||||||
/* (3) Find controller by URI
|
/* (3) Find controller by URI
|
||||||
---------------------------------------------------------*/
|
---------------------------------------------------------*/
|
||||||
/* (1) Try to find one */
|
// 1. Try to find one
|
||||||
controller, arguments := ctl.Match(clientURI)
|
controller, arguments := ctl.Match(clientURI)
|
||||||
|
|
||||||
/* (2) If nothing found -> error */
|
// 2. If nothing found -> error
|
||||||
if controller == nil {
|
if controller == nil {
|
||||||
return nil, fmt.Errorf("No controller found, no default controller set\n")
|
return nil, fmt.Errorf("No controller found, no default controller set")
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Copy arguments */
|
// 3. Copy arguments
|
||||||
cli.iface.Arguments = arguments
|
cli.iface.Arguments = arguments
|
||||||
|
|
||||||
/* (4) Launch client routines
|
/* (4) Launch client routines
|
||||||
---------------------------------------------------------*/
|
---------------------------------------------------------*/
|
||||||
/* (1) Launch client controller */
|
// 1. Launch client controller
|
||||||
go controller.Fun(
|
go controller.Fun(
|
||||||
cli.iface, // pass the client
|
cli.iface, // pass the client
|
||||||
cli.ch.receive, // the receiver
|
cli.ch.receive, // the receiver
|
||||||
|
@ -106,10 +106,10 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
|
||||||
serverCh.broadcast, // broadcast sender
|
serverCh.broadcast, // broadcast sender
|
||||||
)
|
)
|
||||||
|
|
||||||
/* (2) Launch message reader */
|
// 2. Launch message reader
|
||||||
go clientReader(cli)
|
go clientReader(cli)
|
||||||
|
|
||||||
/* (3) Launc writer */
|
// 3. Launc writer
|
||||||
go clientWriter(cli)
|
go clientWriter(cli)
|
||||||
|
|
||||||
return cli, nil
|
return cli, nil
|
||||||
|
@ -127,13 +127,13 @@ func clientReader(c *client) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
/* (1) if currently closing -> exit */
|
// 1. if currently closing -> exit
|
||||||
if c.io.closing {
|
if c.io.closing {
|
||||||
fmt.Printf("[reader] killed because closing")
|
fmt.Printf("[reader] killed because closing")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Parse message */
|
// 2. Parse message
|
||||||
msg, err := readMessage(c.io.reader)
|
msg, err := readMessage(c.io.reader)
|
||||||
|
|
||||||
if err == ErrUnmaskedFrame || err == ErrReservedBits {
|
if err == ErrUnmaskedFrame || err == ErrReservedBits {
|
||||||
|
@ -143,7 +143,7 @@ func clientReader(c *client) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Fail on invalid message */
|
// 3. Fail on invalid message
|
||||||
msgErr := msg.check(frag != nil)
|
msgErr := msg.check(frag != nil)
|
||||||
if msgErr != nil {
|
if msgErr != nil {
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ func clientReader(c *client) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Ping <-> Pong */
|
// 4. Ping <-> Pong
|
||||||
if msg.Type == Ping && c.io.writing {
|
if msg.Type == Ping && c.io.writing {
|
||||||
msg.Final = true
|
msg.Final = true
|
||||||
msg.Type = Pong
|
msg.Type = Pong
|
||||||
|
@ -190,7 +190,7 @@ func clientReader(c *client) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Store first fragment */
|
// 5. Store first fragment
|
||||||
if frag == nil && !msg.Final {
|
if frag == nil && !msg.Final {
|
||||||
frag = &Message{
|
frag = &Message{
|
||||||
Type: msg.Type,
|
Type: msg.Type,
|
||||||
|
@ -201,7 +201,7 @@ func clientReader(c *client) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (6) Store fragments */
|
// 6. Store fragments
|
||||||
if frag != nil {
|
if frag != nil {
|
||||||
frag.Final = msg.Final
|
frag.Final = msg.Final
|
||||||
frag.Size += msg.Size
|
frag.Size += msg.Size
|
||||||
|
@ -226,7 +226,7 @@ func clientReader(c *client) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (7) Dispatch to receiver */
|
// 7. Dispatch to receiver
|
||||||
if msg.Type == Text || msg.Type == Binary {
|
if msg.Type == Text || msg.Type == Binary {
|
||||||
c.ch.receive <- *msg
|
c.ch.receive <- *msg
|
||||||
}
|
}
|
||||||
|
@ -236,7 +236,7 @@ func clientReader(c *client) {
|
||||||
close(c.ch.receive)
|
close(c.ch.receive)
|
||||||
c.io.reading.Done()
|
c.io.reading.Done()
|
||||||
|
|
||||||
/* (8) close channel (if not already done) */
|
// 8. close channel (if not already done)
|
||||||
// fmt.Printf("[reader] end\n")
|
// fmt.Printf("[reader] end\n")
|
||||||
c.close(closeStatus, clientAck)
|
c.close(closeStatus, clientAck)
|
||||||
|
|
||||||
|
@ -250,10 +250,10 @@ func clientWriter(c *client) {
|
||||||
|
|
||||||
for msg := range c.ch.send {
|
for msg := range c.ch.send {
|
||||||
|
|
||||||
/* (2) Send message */
|
// 2. Send message
|
||||||
err := msg.Send(c.io.sock)
|
err := msg.Send(c.io.sock)
|
||||||
|
|
||||||
/* (3) Fail on error */
|
// 3. Fail on error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf(" [writer] %s\n", err)
|
fmt.Printf(" [writer] %s\n", err)
|
||||||
c.io.writing = false
|
c.io.writing = false
|
||||||
|
@ -264,7 +264,7 @@ func clientWriter(c *client) {
|
||||||
|
|
||||||
c.io.writing = false
|
c.io.writing = false
|
||||||
|
|
||||||
/* (4) close channel (if not already done) */
|
// 4. close channel (if not already done)
|
||||||
// fmt.Printf("[writer] end\n")
|
// fmt.Printf("[writer] end\n")
|
||||||
c.close(Normal, true)
|
c.close(Normal, true)
|
||||||
|
|
||||||
|
@ -276,7 +276,7 @@ func clientWriter(c *client) {
|
||||||
// then delete client
|
// then delete client
|
||||||
func (c *client) close(status MessageError, clientACK bool) {
|
func (c *client) close(status MessageError, clientACK bool) {
|
||||||
|
|
||||||
/* (1) Fail if already closing */
|
// 1. Fail if already closing
|
||||||
alreadyClosing := false
|
alreadyClosing := false
|
||||||
c.io.closingMu.Lock()
|
c.io.closingMu.Lock()
|
||||||
alreadyClosing = c.io.closing
|
alreadyClosing = c.io.closing
|
||||||
|
@ -287,18 +287,18 @@ func (c *client) close(status MessageError, clientACK bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) kill writer' if still running */
|
// 2. kill writer' if still running
|
||||||
if c.io.writing {
|
if c.io.writing {
|
||||||
close(c.ch.send)
|
close(c.ch.send)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) kill reader if still running */
|
// 3. kill reader if still running
|
||||||
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
|
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
|
||||||
c.io.reading.Wait()
|
c.io.reading.Wait()
|
||||||
|
|
||||||
if status != None {
|
if status != None {
|
||||||
|
|
||||||
/* (3) Build message */
|
// 3. Build message
|
||||||
msg := &Message{
|
msg := &Message{
|
||||||
Final: true,
|
Final: true,
|
||||||
Type: Close,
|
Type: Close,
|
||||||
|
@ -307,7 +307,7 @@ func (c *client) close(status MessageError, clientACK bool) {
|
||||||
}
|
}
|
||||||
binary.BigEndian.PutUint16(msg.Data, uint16(status))
|
binary.BigEndian.PutUint16(msg.Data, uint16(status))
|
||||||
|
|
||||||
/* (4) Send message */
|
// 4. Send message
|
||||||
msg.Send(c.io.sock)
|
msg.Send(c.io.sock)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// fmt.Printf("[close] send error (%s0\n", err)
|
// fmt.Printf("[close] send error (%s0\n", err)
|
||||||
|
@ -315,7 +315,7 @@ func (c *client) close(status MessageError, clientACK bool) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Wait for client CLOSE if needed */
|
// 2. Wait for client CLOSE if needed
|
||||||
if clientACK {
|
if clientACK {
|
||||||
|
|
||||||
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
|
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||||
|
@ -334,11 +334,11 @@ func (c *client) close(status MessageError, clientACK bool) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Close socket */
|
// 3. Close socket
|
||||||
c.io.sock.Close()
|
c.io.sock.Close()
|
||||||
// fmt.Printf("[close] socket closed\n")
|
// fmt.Printf("[close] socket closed\n")
|
||||||
|
|
||||||
/* (4) Unregister */
|
// 4. Unregister
|
||||||
c.io.kill <- c
|
c.io.kill <- c
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -11,10 +11,10 @@ func main() {
|
||||||
|
|
||||||
startTime := time.Now().UnixNano()
|
startTime := time.Now().UnixNano()
|
||||||
|
|
||||||
/* (1) Bind WebSocket server */
|
// 1. Bind WebSocket server
|
||||||
serv := ws.CreateServer("0.0.0.0", 4444)
|
serv := ws.CreateServer("0.0.0.0", 4444)
|
||||||
|
|
||||||
/* (2) Bind default controller */
|
// 2. Bind default controller
|
||||||
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -33,7 +33,7 @@ func main() {
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
/* (3) Bind to URI */
|
// 3. Bind to URI
|
||||||
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
|
||||||
|
|
||||||
fmt.Printf("[uri] connected\n")
|
fmt.Printf("[uri] connected\n")
|
||||||
|
@ -52,7 +52,7 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Launch the server */
|
// 4. Launch the server
|
||||||
err = serv.Launch()
|
err = serv.Launch()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[ERROR] %s\n", err)
|
fmt.Printf("[ERROR] %s\n", err)
|
||||||
|
|
|
@ -30,10 +30,10 @@ type ControllerSet struct {
|
||||||
// also it returns the matching string patterns
|
// also it returns the matching string patterns
|
||||||
func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
|
func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
|
||||||
|
|
||||||
/* (1) Initialise argument list */
|
// 1. Initialise argument list
|
||||||
arguments := [][]string{{uri}}
|
arguments := [][]string{{uri}}
|
||||||
|
|
||||||
/* (2) Try each controller */
|
// 2. Try each controller
|
||||||
for _, c := range s.URI {
|
for _, c := range s.URI {
|
||||||
|
|
||||||
/* 1. If matches */
|
/* 1. If matches */
|
||||||
|
@ -52,12 +52,12 @@ func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) If no controller found -> set default controller */
|
// 3. If no controller found -> set default controller
|
||||||
if s.Def != nil {
|
if s.Def != nil {
|
||||||
return s.Def, arguments
|
return s.Def, arguments
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) If default is NIL, return empty controller */
|
// 4. If default is NIL, return empty controller
|
||||||
return nil, arguments
|
return nil, arguments
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,16 +35,16 @@ func NewReader(r io.Reader) *ChunkReader {
|
||||||
// Read reads a chunk, err is io.EOF when done
|
// Read reads a chunk, err is io.EOF when done
|
||||||
func (r *ChunkReader) Read() ([]byte, error) {
|
func (r *ChunkReader) Read() ([]byte, error) {
|
||||||
|
|
||||||
/* (1) If already ended */
|
// 1. If already ended
|
||||||
if r.isEnded {
|
if r.isEnded {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Read line */
|
// 2. Read line
|
||||||
var line []byte
|
var line []byte
|
||||||
line, err := r.reader.ReadSlice('\n')
|
line, err := r.reader.ReadSlice('\n')
|
||||||
|
|
||||||
/* (3) manage errors */
|
// 3. manage errors
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -57,10 +57,10 @@ func (r *ChunkReader) Read() ([]byte, error) {
|
||||||
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
|
return nil, fmt.Errorf("HTTP line %d exceeded buffer size %d", len(line), maxLineLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Trim */
|
// 4. Trim
|
||||||
line = removeTrailingSpace(line)
|
line = removeTrailingSpace(line)
|
||||||
|
|
||||||
/* (5) Manage ending line */
|
// 5. Manage ending line
|
||||||
if len(line) == 0 {
|
if len(line) == 0 {
|
||||||
r.isEnded = true
|
r.isEnded = true
|
||||||
return line, io.EOF
|
return line, io.EOF
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package request
|
package upgrade
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
|
@ -0,0 +1,74 @@
|
||||||
|
package upgrade
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderType represents all 'valid' HTTP request headers
|
||||||
|
type HeaderType uint8
|
||||||
|
|
||||||
|
// header types
|
||||||
|
const (
|
||||||
|
Unknown HeaderType = iota
|
||||||
|
Host
|
||||||
|
Upgrade
|
||||||
|
Connection
|
||||||
|
Origin
|
||||||
|
WSKey
|
||||||
|
WSProtocol
|
||||||
|
WSExtensions
|
||||||
|
WSVersion
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderValue represents a unique or multiple header value(s)
|
||||||
|
type HeaderValue [][]byte
|
||||||
|
|
||||||
|
// Header represents the data of a HTTP request header
|
||||||
|
type Header struct {
|
||||||
|
Name HeaderType
|
||||||
|
Values HeaderValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadHeader tries to parse an HTTP header from a byte array
|
||||||
|
func ReadHeader(b []byte) (*Header, error) {
|
||||||
|
|
||||||
|
// 1. Split by ':'
|
||||||
|
parts := bytes.Split(b, []byte(": "))
|
||||||
|
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, fmt.Errorf("Invalid HTTP header format '%s'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Create instance
|
||||||
|
inst := &Header{}
|
||||||
|
|
||||||
|
// 3. Check for header name
|
||||||
|
switch strings.ToLower(string(parts[0])) {
|
||||||
|
case "host":
|
||||||
|
inst.Name = Host
|
||||||
|
case "upgrade":
|
||||||
|
inst.Name = Upgrade
|
||||||
|
case "connection":
|
||||||
|
inst.Name = Connection
|
||||||
|
case "origin":
|
||||||
|
inst.Name = Origin
|
||||||
|
case "sec-websocket-key":
|
||||||
|
inst.Name = WSKey
|
||||||
|
case "sec-websocket-protocol":
|
||||||
|
inst.Name = WSProtocol
|
||||||
|
case "sec-websocket-extensions":
|
||||||
|
inst.Name = WSExtensions
|
||||||
|
case "sec-websocket-version":
|
||||||
|
inst.Name = WSVersion
|
||||||
|
default:
|
||||||
|
inst.Name = Unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Split values
|
||||||
|
inst.Values = bytes.Split(parts[1], []byte(", "))
|
||||||
|
|
||||||
|
return inst, nil
|
||||||
|
|
||||||
|
}
|
|
@ -1,16 +1,13 @@
|
||||||
package request
|
package upgrade
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header"
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade/response"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// checkHost checks and extracts the Host header
|
// checkHost checks and extracts the Host header
|
||||||
func (r *T) extractHostPort(bb header.HeaderValue) error {
|
func (r *Request) extractHostPort(bb HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 {
|
if len(bb) != 1 {
|
||||||
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
|
@ -32,7 +29,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
// extract port
|
// extract port
|
||||||
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.code = response.BadRequest
|
r.code = BadRequest
|
||||||
return &InvalidRequest{"Host", "cannot read port"}
|
return &InvalidRequest{"Host", "cannot read port"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +39,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
if len(r.origin) > 0 {
|
if len(r.origin) > 0 {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = r.checkOriginPolicy()
|
err = r.checkOriginPolicy()
|
||||||
r.code = response.Forbidden
|
r.code = Forbidden
|
||||||
return &InvalidOriginPolicy{r.host, r.origin, err}
|
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -52,7 +49,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkOrigin checks the Origin Header
|
// checkOrigin checks the Origin Header
|
||||||
func (r *T) extractOrigin(bb header.HeaderValue) error {
|
func (r *Request) extractOrigin(bb HeaderValue) error {
|
||||||
|
|
||||||
// bypass
|
// bypass
|
||||||
if bypassOriginPolicy {
|
if bypassOriginPolicy {
|
||||||
|
@ -60,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(bb) != 1 {
|
if len(bb) != 1 {
|
||||||
r.code = response.Forbidden
|
r.code = Forbidden
|
||||||
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
||||||
if len(r.host) > 0 {
|
if len(r.host) > 0 {
|
||||||
err := r.checkOriginPolicy()
|
err := r.checkOriginPolicy()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.code = response.Forbidden
|
r.code = Forbidden
|
||||||
return &InvalidOriginPolicy{r.host, r.origin, err}
|
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,7 +77,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkOriginPolicy origin policy based on 'host' value
|
// checkOriginPolicy origin policy based on 'host' value
|
||||||
func (r *T) checkOriginPolicy() error {
|
func (r *Request) checkOriginPolicy() error {
|
||||||
// TODO: Origin policy, for now BYPASS
|
// TODO: Origin policy, for now BYPASS
|
||||||
r.validPolicy = true
|
r.validPolicy = true
|
||||||
return nil
|
return nil
|
||||||
|
@ -88,7 +85,7 @@ func (r *T) checkOriginPolicy() error {
|
||||||
|
|
||||||
// checkConnection checks the 'Connection' header
|
// checkConnection checks the 'Connection' header
|
||||||
// it MUST contain 'Upgrade'
|
// it MUST contain 'Upgrade'
|
||||||
func (r *T) checkConnection(bb header.HeaderValue) error {
|
func (r *Request) checkConnection(bb HeaderValue) error {
|
||||||
|
|
||||||
for _, b := range bb {
|
for _, b := range bb {
|
||||||
|
|
||||||
|
@ -99,17 +96,17 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.code = response.BadRequest
|
r.code = BadRequest
|
||||||
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkUpgrade checks the 'Upgrade' header
|
// checkUpgrade checks the 'Upgrade' header
|
||||||
// it MUST be 'websocket'
|
// it MUST be 'websocket'
|
||||||
func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
func (r *Request) checkUpgrade(bb HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 {
|
if len(bb) != 1 {
|
||||||
r.code = response.BadRequest
|
r.code = BadRequest
|
||||||
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,17 +115,17 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r.code = response.BadRequest
|
r.code = BadRequest
|
||||||
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkVersion checks the 'Sec-WebSocket-Version' header
|
// checkVersion checks the 'Sec-WebSocket-Version' header
|
||||||
// it MUST be '13'
|
// it MUST be '13'
|
||||||
func (r *T) checkVersion(bb header.HeaderValue) error {
|
func (r *Request) checkVersion(bb HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 || string(bb[0]) != "13" {
|
if len(bb) != 1 || string(bb[0]) != "13" {
|
||||||
r.code = response.UpgradeRequired
|
r.code = UpgradeRequired
|
||||||
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,10 +136,10 @@ func (r *T) checkVersion(bb header.HeaderValue) error {
|
||||||
|
|
||||||
// extractKey extracts the 'Sec-WebSocket-Key' header
|
// extractKey extracts the 'Sec-WebSocket-Key' header
|
||||||
// it MUST be 24 bytes (base64)
|
// it MUST be 24 bytes (base64)
|
||||||
func (r *T) extractKey(bb header.HeaderValue) error {
|
func (r *Request) extractKey(bb HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 || len(bb[0]) != 24 {
|
if len(bb) != 1 || len(bb[0]) != 24 {
|
||||||
r.code = response.BadRequest
|
r.code = BadRequest
|
||||||
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +151,7 @@ func (r *T) extractKey(bb header.HeaderValue) error {
|
||||||
|
|
||||||
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header
|
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header
|
||||||
// it can contain multiple values
|
// it can contain multiple values
|
||||||
func (r *T) extractProtocols(bb header.HeaderValue) error {
|
func (r *Request) extractProtocols(bb HeaderValue) error {
|
||||||
|
|
||||||
r.protocols = bb
|
r.protocols = bb
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package request
|
package upgrade
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -6,50 +6,38 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// httpMethod represents available http methods
|
// Line represents the HTTP Request line
|
||||||
type httpMethod byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
OPTIONS httpMethod = iota
|
|
||||||
GET
|
|
||||||
HEAD
|
|
||||||
POST
|
|
||||||
PUT
|
|
||||||
DELETE
|
|
||||||
)
|
|
||||||
|
|
||||||
// RequestLine represents the HTTP Request line
|
|
||||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||||
type RequestLine struct {
|
type Line struct {
|
||||||
method httpMethod
|
method Method
|
||||||
uri string
|
uri string
|
||||||
version byte
|
version byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseRequestLine parses the first HTTP request line
|
// Parse parses the first HTTP request line
|
||||||
func (r *RequestLine) Parse(b []byte) error {
|
func (r *Line) Parse(b []byte) error {
|
||||||
|
|
||||||
/* (1) Split by ' ' */
|
// 1. Split by ' '
|
||||||
parts := bytes.Split(b, []byte(" "))
|
parts := bytes.Split(b, []byte(" "))
|
||||||
|
|
||||||
/* (2) Fail when missing parts */
|
// 2. Fail when missing parts
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Extract HTTP method */
|
// 3. Extract HTTP method
|
||||||
err := r.extractHttpMethod(parts[0])
|
err := r.extractHttpMethod(parts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Extract URI */
|
// 4. Extract URI
|
||||||
err = r.extractURI(parts[1])
|
err = r.extractURI(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Extract version */
|
// 5. Extract version
|
||||||
err = r.extractHttpVersion(parts[2])
|
err = r.extractHttpVersion(parts[2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -60,19 +48,19 @@ func (r *RequestLine) Parse(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetURI returns the actual URI
|
// GetURI returns the actual URI
|
||||||
func (r RequestLine) GetURI() string {
|
func (r Line) GetURI() string {
|
||||||
return r.uri
|
return r.uri
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractHttpMethod extracts the HTTP method from a []byte
|
// extractHttpMethod extracts the HTTP method from a []byte
|
||||||
// and checks for errors
|
// and checks for errors
|
||||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||||
func (r *RequestLine) extractHttpMethod(b []byte) error {
|
func (r *Line) extractHttpMethod(b []byte) error {
|
||||||
|
|
||||||
switch string(b) {
|
switch string(b) {
|
||||||
// case "OPTIONS": r.method = OPTIONS
|
// case "OPTIONS": r.method = OPTIONS
|
||||||
case "GET":
|
case "GET":
|
||||||
r.method = GET
|
r.method = Get
|
||||||
// case "HEAD": r.method = HEAD
|
// case "HEAD": r.method = HEAD
|
||||||
// case "POST": r.method = POST
|
// case "POST": r.method = POST
|
||||||
// case "PUT": r.method = PUT
|
// case "PUT": r.method = PUT
|
||||||
|
@ -87,15 +75,15 @@ func (r *RequestLine) extractHttpMethod(b []byte) error {
|
||||||
|
|
||||||
// extractURI extracts the URI from a []byte and checks for errors
|
// extractURI extracts the URI from a []byte and checks for errors
|
||||||
// allowed format: /([^/]/)*/?
|
// allowed format: /([^/]/)*/?
|
||||||
func (r *RequestLine) extractURI(b []byte) error {
|
func (r *Line) extractURI(b []byte) error {
|
||||||
|
|
||||||
/* (1) Check format */
|
// 1. Check format
|
||||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||||
if !checker.Match(b) {
|
if !checker.Match(b) {
|
||||||
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Store */
|
// 2. Store
|
||||||
r.uri = string(b)
|
r.uri = string(b)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -104,26 +92,26 @@ func (r *RequestLine) extractURI(b []byte) error {
|
||||||
|
|
||||||
// extractHttpVersion extracts the version and checks for errors
|
// extractHttpVersion extracts the version and checks for errors
|
||||||
// allowed format: [1-9] or [1.9].[0-9]
|
// allowed format: [1-9] or [1.9].[0-9]
|
||||||
func (r *RequestLine) extractHttpVersion(b []byte) error {
|
func (r *Line) extractHttpVersion(b []byte) error {
|
||||||
|
|
||||||
/* (1) Extract version parts */
|
// 1. Extract version parts
|
||||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
|
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`)
|
||||||
|
|
||||||
if !extractor.Match(b) {
|
if !extractor.Match(b) {
|
||||||
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
|
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Extract version number */
|
// 2. Extract version number
|
||||||
matches := extractor.FindSubmatch(b)
|
matches := extractor.FindSubmatch(b)
|
||||||
var version byte = matches[1][0] - '0'
|
var version byte = matches[1][0] - '0'
|
||||||
|
|
||||||
/* (3) Extract subversion (if exists) */
|
// 3. Extract subversion (if exists)
|
||||||
var subVersion byte = 0
|
var subVersion byte = 0
|
||||||
if len(matches[2]) > 0 {
|
if len(matches[2]) > 0 {
|
||||||
subVersion = matches[2][0] - '0'
|
subVersion = matches[2][0] - '0'
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Store version (x 10 to fit uint8) */
|
// 4. Store version (x 10 to fit uint8)
|
||||||
r.version = version*10 + subVersion
|
r.version = version*10 + subVersion
|
||||||
|
|
||||||
return nil
|
return nil
|
|
@ -0,0 +1,14 @@
|
||||||
|
package upgrade
|
||||||
|
|
||||||
|
// Method represents available http methods
|
||||||
|
type Method uint8
|
||||||
|
|
||||||
|
// http methods
|
||||||
|
const (
|
||||||
|
Options Method = iota
|
||||||
|
Get
|
||||||
|
Head
|
||||||
|
Post
|
||||||
|
Put
|
||||||
|
Delete
|
||||||
|
)
|
|
@ -0,0 +1,216 @@
|
||||||
|
package upgrade
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"git.xdrm.io/go/ws/internal/http/reader"
|
||||||
|
)
|
||||||
|
|
||||||
|
// If origin is required
|
||||||
|
const bypassOriginPolicy = true
|
||||||
|
|
||||||
|
// Request represents an HTTP Upgrade request
|
||||||
|
type Request struct {
|
||||||
|
first bool // whether the first line has been read (GET uri HTTP/version)
|
||||||
|
|
||||||
|
// status code
|
||||||
|
code StatusCode
|
||||||
|
|
||||||
|
// request line
|
||||||
|
request Line
|
||||||
|
|
||||||
|
// data to check origin (depends of reading order)
|
||||||
|
host string
|
||||||
|
port uint16 // 0 if not set
|
||||||
|
origin string
|
||||||
|
validPolicy bool
|
||||||
|
|
||||||
|
// ws data
|
||||||
|
key []byte
|
||||||
|
protocols [][]byte
|
||||||
|
|
||||||
|
// required fields check
|
||||||
|
hasConnection bool
|
||||||
|
hasUpgrade bool
|
||||||
|
hasVersion bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse builds an upgrade HTTP request
|
||||||
|
// from a reader (typically bufio.NewRead of the socket)
|
||||||
|
func Parse(r io.Reader) (request *Request, err error) {
|
||||||
|
|
||||||
|
req := &Request{
|
||||||
|
code: 500,
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (1) Parse request
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
// 1. Get chunk reader
|
||||||
|
cr := reader.NewReader(r)
|
||||||
|
if err != nil {
|
||||||
|
return req, fmt.Errorf("Error while creating chunk reader: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Parse header line by line
|
||||||
|
for {
|
||||||
|
|
||||||
|
line, err := cr.Read()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = req.parseHeader(line)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check completion
|
||||||
|
err = req.isComplete()
|
||||||
|
if err != nil {
|
||||||
|
req.code = BadRequest
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.code = SwitchingProtocols
|
||||||
|
return req, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode returns the status current
|
||||||
|
func (r Request) StatusCode() StatusCode {
|
||||||
|
return r.code
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildResponse builds a response from the request
|
||||||
|
func (r *Request) BuildResponse() *Response {
|
||||||
|
|
||||||
|
inst := &Response{}
|
||||||
|
|
||||||
|
// 1. Copy code
|
||||||
|
inst.SetStatusCode(r.code)
|
||||||
|
|
||||||
|
// 2. Set Protocol
|
||||||
|
if len(r.protocols) > 0 {
|
||||||
|
inst.SetProtocol(r.protocols[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Process key
|
||||||
|
inst.ProcessKey(r.key)
|
||||||
|
|
||||||
|
return inst
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetURI returns the actual URI
|
||||||
|
func (r Request) GetURI() string {
|
||||||
|
return r.request.GetURI()
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHeader parses any http request line
|
||||||
|
// (header and request-line)
|
||||||
|
func (r *Request) parseHeader(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) First line -> GET {uri} HTTP/{version}
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
if !r.first {
|
||||||
|
|
||||||
|
err := r.request.Parse(b)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
r.code = BadRequest
|
||||||
|
return &InvalidRequest{"Request-Line", err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.first = true
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Other lines -> Header-Name: Header-Value
|
||||||
|
---------------------------------------------------------*/
|
||||||
|
// 1. Try to parse header
|
||||||
|
head, err := ReadHeader(b)
|
||||||
|
if err != nil {
|
||||||
|
r.code = BadRequest
|
||||||
|
return fmt.Errorf("Error parsing header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Manage header
|
||||||
|
switch head.Name {
|
||||||
|
case Host:
|
||||||
|
err = r.extractHostPort(head.Values)
|
||||||
|
case Origin:
|
||||||
|
err = r.extractOrigin(head.Values)
|
||||||
|
case Upgrade:
|
||||||
|
err = r.checkUpgrade(head.Values)
|
||||||
|
case Connection:
|
||||||
|
err = r.checkConnection(head.Values)
|
||||||
|
case WSVersion:
|
||||||
|
err = r.checkVersion(head.Values)
|
||||||
|
case WSKey:
|
||||||
|
err = r.extractKey(head.Values)
|
||||||
|
case WSProtocol:
|
||||||
|
err = r.extractProtocols(head.Values)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// isComplete returns whether the Upgrade Request
|
||||||
|
// is complete (no missing required item)
|
||||||
|
func (r Request) isComplete() error {
|
||||||
|
|
||||||
|
// 1. Request-Line
|
||||||
|
if !r.first {
|
||||||
|
return &IncompleteRequest{"Request-Line"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Host
|
||||||
|
if len(r.host) == 0 {
|
||||||
|
return &IncompleteRequest{"Host"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Origin
|
||||||
|
if !bypassOriginPolicy && len(r.origin) == 0 {
|
||||||
|
return &IncompleteRequest{"Origin"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Connection
|
||||||
|
if !r.hasConnection {
|
||||||
|
return &IncompleteRequest{"Connection"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Upgrade
|
||||||
|
if !r.hasUpgrade {
|
||||||
|
return &IncompleteRequest{"Upgrade"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Sec-WebSocket-Version
|
||||||
|
if !r.hasVersion {
|
||||||
|
return &IncompleteRequest{"Sec-WebSocket-Version"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Sec-WebSocket-Key
|
||||||
|
if len(r.key) < 1 {
|
||||||
|
return &IncompleteRequest{"Sec-WebSocket-Key"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
|
@ -1,41 +0,0 @@
|
||||||
package header
|
|
||||||
|
|
||||||
import (
|
|
||||||
// "regexp"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parse tries to return a 'T' (httpHeader) from a byte array
|
|
||||||
func Parse(b []byte) (*T, error) {
|
|
||||||
|
|
||||||
/* (1) Split by ':' */
|
|
||||||
parts := bytes.Split(b, []byte(": "))
|
|
||||||
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil, fmt.Errorf("Invalid HTTP header format '%s'", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Create instance */
|
|
||||||
inst := new(T)
|
|
||||||
|
|
||||||
/* (3) Check for header name */
|
|
||||||
switch strings.ToLower(string(parts[0])) {
|
|
||||||
case "host": inst.Name = HOST
|
|
||||||
case "upgrade": inst.Name = UPGRADE
|
|
||||||
case "connection": inst.Name = CONNECTION
|
|
||||||
case "origin": inst.Name = ORIGIN
|
|
||||||
case "sec-websocket-key": inst.Name = WSKEY
|
|
||||||
case "sec-websocket-protocol": inst.Name = WSPROTOCOL
|
|
||||||
case "sec-websocket-extensions": inst.Name = WSEXTENSIONS
|
|
||||||
case "sec-websocket-version": inst.Name = WSVERSION
|
|
||||||
default: inst.Name = UNKNOWN
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (4) Split values */
|
|
||||||
inst.Values = bytes.Split(parts[1], []byte(", "))
|
|
||||||
|
|
||||||
return inst, nil
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
package header
|
|
||||||
|
|
||||||
// HeaderType represents all 'valid' HTTP request headers
|
|
||||||
type HeaderType byte
|
|
||||||
const (
|
|
||||||
UNKNOWN HeaderType = iota
|
|
||||||
HOST
|
|
||||||
UPGRADE
|
|
||||||
CONNECTION
|
|
||||||
ORIGIN
|
|
||||||
WSKEY
|
|
||||||
WSPROTOCOL
|
|
||||||
WSEXTENSIONS
|
|
||||||
WSVERSION
|
|
||||||
)
|
|
||||||
|
|
||||||
// HeaderValue represents a unique or multiple header value(s)
|
|
||||||
type HeaderValue [][]byte
|
|
||||||
|
|
||||||
|
|
||||||
// T represents the data of a HTTP request header
|
|
||||||
type T struct{
|
|
||||||
Name HeaderType
|
|
||||||
Values HeaderValue
|
|
||||||
}
|
|
|
@ -1,110 +0,0 @@
|
||||||
package request
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header"
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade/response"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseHeader parses any http request line
|
|
||||||
// (header and request-line)
|
|
||||||
func (r *T) parseHeader(b []byte) error {
|
|
||||||
|
|
||||||
/* (1) First line -> GET {uri} HTTP/{version}
|
|
||||||
---------------------------------------------------------*/
|
|
||||||
if !r.first {
|
|
||||||
|
|
||||||
err := r.request.Parse(b)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
r.code = response.BadRequest
|
|
||||||
return &InvalidRequest{"Request-Line", err.Error()}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.first = true
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Other lines -> Header-Name: Header-Value
|
|
||||||
---------------------------------------------------------*/
|
|
||||||
/* (1) Try to parse header */
|
|
||||||
head, err := header.Parse(b)
|
|
||||||
if err != nil {
|
|
||||||
r.code = response.BadRequest
|
|
||||||
return fmt.Errorf("Error parsing header: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Manage header */
|
|
||||||
switch head.Name {
|
|
||||||
case header.HOST:
|
|
||||||
err = r.extractHostPort(head.Values)
|
|
||||||
case header.ORIGIN:
|
|
||||||
err = r.extractOrigin(head.Values)
|
|
||||||
case header.UPGRADE:
|
|
||||||
err = r.checkUpgrade(head.Values)
|
|
||||||
case header.CONNECTION:
|
|
||||||
err = r.checkConnection(head.Values)
|
|
||||||
case header.WSVERSION:
|
|
||||||
err = r.checkVersion(head.Values)
|
|
||||||
case header.WSKEY:
|
|
||||||
err = r.extractKey(head.Values)
|
|
||||||
case header.WSPROTOCOL:
|
|
||||||
err = r.extractProtocols(head.Values)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// isComplete returns whether the Upgrade Request
|
|
||||||
// is complete (no missing required item)
|
|
||||||
func (r T) isComplete() error {
|
|
||||||
|
|
||||||
/* (1) Request-Line */
|
|
||||||
if !r.first {
|
|
||||||
return &IncompleteRequest{"Request-Line"}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Host */
|
|
||||||
if len(r.host) == 0 {
|
|
||||||
return &IncompleteRequest{"Host"}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (3) Origin */
|
|
||||||
if !bypassOriginPolicy && len(r.origin) == 0 {
|
|
||||||
return &IncompleteRequest{"Origin"}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (4) Connection */
|
|
||||||
if !r.hasConnection {
|
|
||||||
return &IncompleteRequest{"Connection"}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (5) Upgrade */
|
|
||||||
if !r.hasUpgrade {
|
|
||||||
return &IncompleteRequest{"Upgrade"}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (6) Sec-WebSocket-Version */
|
|
||||||
if !r.hasVersion {
|
|
||||||
return &IncompleteRequest{"Sec-WebSocket-Version"}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (7) Sec-WebSocket-Key */
|
|
||||||
if len(r.key) < 1 {
|
|
||||||
return &IncompleteRequest{"Sec-WebSocket-Key"}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,85 +0,0 @@
|
||||||
package request
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"git.xdrm.io/go/ws/internal/http/reader"
|
|
||||||
"git.xdrm.io/go/ws/internal/http/upgrade/response"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Parse builds an upgrade HTTP request
|
|
||||||
// from a reader (typically bufio.NewRead of the socket)
|
|
||||||
func Parse(r io.Reader) (request *T, err error) {
|
|
||||||
|
|
||||||
req := new(T)
|
|
||||||
req.code = 500
|
|
||||||
|
|
||||||
/* (1) Parse request
|
|
||||||
---------------------------------------------------------*/
|
|
||||||
/* (1) Get chunk reader */
|
|
||||||
cr := reader.NewReader(r)
|
|
||||||
if err != nil {
|
|
||||||
return req, fmt.Errorf("Error while creating chunk reader: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Parse header line by line */
|
|
||||||
for {
|
|
||||||
|
|
||||||
line, err := cr.Read()
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return req, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.parseHeader(line)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return req, err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (3) Check completion */
|
|
||||||
err = req.isComplete()
|
|
||||||
if err != nil {
|
|
||||||
req.code = response.BadRequest
|
|
||||||
return req, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.code = response.SwitchingProtocols
|
|
||||||
return req, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// StatusCode returns the status current
|
|
||||||
func (r T) StatusCode() response.StatusCode {
|
|
||||||
return r.code
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildResponse builds a response.T from the request
|
|
||||||
func (r *T) BuildResponse() *response.T {
|
|
||||||
|
|
||||||
inst := new(response.T)
|
|
||||||
|
|
||||||
/* (1) Copy code */
|
|
||||||
inst.SetStatusCode(r.code)
|
|
||||||
|
|
||||||
/* (2) Set Protocol */
|
|
||||||
if len(r.protocols) > 0 {
|
|
||||||
inst.SetProtocol(r.protocols[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (4) Process key */
|
|
||||||
inst.ProcessKey(r.key)
|
|
||||||
|
|
||||||
return inst
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetURI returns the actual URI
|
|
||||||
func (r T) GetURI() string {
|
|
||||||
return r.request.GetURI()
|
|
||||||
}
|
|
|
@ -1,32 +0,0 @@
|
||||||
package request
|
|
||||||
|
|
||||||
import "git.xdrm.io/go/ws/internal/http/upgrade/response"
|
|
||||||
|
|
||||||
// If origin is required
|
|
||||||
const bypassOriginPolicy = true
|
|
||||||
|
|
||||||
// T represents an HTTP Upgrade request
|
|
||||||
type T struct {
|
|
||||||
first bool // whether the first line has been read (GET uri HTTP/version)
|
|
||||||
|
|
||||||
// status code
|
|
||||||
code response.StatusCode
|
|
||||||
|
|
||||||
// request line
|
|
||||||
request RequestLine
|
|
||||||
|
|
||||||
// data to check origin (depends of reading order)
|
|
||||||
host string
|
|
||||||
port uint16 // 0 if not set
|
|
||||||
origin string
|
|
||||||
validPolicy bool
|
|
||||||
|
|
||||||
// ws data
|
|
||||||
key []byte
|
|
||||||
protocols [][]byte
|
|
||||||
|
|
||||||
// required fields check
|
|
||||||
hasConnection bool
|
|
||||||
hasUpgrade bool
|
|
||||||
hasVersion bool
|
|
||||||
}
|
|
|
@ -1,4 +1,4 @@
|
||||||
package request
|
package upgrade
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -6,13 +6,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// /* (1) Parse request */
|
// // 1. Parse request
|
||||||
// req, _ := request.Parse(s)
|
// req, _ := request.Parse(s)
|
||||||
|
|
||||||
// /* (3) Build response */
|
// // 3. Build response
|
||||||
// res := req.BuildResponse()
|
// res := req.BuildResponse()
|
||||||
|
|
||||||
// /* (4) Write into socket */
|
// // 4. Write into socket
|
||||||
// _, err := res.Send(s)
|
// _, err := res.Send(s)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return nil, fmt.Errorf("Upgrade write error: %s", err)
|
// return nil, fmt.Errorf("Upgrade write error: %s", err)
|
||||||
|
@ -25,7 +25,7 @@ import (
|
||||||
|
|
||||||
func TestEOFSocket(t *testing.T) {
|
func TestEOFSocket(t *testing.T) {
|
||||||
|
|
||||||
socket := new(bytes.Buffer)
|
socket := &bytes.Buffer{}
|
||||||
|
|
||||||
_, err := Parse(socket)
|
_, err := Parse(socket)
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ func TestEOFSocket(t *testing.T) {
|
||||||
|
|
||||||
func TestInvalidRequestLine(t *testing.T) {
|
func TestInvalidRequestLine(t *testing.T) {
|
||||||
|
|
||||||
socket := new(bytes.Buffer)
|
socket := &bytes.Buffer{}
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Reqline string
|
Reqline string
|
||||||
HasError bool
|
HasError bool
|
||||||
|
@ -113,7 +113,7 @@ func TestInvalidHost(t *testing.T) {
|
||||||
|
|
||||||
requestLine := []byte("GET / HTTP/1.1\r\n")
|
requestLine := []byte("GET / HTTP/1.1\r\n")
|
||||||
|
|
||||||
socket := new(bytes.Buffer)
|
socket := &bytes.Buffer{}
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Host string
|
Host string
|
||||||
HasError bool
|
HasError bool
|
|
@ -1,4 +1,4 @@
|
||||||
package response
|
package upgrade
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
|
@ -7,19 +7,35 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HTTPVersion constant
|
||||||
|
const HTTPVersion = "1.1"
|
||||||
|
|
||||||
|
// UsedWSVersion constant websocket version
|
||||||
|
const UsedWSVersion = 13
|
||||||
|
|
||||||
|
// WSSalt constant websocket salt
|
||||||
|
const WSSalt = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
|
||||||
|
// Response represents an HTTP Upgrade Response
|
||||||
|
type Response struct {
|
||||||
|
code StatusCode // status code
|
||||||
|
accept []byte // processed from Sec-WebSocket-Key
|
||||||
|
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatusCode sets the status code
|
// SetStatusCode sets the status code
|
||||||
func (r *T) SetStatusCode(sc StatusCode) {
|
func (r *Response) SetStatusCode(sc StatusCode) {
|
||||||
r.code = sc
|
r.code = sc
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetProtocol sets the protocols
|
// SetProtocol sets the protocols
|
||||||
func (r *T) SetProtocol(p []byte) {
|
func (r *Response) SetProtocol(p []byte) {
|
||||||
r.protocol = p
|
r.protocol = p
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessKey processes the accept token according
|
// ProcessKey processes the accept token according
|
||||||
// to the rfc from the Sec-WebSocket-Key
|
// to the rfc from the Sec-WebSocket-Key
|
||||||
func (r *T) ProcessKey(k []byte) {
|
func (r *Response) ProcessKey(k []byte) {
|
||||||
|
|
||||||
// do nothing for empty key
|
// do nothing for empty key
|
||||||
if k == nil || len(k) == 0 {
|
if k == nil || len(k) == 0 {
|
||||||
|
@ -27,40 +43,40 @@ func (r *T) ProcessKey(k []byte) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (1) Concat with constant salt */
|
// 1. Concat with constant salt
|
||||||
mix := append(k, WSSalt...)
|
mix := append(k, []byte(WSSalt)...)
|
||||||
|
|
||||||
/* (2) Hash with sha1 algorithm */
|
// 2. Hash with sha1 algorithm
|
||||||
digest := sha1.Sum(mix)
|
digest := sha1.Sum(mix)
|
||||||
|
|
||||||
/* (3) Base64 encode it */
|
// 3. Base64 encode it
|
||||||
r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
|
r.accept = []byte(base64.StdEncoding.EncodeToString(digest[:sha1.Size]))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends the response through an io.Writer
|
// Send sends the response through an io.Writer
|
||||||
// typically a socket
|
// typically a socket
|
||||||
func (r T) Send(w io.Writer) (int, error) {
|
func (r Response) Send(w io.Writer) (int, error) {
|
||||||
|
|
||||||
/* (1) Build response line */
|
// 1. Build response line
|
||||||
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HttpVersion, r.code, r.code.String())
|
responseLine := fmt.Sprintf("HTTP/%s %d %s\r\n", HTTPVersion, r.code, r.code)
|
||||||
|
|
||||||
/* (2) Build headers */
|
// 2. Build headers
|
||||||
optionalProtocol := ""
|
optionalProtocol := ""
|
||||||
if len(r.protocol) > 0 {
|
if len(r.protocol) > 0 {
|
||||||
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol)
|
optionalProtocol = fmt.Sprintf("Sec-WebSocket-Protocol: %s\r\n", r.protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", WSVersion, optionalProtocol)
|
headers := fmt.Sprintf("Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: %d\r\n%s", UsedWSVersion, optionalProtocol)
|
||||||
if r.accept != nil {
|
if r.accept != nil {
|
||||||
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept)
|
headers = fmt.Sprintf("%sSec-WebSocket-Accept: %s\r\n", headers, r.accept)
|
||||||
}
|
}
|
||||||
headers = fmt.Sprintf("%s\r\n", headers)
|
headers = fmt.Sprintf("%s\r\n", headers)
|
||||||
|
|
||||||
/* (3) Build all */
|
// 3. Build all
|
||||||
raw := []byte(fmt.Sprintf("%s%s", responseLine, headers))
|
raw := []byte(fmt.Sprintf("%s%s", responseLine, headers))
|
||||||
|
|
||||||
/* (4) Write */
|
// 4. Write
|
||||||
written, err := w.Write(raw)
|
written, err := w.Write(raw)
|
||||||
|
|
||||||
return written, err
|
return written, err
|
||||||
|
@ -68,11 +84,11 @@ func (r T) Send(w io.Writer) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProtocol returns the choosen protocol if set, else nil
|
// GetProtocol returns the choosen protocol if set, else nil
|
||||||
func (r T) GetProtocol() []byte {
|
func (r Response) GetProtocol() []byte {
|
||||||
return r.protocol
|
return r.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatusCode returns the response status code
|
// GetStatusCode returns the response status code
|
||||||
func (r T) GetStatusCode() StatusCode {
|
func (r Response) GetStatusCode() StatusCode {
|
||||||
return r.code
|
return r.code
|
||||||
}
|
}
|
|
@ -1,15 +0,0 @@
|
||||||
package response
|
|
||||||
|
|
||||||
// Constant
|
|
||||||
const HttpVersion = "1.1"
|
|
||||||
const WSVersion = 13
|
|
||||||
|
|
||||||
var WSSalt []byte = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
|
||||||
|
|
||||||
// T represents an HTTP Upgrade Response
|
|
||||||
type T struct {
|
|
||||||
code StatusCode // status code
|
|
||||||
accept []byte // processed from Sec-WebSocket-Key
|
|
||||||
protocol []byte // set from Sec-WebSocket-Protocol or none if not received
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,9 +1,9 @@
|
||||||
package response
|
package upgrade
|
||||||
|
|
||||||
// StatusCode maps the status codes (and description)
|
// StatusCode maps the status codes (and description)
|
||||||
type StatusCode uint16
|
type StatusCode uint16
|
||||||
|
|
||||||
var (
|
const (
|
||||||
// SwitchingProtocols - handshake success
|
// SwitchingProtocols - handshake success
|
||||||
SwitchingProtocols StatusCode = 101
|
SwitchingProtocols StatusCode = 101
|
||||||
// BadRequest - missing/malformed headers
|
// BadRequest - missing/malformed headers
|
|
@ -9,45 +9,45 @@ import (
|
||||||
// from a pattern string
|
// from a pattern string
|
||||||
func buildScheme(ss []string) (Scheme, error) {
|
func buildScheme(ss []string) (Scheme, error) {
|
||||||
|
|
||||||
/* (1) Build scheme */
|
// 1. Build scheme
|
||||||
sch := make(Scheme, 0, maxMatch)
|
sch := make(Scheme, 0, maxMatch)
|
||||||
|
|
||||||
for _, s := range ss {
|
for _, s := range ss {
|
||||||
|
|
||||||
/* (2) ignore empty */
|
// 2. ignore empty
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
m := new(matcher)
|
m := &matcher{}
|
||||||
|
|
||||||
switch s {
|
switch s {
|
||||||
|
|
||||||
/* (3) Card: 0, N */
|
// 3. Card: 0, N
|
||||||
case "**":
|
case "**":
|
||||||
m.req = false
|
m.req = false
|
||||||
m.mul = true
|
m.mul = true
|
||||||
sch = append(sch, m)
|
sch = append(sch, m)
|
||||||
|
|
||||||
/* (4) Card: 1, N */
|
// 4. Card: 1, N
|
||||||
case "..":
|
case "..":
|
||||||
m.req = true
|
m.req = true
|
||||||
m.mul = true
|
m.mul = true
|
||||||
sch = append(sch, m)
|
sch = append(sch, m)
|
||||||
|
|
||||||
/* (5) Card: 0, 1 */
|
// 5. Card: 0, 1
|
||||||
case "*":
|
case "*":
|
||||||
m.req = false
|
m.req = false
|
||||||
m.mul = false
|
m.mul = false
|
||||||
sch = append(sch, m)
|
sch = append(sch, m)
|
||||||
|
|
||||||
/* (6) Card: 1 */
|
// 6. Card: 1
|
||||||
case ".":
|
case ".":
|
||||||
m.req = true
|
m.req = true
|
||||||
m.mul = false
|
m.mul = false
|
||||||
sch = append(sch, m)
|
sch = append(sch, m)
|
||||||
|
|
||||||
/* (7) Card: 1, literal string */
|
// 7. Card: 1, literal string
|
||||||
default:
|
default:
|
||||||
m.req = true
|
m.req = true
|
||||||
m.mul = false
|
m.mul = false
|
||||||
|
@ -64,16 +64,16 @@ func buildScheme(ss []string) (Scheme, error) {
|
||||||
// optimise optimised the scheme for further parsing
|
// optimise optimised the scheme for further parsing
|
||||||
func (s Scheme) optimise() (Scheme, error) {
|
func (s Scheme) optimise() (Scheme, error) {
|
||||||
|
|
||||||
/* (1) Nothing to do if only 1 element */
|
// 1. Nothing to do if only 1 element
|
||||||
if len(s) <= 1 {
|
if len(s) <= 1 {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Init reshifted scheme */
|
// 2. Init reshifted scheme
|
||||||
rshift := make(Scheme, 0, maxMatch)
|
rshift := make(Scheme, 0, maxMatch)
|
||||||
rshift = append(rshift, s[0])
|
rshift = append(rshift, s[0])
|
||||||
|
|
||||||
/* (2) Iterate over matchers */
|
// 2. Iterate over matchers
|
||||||
for p, i, l := 0, 1, len(s); i < l; i++ {
|
for p, i, l := 0, 1, len(s); i < l; i++ {
|
||||||
|
|
||||||
pre, cur := s[p], s[i]
|
pre, cur := s[p], s[i]
|
||||||
|
@ -106,11 +106,11 @@ func (s Scheme) optimise() (Scheme, error) {
|
||||||
// it returns a cleared uri, without STRING data
|
// it returns a cleared uri, without STRING data
|
||||||
func (s Scheme) matchString(uri string) (string, bool) {
|
func (s Scheme) matchString(uri string) (string, bool) {
|
||||||
|
|
||||||
/* (1) Initialise variables */
|
// 1. Initialise variables
|
||||||
clr := uri // contains cleared input string
|
clr := uri // contains cleared input string
|
||||||
minOff := 0 // minimum offset
|
minOff := 0 // minimum offset
|
||||||
|
|
||||||
/* (2) Iterate over strings */
|
// 2. Iterate over strings
|
||||||
for _, m := range s {
|
for _, m := range s {
|
||||||
|
|
||||||
ls := len(m.pat)
|
ls := len(m.pat)
|
||||||
|
@ -147,12 +147,12 @@ func (s Scheme) matchString(uri string) (string, bool) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) If exists, remove trailing '/' */
|
// 3. If exists, remove trailing '/'
|
||||||
if clr[len(clr)-1] == '/' {
|
if clr[len(clr)-1] == '/' {
|
||||||
clr = clr[:len(clr)-1]
|
clr = clr[:len(clr)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) If exists, remove trailing '\a' */
|
// 4. If exists, remove trailing '\a'
|
||||||
if clr[len(clr)-1] == '\a' {
|
if clr[len(clr)-1] == '\a' {
|
||||||
clr = clr[:len(clr)-1]
|
clr = clr[:len(clr)-1]
|
||||||
}
|
}
|
||||||
|
@ -166,7 +166,7 @@ func (s Scheme) matchString(uri string) (string, bool) {
|
||||||
// + it sets the matchers buffers for later extraction
|
// + it sets the matchers buffers for later extraction
|
||||||
func (s Scheme) matchWildcards(clear string) bool {
|
func (s Scheme) matchWildcards(clear string) bool {
|
||||||
|
|
||||||
/* (1) Extract wildcards (ref) */
|
// 1. Extract wildcards (ref)
|
||||||
wildcards := make(Scheme, 0, maxMatch)
|
wildcards := make(Scheme, 0, maxMatch)
|
||||||
|
|
||||||
for _, m := range s {
|
for _, m := range s {
|
||||||
|
@ -176,15 +176,15 @@ func (s Scheme) matchWildcards(clear string) bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) If no wildcards -> match */
|
// 2. If no wildcards -> match
|
||||||
if len(wildcards) == 0 {
|
if len(wildcards) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Break uri by '\a' characters */
|
// 3. Break uri by '\a' characters
|
||||||
matches := strings.Split(clear, "\a")[1:]
|
matches := strings.Split(clear, "\a")[1:]
|
||||||
|
|
||||||
/* (4) Iterate over matches */
|
// 4. Iterate over matches
|
||||||
for n, match := range matches {
|
for n, match := range matches {
|
||||||
|
|
||||||
// {1} If no more matcher //
|
// {1} If no more matcher //
|
||||||
|
@ -210,7 +210,7 @@ func (s Scheme) matchWildcards(clear string) bool {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Match */
|
// 5. Match
|
||||||
return true
|
return true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,15 +8,15 @@ import (
|
||||||
// Build builds an URI scheme from a pattern string
|
// Build builds an URI scheme from a pattern string
|
||||||
func Build(s string) (*Scheme, error) {
|
func Build(s string) (*Scheme, error) {
|
||||||
|
|
||||||
/* (1) Manage '/' at the start */
|
// 1. Manage '/' at the start
|
||||||
if len(s) < 1 || s[0] != '/' {
|
if len(s) < 1 || s[0] != '/' {
|
||||||
return nil, fmt.Errorf("URI must begin with '/'")
|
return nil, fmt.Errorf("URI must begin with '/'")
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Split by '/' */
|
// 2. Split by '/'
|
||||||
parts := strings.Split(s, "/")
|
parts := strings.Split(s, "/")
|
||||||
|
|
||||||
/* (3) Max exceeded */
|
// 3. Max exceeded
|
||||||
if len(parts)-2 > maxMatch {
|
if len(parts)-2 > maxMatch {
|
||||||
for i, p := range parts {
|
for i, p := range parts {
|
||||||
fmt.Printf("%d: '%s'\n", i, p)
|
fmt.Printf("%d: '%s'\n", i, p)
|
||||||
|
@ -24,13 +24,13 @@ func Build(s string) (*Scheme, error) {
|
||||||
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
|
return nil, fmt.Errorf("URI must not exceed %d slash-separated components, got %d", maxMatch, len(parts))
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Build for each part */
|
// 4. Build for each part
|
||||||
sch, err := buildScheme(parts)
|
sch, err := buildScheme(parts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Optimise structure */
|
// 5. Optimise structure
|
||||||
opti, err := sch.optimise()
|
opti, err := sch.optimise()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -43,18 +43,18 @@ func Build(s string) (*Scheme, error) {
|
||||||
// Match returns if the given URI is matched by the scheme
|
// Match returns if the given URI is matched by the scheme
|
||||||
func (s Scheme) Match(str string) bool {
|
func (s Scheme) Match(str string) bool {
|
||||||
|
|
||||||
/* (1) Nothing -> match all */
|
// 1. Nothing -> match all
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Check for string match */
|
// 2. Check for string match
|
||||||
clearURI, match := s.matchString(str)
|
clearURI, match := s.matchString(str)
|
||||||
if !match {
|
if !match {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Check for non-string match (wildcards) */
|
// 3. Check for non-string match (wildcards)
|
||||||
match = s.matchWildcards(clearURI)
|
match = s.matchWildcards(clearURI)
|
||||||
if !match {
|
if !match {
|
||||||
return false
|
return false
|
||||||
|
@ -66,12 +66,12 @@ func (s Scheme) Match(str string) bool {
|
||||||
// GetMatch returns the indexed match (excluding string matchers)
|
// GetMatch returns the indexed match (excluding string matchers)
|
||||||
func (s Scheme) GetMatch(n uint8) ([]string, error) {
|
func (s Scheme) GetMatch(n uint8) ([]string, error) {
|
||||||
|
|
||||||
/* (1) Index out of range */
|
// 1. Index out of range
|
||||||
if n > uint8(len(s)) {
|
if n > uint8(len(s)) {
|
||||||
return nil, fmt.Errorf("Index out of range")
|
return nil, fmt.Errorf("Index out of range")
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Iterate to find index (exclude strings) */
|
// 2. Iterate to find index (exclude strings)
|
||||||
ni := -1
|
ni := -1
|
||||||
for _, m := range s {
|
for _, m := range s {
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ func (s Scheme) GetMatch(n uint8) ([]string, error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) If nothing found -> return empty set */
|
// 3. If nothing found -> return empty set
|
||||||
return nil, fmt.Errorf("Index out of range (max: %d)", ni)
|
return nil, fmt.Errorf("Index out of range (max: %d)", ni)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
56
message.go
56
message.go
|
@ -2,32 +2,36 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// constant error
|
||||||
|
type constErr string
|
||||||
|
|
||||||
|
func (c constErr) Error() string { return string(c) }
|
||||||
|
|
||||||
|
const (
|
||||||
// ErrUnmaskedFrame error
|
// ErrUnmaskedFrame error
|
||||||
ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame")
|
ErrUnmaskedFrame = constErr("Received unmasked frame")
|
||||||
// ErrTooLongControlFrame error
|
// ErrTooLongControlFrame error
|
||||||
ErrTooLongControlFrame = fmt.Errorf("Received a control frame that is fragmented or too long")
|
ErrTooLongControlFrame = constErr("Received a control frame that is fragmented or too long")
|
||||||
// ErrInvalidFragment error
|
// ErrInvalidFragment error
|
||||||
ErrInvalidFragment = fmt.Errorf("Received invalid fragmentation")
|
ErrInvalidFragment = constErr("Received invalid fragmentation")
|
||||||
// ErrUnexpectedContinuation error
|
// ErrUnexpectedContinuation error
|
||||||
ErrUnexpectedContinuation = fmt.Errorf("Received unexpected continuation frame")
|
ErrUnexpectedContinuation = constErr("Received unexpected continuation frame")
|
||||||
// ErrInvalidSize error
|
// ErrInvalidSize error
|
||||||
ErrInvalidSize = fmt.Errorf("Received invalid payload size")
|
ErrInvalidSize = constErr("Received invalid payload size")
|
||||||
// ErrInvalidPayload error
|
// ErrInvalidPayload error
|
||||||
ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
|
ErrInvalidPayload = constErr("Received invalid utf8 payload")
|
||||||
// ErrInvalidCloseStatus error
|
// ErrInvalidCloseStatus error
|
||||||
ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
|
ErrInvalidCloseStatus = constErr("Received invalid close status")
|
||||||
// ErrInvalidOpCode error
|
// ErrInvalidOpCode error
|
||||||
ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
|
ErrInvalidOpCode = constErr("Received invalid OpCode")
|
||||||
// ErrReservedBits error
|
// ErrReservedBits error
|
||||||
ErrReservedBits = fmt.Errorf("Received reserved bits")
|
ErrReservedBits = constErr("Received reserved bits")
|
||||||
// ErrCloseFrame error
|
// ErrCloseFrame error
|
||||||
ErrCloseFrame = fmt.Errorf("Received close Frame")
|
ErrCloseFrame = constErr("Received close Frame")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
|
||||||
|
@ -88,9 +92,9 @@ func readMessage(reader io.Reader) (*Message, error) {
|
||||||
var mask []byte
|
var mask []byte
|
||||||
var cursor int
|
var cursor int
|
||||||
|
|
||||||
m := new(Message)
|
m := &Message{}
|
||||||
|
|
||||||
/* (2) Byte 1: FIN and OpCode */
|
// 2. Byte 1: FIN and OpCode
|
||||||
tmpBuf = make([]byte, 1)
|
tmpBuf = make([]byte, 1)
|
||||||
err = readBytes(reader, tmpBuf)
|
err = readBytes(reader, tmpBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -105,7 +109,7 @@ func readMessage(reader io.Reader) (*Message, error) {
|
||||||
m.Final = bool(tmpBuf[0]&0x80 == 0x80)
|
m.Final = bool(tmpBuf[0]&0x80 == 0x80)
|
||||||
m.Type = MessageType(tmpBuf[0] & 0x0f)
|
m.Type = MessageType(tmpBuf[0] & 0x0f)
|
||||||
|
|
||||||
/* (3) Byte 2: Mask and Length[0] */
|
// 3. Byte 2: Mask and Length[0]
|
||||||
tmpBuf = make([]byte, 1)
|
tmpBuf = make([]byte, 1)
|
||||||
err = readBytes(reader, tmpBuf)
|
err = readBytes(reader, tmpBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -120,7 +124,7 @@ func readMessage(reader io.Reader) (*Message, error) {
|
||||||
// payload length
|
// payload length
|
||||||
m.Size = uint(tmpBuf[0] & 0x7f)
|
m.Size = uint(tmpBuf[0] & 0x7f)
|
||||||
|
|
||||||
/* (4) Extended payload */
|
// 4. Extended payload
|
||||||
if m.Size == 127 {
|
if m.Size == 127 {
|
||||||
|
|
||||||
tmpBuf = make([]byte, 8)
|
tmpBuf = make([]byte, 8)
|
||||||
|
@ -143,7 +147,7 @@ func readMessage(reader io.Reader) (*Message, error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Masking key */
|
// 5. Masking key
|
||||||
if mask != nil {
|
if mask != nil {
|
||||||
|
|
||||||
tmpBuf = make([]byte, 4)
|
tmpBuf = make([]byte, 4)
|
||||||
|
@ -157,7 +161,7 @@ func readMessage(reader io.Reader) (*Message, error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (6) Read payload by chunks */
|
// 6. Read payload by chunks
|
||||||
m.Data = make([]byte, int(m.Size))
|
m.Data = make([]byte, int(m.Size))
|
||||||
|
|
||||||
cursor = 0
|
cursor = 0
|
||||||
|
@ -207,14 +211,14 @@ func (m Message) Send(writer io.Writer) error {
|
||||||
m.Size = uint(len(m.Data))
|
m.Size = uint(len(m.Data))
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (1) Byte 0 : FIN + opcode */
|
// 1. Byte 0 : FIN + opcode
|
||||||
var final byte = 0x80
|
var final byte = 0x80
|
||||||
if !m.Final {
|
if !m.Final {
|
||||||
final = 0
|
final = 0
|
||||||
}
|
}
|
||||||
header = append(header, final|byte(m.Type))
|
header = append(header, final|byte(m.Type))
|
||||||
|
|
||||||
/* (2) Get payload length */
|
// 2. Get payload length
|
||||||
if m.Size < 126 { // simple
|
if m.Size < 126 { // simple
|
||||||
|
|
||||||
header = append(header, byte(m.Size))
|
header = append(header, byte(m.Size))
|
||||||
|
@ -237,12 +241,12 @@ func (m Message) Send(writer io.Writer) error {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Build write buffer */
|
// 3. Build write buffer
|
||||||
writeBuf := make([]byte, 0, len(header)+int(m.Size))
|
writeBuf := make([]byte, 0, len(header)+int(m.Size))
|
||||||
writeBuf = append(writeBuf, header...)
|
writeBuf = append(writeBuf, header...)
|
||||||
writeBuf = append(writeBuf, m.Data[0:m.Size]...)
|
writeBuf = append(writeBuf, m.Data[0:m.Size]...)
|
||||||
|
|
||||||
/* (4) Send over socket by chunks */
|
// 4. Send over socket by chunks
|
||||||
toWrite := len(header) + int(m.Size)
|
toWrite := len(header) + int(m.Size)
|
||||||
cursor := 0
|
cursor := 0
|
||||||
for cursor < toWrite {
|
for cursor < toWrite {
|
||||||
|
@ -272,17 +276,17 @@ func (m Message) Send(writer io.Writer) error {
|
||||||
// returns the message error
|
// returns the message error
|
||||||
func (m *Message) check(fragment bool) error {
|
func (m *Message) check(fragment bool) error {
|
||||||
|
|
||||||
/* (1) Invalid first fragment (not TEXT nor BINARY) */
|
// 1. Invalid first fragment (not TEXT nor BINARY)
|
||||||
if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
|
if !m.Final && !fragment && m.Type != Text && m.Type != Binary {
|
||||||
return ErrInvalidFragment
|
return ErrInvalidFragment
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Waiting fragment but received standalone frame */
|
// 2. Waiting fragment but received standalone frame
|
||||||
if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong {
|
if fragment && m.Type != Continuation && m.Type != Close && m.Type != Ping && m.Type != Pong {
|
||||||
return ErrInvalidFragment
|
return ErrInvalidFragment
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Control frame too long */
|
// 3. Control frame too long
|
||||||
if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) {
|
if (m.Type == Close || m.Type == Ping || m.Type == Pong) && (m.Size > 125 || !m.Final) {
|
||||||
return ErrTooLongControlFrame
|
return ErrTooLongControlFrame
|
||||||
}
|
}
|
||||||
|
@ -335,8 +339,6 @@ func (m *Message) check(fragment bool) error {
|
||||||
return ErrInvalidOpCode
|
return ErrInvalidOpCode
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// readBytes reads from a reader into a byte array
|
// readBytes reads from a reader into a byte array
|
||||||
|
|
|
@ -267,7 +267,7 @@ func TestSimpleMessageSending(t *testing.T) {
|
||||||
|
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
|
||||||
writer := new(bytes.Buffer)
|
writer := &bytes.Buffer{}
|
||||||
|
|
||||||
err := tc.Base.Send(writer)
|
err := tc.Base.Send(writer)
|
||||||
|
|
||||||
|
|
22
server.go
22
server.go
|
@ -65,13 +65,13 @@ func (s *Server) BindDefault(f ControllerFunc) {
|
||||||
// Bind a controller to an URI scheme
|
// Bind a controller to an URI scheme
|
||||||
func (s *Server) Bind(uri string, f ControllerFunc) error {
|
func (s *Server) Bind(uri string, f ControllerFunc) error {
|
||||||
|
|
||||||
/* (1) Build URI parser */
|
// 1. Build URI parser
|
||||||
uriScheme, err := parser.Build(uri)
|
uriScheme, err := parser.Build(uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Cannot build URI: %s", err)
|
return fmt.Errorf("Cannot build URI: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Create controller */
|
// 2. Create controller
|
||||||
s.ctl.URI = append(s.ctl.URI, &Controller{
|
s.ctl.URI = append(s.ctl.URI, &Controller{
|
||||||
URI: uriScheme,
|
URI: uriScheme,
|
||||||
Fun: f,
|
Fun: f,
|
||||||
|
@ -88,10 +88,10 @@ func (s *Server) Launch() error {
|
||||||
|
|
||||||
/* (1) Listen socket
|
/* (1) Listen socket
|
||||||
---------------------------------------------------------*/
|
---------------------------------------------------------*/
|
||||||
/* (1) Build full url */
|
// 1. Build full url
|
||||||
url := fmt.Sprintf("%s:%d", s.addr, s.port)
|
url := fmt.Sprintf("%s:%d", s.addr, s.port)
|
||||||
|
|
||||||
/* (2) Bind socket to listen */
|
// 2. Bind socket to listen
|
||||||
s.sock, err = net.Listen("tcp", url)
|
s.sock, err = net.Listen("tcp", url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Listen socket: %s", err)
|
return fmt.Errorf("Listen socket: %s", err)
|
||||||
|
@ -101,14 +101,14 @@ func (s *Server) Launch() error {
|
||||||
|
|
||||||
fmt.Printf("+ listening on %s\n", url)
|
fmt.Printf("+ listening on %s\n", url)
|
||||||
|
|
||||||
/* (3) Launch scheduler */
|
// 3. Launch scheduler
|
||||||
go s.scheduler()
|
go s.scheduler()
|
||||||
|
|
||||||
/* (2) For each incoming connection (client)
|
/* (2) For each incoming connection (client)
|
||||||
---------------------------------------------------------*/
|
---------------------------------------------------------*/
|
||||||
for {
|
for {
|
||||||
|
|
||||||
/* (1) Wait for client */
|
// 1. Wait for client
|
||||||
sock, err := s.sock.Accept()
|
sock, err := s.sock.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
|
@ -116,14 +116,14 @@ func (s *Server) Launch() error {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|
||||||
/* (2) Try to create client */
|
// 2. Try to create client
|
||||||
cli, err := buildClient(sock, s.ctl, s.ch)
|
cli, err := buildClient(sock, s.ctl, s.ch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf(" - %s\n", err)
|
fmt.Printf(" - %s\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Register client */
|
// 3. Register client
|
||||||
s.ch.register <- cli
|
s.ch.register <- cli
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
@ -141,15 +141,15 @@ func (s *Server) scheduler() {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|
||||||
/* (1) Create client */
|
// 1. Create client
|
||||||
case client := <-s.ch.register:
|
case client := <-s.ch.register:
|
||||||
s.clients[client.io.sock] = client
|
s.clients[client.io.sock] = client
|
||||||
|
|
||||||
/* (2) Remove client */
|
// 2. Remove client
|
||||||
case client := <-s.ch.unregister:
|
case client := <-s.ch.unregister:
|
||||||
delete(s.clients, client.io.sock)
|
delete(s.clients, client.io.sock)
|
||||||
|
|
||||||
/* (3) Broadcast */
|
// 3. Broadcast
|
||||||
case msg := <-s.ch.broadcast:
|
case msg := <-s.ch.broadcast:
|
||||||
for _, c := range s.clients {
|
for _, c := range s.clients {
|
||||||
c.ch.send <- msg
|
c.ch.send <- msg
|
||||||
|
|
Loading…
Reference in New Issue