293 lines
5.6 KiB
Go
293 lines
5.6 KiB
Go
package websocket
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.xdrm.io/go/ws/internal/http/upgrade"
|
|
)
|
|
|
|
// Represents a client socket utility (reader, writer, ..)
|
|
type clientIO struct {
|
|
sock net.Conn
|
|
reader *bufio.Reader
|
|
kill chan<- *client // unregisters client
|
|
closing bool
|
|
closingMu sync.Mutex
|
|
reading sync.WaitGroup
|
|
writing bool
|
|
}
|
|
|
|
// Represents all channels that need a client
|
|
type clientChannelSet struct {
|
|
receive chan Message
|
|
send chan Message
|
|
}
|
|
|
|
// Represents a websocket client
|
|
type client struct {
|
|
io clientIO
|
|
iface *Client
|
|
ch clientChannelSet
|
|
status MessageError // close status ; 0 = nothing ; else -> must close
|
|
}
|
|
|
|
// newClient creates a new client
|
|
func newClient(s net.Conn, ctl ControllerSet, serverCh serverChannelSet) (*client, error) {
|
|
req := &upgrade.Request{}
|
|
_, err := req.ReadFrom(s)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request read: %w", err)
|
|
}
|
|
|
|
res := req.BuildResponse()
|
|
|
|
_, err = res.WriteTo(s)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("upgrade write: %w", err)
|
|
}
|
|
|
|
if res.StatusCode != 101 {
|
|
s.Close()
|
|
return nil, fmt.Errorf("upgrade failed (HTTP %d)", res.StatusCode)
|
|
}
|
|
|
|
var cli = &client{
|
|
io: clientIO{
|
|
sock: s,
|
|
reader: bufio.NewReader(s),
|
|
kill: serverCh.unregister,
|
|
},
|
|
|
|
iface: &Client{
|
|
Protocol: string(res.Protocol),
|
|
Arguments: [][]string{{req.URI()}},
|
|
},
|
|
|
|
ch: clientChannelSet{
|
|
receive: make(chan Message, 1),
|
|
send: make(chan Message, 1),
|
|
},
|
|
}
|
|
|
|
// find controller by URI
|
|
controller, arguments := ctl.Match(req.URI())
|
|
if controller == nil {
|
|
return nil, fmt.Errorf("no controller found, no default controller set")
|
|
}
|
|
|
|
// copy args
|
|
cli.iface.Arguments = arguments
|
|
|
|
go controller.Fun(
|
|
cli.iface, // pass the client
|
|
cli.ch.receive, // the receiver
|
|
cli.ch.send, // the sender
|
|
serverCh.broadcast, // broadcast sender
|
|
)
|
|
go clientReader(cli)
|
|
go clientWriter(cli)
|
|
return cli, nil
|
|
}
|
|
|
|
// clientReader reads and parses messages from the buffer
|
|
func clientReader(c *client) {
|
|
var (
|
|
frag *Message
|
|
closeStatus = Normal
|
|
clientAck = true
|
|
)
|
|
|
|
c.io.reading.Add(1)
|
|
|
|
for {
|
|
// currently closing -> exit
|
|
if c.io.closing {
|
|
fmt.Printf("[reader] killed because closing")
|
|
break
|
|
}
|
|
|
|
// Parse message
|
|
var msg = &Message{}
|
|
_, err := msg.ReadFrom(c.io.reader)
|
|
if err == ErrUnmaskedFrame || err == ErrReservedBits {
|
|
closeStatus = ProtocolError
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
// invalid message
|
|
msgErr := msg.check(frag != nil)
|
|
if msgErr != nil {
|
|
|
|
mustClose := false
|
|
|
|
switch msgErr {
|
|
|
|
// fail
|
|
case ErrUnexpectedContinuation:
|
|
closeStatus = None
|
|
clientAck = false
|
|
mustClose = true
|
|
|
|
// proper close
|
|
case ErrCloseFrame:
|
|
closeStatus = Normal
|
|
clientAck = true
|
|
mustClose = true
|
|
|
|
// invalid payload proper close
|
|
case ErrInvalidPayload:
|
|
closeStatus = InvalidPayload
|
|
clientAck = true
|
|
mustClose = true
|
|
|
|
// any other error -> protocol error
|
|
default:
|
|
closeStatus = ProtocolError
|
|
clientAck = true
|
|
mustClose = true
|
|
}
|
|
|
|
if mustClose {
|
|
break
|
|
}
|
|
|
|
}
|
|
|
|
// ping <-> Pong
|
|
if msg.Type == Ping && c.io.writing {
|
|
msg.Final = true
|
|
msg.Type = Pong
|
|
c.ch.send <- *msg
|
|
continue
|
|
}
|
|
|
|
// store first fragment
|
|
if frag == nil && !msg.Final {
|
|
frag = &Message{
|
|
Type: msg.Type,
|
|
Final: msg.Final,
|
|
Data: msg.Data,
|
|
Size: msg.Size,
|
|
}
|
|
continue
|
|
}
|
|
|
|
// store fragments
|
|
if frag != nil {
|
|
frag.Final = msg.Final
|
|
frag.Size += msg.Size
|
|
frag.Data = append(frag.Data, msg.Data...)
|
|
|
|
if !frag.Final { // continue if not last fragment
|
|
continue
|
|
}
|
|
|
|
// check message errors
|
|
fragErr := frag.check(false)
|
|
if fragErr == ErrInvalidPayload {
|
|
closeStatus = InvalidPayload
|
|
break
|
|
} else if fragErr != nil {
|
|
closeStatus = ProtocolError
|
|
break
|
|
}
|
|
|
|
msg = frag
|
|
frag = nil
|
|
|
|
}
|
|
|
|
// dispatch to receiver
|
|
if msg.Type == Text || msg.Type == Binary {
|
|
c.ch.receive <- *msg
|
|
}
|
|
|
|
}
|
|
|
|
close(c.ch.receive)
|
|
c.io.reading.Done()
|
|
|
|
// close channel (if not already done)
|
|
// fmt.Printf("[reader] end\n")
|
|
c.close(closeStatus, clientAck)
|
|
|
|
}
|
|
|
|
// clientWriter writes to the websocket connection and is triggered by
|
|
// client.ch.send channel
|
|
func clientWriter(c *client) {
|
|
c.io.writing = true // if channel still exists
|
|
|
|
for msg := range c.ch.send {
|
|
_, err := msg.WriteTo(c.io.sock)
|
|
if err != nil {
|
|
fmt.Printf(" [writer] %s\n", err)
|
|
c.io.writing = false
|
|
break
|
|
}
|
|
}
|
|
|
|
c.io.writing = false
|
|
|
|
// close channel (if not already done)
|
|
// fmt.Printf("[writer] end\n")
|
|
c.close(Normal, true)
|
|
|
|
}
|
|
|
|
// close the connection
|
|
// send CLOSE frame is 'status' is not NONE
|
|
// wait for the next message (CLOSE acknowledge) if 'clientACK'
|
|
// then delete client
|
|
func (c *client) close(status MessageError, clientACK bool) {
|
|
// fail if already closing
|
|
alreadyClosing := false
|
|
c.io.closingMu.Lock()
|
|
alreadyClosing = c.io.closing
|
|
c.io.closing = true
|
|
c.io.closingMu.Unlock()
|
|
if alreadyClosing {
|
|
return
|
|
}
|
|
|
|
// kill writer' if still running
|
|
if c.io.writing {
|
|
close(c.ch.send)
|
|
}
|
|
|
|
// kill reader if still running
|
|
c.io.sock.SetReadDeadline(time.Now().Add(time.Second * -1))
|
|
c.io.reading.Wait()
|
|
|
|
if status != None {
|
|
msg := &Message{
|
|
Final: true,
|
|
Type: Close,
|
|
Size: 2,
|
|
Data: make([]byte, 2),
|
|
}
|
|
binary.BigEndian.PutUint16(msg.Data, uint16(status))
|
|
|
|
msg.WriteTo(c.io.sock)
|
|
}
|
|
|
|
// wait for client CLOSE if needed
|
|
if clientACK {
|
|
c.io.sock.SetReadDeadline(time.Now().Add(time.Millisecond))
|
|
var tmpMsg = &Message{}
|
|
tmpMsg.ReadFrom(c.io.reader)
|
|
}
|
|
|
|
c.io.sock.Close()
|
|
// fmt.Printf("[close] socket closed\n")
|
|
|
|
c.io.kill <- c
|
|
}
|