diff --git a/cmd/iface/main.go b/cmd/iface/main.go index e6b4b1f..2dcbf2e 100644 --- a/cmd/iface/main.go +++ b/cmd/iface/main.go @@ -24,6 +24,7 @@ func main(){ case receivedFrame := <- receiver: fmt.Printf("[default] received '%s'\n", receivedFrame.Payload.Buffer) sender <- receivedFrame.Payload.Buffer + fmt.Printf("[default] sent\n") case closeFunc := <- closer: fmt.Printf("[default] client with protocol '%s' exited\n", client.Protocol) closeFunc() diff --git a/internal/http/upgrade/request/header_check.go b/internal/http/upgrade/request/header_check.go index b6deaec..e56678d 100644 --- a/internal/http/upgrade/request/header_check.go +++ b/internal/http/upgrade/request/header_check.go @@ -52,6 +52,9 @@ func (r *T) extractHostPort(bb header.HeaderValue) error { // checkOrigin checks the Origin Header func (r *T) extractOrigin(bb header.HeaderValue) error { + // bypass + if bypassOriginPolicy { return nil } + if len(bb) != 1 { r.code = response.FORBIDDEN return fmt.Errorf("Origin header must have a unique value") diff --git a/internal/http/upgrade/request/private.go b/internal/http/upgrade/request/private.go index 2a47343..d3dd204 100644 --- a/internal/http/upgrade/request/private.go +++ b/internal/http/upgrade/request/private.go @@ -53,7 +53,10 @@ func (r *T) parseHeader(b []byte) error { } - if err != nil { return err } + if err != nil { + fmt.Printf("ERR: %s\n", err) + return err + } return nil @@ -76,7 +79,7 @@ func (r T) isComplete() error { } /* (3) Origin */ - if len(r.origin) == 0 { + if !bypassOriginPolicy && len(r.origin) == 0 { return fmt.Errorf("Missing 'Origin' header") } diff --git a/internal/http/upgrade/request/public.go b/internal/http/upgrade/request/public.go index 46fe126..eeb834b 100644 --- a/internal/http/upgrade/request/public.go +++ b/internal/http/upgrade/request/public.go @@ -46,6 +46,7 @@ func Parse(r io.Reader) (request *T, err error) { /* (3) Check completion */ err = req.isComplete() if err != nil { + fmt.Printf("not complete: %s\b", err) req.code = response.BAD_REQUEST return req, err } diff --git a/internal/http/upgrade/request/types.go b/internal/http/upgrade/request/types.go index 388ba42..752da3b 100644 --- a/internal/http/upgrade/request/types.go +++ b/internal/http/upgrade/request/types.go @@ -3,6 +3,9 @@ package request import "git.xdrm.io/gws/internal/http/upgrade/request/parser/reqline" import "git.xdrm.io/gws/internal/http/upgrade/response" +// If origin is required +const bypassOriginPolicy = true + // T represents an HTTP Upgrade request type T struct { first bool // whether the first line has been read (GET uri HTTP/version) diff --git a/internal/ws/reader/reader.go b/internal/ws/reader/reader.go index 62472df..e7bf0e1 100644 --- a/internal/ws/reader/reader.go +++ b/internal/ws/reader/reader.go @@ -1,8 +1,8 @@ package reader import ( - "fmt" "io" + "bufio" ) // Maximum chunk size @@ -11,27 +11,16 @@ const MaxChunkSize = 4096 // Read reads a chunk of n bytes // err is io.EOF when done -func ReadBytes(r io.Reader, n uint) ([]byte, error){ +func ReadBytes(br *bufio.Reader, n int) ([]byte, error){ - res := make([]byte, 0, MaxChunkSize) - - totalRead := uint(0) - - // socket -> tmp( buffer) - for totalRead < n { - - tmp := make([]byte, 1) - - read, err := r.Read(tmp) - if err != nil { return nil, err } - if read == 0 { return nil, fmt.Errorf("Cannot read") } - - totalRead += uint(1) - - res = append(res, tmp[0]) + buf, err := br.Peek(n) + if err == io.EOF && len(buf) < n && n > 0 { + err = io.ErrUnexpectedEOF } - return res, nil + br.Discard(len(buf)) + + return buf, err } \ No newline at end of file diff --git a/ws/client.go b/ws/client.go index d81e559..e1e2639 100644 --- a/ws/client.go +++ b/ws/client.go @@ -3,7 +3,6 @@ package ws import ( "fmt" "time" - "bytes" ) @@ -15,19 +14,15 @@ import ( // for a given client func (c *Client) asyncReader(s *Server) { - // Get buffer - buf := new(bytes.Buffer) - - for { var startTime int64 = time.Now().UnixNano() // Try to read frame header - frame, err := ReadFrame(buf, c.sock) + frame, err := ReadFrame(c) if err != nil { - fmt.Printf("%s\n", err) + fmt.Printf("[read.err] %s\n", err) break } @@ -46,10 +41,12 @@ func (c *Client) asyncReader(s *Server) { c.closec <- func(){ // Remove client from server - delete(s.clients, c.sock) + s.clientsMutex.Lock() + delete(s.clients, c.conn.sock) + s.clientsMutex.Unlock() // Close socket - c.sock.Close() + c.conn.sock.Close() } @@ -64,17 +61,16 @@ func (c *Client) asyncWriter(s *Server){ for payload := range c.sendc { - fmt.Printf("Writing '%s'\n", payload) - // Build Frame f := buildFrame(payload) // Send over socket - senderr := f.Send(c.sock) + senderr := f.Send(&c.conn) if senderr != nil { fmt.Printf("Writing error: %s\n", senderr) } + } diff --git a/ws/frame.go b/ws/frame.go index 2e17326..172e733 100644 --- a/ws/frame.go +++ b/ws/frame.go @@ -1,36 +1,36 @@ package ws import ( + "bufio" "git.xdrm.io/gws/ws/frame/opcode" "encoding/binary" "git.xdrm.io/gws/ws/frame" "fmt" "git.xdrm.io/gws/internal/ws/reader" "net" - "bytes" ) // ReadFrame reads the frame from a socket -func ReadFrame(b *bytes.Buffer, s net.Conn) (*Frame, error) { +func ReadFrame(c *Client) (*Frame, error) { f := new(Frame) /* (1) Read header ---------------------------------------------------------*/ - err := f.readHeader(b, s) + err := f.readHeader(c.conn.br) if err != nil { - return nil, fmt.Errorf("Header read: %s\n", err) + return nil, err } /* (2) Read payload ---------------------------------------------------------*/ - err = f.readPayload(b, s) + err = f.readPayload(c.conn.br) if err != nil { - return nil, fmt.Errorf("Payload read: %s\n", err) + return nil, err } @@ -50,7 +50,7 @@ func ReadFrame(b *bytes.Buffer, s net.Conn) (*Frame, error) { case opcode.PING: fmt.Printf("Opcode: PING\n") - err = buildPong().Send(s) + err = buildPong().Send(&c.conn) if err != nil { return nil, fmt.Errorf("Pong frame: %s\n", err) } @@ -58,7 +58,7 @@ func ReadFrame(b *bytes.Buffer, s net.Conn) (*Frame, error) { default: fmt.Printf("Opcode: CLOSE\n") - buildClose().Send(s) + buildClose().Send(&c.conn) return nil, fmt.Errorf("Unknown Opcode %x\n", f.Header.Opc) } @@ -69,19 +69,19 @@ func ReadFrame(b *bytes.Buffer, s net.Conn) (*Frame, error) { // readHeader reads the frame header -func (f *Frame) readHeader(buf *bytes.Buffer, s net.Conn) error{ +func (f *Frame) readHeader(br *bufio.Reader) error{ var err error /* (2) Byte 1: FIN and OpCode */ - b, err := reader.ReadBytes(s, 1) - if err != nil { return fmt.Errorf("Cannot read byte Fin nor OpCode (%s)", err) } + b, err := reader.ReadBytes(br, 1) + if err != nil { return err } f.Header.Fin = b[0] & 0x80 == 0x80 f.Header.Opc = frame.OpCode( b[0] & 0x0f ) /* (3) Byte 2: Mask and Length[0] */ - b, err = reader.ReadBytes(s, 1) - if err != nil { return fmt.Errorf("Cannot read byte if has Mask nor Length (%s)", err) } + b, err = reader.ReadBytes(br, 1) + if err != nil { return err } // if mask, byte array not nil if b[0] & 0x80 == 0x80 { @@ -94,15 +94,15 @@ func (f *Frame) readHeader(buf *bytes.Buffer, s net.Conn) error{ /* (4) Extended payload */ if f.Payload.Length == 127 { - bx, err := reader.ReadBytes(s, 8) - if err != nil { return fmt.Errorf("Cannot read payload extended length of 64 bytes (%s)", err) } + bx, err := reader.ReadBytes(br, 8) + if err != nil { return err } f.Payload.Length = binary.BigEndian.Uint64(bx) } else if f.Payload.Length == 126 { - bx, err := reader.ReadBytes(s, 2) - if err != nil { return fmt.Errorf("Cannot read payload extended length of 16 bytes (%s)", err) } + bx, err := reader.ReadBytes(br, 2) + if err != nil { return err } f.Payload.Length = uint64( binary.BigEndian.Uint16(bx) ) @@ -111,8 +111,8 @@ func (f *Frame) readHeader(buf *bytes.Buffer, s net.Conn) error{ /* (5) Masking key */ if f.Header.Msk != nil { - bx, err := reader.ReadBytes(s, 4) - if err != nil { return fmt.Errorf("Cannot read mask or 32 bytes (%s)", err) } + bx, err := reader.ReadBytes(br, 4) + if err != nil { return err } f.Header.Msk = make([]byte, 4) copy(f.Header.Msk, bx) @@ -134,10 +134,10 @@ func (f *Frame) readHeader(buf *bytes.Buffer, s net.Conn) error{ // readPayload reads the frame payload -func (f *Frame) readPayload(buf *bytes.Buffer, s net.Conn) error{ +func (f *Frame) readPayload(br *bufio.Reader) error{ /* (1) Read payload */ - b, err := reader.ReadBytes(s, uint(f.Payload.Length) ) + b, err := reader.ReadBytes(br, int(f.Payload.Length) ) if err != nil { return fmt.Errorf("Cannot read payload (%s)", err) } f.Payload.Buffer = make([]byte, 0, f.Payload.Length) @@ -229,66 +229,50 @@ func buildPong() *Frame{ // Send sends a frame over a socket -func (f Frame) Send(s net.Conn) error { +func (f Frame) Send(c *Conn) error { - written := 0 + frameHeader := make([]byte, 0, maxHeaderLength) /* (1) Byte 0 : FIN + opcode */ - buf := []byte{ 0x80 + byte(opcode.TEXT) } - w, err := s.Write(buf) - if err != nil { return err } - written += w + frameHeader = append(frameHeader, 0x80 | byte(opcode.TEXT) ) /* (2) Get payload length */ if f.Payload.Length < 126 { // simple - buf := []byte{ byte(f.Payload.Length) } - w, err = s.Write(buf) - if err != nil { return err } - written += w + frameHeader = append(frameHeader, byte(f.Payload.Length) ) } else if f.Payload.Length < 0xffff { // extended: 16 bits - w, err = s.Write( []byte{126} ) - if err != nil { return err } - written += w + + frameHeader = append(frameHeader, 126) buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(f.Payload.Length)) - w, err = s.Write(buf) - if err != nil { return err } - written += w + frameHeader = append(frameHeader, buf...) } else if f.Payload.Length < 0xffffffffffffffff { // extended: 64 bits - w, err = s.Write( []byte{127} ) - if err != nil { return err } - written += w + frameHeader = append(frameHeader, 127) buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, f.Payload.Length) - w, err = s.Write(buf) - if err != nil { return err } - written += w + frameHeader = append(frameHeader, buf...) } - /* (3) Payload */ - w, err = s.Write(f.Payload.Buffer) + /* (3) Add payload */ + writeBuffer := net.Buffers{frameHeader, f.Payload.Buffer[:f.Payload.Length]} + fmt.Printf("[send]\n") + _, err := writeBuffer.WriteTo(c.sock) + fmt.Printf("[/send] ") if err != nil { return err } - fmt.Printf("[2] written %d bytes\n", written) - - fmt.Printf(" + Header\n") - fmt.Printf(" + FIN: %t\n", f.Header.Fin) - fmt.Printf(" + OPC: %x\n", f.Header.Opc) - fmt.Printf(" + MASK?: %t\n", f.Header.Msk != nil) - if f.Header.Msk != nil { - fmt.Printf(" + MASK: %x\n", f.Header.Msk) - } - fmt.Printf(" + LEN: %d\n", f.Payload.Length) - fmt.Printf("Total written: %d bytes (%d + %d)\n", written+w, written, w) + // fmt.Printf(" + Header\n") + // fmt.Printf(" + FIN: %t\n", f.Header.Fin) + // fmt.Printf(" + OPC: %x\n", f.Header.Opc) + // fmt.Printf(" + MASK?: %t\n", f.Header.Msk != nil) + // fmt.Printf(" + LEN: %d\n", f.Payload.Length) return nil } \ No newline at end of file diff --git a/ws/private.go b/ws/private.go index 867b886..2ad5d28 100644 --- a/ws/private.go +++ b/ws/private.go @@ -1,6 +1,7 @@ package ws import ( + "bufio" "git.xdrm.io/gws/upgrader" "net" ) @@ -15,7 +16,10 @@ func (s *Server) dispatch(sock net.Conn, u *upgrader.T){ client := &Client{ - sock: sock, + conn: Conn{ + sock: sock, + br: bufio.NewReader(sock), + }, Arguments: [][]string{ []string{ uri } }, Protocol: string(u.Response.GetProtocol()), recvc: make(chan Frame, maxChannelBufferLength), @@ -66,7 +70,9 @@ func (s *Server) dispatch(sock net.Conn, u *upgrader.T){ client.Controller = controller /* (2) Add client to server */ + s.clientsMutex.Lock() s.clients[sock] = client + s.clientsMutex.Unlock() /* (3) Bind controller */ go controller.fun(client, client.recvc, client.sendc, client.closec) @@ -75,7 +81,7 @@ func (s *Server) dispatch(sock net.Conn, u *upgrader.T){ go client.asyncWriter(s) /* (5) Run asynchronous frame reader */ - client.asyncReader(s) + go client.asyncReader(s) } diff --git a/ws/server.go b/ws/server.go index a224a47..69fd935 100644 --- a/ws/server.go +++ b/ws/server.go @@ -77,7 +77,7 @@ func (s *Server) Launch() error { continue } if upgrader.Response.GetStatusCode() != 101 { - fmt.Printf(" - upgrade bad request\n") + fmt.Printf(" - upgrade bad request (status code %d)\n", upgrader.Response.GetStatusCode()) sock.Close() continue } diff --git a/ws/types.go b/ws/types.go index 4d3707d..ebb4bd6 100644 --- a/ws/types.go +++ b/ws/types.go @@ -1,14 +1,23 @@ package ws import ( + "sync" + "bufio" "net" "git.xdrm.io/gws/internal/uri/parser" "git.xdrm.io/gws/ws/frame" ) const maxBufferLength = 4096 +const maxHeaderLength = 2 + 8 + 4 const maxChannelBufferLength = 1 +// Represents a websocket connection (socket + reader) +type Conn struct { + sock net.Conn + br *bufio.Reader +} + // Represents a websocket controller callback function type ControllerFunc func(*Client, <-chan Frame, chan<- []byte, <-chan func()) @@ -21,7 +30,7 @@ type Controller struct { // Represents a websocket client type Client struct { - sock net.Conn // communication socket + conn Conn // connection (socket + reader) Protocol string // choosen protocol (Sec-WebSocket-Protocol) Arguments [][]string // URI parameters, index 0 is full URI, then matching groups @@ -40,6 +49,7 @@ type Server struct { port uint16 // server listening port clients map[net.Conn]*Client // clients + clientsMutex sync.Mutex defaultController *Controller // default controller controllers []*Controller // URI-bound controllers