ws/internal/http/upgrade/request.go

190 lines
3.5 KiB
Go

package upgrade
import (
"fmt"
"io"
"git.xdrm.io/go/ws/internal/http/reader"
)
// bypassOriginPolicy disables the Origin requirement: when true, isComplete
// accepts a request that carries no Origin header.
const bypassOriginPolicy = true
// Request represents an HTTP Upgrade request and holds the state gathered
// while its lines are parsed one by one.
type Request struct {
// whether the first line has been read (GET uri HTTP/version)
first bool
// current status of the request; updated by ReadFrom and parseHeader
statusCode StatusCode
// parsed Request-Line; its URI is exposed through URI()
requestLine RequestLine
// data to check origin (depends on reading order)
host string
port uint16 // 0 if not set
origin string
// NOTE(review): presumably the outcome of the origin policy check; it is
// not read or written in the visible code - confirm against the helpers.
validPolicy bool
// websocket specific
// key is handed to Response.ProcessKey by BuildResponse
key []byte
// BuildResponse selects the first entry, if any, as the response protocol
protocols [][]byte
// mandatory fields to check; isComplete rejects the request when any of
// these flags is still false
hasConnection bool
hasUpgrade bool
hasVersion bool
}
// ReadFrom consumes an HTTP upgrade request from r (typically a
// bufio.NewReader wrapping the socket), feeding it line by line to
// parseHeader until the line reader reports io.EOF.
//
// The returned count is the sum of the lengths of the lines handed back by
// the line reader. On success the status code is set to
// StatusSwitchingProtocols; when a mandatory field is missing it is set to
// StatusBadRequest and the corresponding error is returned.
//
// implements io.ReaderFrom
func (req *Request) ReadFrom(r io.Reader) (int64, error) {
	// Start from a server-error status; it is replaced below once the
	// outcome of the parsing is known.
	req.statusCode = 500

	var (
		total int64
		lines = reader.NewReader(r)
	)

	for {
		chunk, readErr := lines.Read()
		total += int64(len(chunk))

		if readErr != nil {
			// io.EOF marks the end of the header section, not a failure.
			if readErr == io.EOF {
				break
			}
			return total, readErr
		}

		if parseErr := req.parseHeader(chunk); parseErr != nil {
			return total, parseErr
		}
	}

	// Every line has been consumed: make sure nothing mandatory is missing.
	if missing := req.isComplete(); missing != nil {
		req.statusCode = StatusBadRequest
		return total, missing
	}

	req.statusCode = StatusSwitchingProtocols
	return total, nil
}
// StatusCode returns the current status code of the request, as last set
// while reading and validating it.
func (req Request) StatusCode() StatusCode {
return req.statusCode
}
// BuildResponse creates the Response matching the current state of the
// request: it carries the request status code, the first requested protocol
// (left nil when none was received) and has the request key processed
// through Response.ProcessKey.
func (req *Request) BuildResponse() *Response {
	out := new(Response)
	out.StatusCode = req.statusCode

	// Only the first protocol is selected; Protocol stays nil otherwise.
	if len(req.protocols) != 0 {
		out.Protocol = req.protocols[0]
	}

	out.ProcessKey(req.key)
	return out
}
// URI returns the URI of the request, as reported by the parsed Request-Line.
func (req Request) URI() string {
return req.requestLine.URI()
}
// parseHeader handles a single line of the HTTP request.
//
// The first line it receives is parsed as the Request-Line; every later line
// is parsed as a header and dispatched to the handler matching its name.
// Headers with an unknown name are ignored. A malformed Request-Line or
// header sets the status code to StatusBadRequest before the error is
// returned.
func (req *Request) parseHeader(b []byte) error {
	// Request-Line: GET {uri} HTTP/{version}
	if !req.first {
		if _, err := req.requestLine.Read(b); err != nil {
			req.statusCode = StatusBadRequest
			return &ErrInvalidRequest{"Request-Line", err.Error()}
		}
		req.first = true
		return nil
	}

	// Header line: Header-Name: Header-Value
	header, err := ReadHeader(b)
	if err != nil {
		req.statusCode = StatusBadRequest
		return fmt.Errorf("parse header: %w", err)
	}

	// Hand the values over to the handler of the header, if there is one.
	switch header.Name {
	case Host:
		return req.extractHostPort(header.Values)
	case Origin:
		return req.extractOrigin(header.Values)
	case Upgrade:
		return req.checkUpgrade(header.Values)
	case Connection:
		return req.checkConnection(header.Values)
	case WSVersion:
		return req.checkVersion(header.Values)
	case WSKey:
		return req.extractKey(header.Values)
	case WSProtocol:
		return req.extractProtocols(header.Values)
	}

	return nil
}
// isComplete verifies that every mandatory part of the Upgrade Request has
// been received. It returns nil when nothing is missing, otherwise an
// ErrIncompleteRequest naming the first missing element, checked in this
// order: Request-Line, Host, Origin (unless bypassOriginPolicy is set),
// Connection, Upgrade, Sec-WebSocket-Version, Sec-WebSocket-Key.
func (req Request) isComplete() error {
	switch {
	case !req.first:
		return ErrIncompleteRequest("Request-Line")
	case len(req.host) == 0:
		return ErrIncompleteRequest("Host")
	case !bypassOriginPolicy && len(req.origin) == 0:
		return ErrIncompleteRequest("Origin")
	case !req.hasConnection:
		return ErrIncompleteRequest("Connection")
	case !req.hasUpgrade:
		return ErrIncompleteRequest("Upgrade")
	case !req.hasVersion:
		return ErrIncompleteRequest("Sec-WebSocket-Version")
	case len(req.key) == 0:
		return ErrIncompleteRequest("Sec-WebSocket-Key")
	}

	return nil
}