package upgrade

import (
	"fmt"
	"io"

	"git.xdrm.io/go/ws/internal/http/reader"
)

// bypassOriginPolicy, when true, lets isComplete accept a request that
// carries no Origin header. Set it to false to make Origin mandatory.
const bypassOriginPolicy = true

// Request represents an HTTP Upgrade request.
type Request struct {
	// first reports whether the request line (GET uri HTTP/version)
	// has already been read.
	first bool

	// code is the status code resulting from parsing so far.
	code StatusCode

	// request is the parsed request line.
	request Line

	// data to check origin (depends on reading order)
	host        string
	port        uint16 // 0 if not set
	origin      string
	validPolicy bool

	// ws data
	key       []byte
	protocols [][]byte

	// required fields check
	hasConnection bool
	hasUpgrade    bool
	hasVersion    bool
}

// Parse builds an upgrade HTTP request from a reader (typically a
// bufio.Reader wrapping the socket).
//
// The returned *Request is never nil, even when err is non-nil, and its
// status code reflects the outcome:
//   - SwitchingProtocols when every line parsed and no required item is
//     missing;
//   - BadRequest when the request line or a header cannot be parsed, or
//     when isComplete reports a missing item;
//   - 500 (the initial value) when reading fails with anything other than
//     io.EOF, unless a handler invoked by parseHeader changed it.
func Parse(r io.Reader) (request *Request, err error) {
	req := &Request{
		code: 500,
	}

	// 1. Get the chunk reader. NewReader yields no error, so there is
	// nothing to check at this point.
	cr := reader.NewReader(r)

	// 2. Parse the request line, then the headers, one line at a time.
	for {
		line, readErr := cr.Read()
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return req, readErr
		}

		if parseErr := req.parseHeader(line); parseErr != nil {
			return req, parseErr
		}
	}

	// 3. Check that no required item is missing.
	if err := req.isComplete(); err != nil {
		req.code = BadRequest
		return req, err
	}

	req.code = SwitchingProtocols
	return req, nil
}

// StatusCode returns the current status code of the request.
func (r Request) StatusCode() StatusCode {
	return r.code
}

// BuildResponse builds a response from the request.
func (r *Request) BuildResponse() *Response {
	inst := &Response{}

	// 1. Copy the status code.
	inst.SetStatusCode(r.code)

	// 2. Set the protocol: only the first requested protocol is used.
	if len(r.protocols) > 0 {
		inst.SetProtocol(r.protocols[0])
	}

	// 3. Process the key.
	inst.ProcessKey(r.key)

	return inst
}

// GetURI returns the URI of the request line.
func (r Request) GetURI() string {
	return r.request.GetURI()
}

// parseHeader parses a single line of the HTTP request. The first line it
// receives is parsed as the request line; every following line is parsed
// as a header and dispatched on its name. Headers with a name that is not
// handled are ignored.
func (r *Request) parseHeader(b []byte) error {
	// (1) First line -> GET {uri} HTTP/{version}
	if !r.first {
		if err := r.request.Parse(b); err != nil {
			r.code = BadRequest
			return &InvalidRequest{"Request-Line", err.Error()}
		}

		r.first = true
		return nil
	}

	// (2) Other lines -> Header-Name: Header-Value

	// 1. Try to parse the header.
	head, err := ReadHeader(b)
	if err != nil {
		r.code = BadRequest
		return fmt.Errorf("parse header: %w", err)
	}

	// 2. Hand the values over to the handler matching the header name.
	switch head.Name {
	case Host:
		err = r.extractHostPort(head.Values)
	case Origin:
		err = r.extractOrigin(head.Values)
	case Upgrade:
		err = r.checkUpgrade(head.Values)
	case Connection:
		err = r.checkConnection(head.Values)
	case WSVersion:
		err = r.checkVersion(head.Values)
	case WSKey:
		err = r.extractKey(head.Values)
	case WSProtocol:
		err = r.extractProtocols(head.Values)
	default:
		return nil
	}

	// 3. Dispatch the handler error, if any.
	if err != nil {
		return err
	}
	return nil
}

// isComplete returns nil when the Upgrade request is complete, or an
// *IncompleteRequest naming the first missing required item otherwise.
// Items are checked in this order: Request-Line, Host, Origin (skipped
// while bypassOriginPolicy is true), Connection, Upgrade,
// Sec-WebSocket-Version and Sec-WebSocket-Key.
func (r Request) isComplete() error {
	// 1. Request-Line
	if !r.first {
		return &IncompleteRequest{"Request-Line"}
	}

	// 2. Host
	if len(r.host) == 0 {
		return &IncompleteRequest{"Host"}
	}

	// 3. Origin
	if !bypassOriginPolicy && len(r.origin) == 0 {
		return &IncompleteRequest{"Origin"}
	}

	// 4. Connection
	if !r.hasConnection {
		return &IncompleteRequest{"Connection"}
	}

	// 5. Upgrade
	if !r.hasUpgrade {
		return &IncompleteRequest{"Upgrade"}
	}

	// 6. Sec-WebSocket-Version
	if !r.hasVersion {
		return &IncompleteRequest{"Sec-WebSocket-Version"}
	}

	// 7. Sec-WebSocket-Key
	if len(r.key) < 1 {
		return &IncompleteRequest{"Sec-WebSocket-Key"}
	}

	return nil
}