update structure according to proper go project

This commit is contained in:
Adrien Marquès 2018-09-29 14:36:47 +02:00
parent 53de261a43
commit a9986c3123
10 changed files with 356 additions and 399 deletions

View File

@ -1,45 +1,42 @@
package ws package websocket
import ( import (
"time"
"sync"
"bufio" "bufio"
"encoding/binary" "encoding/binary"
"git.xdrm.io/gws/internal/http/upgrade/request"
"net"
"fmt" "fmt"
"git.xdrm.io/go/websocket/internal/http/upgrade/request"
"net"
"sync"
"time"
) )
// Represents a client socket utility (reader, writer, ..) // Represents a client socket utility (reader, writer, ..)
type clientIO struct { type clientIO struct {
sock net.Conn sock net.Conn
reader *bufio.Reader reader *bufio.Reader
kill chan<- *client // unregisters client kill chan<- *client // unregisters client
closing bool closing bool
closingMu sync.Mutex closingMu sync.Mutex
reading sync.WaitGroup reading sync.WaitGroup
writing bool writing bool
} }
// Represents all channels that need a client // Represents all channels that need a client
type clientChannelSet struct{ type clientChannelSet struct {
receive chan Message receive chan Message
send chan Message send chan Message
} }
// Represents a websocket client // Represents a websocket client
type client struct { type client struct {
io clientIO io clientIO
iface *Client iface *Client
ch clientChannelSet ch clientChannelSet
status MessageError // close status ; 0 = nothing ; else -> must close status MessageError // close status ; 0 = nothing ; else -> must close
} }
// Create creates a new client // Create creates a new client
func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error){ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) {
/* (1) Manage UPGRADE request /* (1) Manage UPGRADE request
---------------------------------------------------------*/ ---------------------------------------------------------*/
@ -60,7 +57,6 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode()) return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
} }
/* (2) Initialise client /* (2) Initialise client
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Get upgrade data */ /* (1) Get upgrade data */
@ -70,14 +66,14 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
/* (2) Initialise client */ /* (2) Initialise client */
cli := &client{ cli := &client{
io: clientIO{ io: clientIO{
sock: s, sock: s,
reader: bufio.NewReader(s), reader: bufio.NewReader(s),
kill: serverCh.unregister, kill: serverCh.unregister,
}, },
iface: &Client{ iface: &Client{
Protocol: string(clientProtocol), Protocol: string(clientProtocol),
Arguments: [][]string{ []string{ clientURI } }, Arguments: [][]string{[]string{clientURI}},
}, },
ch: clientChannelSet{ ch: clientChannelSet{
@ -86,12 +82,10 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
}, },
} }
/* (3) Find controller by URI /* (3) Find controller by URI
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Try to find one */ /* (1) Try to find one */
controller, arguments := ctl.Match(clientURI); controller, arguments := ctl.Match(clientURI)
/* (2) If nothing found -> error */ /* (2) If nothing found -> error */
if controller == nil { if controller == nil {
@ -99,18 +93,16 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
} }
/* (3) Copy arguments */ /* (3) Copy arguments */
cli.iface.Arguments = arguments cli.iface.Arguments = arguments
/* (4) Launch client routines /* (4) Launch client routines
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Launch client controller */ /* (1) Launch client controller */
go controller.Fun( go controller.Fun(
cli.iface, // pass the client cli.iface, // pass the client
cli.ch.receive, // the receiver cli.ch.receive, // the receiver
cli.ch.send, // the sender cli.ch.send, // the sender
serverCh.broadcast, // broadcast sender serverCh.broadcast, // broadcast sender
) )
/* (2) Launch message reader */ /* (2) Launch message reader */
@ -123,11 +115,8 @@ func buildClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*cli
} }
// reader reads and parses messages from the buffer // reader reads and parses messages from the buffer
func clientReader(c *client){ func clientReader(c *client) {
var frag *Message var frag *Message
closeStatus := NORMAL closeStatus := NORMAL
@ -149,7 +138,9 @@ func clientReader(c *client){
if err == ErrUnmaskedFrame || err == ErrReservedBits { if err == ErrUnmaskedFrame || err == ErrReservedBits {
closeStatus = PROTOCOL_ERR closeStatus = PROTOCOL_ERR
} }
if err != nil { break } if err != nil {
break
}
/* (3) Fail on invalid message */ /* (3) Fail on invalid message */
msgErr := msg.check(frag != nil) msgErr := msg.check(frag != nil)
@ -159,39 +150,41 @@ func clientReader(c *client){
switch msgErr { switch msgErr {
// Fail // Fail
case ErrUnexpectedContinuation: case ErrUnexpectedContinuation:
closeStatus = NONE closeStatus = NONE
clientAck = false clientAck = false
mustClose = true mustClose = true
// proper close // proper close
case CloseFrame: case CloseFrame:
closeStatus = NORMAL closeStatus = NORMAL
clientAck = true clientAck = true
mustClose = true mustClose = true
// invalid payload proper close // invalid payload proper close
case ErrInvalidPayload: case ErrInvalidPayload:
closeStatus = INVALID_PAYLOAD closeStatus = INVALID_PAYLOAD
clientAck = true clientAck = true
mustClose = true mustClose = true
// any other error -> protocol error // any other error -> protocol error
default: default:
closeStatus = PROTOCOL_ERR closeStatus = PROTOCOL_ERR
clientAck = true clientAck = true
mustClose = true mustClose = true
} }
if mustClose { break } if mustClose {
break
}
} }
/* (4) Ping <-> Pong */ /* (4) Ping <-> Pong */
if msg.Type == PING && c.io.writing { if msg.Type == PING && c.io.writing {
msg.Final = true msg.Final = true
msg.Type = PONG msg.Type = PONG
c.ch.send <- *msg c.ch.send <- *msg
continue continue
} }
@ -199,10 +192,10 @@ func clientReader(c *client){
/* (5) Store first fragment */ /* (5) Store first fragment */
if frag == nil && !msg.Final { if frag == nil && !msg.Final {
frag = &Message{ frag = &Message{
Type: msg.Type, Type: msg.Type,
Final: msg.Final, Final: msg.Final,
Data: msg.Data, Data: msg.Data,
Size: msg.Size, Size: msg.Size,
} }
continue continue
} }
@ -248,11 +241,9 @@ func clientReader(c *client){
} }
// writer writes into websocket // writer writes into websocket
// and is triggered by client.ch.send channel // and is triggered by client.ch.send channel
func clientWriter(c *client){ func clientWriter(c *client) {
c.io.writing = true // if channel still exists c.io.writing = true // if channel still exists
@ -278,20 +269,17 @@ func clientWriter(c *client){
} }
// closes the connection // closes the connection
// send CLOSE frame is 'status' is not NONE // send CLOSE frame is 'status' is not NONE
// wait for the next message (CLOSE acknowledge) if 'clientACK' // wait for the next message (CLOSE acknowledge) if 'clientACK'
// then delete client // then delete client
func (c *client) close(status MessageError, clientACK bool){ func (c *client) close(status MessageError, clientACK bool) {
/* (1) Fail if already closing */ /* (1) Fail if already closing */
alreadyClosing := false alreadyClosing := false
c.io.closingMu.Lock() c.io.closingMu.Lock()
alreadyClosing = c.io.closing alreadyClosing = c.io.closing
c.io.closing = true c.io.closing = true
c.io.closingMu.Unlock() c.io.closingMu.Unlock()
if alreadyClosing { if alreadyClosing {
@ -304,19 +292,17 @@ func (c *client) close(status MessageError, clientACK bool){
} }
/* (3) kill reader if still running */ /* (3) kill reader if still running */
c.io.sock.SetReadDeadline(time.Now().Add(time.Second*-1)) c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
c.io.reading.Wait() c.io.reading.Wait()
if status != NONE { if status != NONE {
/* (3) Build message */ /* (3) Build message */
msg := &Message{ msg := &Message{
Final: true, Final: true,
Type: CLOSE, Type: CLOSE,
Size: 2, Size: 2,
Data: make([]byte, 2), Data: make([]byte, 2),
} }
binary.BigEndian.PutUint16(msg.Data, uint16(status)) binary.BigEndian.PutUint16(msg.Data, uint16(status))
@ -328,7 +314,6 @@ func (c *client) close(status MessageError, clientACK bool){
} }
/* (2) Wait for client CLOSE if needed */ /* (2) Wait for client CLOSE if needed */
if clientACK { if clientACK {
@ -358,4 +343,3 @@ func (c *client) close(status MessageError, clientACK bool){
return return
} }

View File

@ -1,13 +1,12 @@
package main package iface
import ( import (
"git.xdrm.io/gws/ws"
"time"
"fmt" "fmt"
ws "git.xdrm.io/go/websocket"
"time"
) )
func main() {
func main(){
startTime := time.Now().UnixNano() startTime := time.Now().UnixNano()
@ -15,10 +14,10 @@ func main(){
serv := ws.CreateServer("0.0.0.0", 4444) serv := ws.CreateServer("0.0.0.0", 4444)
/* (2) Bind default controller */ /* (2) Bind default controller */
serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message){ serv.BindDefault(func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
defer func(){ defer func() {
if (recover() != nil) { if recover() != nil {
fmt.Printf("*** PANIC\n") fmt.Printf("*** PANIC\n")
} }
}() }()
@ -34,11 +33,11 @@ func main(){
}) })
/* (3) Bind to URI */ /* (3) Bind to URI */
err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message){ err := serv.Bind("/channel/./room/./", func(cli *ws.Client, receiver <-chan ws.Message, sender chan<- ws.Message, bc chan<- ws.Message) {
fmt.Printf("[uri] connected\n") fmt.Printf("[uri] connected\n")
for msg := range receiver{ for msg := range receiver {
fmt.Printf("[uri] received '%s'\n", msg.Data) fmt.Printf("[uri] received '%s'\n", msg.Data)
sender <- msg sender <- msg
@ -48,7 +47,9 @@ func main(){
fmt.Printf("[uri] unexpectedly closed\n") fmt.Printf("[uri] unexpectedly closed\n")
}) })
if err != nil { panic(err) } if err != nil {
panic(err)
}
/* (4) Launch the server */ /* (4) Launch the server */
err = serv.Launch() err = serv.Launch()
@ -57,7 +58,6 @@ func main(){
return return
} }
fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3) fmt.Printf("+ elapsed: %1.1f us\n", float32(time.Now().UnixNano()-startTime)/1e3)
} }

View File

@ -1,14 +1,14 @@
package ws package websocket
import ( import (
"git.xdrm.io/gws/internal/uri/parser" "git.xdrm.io/go/websocket/internal/uri/parser"
) )
// Represents available information about a client // Represents available information about a client
type Client struct { type Client struct {
Protocol string // choosen protocol (Sec-WebSocket-Protocol) Protocol string // choosen protocol (Sec-WebSocket-Protocol)
Arguments [][]string // URI parameters, index 0 is full URI, then matching groups Arguments [][]string // URI parameters, index 0 is full URI, then matching groups
Store interface{} // store (for client implementation-specific data) Store interface{} // store (for client implementation-specific data)
} }
// Represents a websocket controller callback function // Represents a websocket controller callback function
@ -16,8 +16,8 @@ type ControllerFunc func(*Client, <-chan Message, chan<- Message, chan<- Message
// Represents a websocket controller // Represents a websocket controller
type Controller struct { type Controller struct {
URI *parser.Scheme // uri scheme URI *parser.Scheme // uri scheme
Fun ControllerFunc // controller function Fun ControllerFunc // controller function
} }
// Represents a controller set // Represents a controller set
@ -26,14 +26,12 @@ type ControllerSet struct {
Uri []*Controller // uri controllers Uri []*Controller // uri controllers
} }
// Match finds a controller for a given URI // Match finds a controller for a given URI
// also it returns the matching string patterns // also it returns the matching string patterns
func (s *ControllerSet) Match(uri string) (*Controller, [][]string){ func (s *ControllerSet) Match(uri string) (*Controller, [][]string) {
/* (1) Initialise argument list */ /* (1) Initialise argument list */
arguments := [][]string{ []string{ uri } } arguments := [][]string{[]string{uri}}
/* (2) Try each controller */ /* (2) Try each controller */
for _, c := range s.Uri { for _, c := range s.Uri {
@ -62,4 +60,4 @@ func (s *ControllerSet) Match(uri string) (*Controller, [][]string){
/* (4) If default is NIL, return empty controller */ /* (4) If default is NIL, return empty controller */
return nil, arguments return nil, arguments
} }

View File

@ -1,15 +1,13 @@
package request package request
import ( import (
"git.xdrm.io/gws/internal/http/upgrade/response"
"git.xdrm.io/gws/internal/http/upgrade/request/parser/header"
"fmt" "fmt"
"git.xdrm.io/go/websocket/internal/http/upgrade/request/parser/header"
"git.xdrm.io/go/websocket/internal/http/upgrade/response"
"strconv" "strconv"
"strings" "strings"
) )
// checkHost checks and extracts the Host header // checkHost checks and extracts the Host header
func (r *T) extractHostPort(bb header.HeaderValue) error { func (r *T) extractHostPort(bb header.HeaderValue) error {
@ -42,7 +40,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
// if 'Origin' header is already read, check it // if 'Origin' header is already read, check it
if len(r.origin) > 0 { if len(r.origin) > 0 {
if err != nil { if err != nil {
err = r.checkOriginPolicy() err = r.checkOriginPolicy()
r.code = response.FORBIDDEN r.code = response.FORBIDDEN
return &InvalidOriginPolicy{r.host, r.origin, err} return &InvalidOriginPolicy{r.host, r.origin, err}
} }
@ -52,12 +50,13 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
} }
// checkOrigin checks the Origin Header // checkOrigin checks the Origin Header
func (r *T) extractOrigin(bb header.HeaderValue) error { func (r *T) extractOrigin(bb header.HeaderValue) error {
// bypass // bypass
if bypassOriginPolicy { return nil } if bypassOriginPolicy {
return nil
}
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.FORBIDDEN r.code = response.FORBIDDEN
@ -92,7 +91,7 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
for _, b := range bb { for _, b := range bb {
if strings.ToLower( string(b) ) == "upgrade" { if strings.ToLower(string(b)) == "upgrade" {
r.hasConnection = true r.hasConnection = true
return nil return nil
} }
@ -113,7 +112,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
if strings.ToLower( string(bb[0]) ) == "websocket" { if strings.ToLower(string(bb[0])) == "websocket" {
r.hasUpgrade = true r.hasUpgrade = true
return nil return nil
} }
@ -152,7 +151,6 @@ func (r *T) extractKey(bb header.HeaderValue) error {
} }
// extractProtocols extracts the 'Sec-WebSocket-Protocol' header // extractProtocols extracts the 'Sec-WebSocket-Protocol' header
// it can contain multiple values // it can contain multiple values
func (r *T) extractProtocols(bb header.HeaderValue) error { func (r *T) extractProtocols(bb header.HeaderValue) error {
@ -160,4 +158,4 @@ func (r *T) extractProtocols(bb header.HeaderValue) error {
r.protocols = bb r.protocols = bb
return nil return nil
} }

View File

@ -1,12 +1,11 @@
package request package request
import ( import (
"git.xdrm.io/gws/internal/http/upgrade/response"
"fmt" "fmt"
"git.xdrm.io/gws/internal/http/upgrade/request/parser/header" "git.xdrm.io/go/websocket/internal/http/upgrade/request/parser/header"
"git.xdrm.io/go/websocket/internal/http/upgrade/response"
) )
// parseHeader parses any http request line // parseHeader parses any http request line
// (header and request-line) // (header and request-line)
func (r *T) parseHeader(b []byte) error { func (r *T) parseHeader(b []byte) error {
@ -27,8 +26,6 @@ func (r *T) parseHeader(b []byte) error {
} }
/* (2) Other lines -> Header-Name: Header-Value /* (2) Other lines -> Header-Name: Header-Value
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Try to parse header */ /* (1) Try to parse header */
@ -40,19 +37,25 @@ func (r *T) parseHeader(b []byte) error {
/* (2) Manage header */ /* (2) Manage header */
switch head.Name { switch head.Name {
case header.HOST: err = r.extractHostPort(head.Values) case header.HOST:
case header.ORIGIN: err = r.extractOrigin(head.Values) err = r.extractHostPort(head.Values)
case header.UPGRADE: err = r.checkUpgrade(head.Values) case header.ORIGIN:
case header.CONNECTION: err = r.checkConnection(head.Values) err = r.extractOrigin(head.Values)
case header.WSVERSION: err = r.checkVersion(head.Values) case header.UPGRADE:
case header.WSKEY: err = r.extractKey(head.Values) err = r.checkUpgrade(head.Values)
case header.WSPROTOCOL: err = r.extractProtocols(head.Values) case header.CONNECTION:
err = r.checkConnection(head.Values)
case header.WSVERSION:
err = r.checkVersion(head.Values)
case header.WSKEY:
err = r.extractKey(head.Values)
case header.WSPROTOCOL:
err = r.extractProtocols(head.Values)
default: default:
return nil return nil
} }
// dispatch error // dispatch error
if err != nil { if err != nil {
return err return err
@ -62,8 +65,6 @@ func (r *T) parseHeader(b []byte) error {
} }
// isComplete returns whether the Upgrade Request // isComplete returns whether the Upgrade Request
// is complete (no missing required item) // is complete (no missing required item)
func (r T) isComplete() error { func (r T) isComplete() error {
@ -105,4 +106,4 @@ func (r T) isComplete() error {
return nil return nil
} }

View File

@ -1,9 +1,9 @@
package request package request
import ( import (
"git.xdrm.io/gws/internal/http/upgrade/response"
"git.xdrm.io/gws/internal/http/reader"
"fmt" "fmt"
"git.xdrm.io/go/websocket/internal/http/reader"
"git.xdrm.io/go/websocket/internal/http/upgrade/response"
"io" "io"
) )
@ -14,7 +14,6 @@ func Parse(r io.Reader) (request *T, err error) {
req := new(T) req := new(T)
req.code = 500 req.code = 500
/* (1) Parse request /* (1) Parse request
---------------------------------------------------------*/ ---------------------------------------------------------*/
/* (1) Get chunk reader */ /* (1) Get chunk reader */
@ -50,23 +49,18 @@ func Parse(r io.Reader) (request *T, err error) {
return req, err return req, err
} }
req.code = response.SWITCHING_PROTOCOLS req.code = response.SWITCHING_PROTOCOLS
return req, nil return req, nil
} }
// StatusCode returns the status current // StatusCode returns the status current
func (r T) StatusCode() response.StatusCode { func (r T) StatusCode() response.StatusCode {
return r.code return r.code
} }
// BuildResponse builds a response.T from the request // BuildResponse builds a response.T from the request
func (r *T) BuildResponse() *response.T{ func (r *T) BuildResponse() *response.T {
inst := new(response.T) inst := new(response.T)
@ -84,9 +78,7 @@ func (r *T) BuildResponse() *response.T{
return inst return inst
} }
// GetURI returns the actual URI // GetURI returns the actual URI
func (r T) GetURI() string{ func (r T) GetURI() string {
return r.request.GetURI() return r.request.GetURI()
} }

View File

@ -1,19 +1,19 @@
package request package request
import "git.xdrm.io/gws/internal/http/upgrade/response" import "git.xdrm.io/go/websocket/internal/http/upgrade/response"
// If origin is required // If origin is required
const bypassOriginPolicy = true const bypassOriginPolicy = true
// T represents an HTTP Upgrade request // T represents an HTTP Upgrade request
type T struct { type T struct {
first bool // whether the first line has been read (GET uri HTTP/version) first bool // whether the first line has been read (GET uri HTTP/version)
// status code // status code
code response.StatusCode code response.StatusCode
// request line // request line
request RequestLine request RequestLine
// data to check origin (depends of reading order) // data to check origin (depends of reading order)
host string host string
@ -22,8 +22,8 @@ type T struct {
validPolicy bool validPolicy bool
// ws data // ws data
key []byte key []byte
protocols [][]byte protocols [][]byte
// required fields check // required fields check
hasConnection bool hasConnection bool

View File

@ -1,10 +1,10 @@
package ws package websocket
import ( import (
"unicode/utf8" "encoding/binary"
"fmt" "fmt"
"io" "io"
"encoding/binary" "unicode/utf8"
) )
var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame") var ErrUnmaskedFrame = fmt.Errorf("Received unmasked frame")
@ -15,10 +15,9 @@ var ErrInvalidSize = fmt.Errorf("Received invalid payload size")
var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload") var ErrInvalidPayload = fmt.Errorf("Received invalid utf8 payload")
var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status") var ErrInvalidCloseStatus = fmt.Errorf("Received invalid close status")
var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode") var ErrInvalidOpCode = fmt.Errorf("Received invalid OpCode")
var ErrReservedBits = fmt.Errorf("Received reserved bits") var ErrReservedBits = fmt.Errorf("Received reserved bits")
var CloseFrame = fmt.Errorf("Received close Frame") var CloseFrame = fmt.Errorf("Received close Frame")
// Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask // Maximum Header Size = Final/OpCode + isMask/Length + Length + Mask
const maximumHeaderSize = 1 + 1 + 8 + 4 const maximumHeaderSize = 1 + 1 + 8 + 4
const maxWriteChunk = 0x7fff const maxWriteChunk = 0x7fff
@ -46,8 +45,7 @@ const (
CLOSE MessageType = 0x08 CLOSE MessageType = 0x08
PING MessageType = 0x09 PING MessageType = 0x09
PONG MessageType = 0x0a PONG MessageType = 0x0a
); )
// Represents a websocket message // Represents a websocket message
type Message struct { type Message struct {
@ -57,13 +55,12 @@ type Message struct {
Data []byte Data []byte
} }
// receive reads a message form reader // receive reads a message form reader
func readMessage(reader io.Reader) (*Message, error){ func readMessage(reader io.Reader) (*Message, error) {
var err error var err error
var tmpBuf []byte var tmpBuf []byte
var mask []byte var mask []byte
var cursor int var cursor int
m := new(Message) m := new(Message)
@ -71,46 +68,53 @@ func readMessage(reader io.Reader) (*Message, error){
/* (2) Byte 1: FIN and OpCode */ /* (2) Byte 1: FIN and OpCode */
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
err = readBytes(reader, tmpBuf) err = readBytes(reader, tmpBuf)
if err != nil { return m, err } if err != nil {
return m, err
}
// check reserved bits // check reserved bits
if tmpBuf[0] & 0x70 != 0 { if tmpBuf[0]&0x70 != 0 {
return m, ErrReservedBits return m, ErrReservedBits
} }
m.Final = bool( tmpBuf[0] & 0x80 == 0x80 ) m.Final = bool(tmpBuf[0]&0x80 == 0x80)
m.Type = MessageType( tmpBuf[0] & 0x0f ) m.Type = MessageType(tmpBuf[0] & 0x0f)
/* (3) Byte 2: Mask and Length[0] */ /* (3) Byte 2: Mask and Length[0] */
tmpBuf = make([]byte, 1) tmpBuf = make([]byte, 1)
err = readBytes(reader, tmpBuf) err = readBytes(reader, tmpBuf)
if err != nil { return m, err } if err != nil {
return m, err
}
// if mask, byte array not nil // if mask, byte array not nil
if tmpBuf[0] & 0x80 == 0x80 { if tmpBuf[0]&0x80 == 0x80 {
mask = make([]byte, 0) mask = make([]byte, 0)
} }
// payload length // payload length
m.Size = uint( tmpBuf[0] & 0x7f ) m.Size = uint(tmpBuf[0] & 0x7f)
/* (4) Extended payload */ /* (4) Extended payload */
if m.Size == 127 { if m.Size == 127 {
tmpBuf = make([]byte, 8) tmpBuf = make([]byte, 8)
err := readBytes(reader, tmpBuf) err := readBytes(reader, tmpBuf)
if err != nil { return m, err } if err != nil {
return m, err
}
m.Size = uint( binary.BigEndian.Uint64(tmpBuf) ) m.Size = uint(binary.BigEndian.Uint64(tmpBuf))
} else if m.Size == 126 { } else if m.Size == 126 {
tmpBuf = make([]byte, 2) tmpBuf = make([]byte, 2)
err := readBytes(reader, tmpBuf) err := readBytes(reader, tmpBuf)
if err != nil { return m, err } if err != nil {
return m, err
}
m.Size = uint( binary.BigEndian.Uint16(tmpBuf) ) m.Size = uint(binary.BigEndian.Uint16(tmpBuf))
} }
@ -119,7 +123,9 @@ func readMessage(reader io.Reader) (*Message, error){
tmpBuf = make([]byte, 4) tmpBuf = make([]byte, 4)
err := readBytes(reader, tmpBuf) err := readBytes(reader, tmpBuf)
if err != nil { return m, err } if err != nil {
return m, err
}
mask = make([]byte, 4) mask = make([]byte, 4)
copy(mask, tmpBuf) copy(mask, tmpBuf)
@ -142,7 +148,7 @@ func readMessage(reader io.Reader) (*Message, error){
// {3} Unmask data // // {3} Unmask data //
if mask != nil { if mask != nil {
for i, l := cursor, cursor+nbread ; i < l ; i++ { for i, l := cursor, cursor+nbread; i < l; i++ {
mi := i % 4 // mask index mi := i % 4 // mask index
m.Data[i] = m.Data[i] ^ mask[mi] m.Data[i] = m.Data[i] ^ mask[mi]
@ -166,9 +172,6 @@ func readMessage(reader io.Reader) (*Message, error){
} }
// Send sends a frame over a socket // Send sends a frame over a socket
func (m Message) Send(writer io.Writer) error { func (m Message) Send(writer io.Writer) error {
@ -176,7 +179,7 @@ func (m Message) Send(writer io.Writer) error {
// fix size // fix size
if uint(len(m.Data)) <= m.Size { if uint(len(m.Data)) <= m.Size {
m.Size = uint( len(m.Data) ) m.Size = uint(len(m.Data))
} }
/* (1) Byte 0 : FIN + opcode */ /* (1) Byte 0 : FIN + opcode */
@ -184,12 +187,12 @@ func (m Message) Send(writer io.Writer) error {
if !m.Final { if !m.Final {
final = 0 final = 0
} }
header = append(header, final | byte(m.Type) ) header = append(header, final|byte(m.Type))
/* (2) Get payload length */ /* (2) Get payload length */
if m.Size < 126 { // simple if m.Size < 126 { // simple
header = append(header, byte(m.Size) ) header = append(header, byte(m.Size))
} else if m.Size <= 0xffff { // extended: 16 bits } else if m.Size <= 0xffff { // extended: 16 bits
@ -210,7 +213,7 @@ func (m Message) Send(writer io.Writer) error {
} }
/* (3) Build write buffer */ /* (3) Build write buffer */
writeBuf := make([]byte, 0, len(header) + int(m.Size)) writeBuf := make([]byte, 0, len(header)+int(m.Size))
writeBuf = append(writeBuf, header...) writeBuf = append(writeBuf, header...)
writeBuf = append(writeBuf, m.Data[0:m.Size]...) writeBuf = append(writeBuf, m.Data[0:m.Size]...)
@ -219,14 +222,16 @@ func (m Message) Send(writer io.Writer) error {
cursor := 0 cursor := 0
for cursor < toWrite { for cursor < toWrite {
maxBoundary := cursor+maxWriteChunk maxBoundary := cursor + maxWriteChunk
if maxBoundary > toWrite { if maxBoundary > toWrite {
maxBoundary = toWrite maxBoundary = toWrite
} }
// Try to wrote (at max 1024 bytes) // // Try to wrote (at max 1024 bytes) //
nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary]) nbwritten, err := writer.Write(writeBuf[cursor:maxBoundary])
if err != nil { return err } if err != nil {
return err
}
// Update cursor // // Update cursor //
cursor += nbwritten cursor += nbwritten
@ -236,14 +241,11 @@ func (m Message) Send(writer io.Writer) error {
return nil return nil
} }
// Check for message errors with: // Check for message errors with:
// (m) the current message // (m) the current message
// (fragment) whether there is a fragment in construction // (fragment) whether there is a fragment in construction
// returns the message error // returns the message error
func (m *Message) check(fragment bool) error{ func (m *Message) check(fragment bool) error {
/* (1) Invalid first fragment (not TEXT nor BINARY) */ /* (1) Invalid first fragment (not TEXT nor BINARY) */
if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY { if !m.Final && !fragment && m.Type != TEXT && m.Type != BINARY {
@ -261,59 +263,57 @@ func (m *Message) check(fragment bool) error{
} }
switch m.Type { switch m.Type {
case CONTINUATION: case CONTINUATION:
// unexpected continuation // unexpected continuation
if !fragment { if !fragment {
return ErrUnexpectedContinuation return ErrUnexpectedContinuation
} }
return nil return nil
case TEXT: case TEXT:
if m.Final && !utf8.Valid(m.Data) { if m.Final && !utf8.Valid(m.Data) {
return ErrInvalidPayload return ErrInvalidPayload
} }
return nil return nil
case BINARY: case BINARY:
return nil return nil
case CLOSE: case CLOSE:
// incomplete code // incomplete code
if m.Size == 1 { if m.Size == 1 {
return ErrInvalidCloseStatus
}
// invalid utf-8 reason
if m.Size > 2 && !utf8.Valid(m.Data[2:]) {
return ErrInvalidPayload
}
// invalid code
if m.Size >= 2 {
c := binary.BigEndian.Uint16(m.Data[0:2])
if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 {
return ErrInvalidCloseStatus return ErrInvalidCloseStatus
} }
}
return CloseFrame
// invalid utf-8 reason case PING:
if m.Size > 2 && !utf8.Valid(m.Data[2:]) { return nil
return ErrInvalidPayload
}
// invalid code case PONG:
if m.Size >= 2 { return nil
c := binary.BigEndian.Uint16(m.Data[0:2])
if c < 1000 || c >= 1004 && c <= 1006 || c >= 1012 && c <= 1016 || c == 1100 || c == 2000 || c == 2999 { default:
return ErrInvalidCloseStatus return ErrInvalidOpCode
}
}
return CloseFrame
case PING:
return nil
case PONG:
return nil
default:
return ErrInvalidOpCode
} }
return nil return nil
} }
// readBytes reads from a reader into a byte array // readBytes reads from a reader into a byte array
// until the byte length is fully filled with data // until the byte length is fully filled with data
// loops while there is no error // loops while there is no error
@ -336,4 +336,4 @@ func readBytes(reader io.Reader, buffer []byte) error {
return nil return nil
} }

View File

@ -1,15 +1,14 @@
package ws package websocket
import ( import (
"io"
"bytes" "bytes"
"io"
"testing" "testing"
) )
func TestSimpleMessageReading(t *testing.T) { func TestSimpleMessageReading(t *testing.T) {
cases := []struct{ cases := []struct {
Name string Name string
ReadBuffer []byte ReadBuffer []byte
Expected Message Expected Message
@ -17,57 +16,57 @@ func TestSimpleMessageReading(t *testing.T) {
}{ }{
{ // FIN ; TEXT ; Unmasked -> error { // FIN ; TEXT ; Unmasked -> error
"must fail on unmasked frame", "must fail on unmasked frame",
[]byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f}, []byte{0x81, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
Message{}, Message{},
ErrUnmaskedFrame, ErrUnmaskedFrame,
}, },
{ // FIN ; TEXT ; Unmasked -> error { // FIN ; TEXT ; Unmasked -> error
"must fail because of RSV bit 1 set", "must fail because of RSV bit 1 set",
[]byte{0x81 | 0x40,0x10,0x00,0x00,0x00,0x00}, []byte{0x81 | 0x40, 0x10, 0x00, 0x00, 0x00, 0x00},
Message{}, Message{},
ErrReservedBits, ErrReservedBits,
}, },
{ // FIN ; TEXT ; Unmasked -> error { // FIN ; TEXT ; Unmasked -> error
"must fail because of RSV bit 2 set", "must fail because of RSV bit 2 set",
[]byte{0x81 | 0x20,0x10,0x00,0x00,0x00,0x00}, []byte{0x81 | 0x20, 0x10, 0x00, 0x00, 0x00, 0x00},
Message{}, Message{},
ErrReservedBits, ErrReservedBits,
}, },
{ // FIN ; TEXT ; Unmasked -> error { // FIN ; TEXT ; Unmasked -> error
"must fail because of RSV bit 3 set", "must fail because of RSV bit 3 set",
[]byte{0x81 | 0x10,0x10,0x00,0x00,0x00,0x00}, []byte{0x81 | 0x10, 0x10, 0x00, 0x00, 0x00, 0x00},
Message{}, Message{},
ErrReservedBits, ErrReservedBits,
}, },
{ // FIN ; TEXT ; hello { // FIN ; TEXT ; hello
"simple hello text message", "simple hello text message",
[]byte{0x81,0x85,0x00,0x00,0x00,0x00,0x68,0x65,0x6c,0x6c,0x6f}, []byte{0x81, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
Message{ true, TEXT, 5, []byte("hello") }, Message{true, TEXT, 5, []byte("hello")},
nil, nil,
}, },
{ // FIN ; BINARY ; hello { // FIN ; BINARY ; hello
"simple hello binary message", "simple hello binary message",
[]byte{0x82,0x85,0x00,0x00,0x00,0x00,0x68,0x65,0x6c,0x6c,0x6f}, []byte{0x82, 0x85, 0x00, 0x00, 0x00, 0x00, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
Message{ true, BINARY, 5, []byte("hello") }, Message{true, BINARY, 5, []byte("hello")},
nil, nil,
}, },
{ // FIN ; BINARY ; test unmasking { // FIN ; BINARY ; test unmasking
"unmasking test", "unmasking test",
[]byte{0x82,0x88,0x01,0x02,0x03,0x04,0x10,0x20,0x30,0x40,0x50,0x60,0x70,0x80}, []byte{0x82, 0x88, 0x01, 0x02, 0x03, 0x04, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80},
Message{ true, BINARY, 8, []byte{0x11,0x22,0x33,0x44,0x51,0x62,0x73,0x84} }, Message{true, BINARY, 8, []byte{0x11, 0x22, 0x33, 0x44, 0x51, 0x62, 0x73, 0x84}},
nil, nil,
}, },
{ // FIN=0 ; TEXT ; { // FIN=0 ; TEXT ;
"non final frame", "non final frame",
[]byte{0x01,0x82,0x00,0x00,0x00,0x00,0x01,0x02}, []byte{0x01, 0x82, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02},
Message{ false, TEXT, 2, []byte{0x01,0x02} }, Message{false, TEXT, 2, []byte{0x01, 0x02}},
nil, nil,
}, },
} }
for _, tc := range cases{ for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T){ t.Run(tc.Name, func(t *testing.T) {
reader := bytes.NewBuffer(tc.ReadBuffer) reader := bytes.NewBuffer(tc.ReadBuffer)
@ -108,10 +107,9 @@ func TestSimpleMessageReading(t *testing.T) {
} }
func TestReadEOF(t *testing.T) { func TestReadEOF(t *testing.T) {
cases := []struct{ cases := []struct {
Name string Name string
ReadBuffer []byte ReadBuffer []byte
eof bool eof bool
@ -127,58 +125,58 @@ func TestReadEOF(t *testing.T) {
true, false, true, false,
}, { }, {
"only opcode and 0 length", "only opcode and 0 length",
[]byte{0x82,0x00}, []byte{0x82, 0x00},
false, true, false, true,
}, { }, {
"missing extended 16 bits length", "missing extended 16 bits length",
[]byte{0x82,126}, []byte{0x82, 126},
true, false, true, false,
}, { }, {
"incomplete extended 16 bits length", "incomplete extended 16 bits length",
[]byte{0x82,126, 0x00}, []byte{0x82, 126, 0x00},
true, false, true, false,
}, { }, {
"complete extended 16 bits length", "complete extended 16 bits length",
[]byte{0x82,126, 0x00, 0x00}, []byte{0x82, 126, 0x00, 0x00},
false, true, false, true,
}, { }, {
"missing extended 64 bits length", "missing extended 64 bits length",
[]byte{0x82,127}, []byte{0x82, 127},
true, false, true, false,
}, { }, {
"incomplete extended 64 bits length", "incomplete extended 64 bits length",
[]byte{0x82,127, 0x00}, []byte{0x82, 127, 0x00},
true, false, true, false,
}, { }, {
"incomplete extended 64 bits length", "incomplete extended 64 bits length",
[]byte{0x82,127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, []byte{0x82, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
true, false, true, false,
}, { }, {
"complete extended 64 bits length", "complete extended 64 bits length",
[]byte{0x82,127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, []byte{0x82, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
false, true, false, true,
}, { }, {
"missing mask", "missing mask",
[]byte{0x82,0x80}, []byte{0x82, 0x80},
true, false, true, false,
}, { }, {
"incomplete mask 1", "incomplete mask 1",
[]byte{0x82,0x80, 0x00}, []byte{0x82, 0x80, 0x00},
true, false, true, false,
}, { }, {
"incomplete mask 2", "incomplete mask 2",
[]byte{0x82,0x80, 0x00, 0x00, 0x00}, []byte{0x82, 0x80, 0x00, 0x00, 0x00},
true, false, true, false,
},{ }, {
"complete mask", "complete mask",
[]byte{0x82,0x80, 0x00, 0x00, 0x00, 0x00}, []byte{0x82, 0x80, 0x00, 0x00, 0x00, 0x00},
false, false, false, false,
}, },
} }
for _, tc := range cases{ for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T){ t.Run(tc.Name, func(t *testing.T) {
reader := bytes.NewBuffer(tc.ReadBuffer) reader := bytes.NewBuffer(tc.ReadBuffer)
@ -186,7 +184,7 @@ func TestReadEOF(t *testing.T) {
if tc.eof { if tc.eof {
if err != io.EOF{ if err != io.EOF {
t.Errorf("Expected EOF, got %v", err) t.Errorf("Expected EOF, got %v", err)
} }
@ -207,69 +205,67 @@ func TestReadEOF(t *testing.T) {
} }
func TestSimpleMessageSending(t *testing.T) { func TestSimpleMessageSending(t *testing.T) {
m4b1 := make([]byte, 0x7e - 1) m4b1 := make([]byte, 0x7e-1)
m4b2 := make([]byte, 0x7e) m4b2 := make([]byte, 0x7e)
m4b3 := make([]byte, 0x7e + 1) m4b3 := make([]byte, 0x7e+1)
m16b1 := make([]byte, 0xffff - 1) m16b1 := make([]byte, 0xffff-1)
m16b2 := make([]byte, 0xffff) m16b2 := make([]byte, 0xffff)
m16b3 := make([]byte, 0xffff + 1) m16b3 := make([]byte, 0xffff+1)
cases := []struct{ cases := []struct {
Name string Name string
Base Message Base Message
Expected []byte Expected []byte
}{ }{
{ {
"simple hello text message", "simple hello text message",
Message{ true, TEXT, 5, []byte("hello") }, Message{true, TEXT, 5, []byte("hello")},
[]byte{0x81,0x05,0x68,0x65,0x6c,0x6c,0x6f}, []byte{0x81, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
}, { }, {
"simple hello binary message", "simple hello binary message",
Message{ true, BINARY, 5, []byte("hello") }, Message{true, BINARY, 5, []byte("hello")},
[]byte{0x82,0x05,0x68,0x65,0x6c,0x6c,0x6f}, []byte{0x82, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f},
}, { }, {
"other simple binary message", "other simple binary message",
Message{ true, BINARY, 8, []byte{0x10,0x20,0x30,0x40,0x50,0x60,0x70,0x80} }, Message{true, BINARY, 8, []byte{0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80}},
[]byte{0x82,0x08,0x10,0x20,0x30,0x40,0x50,0x60,0x70,0x80}, []byte{0x82, 0x08, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80},
}, { }, {
"non final frame", "non final frame",
Message{ false, TEXT, 2, []byte{0x01,0x02} }, Message{false, TEXT, 2, []byte{0x01, 0x02}},
[]byte{0x01,0x02,0x01,0x02}, []byte{0x01, 0x02, 0x01, 0x02},
}, { }, {
"125 > normal length", "125 > normal length",
Message{ true, TEXT, uint(len(m4b1)), m4b1 }, Message{true, TEXT, uint(len(m4b1)), m4b1},
append([]byte{0x81,0x7e-1}, m4b1...), append([]byte{0x81, 0x7e - 1}, m4b1...),
}, { }, {
"126 > extended 16 bits length", "126 > extended 16 bits length",
Message{ true, TEXT, uint(len(m4b2)), m4b2 }, Message{true, TEXT, uint(len(m4b2)), m4b2},
append([]byte{0x81,126,0x00,0x7e}, m4b2...), append([]byte{0x81, 126, 0x00, 0x7e}, m4b2...),
}, { }, {
"127 > extended 16 bits length", "127 > extended 16 bits length",
Message{ true, TEXT, uint(len(m4b3)), m4b3 }, Message{true, TEXT, uint(len(m4b3)), m4b3},
append([]byte{0x81,126,0x00,0x7e+1}, m4b3...), append([]byte{0x81, 126, 0x00, 0x7e + 1}, m4b3...),
}, { }, {
"fffe > extended 16 bits length", "fffe > extended 16 bits length",
Message{ true, TEXT, uint(len(m16b1)), m16b1 }, Message{true, TEXT, uint(len(m16b1)), m16b1},
append([]byte{0x81,126, 0xff, 0xfe}, m16b1...), append([]byte{0x81, 126, 0xff, 0xfe}, m16b1...),
}, { }, {
"ffff > extended 16 bits length", "ffff > extended 16 bits length",
Message{ true, TEXT, uint(len(m16b2)), m16b2 }, Message{true, TEXT, uint(len(m16b2)), m16b2},
append([]byte{0x81,126,0xff,0xff}, m16b2...), append([]byte{0x81, 126, 0xff, 0xff}, m16b2...),
}, { }, {
"10000 > extended 64 bits length", "10000 > extended 64 bits length",
Message{ true, TEXT, uint(len(m16b3)), m16b3 }, Message{true, TEXT, uint(len(m16b3)), m16b3},
append([]byte{0x81,127, 0x00,0x00,0x00,0x00,0x00,0x01,0x00,0x00,}, m16b3...), append([]byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, m16b3...),
}, },
} }
for _, tc := range cases{ for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T){ t.Run(tc.Name, func(t *testing.T) {
writer := new(bytes.Buffer) writer := new(bytes.Buffer)
@ -291,9 +287,6 @@ func TestSimpleMessageSending(t *testing.T) {
} }
func TestMessageCheck(t *testing.T) { func TestMessageCheck(t *testing.T) {
type Case struct { type Case struct {
@ -303,8 +296,8 @@ func TestMessageCheck(t *testing.T) {
Expected error Expected error
} }
cases := []struct{ cases := []struct {
Name string Name string
Cases []Case Cases []Case
}{ }{
{ {
@ -358,7 +351,7 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE must not fail", "CLOSE must not fail",
Message{true, CLOSE, 125, []byte{0x03,0xe8,0}}, false, CloseFrame, Message{true, CLOSE, 125, []byte{0x03, 0xe8, 0}}, false, CloseFrame,
}, { }, {
"PING must not fail", "PING must not fail",
Message{true, PING, 125, []byte{}}, false, nil, Message{true, PING, 125, []byte{}}, false, nil,
@ -372,7 +365,7 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE must fail", "CLOSE must fail",
Message{true, CLOSE, 126, []byte{0x03,0xe8,0}}, false, ErrTooLongControlFrame, Message{true, CLOSE, 126, []byte{0x03, 0xe8, 0}}, false, ErrTooLongControlFrame,
}, { }, {
"PING must fail", "PING must fail",
Message{true, PING, 126, []byte{}}, false, ErrTooLongControlFrame, Message{true, PING, 126, []byte{}}, false, ErrTooLongControlFrame,
@ -386,7 +379,7 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE must fail", "CLOSE must fail",
Message{false, CLOSE, 126, []byte{0x03,0xe8,0}}, false, ErrInvalidFragment, Message{false, CLOSE, 126, []byte{0x03, 0xe8, 0}}, false, ErrInvalidFragment,
}, { }, {
"PING must fail", "PING must fail",
Message{false, PING, 126, []byte{}}, false, ErrInvalidFragment, Message{false, PING, 126, []byte{}}, false, ErrInvalidFragment,
@ -411,23 +404,23 @@ func TestMessageCheck(t *testing.T) {
Cases: []Case{ Cases: []Case{
{ {
"CLOSE valid reason", "CLOSE valid reason",
Message{true, CLOSE, 5, []byte{0x03,0xe8, 0xe2,0x82,0xa1}}, false, CloseFrame, Message{true, CLOSE, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0xa1}}, false, CloseFrame,
}, { }, {
"CLOSE invalid reason byte 2", "CLOSE invalid reason byte 2",
Message{true, CLOSE, 5, []byte{0x03,0xe8, 0xe2,0x28,0xa1}}, false, ErrInvalidPayload, Message{true, CLOSE, 5, []byte{0x03, 0xe8, 0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload,
}, { }, {
"CLOSE invalid reason byte 3", "CLOSE invalid reason byte 3",
Message{true, CLOSE, 5, []byte{0x03,0xe8, 0xe2,0x82,0x28}}, false, ErrInvalidPayload, Message{true, CLOSE, 5, []byte{0x03, 0xe8, 0xe2, 0x82, 0x28}}, false, ErrInvalidPayload,
}, },
{ {
"TEXT valid reason", "TEXT valid reason",
Message{true, TEXT, 3, []byte{0xe2,0x82,0xa1}}, false, nil, Message{true, TEXT, 3, []byte{0xe2, 0x82, 0xa1}}, false, nil,
}, { }, {
"TEXT invalid reason byte 2", "TEXT invalid reason byte 2",
Message{true, TEXT, 3, []byte{0xe2,0x28,0xa1}}, false, ErrInvalidPayload, Message{true, TEXT, 3, []byte{0xe2, 0x28, 0xa1}}, false, ErrInvalidPayload,
}, { }, {
"TEXT invalid reason byte 3", "TEXT invalid reason byte 3",
Message{true, TEXT, 3, []byte{0xe2,0x82,0x28}}, false, ErrInvalidPayload, Message{true, TEXT, 3, []byte{0xe2, 0x82, 0x28}}, false, ErrInvalidPayload,
}, },
}, },
}, { }, {
@ -438,112 +431,111 @@ func TestMessageCheck(t *testing.T) {
Message{true, CLOSE, 1, []byte{0x03}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 1, []byte{0x03}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1000", "valid CLOSE status 1000",
Message{true, CLOSE, 2, []byte{0x03,0xe8}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xe8}}, false, CloseFrame,
}, { }, {
"invalid CLOSE status 999 under 1000", "invalid CLOSE status 999 under 1000",
Message{true, CLOSE, 2, []byte{0x03,0xe7}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xe7}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1001", "valid CLOSE status 1001",
Message{true, CLOSE, 2, []byte{0x03,0xe9}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xe9}}, false, CloseFrame,
}, { }, {
"valid CLOSE status 1002", "valid CLOSE status 1002",
Message{true, CLOSE, 2, []byte{0x03,0xea}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xea}}, false, CloseFrame,
}, { }, {
"valid CLOSE status 1003", "valid CLOSE status 1003",
Message{true, CLOSE, 2, []byte{0x03,0xeb}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xeb}}, false, CloseFrame,
}, { }, {
"invalid CLOSE status 1004", "invalid CLOSE status 1004",
Message{true, CLOSE, 2, []byte{0x03,0xec}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xec}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1005", "invalid CLOSE status 1005",
Message{true, CLOSE, 2, []byte{0x03,0xed}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xed}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1006", "invalid CLOSE status 1006",
Message{true, CLOSE, 2, []byte{0x03,0xee}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xee}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1007", "valid CLOSE status 1007",
Message{true, CLOSE, 2, []byte{0x03,0xef}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xef}}, false, CloseFrame,
}, { }, {
"valid CLOSE status 1011", "valid CLOSE status 1011",
Message{true, CLOSE, 2, []byte{0x03,0xf3}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xf3}}, false, CloseFrame,
}, { }, {
"invalid CLOSE status 1012", "invalid CLOSE status 1012",
Message{true, CLOSE, 2, []byte{0x03,0xf4}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xf4}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1013", "invalid CLOSE status 1013",
Message{true, CLOSE, 2, []byte{0x03,0xf5}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xf5}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1014", "invalid CLOSE status 1014",
Message{true, CLOSE, 2, []byte{0x03,0xf6}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xf6}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1015", "invalid CLOSE status 1015",
Message{true, CLOSE, 2, []byte{0x03,0xf7}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xf7}}, false, ErrInvalidCloseStatus,
}, { }, {
"invalid CLOSE status 1016", "invalid CLOSE status 1016",
Message{true, CLOSE, 2, []byte{0x03,0xf8}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x03, 0xf8}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1017", "valid CLOSE status 1017",
Message{true, CLOSE, 2, []byte{0x03,0xf9}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x03, 0xf9}}, false, CloseFrame,
}, { }, {
"valid CLOSE status 1099", "valid CLOSE status 1099",
Message{true, CLOSE, 2, []byte{0x04,0x4b}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x04, 0x4b}}, false, CloseFrame,
}, { }, {
"invalid CLOSE status 1100", "invalid CLOSE status 1100",
Message{true, CLOSE, 2, []byte{0x04,0x4c}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x04, 0x4c}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 1101", "valid CLOSE status 1101",
Message{true, CLOSE, 2, []byte{0x04,0x4d}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x04, 0x4d}}, false, CloseFrame,
}, { }, {
"valid CLOSE status 1999", "valid CLOSE status 1999",
Message{true, CLOSE, 2, []byte{0x07,0xcf}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x07, 0xcf}}, false, CloseFrame,
}, { }, {
"invalid CLOSE status 2000", "invalid CLOSE status 2000",
Message{true, CLOSE, 2, []byte{0x07,0xd0}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x07, 0xd0}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 2001", "valid CLOSE status 2001",
Message{true, CLOSE, 2, []byte{0x07,0xd1}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x07, 0xd1}}, false, CloseFrame,
}, { }, {
"valid CLOSE status 2998", "valid CLOSE status 2998",
Message{true, CLOSE, 2, []byte{0x0b,0xb6}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x0b, 0xb6}}, false, CloseFrame,
}, { }, {
"invalid CLOSE status 2999", "invalid CLOSE status 2999",
Message{true, CLOSE, 2, []byte{0x0b,0xb7}}, false, ErrInvalidCloseStatus, Message{true, CLOSE, 2, []byte{0x0b, 0xb7}}, false, ErrInvalidCloseStatus,
}, { }, {
"valid CLOSE status 3000", "valid CLOSE status 3000",
Message{true, CLOSE, 2, []byte{0x0b,0xb8}}, false, CloseFrame, Message{true, CLOSE, 2, []byte{0x0b, 0xb8}}, false, CloseFrame,
}, },
}, },
}, { }, {
Name: "OpCode check", Name: "OpCode check",
Cases: []Case{ Cases: []Case{
{ "0", Message{true, 0, 0, []byte{}}, false, ErrUnexpectedContinuation, }, {"0", Message{true, 0, 0, []byte{}}, false, ErrUnexpectedContinuation},
{ "1", Message{true, 1, 0, []byte{}}, false, nil, }, {"1", Message{true, 1, 0, []byte{}}, false, nil},
{ "2", Message{true, 2, 0, []byte{}}, false, nil, }, {"2", Message{true, 2, 0, []byte{}}, false, nil},
{ "3", Message{true, 3, 0, []byte{}}, false, ErrInvalidOpCode, }, {"3", Message{true, 3, 0, []byte{}}, false, ErrInvalidOpCode},
{ "4", Message{true, 4, 0, []byte{}}, false, ErrInvalidOpCode, }, {"4", Message{true, 4, 0, []byte{}}, false, ErrInvalidOpCode},
{ "5", Message{true, 5, 0, []byte{}}, false, ErrInvalidOpCode, }, {"5", Message{true, 5, 0, []byte{}}, false, ErrInvalidOpCode},
{ "6", Message{true, 6, 0, []byte{}}, false, ErrInvalidOpCode, }, {"6", Message{true, 6, 0, []byte{}}, false, ErrInvalidOpCode},
{ "7", Message{true, 7, 0, []byte{}}, false, ErrInvalidOpCode, }, {"7", Message{true, 7, 0, []byte{}}, false, ErrInvalidOpCode},
{ "8", Message{true, 8, 0, []byte{}}, false, CloseFrame, }, {"8", Message{true, 8, 0, []byte{}}, false, CloseFrame},
{ "9", Message{true, 9, 0, []byte{}}, false, nil, }, {"9", Message{true, 9, 0, []byte{}}, false, nil},
{ "10", Message{true, 10, 0, []byte{}}, false, nil, }, {"10", Message{true, 10, 0, []byte{}}, false, nil},
{ "11", Message{true, 11, 0, []byte{}}, false, ErrInvalidOpCode, }, {"11", Message{true, 11, 0, []byte{}}, false, ErrInvalidOpCode},
{ "12", Message{true, 12, 0, []byte{}}, false, ErrInvalidOpCode, }, {"12", Message{true, 12, 0, []byte{}}, false, ErrInvalidOpCode},
{ "13", Message{true, 13, 0, []byte{}}, false, ErrInvalidOpCode, }, {"13", Message{true, 13, 0, []byte{}}, false, ErrInvalidOpCode},
{ "14", Message{true, 14, 0, []byte{}}, false, ErrInvalidOpCode, }, {"14", Message{true, 14, 0, []byte{}}, false, ErrInvalidOpCode},
{ "15", Message{true, 15, 0, []byte{}}, false, ErrInvalidOpCode, }, {"15", Message{true, 15, 0, []byte{}}, false, ErrInvalidOpCode},
}, },
}, },
} }
for _, tcc := range cases {
for _, tcc := range cases{ t.Run(tcc.Name, func(t *testing.T) {
t.Run(tcc.Name, func(t *testing.T){ for _, tc := range tcc.Cases {
for _, tc := range tcc.Cases{ t.Run(tc.Name, func(t *testing.T) {
t.Run(tc.Name, func(t *testing.T){
actual := tc.Msg.check(tc.WaitingFragment) actual := tc.Msg.check(tc.WaitingFragment)
@ -559,4 +551,4 @@ func TestMessageCheck(t *testing.T) {
} }
} }

View File

@ -1,43 +1,39 @@
package ws package websocket
import ( import (
"net"
"fmt" "fmt"
"git.xdrm.io/gws/internal/uri/parser" "git.xdrm.io/go/websocket/internal/uri/parser"
"net"
) )
// Represents all channels that need a server // Represents all channels that need a server
type serverChannelSet struct{ type serverChannelSet struct {
register chan *client register chan *client
unregister chan *client unregister chan *client
broadcast chan Message broadcast chan Message
} }
// Represents a websocket server // Represents a websocket server
type Server struct { type Server struct {
sock net.Listener // listen socket sock net.Listener // listen socket
addr []byte // server listening ip/host addr []byte // server listening ip/host
port uint16 // server listening port port uint16 // server listening port
clients map[net.Conn]*client clients map[net.Conn]*client
ctl ControllerSet // controllers ctl ControllerSet // controllers
ch serverChannelSet ch serverChannelSet
} }
// CreateServer creates a server for a specific HOST and PORT // CreateServer creates a server for a specific HOST and PORT
func CreateServer(host string, port uint16) *Server{ func CreateServer(host string, port uint16) *Server {
return &Server{ return &Server{
addr: []byte(host), addr: []byte(host),
port: port, port: port,
clients: make(map[net.Conn]*client, 0), clients: make(map[net.Conn]*client, 0),
ctl: ControllerSet{ ctl: ControllerSet{
Def: nil, Def: nil,
@ -53,11 +49,10 @@ func CreateServer(host string, port uint16) *Server{
} }
// BindDefault binds a default controller // BindDefault binds a default controller
// it will be called if the URI does not // it will be called if the URI does not
// match another controller // match another controller
func (s *Server) BindDefault(f ControllerFunc){ func (s *Server) BindDefault(f ControllerFunc) {
s.ctl.Def = &Controller{ s.ctl.Def = &Controller{
URI: nil, URI: nil,
@ -66,25 +61,25 @@ func (s *Server) BindDefault(f ControllerFunc){
} }
// Bind binds a controller to an URI scheme // Bind binds a controller to an URI scheme
func (s *Server) Bind(uri string, f ControllerFunc) error { func (s *Server) Bind(uri string, f ControllerFunc) error {
/* (1) Build URI parser */ /* (1) Build URI parser */
uriScheme, err := parser.Build(uri) uriScheme, err := parser.Build(uri)
if err != nil { return fmt.Errorf("Cannot build URI: %s", err) } if err != nil {
return fmt.Errorf("Cannot build URI: %s", err)
}
/* (2) Create controller */ /* (2) Create controller */
s.ctl.Uri = append(s.ctl.Uri, &Controller{ s.ctl.Uri = append(s.ctl.Uri, &Controller{
URI: uriScheme, URI: uriScheme,
Fun: f, Fun: f,
} ) })
return nil return nil
} }
// Launch launches the websocket server // Launch launches the websocket server
func (s *Server) Launch() error { func (s *Server) Launch() error {
@ -108,8 +103,6 @@ func (s *Server) Launch() error {
/* (3) Launch scheduler */ /* (3) Launch scheduler */
go s.scheduler() go s.scheduler()
/* (2) For each incoming connection (client) /* (2) For each incoming connection (client)
---------------------------------------------------------*/ ---------------------------------------------------------*/
for { for {
@ -120,7 +113,7 @@ func (s *Server) Launch() error {
break break
} }
go func(){ go func() {
/* (2) Try to create client */ /* (2) Try to create client */
cli, err := buildClient(sock, s.ctl, s.ch) cli, err := buildClient(sock, s.ctl, s.ch)
@ -140,34 +133,33 @@ func (s *Server) Launch() error {
} }
// Scheduler schedules clients registration and broadcast // Scheduler schedules clients registration and broadcast
func (s *Server) scheduler(){ func (s *Server) scheduler() {
for { for {
select { select {
/* (1) New client */ /* (1) New client */
case client := <- s.ch.register: case client := <-s.ch.register:
// fmt.Printf(" + client\n") // fmt.Printf(" + client\n")
s.clients[client.io.sock] = client s.clients[client.io.sock] = client
/* (2) New client */ /* (2) New client */
case client := <- s.ch.unregister: case client := <-s.ch.unregister:
// fmt.Printf(" - client\n") // fmt.Printf(" - client\n")
delete(s.clients, client.io.sock) delete(s.clients, client.io.sock)
/* (3) Broadcast */ /* (3) Broadcast */
case msg := <- s.ch.broadcast: case msg := <-s.ch.broadcast:
fmt.Printf(" + broadcast\n") fmt.Printf(" + broadcast\n")
for _, c := range s.clients{ for _, c := range s.clients {
c.ch.send <- msg c.ch.send <- msg
} }
} }