refactor upgrade.request + merge 'upgrade.reqline' + began tests
This commit is contained in:
parent
b713011e7b
commit
dc510ad5d9
|
@ -0,0 +1,35 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
// invalid request
|
||||||
|
// - multiple-value if only 1 expected
|
||||||
|
type InvalidRequest struct {
|
||||||
|
Field string
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err InvalidRequest) Error() string {
|
||||||
|
return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request misses fields (request-line or headers)
|
||||||
|
type IncompleteRequest struct {
|
||||||
|
MissingField string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err IncompleteRequest) Error() string {
|
||||||
|
return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Request has a violated origin policy
|
||||||
|
type InvalidOriginPolicy struct {
|
||||||
|
Host string
|
||||||
|
Origin string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
func (err InvalidOriginPolicy) Error() string {
|
||||||
|
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
|
||||||
|
}
|
|
@ -14,7 +14,7 @@ import (
|
||||||
func (r *T) extractHostPort(bb header.HeaderValue) error {
|
func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 {
|
if len(bb) != 1 {
|
||||||
return fmt.Errorf("Host header must have a unique value")
|
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
}
|
}
|
||||||
|
|
||||||
split := strings.Split(string(bb[0]), ":")
|
split := strings.Split(string(bb[0]), ":")
|
||||||
|
@ -30,7 +30,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.code = response.BAD_REQUEST
|
r.code = response.BAD_REQUEST
|
||||||
return fmt.Errorf("Cannot read port number '%s'", split[1])
|
return &InvalidRequest{"Host", "cannot read port"}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.port = uint16(readPort)
|
r.port = uint16(readPort)
|
||||||
|
@ -40,7 +40,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = r.checkOriginPolicy()
|
err = r.checkOriginPolicy()
|
||||||
r.code = response.FORBIDDEN
|
r.code = response.FORBIDDEN
|
||||||
return err
|
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 {
|
if len(bb) != 1 {
|
||||||
r.code = response.FORBIDDEN
|
r.code = response.FORBIDDEN
|
||||||
return fmt.Errorf("Origin header must have a unique value")
|
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.origin = string(bb[0])
|
r.origin = string(bb[0])
|
||||||
|
@ -67,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
||||||
err := r.checkOriginPolicy()
|
err := r.checkOriginPolicy()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.code = response.FORBIDDEN
|
r.code = response.FORBIDDEN
|
||||||
return err
|
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.code = response.BAD_REQUEST
|
r.code = response.BAD_REQUEST
|
||||||
return fmt.Errorf("Connection header must be 'Upgrade'")
|
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 {
|
if len(bb) != 1 {
|
||||||
r.code = response.BAD_REQUEST
|
r.code = response.BAD_REQUEST
|
||||||
return fmt.Errorf("Upgrade header must have only 1 element")
|
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower( string(bb[0]) ) == "websocket" {
|
if strings.ToLower( string(bb[0]) ) == "websocket" {
|
||||||
|
@ -115,7 +115,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.code = response.BAD_REQUEST
|
r.code = response.BAD_REQUEST
|
||||||
return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0])
|
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ func (r *T) checkVersion(bb header.HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 || string(bb[0]) != "13" {
|
if len(bb) != 1 || string(bb[0]) != "13" {
|
||||||
r.code = response.UPGRADE_REQUIRED
|
r.code = response.UPGRADE_REQUIRED
|
||||||
return fmt.Errorf("Sec-WebSocket-Version header must be '13'")
|
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.hasVersion = true
|
r.hasVersion = true
|
||||||
|
@ -139,7 +139,7 @@ func (r *T) extractKey(bb header.HeaderValue) error {
|
||||||
|
|
||||||
if len(bb) != 1 || len(bb[0]) != 24 {
|
if len(bb) != 1 || len(bb[0]) != 24 {
|
||||||
r.code = response.BAD_REQUEST
|
r.code = response.BAD_REQUEST
|
||||||
return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d bytes", len(bb[0]))
|
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.key = bb[0]
|
r.key = bb[0]
|
||||||
|
|
|
@ -1,73 +0,0 @@
|
||||||
package reqline
|
|
||||||
|
|
||||||
import (
|
|
||||||
"regexp"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
// extractHttpMethod extracts the HTTP method from a []byte
|
|
||||||
// and checks for errors
|
|
||||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
|
||||||
func (r *T) extractHttpMethod(b []byte) error {
|
|
||||||
|
|
||||||
switch string(b) {
|
|
||||||
case "OPTIONS": r.method = OPTIONS
|
|
||||||
case "GET": r.method = GET
|
|
||||||
case "HEAD": r.method = HEAD
|
|
||||||
case "POST": r.method = POST
|
|
||||||
case "PUT": r.method = PUT
|
|
||||||
case "DELETE": r.method = DELETE
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("Unknown HTTP method '%s'", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// extractURI extracts the URI from a []byte and checks for errors
|
|
||||||
// allowed format: /([^/]/)*/?
|
|
||||||
func (r *T) extractURI(b []byte) error {
|
|
||||||
|
|
||||||
/* (1) Check format */
|
|
||||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
|
||||||
if !checker.Match(b) {
|
|
||||||
return fmt.Errorf("Invalid URI format, expected an absolute path (starts with /), got '%s'", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Store */
|
|
||||||
r.uri = string(b)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// extractHttpVersion extracts the version and checks for errors
|
|
||||||
// allowed format: [1-9] or [1.9].[0-9]
|
|
||||||
func (r *T) extractHttpVersion(b []byte) error {
|
|
||||||
|
|
||||||
/* (1) Extract version parts */
|
|
||||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
|
|
||||||
|
|
||||||
if !extractor.Match(b) {
|
|
||||||
return fmt.Errorf("Cannot parse HTTP version, expected INT or INT.INT, got '%s'", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (2) Extract version number */
|
|
||||||
matches := extractor.FindSubmatch(b)
|
|
||||||
var version byte = matches[1][0] - '0'
|
|
||||||
|
|
||||||
/* (3) Extract subversion (if exists) */
|
|
||||||
var subVersion byte = 0
|
|
||||||
if len(matches[2]) > 0 {
|
|
||||||
subVersion = matches[2][0] - '0'
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (4) Store version (x 10 to fit uint8) */
|
|
||||||
r.version = version * 10 + subVersion
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
package reqline
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseRequestLine parses the first HTTP request line
|
|
||||||
func (r *T) Parse(b []byte) error {
|
|
||||||
|
|
||||||
/* (1) Split by ' ' */
|
|
||||||
parts := bytes.Split(b, []byte(" "))
|
|
||||||
|
|
||||||
/* (2) Fail when missing parts */
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return fmt.Errorf("Malformed Request-Line must have 3 space-separated elements, got %d", len(parts))
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (3) Extract HTTP method */
|
|
||||||
err := r.extractHttpMethod(parts[0])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (4) Extract URI */
|
|
||||||
err = r.extractURI(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (5) Extract version */
|
|
||||||
err = r.extractHttpVersion(parts[2])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// GetURI returns the actual URI
|
|
||||||
func (r T) GetURI() string {
|
|
||||||
return r.uri
|
|
||||||
}
|
|
|
@ -1,21 +0,0 @@
|
||||||
package reqline
|
|
||||||
|
|
||||||
// httpMethod represents available http methods
|
|
||||||
type httpMethod byte
|
|
||||||
const (
|
|
||||||
OPTIONS httpMethod = iota
|
|
||||||
GET
|
|
||||||
HEAD
|
|
||||||
POST
|
|
||||||
PUT
|
|
||||||
DELETE
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
// httpRequestLine represents the HTTP Request line
|
|
||||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
|
||||||
type T struct {
|
|
||||||
method httpMethod
|
|
||||||
uri string
|
|
||||||
version byte
|
|
||||||
}
|
|
|
@ -19,7 +19,7 @@ func (r *T) parseHeader(b []byte) error {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.code = response.BAD_REQUEST
|
r.code = response.BAD_REQUEST
|
||||||
return fmt.Errorf("Error while parsing first line: %s", err)
|
return &InvalidRequest{"Request-Line", err.Error()}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.first = true
|
r.first = true
|
||||||
|
@ -53,8 +53,8 @@ func (r *T) parseHeader(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// dispatch error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("ERR: %s\n", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,37 +70,37 @@ func (r T) isComplete() error {
|
||||||
|
|
||||||
/* (1) Request-Line */
|
/* (1) Request-Line */
|
||||||
if !r.first {
|
if !r.first {
|
||||||
return fmt.Errorf("Missing HTTP Request-Line");
|
return &IncompleteRequest{"Request-Line"}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (2) Host */
|
/* (2) Host */
|
||||||
if len(r.host) == 0 {
|
if len(r.host) == 0 {
|
||||||
return fmt.Errorf("Missing 'Host' header")
|
return &IncompleteRequest{"Host"}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (3) Origin */
|
/* (3) Origin */
|
||||||
if !bypassOriginPolicy && len(r.origin) == 0 {
|
if !bypassOriginPolicy && len(r.origin) == 0 {
|
||||||
return fmt.Errorf("Missing 'Origin' header")
|
return &IncompleteRequest{"Origin"}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (4) Connection */
|
/* (4) Connection */
|
||||||
if !r.hasConnection {
|
if !r.hasConnection {
|
||||||
return fmt.Errorf("Missing 'Connection' header");
|
return &IncompleteRequest{"Connection"}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (5) Upgrade */
|
/* (5) Upgrade */
|
||||||
if !r.hasUpgrade {
|
if !r.hasUpgrade {
|
||||||
return fmt.Errorf("Missing 'Upgrade' header");
|
return &IncompleteRequest{"Upgrade"}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (6) Sec-WebSocket-Version */
|
/* (6) Sec-WebSocket-Version */
|
||||||
if !r.hasVersion {
|
if !r.hasVersion {
|
||||||
return fmt.Errorf("Missing 'Sec-WebSocket-Version' header");
|
return &IncompleteRequest{"Sec-WebSocket-Version"}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* (7) Sec-WebSocket-Key */
|
/* (7) Sec-WebSocket-Key */
|
||||||
if len(r.key) < 1 {
|
if len(r.key) < 1 {
|
||||||
return fmt.Errorf("Missing 'Sec-WebSocket-Key' header");
|
return &IncompleteRequest{"Sec-WebSocket-Key"}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -32,13 +32,13 @@ func Parse(r io.Reader) (request *T, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return req, fmt.Errorf("Cannot read from reader: %s", err)
|
return req, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = req.parseHeader(line)
|
err = req.parseHeader(line)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return req, fmt.Errorf("Parsing error: %s\n", err);
|
return req, err
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,6 @@ func Parse(r io.Reader) (request *T, err error) {
|
||||||
/* (3) Check completion */
|
/* (3) Check completion */
|
||||||
err = req.isComplete()
|
err = req.isComplete()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("not complete: %s\b", err)
|
|
||||||
req.code = response.BAD_REQUEST
|
req.code = response.BAD_REQUEST
|
||||||
return req, err
|
return req, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,135 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"bytes"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpMethod represents available http methods
|
||||||
|
type httpMethod byte
|
||||||
|
const (
|
||||||
|
OPTIONS httpMethod = iota
|
||||||
|
GET
|
||||||
|
HEAD
|
||||||
|
POST
|
||||||
|
PUT
|
||||||
|
DELETE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
// RequestLine represents the HTTP Request line
|
||||||
|
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||||
|
type RequestLine struct {
|
||||||
|
method httpMethod
|
||||||
|
uri string
|
||||||
|
version byte
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// parseRequestLine parses the first HTTP request line
|
||||||
|
func (r *RequestLine) Parse(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) Split by ' ' */
|
||||||
|
parts := bytes.Split(b, []byte(" "))
|
||||||
|
|
||||||
|
/* (2) Fail when missing parts */
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (3) Extract HTTP method */
|
||||||
|
err := r.extractHttpMethod(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Extract URI */
|
||||||
|
err = r.extractURI(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (5) Extract version */
|
||||||
|
err = r.extractHttpVersion(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// GetURI returns the actual URI
|
||||||
|
func (r RequestLine) GetURI() string {
|
||||||
|
return r.uri
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// extractHttpMethod extracts the HTTP method from a []byte
|
||||||
|
// and checks for errors
|
||||||
|
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||||
|
func (r *RequestLine) extractHttpMethod(b []byte) error {
|
||||||
|
|
||||||
|
switch string(b) {
|
||||||
|
// case "OPTIONS": r.method = OPTIONS
|
||||||
|
case "GET": r.method = GET
|
||||||
|
// case "HEAD": r.method = HEAD
|
||||||
|
// case "POST": r.method = POST
|
||||||
|
// case "PUT": r.method = PUT
|
||||||
|
// case "DELETE": r.method = DELETE
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Invalid HTTP method '%s', expected 'GET'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// extractURI extracts the URI from a []byte and checks for errors
|
||||||
|
// allowed format: /([^/]/)*/?
|
||||||
|
func (r *RequestLine) extractURI(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) Check format */
|
||||||
|
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||||
|
if !checker.Match(b) {
|
||||||
|
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Store */
|
||||||
|
r.uri = string(b)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// extractHttpVersion extracts the version and checks for errors
|
||||||
|
// allowed format: [1-9] or [1.9].[0-9]
|
||||||
|
func (r *RequestLine) extractHttpVersion(b []byte) error {
|
||||||
|
|
||||||
|
/* (1) Extract version parts */
|
||||||
|
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
|
||||||
|
|
||||||
|
if !extractor.Match(b) {
|
||||||
|
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (2) Extract version number */
|
||||||
|
matches := extractor.FindSubmatch(b)
|
||||||
|
var version byte = matches[1][0] - '0'
|
||||||
|
|
||||||
|
/* (3) Extract subversion (if exists) */
|
||||||
|
var subVersion byte = 0
|
||||||
|
if len(matches[2]) > 0 {
|
||||||
|
subVersion = matches[2][0] - '0'
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (4) Store version (x 10 to fit uint8) */
|
||||||
|
r.version = version * 10 + subVersion
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,106 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
// /* (1) Parse request */
|
||||||
|
// req, _ := request.Parse(s)
|
||||||
|
|
||||||
|
// /* (3) Build response */
|
||||||
|
// res := req.BuildResponse()
|
||||||
|
|
||||||
|
// /* (4) Write into socket */
|
||||||
|
// _, err := res.Send(s)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("Upgrade write error: %s", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if res.GetStatusCode() != 101 {
|
||||||
|
// s.Close()
|
||||||
|
// return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
|
||||||
|
// }
|
||||||
|
|
||||||
|
func TestEOFSocket(t *testing.T){
|
||||||
|
|
||||||
|
socket := new(bytes.Buffer)
|
||||||
|
|
||||||
|
_, err := Parse(socket)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Empty socket expected EOF, got no error")
|
||||||
|
} else if err != io.ErrUnexpectedEOF {
|
||||||
|
t.Fatalf("Empty socket expected EOF, got '%s'", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidRequestLine(t *testing.T){
|
||||||
|
|
||||||
|
socket := new(bytes.Buffer)
|
||||||
|
cases := []struct{
|
||||||
|
Reqline string
|
||||||
|
HasError bool
|
||||||
|
}{
|
||||||
|
{ "abc", true },
|
||||||
|
{ "a c", true },
|
||||||
|
{ "a c", true },
|
||||||
|
{ "a c", true },
|
||||||
|
{ "a b c", true },
|
||||||
|
|
||||||
|
{ "GET invaliduri HTTP/1.1", true },
|
||||||
|
{ "GET /validuri HTTP/1.1", false },
|
||||||
|
|
||||||
|
{ "POST /validuri HTTP/1.1", true },
|
||||||
|
{ "PUT /validuri HTTP/1.1", true },
|
||||||
|
{ "DELETE /validuri HTTP/1.1", true },
|
||||||
|
{ "OPTIONS /validuri HTTP/1.1", true },
|
||||||
|
{ "UNKNOWN /validuri HTTP/1.1", true },
|
||||||
|
|
||||||
|
{ "GET / HTTP/52", true },
|
||||||
|
{ "GET / HTTP/1.", true },
|
||||||
|
{ "GET / HTTP/.1", true },
|
||||||
|
{ "GET / HTTP/1.1", false },
|
||||||
|
{ "GET / HTTP/2", false },
|
||||||
|
}
|
||||||
|
|
||||||
|
for ti, tc := range cases {
|
||||||
|
|
||||||
|
socket.Reset()
|
||||||
|
socket.Write( []byte(tc.Reqline) )
|
||||||
|
socket.Write( []byte("\r\n\r\n") )
|
||||||
|
|
||||||
|
_, err := Parse(socket)
|
||||||
|
|
||||||
|
if !tc.HasError {
|
||||||
|
|
||||||
|
// no error -> ok
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
// error for the end of the request -> ok
|
||||||
|
} else if _, ok := err.(*IncompleteRequest); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Errorf("[%d] Expected no error", ti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// missing required error -> error
|
||||||
|
if tc.HasError && err == nil {
|
||||||
|
t.Errorf("[%d] Expected error", ti)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ir, ok := err.(*InvalidRequest);
|
||||||
|
|
||||||
|
// not InvalidRequest err -> error
|
||||||
|
if !ok || ir.Field != "Request-Line" {
|
||||||
|
t.Errorf("[%d] expected InvalidRequest", ti)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -1,6 +1,5 @@
|
||||||
package request
|
package request
|
||||||
|
|
||||||
import "git.xdrm.io/gws/internal/http/upgrade/request/parser/reqline"
|
|
||||||
import "git.xdrm.io/gws/internal/http/upgrade/response"
|
import "git.xdrm.io/gws/internal/http/upgrade/response"
|
||||||
|
|
||||||
// If origin is required
|
// If origin is required
|
||||||
|
@ -14,7 +13,7 @@ type T struct {
|
||||||
code response.StatusCode
|
code response.StatusCode
|
||||||
|
|
||||||
// request line
|
// request line
|
||||||
request reqline.T
|
request RequestLine
|
||||||
|
|
||||||
// data to check origin (depends of reading order)
|
// data to check origin (depends of reading order)
|
||||||
host string
|
host string
|
||||||
|
|
Loading…
Reference in New Issue