ws/internal/http/upgrade/request.go

217 lines
3.8 KiB
Go

package upgrade
import (
"fmt"
"io"
"git.xdrm.io/go/ws/internal/http/reader"
)
// bypassOriginPolicy disables the Origin requirement. When true, a request
// that carries no Origin header is still accepted as complete; set it to false
// to reject such requests as incomplete.
const bypassOriginPolicy = true
// Request holds the state gathered while parsing an HTTP Upgrade request,
// together with the status code that should be answered.
type Request struct {
	// first is set once the request-line (GET uri HTTP/version) was consumed.
	first bool

	// code is the status code copied into the response by BuildResponse.
	code StatusCode

	// request is the parsed request-line.
	request Line

	// Inputs of the origin check; which of them arrives first depends on the
	// order in which the headers are read.
	host        string
	port        uint16 // 0 when no port was given
	origin      string
	validPolicy bool

	// WebSocket handshake data.
	key       []byte
	protocols [][]byte

	// Presence flags for required headers, verified by isComplete.
	hasConnection, hasUpgrade, hasVersion bool
}
// Parse reads an HTTP Upgrade request line by line from r (typically a
// bufio.Reader wrapping the client socket) and validates it.
//
// The returned *Request is never nil, even when err is non-nil, so the caller
// can always build a response from it.
//
// Its status code evolves as follows:
//   - It starts at 500.
//   - parseHeader may change it while reading lines (for example, to
//     BadRequest for a malformed request-line or header line).
//   - It is set to BadRequest when a required item is missing.
//   - It is set to SwitchingProtocols when the request is accepted.
//
// Read errors other than io.EOF are returned unchanged. Errors from
// parseHeader and isComplete are returned as produced.
func Parse(r io.Reader) (*Request, error) {
	// Start from an internal error so that a request abandoned half-way never
	// carries a success status.
	req := &Request{
		code: 500,
	}

	// NewReader returns a single value and cannot fail. The former
	// `if err != nil` check here inspected an always-nil variable and was dead.
	cr := reader.NewReader(r)

	// Feed every line to parseHeader until the reader reports io.EOF
	// (presumably the end of the header section; see the reader package).
	for {
		line, err := cr.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return req, err
		}
		if err := req.parseHeader(line); err != nil {
			return req, err
		}
	}

	// Every required item must have been seen.
	if err := req.isComplete(); err != nil {
		req.code = BadRequest
		return req, err
	}

	req.code = SwitchingProtocols
	return req, nil
}
// StatusCode reports the status code currently associated with the request,
// i.e. the one BuildResponse would put in the response.
func (r Request) StatusCode() StatusCode {
	current := r.code
	return current
}
// BuildResponse creates the handshake response for this request.
//
// The response receives:
//   - the request's current status code;
//   - the first sub-protocol offered by the client, if any was offered;
//   - the client key, handed to ProcessKey (presumably to derive
//     Sec-WebSocket-Accept; confirm in response.go).
func (r *Request) BuildResponse() *Response {
	res := new(Response)
	res.SetStatusCode(r.code)

	// The first offered protocol is picked as-is; no server-side
	// negotiation happens here.
	if offered := r.protocols; len(offered) != 0 {
		res.SetProtocol(offered[0])
	}

	res.ProcessKey(r.key)
	return res
}
// GetURI returns the URI found in the parsed request-line.
func (r Request) GetURI() string {
	uri := r.request.GetURI()
	return uri
}
// parseHeader consumes one line of the request.
//
// The first call parses the request-line (GET {uri} HTTP/{version}). Every
// later call parses a "Header-Name: Header-Value" line. Recognised handshake
// headers are dispatched to their handler, whose error is returned as-is.
// Any other header is ignored. On a syntax error in either kind of line, the
// status code is set to BadRequest.
func (r *Request) parseHeader(b []byte) error {
	// Request-line: only expected once, before any header.
	if !r.first {
		if err := r.request.Parse(b); err != nil {
			r.code = BadRequest
			return &InvalidRequest{"Request-Line", err.Error()}
		}
		r.first = true
		return nil
	}

	// Header line.
	head, err := ReadHeader(b)
	if err != nil {
		r.code = BadRequest
		return fmt.Errorf("parse header: %w", err)
	}

	switch head.Name {
	case Host:
		return r.extractHostPort(head.Values)
	case Origin:
		return r.extractOrigin(head.Values)
	case Upgrade:
		return r.checkUpgrade(head.Values)
	case Connection:
		return r.checkConnection(head.Values)
	case WSVersion:
		return r.checkVersion(head.Values)
	case WSKey:
		return r.extractKey(head.Values)
	case WSProtocol:
		return r.extractProtocols(head.Values)
	}

	// Headers that play no role in the handshake are skipped.
	return nil
}
// isComplete checks that every required part of the Upgrade request was seen.
//
// Items are checked in this fixed order, and the first missing one is
// reported as an *IncompleteRequest:
//  1. the request-line
//  2. Host
//  3. Origin (only when bypassOriginPolicy is false)
//  4. Connection
//  5. Upgrade
//  6. Sec-WebSocket-Version
//  7. Sec-WebSocket-Key
//
// It returns nil when nothing is missing.
func (r Request) isComplete() error {
	switch {
	case !r.first:
		return &IncompleteRequest{"Request-Line"}
	case r.host == "":
		return &IncompleteRequest{"Host"}
	case !bypassOriginPolicy && r.origin == "":
		return &IncompleteRequest{"Origin"}
	case !r.hasConnection:
		return &IncompleteRequest{"Connection"}
	case !r.hasUpgrade:
		return &IncompleteRequest{"Upgrade"}
	case !r.hasVersion:
		return &IncompleteRequest{"Sec-WebSocket-Version"}
	case len(r.key) == 0:
		return &IncompleteRequest{"Sec-WebSocket-Key"}
	}
	return nil
}