refactor upgrade.request + merge 'upgrade.reqline' + began tests
This commit is contained in:
parent
b713011e7b
commit
dc510ad5d9
|
@ -0,0 +1,35 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
// invalid request
|
||||
// - multiple-value if only 1 expected
|
||||
type InvalidRequest struct {
|
||||
Field string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (err InvalidRequest) Error() string {
|
||||
return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason)
|
||||
}
|
||||
|
||||
// Request misses fields (request-line or headers)
|
||||
type IncompleteRequest struct {
|
||||
MissingField string
|
||||
}
|
||||
|
||||
func (err IncompleteRequest) Error() string {
|
||||
return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField)
|
||||
}
|
||||
|
||||
|
||||
// Request has a violated origin policy
|
||||
type InvalidOriginPolicy struct {
|
||||
Host string
|
||||
Origin string
|
||||
err error
|
||||
}
|
||||
func (err InvalidOriginPolicy) Error() string {
|
||||
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
|
||||
}
|
|
@ -14,7 +14,7 @@ import (
|
|||
func (r *T) extractHostPort(bb header.HeaderValue) error {
|
||||
|
||||
if len(bb) != 1 {
|
||||
return fmt.Errorf("Host header must have a unique value")
|
||||
return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
}
|
||||
|
||||
split := strings.Split(string(bb[0]), ":")
|
||||
|
@ -30,7 +30,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
|||
readPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||
if err != nil {
|
||||
r.code = response.BAD_REQUEST
|
||||
return fmt.Errorf("Cannot read port number '%s'", split[1])
|
||||
return &InvalidRequest{"Host", "cannot read port"}
|
||||
}
|
||||
|
||||
r.port = uint16(readPort)
|
||||
|
@ -40,7 +40,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
|
|||
if err != nil {
|
||||
err = r.checkOriginPolicy()
|
||||
r.code = response.FORBIDDEN
|
||||
return err
|
||||
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
|||
|
||||
if len(bb) != 1 {
|
||||
r.code = response.FORBIDDEN
|
||||
return fmt.Errorf("Origin header must have a unique value")
|
||||
return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
}
|
||||
|
||||
r.origin = string(bb[0])
|
||||
|
@ -67,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
|
|||
err := r.checkOriginPolicy()
|
||||
if err != nil {
|
||||
r.code = response.FORBIDDEN
|
||||
return err
|
||||
return &InvalidOriginPolicy{r.host, r.origin, err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,7 +96,7 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
|
|||
}
|
||||
|
||||
r.code = response.BAD_REQUEST
|
||||
return fmt.Errorf("Connection header must be 'Upgrade'")
|
||||
return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
|
||||
|
||||
}
|
||||
|
||||
|
@ -106,7 +106,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
|||
|
||||
if len(bb) != 1 {
|
||||
r.code = response.BAD_REQUEST
|
||||
return fmt.Errorf("Upgrade header must have only 1 element")
|
||||
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
|
||||
}
|
||||
|
||||
if strings.ToLower( string(bb[0]) ) == "websocket" {
|
||||
|
@ -115,7 +115,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
|
|||
}
|
||||
|
||||
r.code = response.BAD_REQUEST
|
||||
return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0])
|
||||
return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
|
||||
|
||||
}
|
||||
|
||||
|
@ -125,7 +125,7 @@ func (r *T) checkVersion(bb header.HeaderValue) error {
|
|||
|
||||
if len(bb) != 1 || string(bb[0]) != "13" {
|
||||
r.code = response.UPGRADE_REQUIRED
|
||||
return fmt.Errorf("Sec-WebSocket-Version header must be '13'")
|
||||
return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
|
||||
}
|
||||
|
||||
r.hasVersion = true
|
||||
|
@ -139,7 +139,7 @@ func (r *T) extractKey(bb header.HeaderValue) error {
|
|||
|
||||
if len(bb) != 1 || len(bb[0]) != 24 {
|
||||
r.code = response.BAD_REQUEST
|
||||
return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d bytes", len(bb[0]))
|
||||
return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
|
||||
}
|
||||
|
||||
r.key = bb[0]
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
package reqline
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
||||
// extractHttpMethod extracts the HTTP method from a []byte
|
||||
// and checks for errors
|
||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||
func (r *T) extractHttpMethod(b []byte) error {
|
||||
|
||||
switch string(b) {
|
||||
case "OPTIONS": r.method = OPTIONS
|
||||
case "GET": r.method = GET
|
||||
case "HEAD": r.method = HEAD
|
||||
case "POST": r.method = POST
|
||||
case "PUT": r.method = PUT
|
||||
case "DELETE": r.method = DELETE
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Unknown HTTP method '%s'", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// extractURI extracts the URI from a []byte and checks for errors
|
||||
// allowed format: /([^/]/)*/?
|
||||
func (r *T) extractURI(b []byte) error {
|
||||
|
||||
/* (1) Check format */
|
||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||
if !checker.Match(b) {
|
||||
return fmt.Errorf("Invalid URI format, expected an absolute path (starts with /), got '%s'", b)
|
||||
}
|
||||
|
||||
/* (2) Store */
|
||||
r.uri = string(b)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
|
||||
// extractHttpVersion extracts the version and checks for errors
|
||||
// allowed format: [1-9] or [1.9].[0-9]
|
||||
func (r *T) extractHttpVersion(b []byte) error {
|
||||
|
||||
/* (1) Extract version parts */
|
||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
|
||||
|
||||
if !extractor.Match(b) {
|
||||
return fmt.Errorf("Cannot parse HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||
}
|
||||
|
||||
/* (2) Extract version number */
|
||||
matches := extractor.FindSubmatch(b)
|
||||
var version byte = matches[1][0] - '0'
|
||||
|
||||
/* (3) Extract subversion (if exists) */
|
||||
var subVersion byte = 0
|
||||
if len(matches[2]) > 0 {
|
||||
subVersion = matches[2][0] - '0'
|
||||
}
|
||||
|
||||
/* (4) Store version (x 10 to fit uint8) */
|
||||
r.version = version * 10 + subVersion
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
package reqline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// parseRequestLine parses the first HTTP request line
|
||||
func (r *T) Parse(b []byte) error {
|
||||
|
||||
/* (1) Split by ' ' */
|
||||
parts := bytes.Split(b, []byte(" "))
|
||||
|
||||
/* (2) Fail when missing parts */
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("Malformed Request-Line must have 3 space-separated elements, got %d", len(parts))
|
||||
}
|
||||
|
||||
/* (3) Extract HTTP method */
|
||||
err := r.extractHttpMethod(parts[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
/* (4) Extract URI */
|
||||
err = r.extractURI(parts[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
/* (5) Extract version */
|
||||
err = r.extractHttpVersion(parts[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
|
||||
// GetURI returns the actual URI
|
||||
func (r T) GetURI() string {
|
||||
return r.uri
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package reqline
|
||||
|
||||
// httpMethod represents available http methods
|
||||
type httpMethod byte
|
||||
const (
|
||||
OPTIONS httpMethod = iota
|
||||
GET
|
||||
HEAD
|
||||
POST
|
||||
PUT
|
||||
DELETE
|
||||
)
|
||||
|
||||
|
||||
// httpRequestLine represents the HTTP Request line
|
||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||
type T struct {
|
||||
method httpMethod
|
||||
uri string
|
||||
version byte
|
||||
}
|
|
@ -19,7 +19,7 @@ func (r *T) parseHeader(b []byte) error {
|
|||
|
||||
if err != nil {
|
||||
r.code = response.BAD_REQUEST
|
||||
return fmt.Errorf("Error while parsing first line: %s", err)
|
||||
return &InvalidRequest{"Request-Line", err.Error()}
|
||||
}
|
||||
|
||||
r.first = true
|
||||
|
@ -53,8 +53,8 @@ func (r *T) parseHeader(b []byte) error {
|
|||
}
|
||||
|
||||
|
||||
// dispatch error
|
||||
if err != nil {
|
||||
fmt.Printf("ERR: %s\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -70,37 +70,37 @@ func (r T) isComplete() error {
|
|||
|
||||
/* (1) Request-Line */
|
||||
if !r.first {
|
||||
return fmt.Errorf("Missing HTTP Request-Line");
|
||||
return &IncompleteRequest{"Request-Line"}
|
||||
}
|
||||
|
||||
/* (2) Host */
|
||||
if len(r.host) == 0 {
|
||||
return fmt.Errorf("Missing 'Host' header")
|
||||
return &IncompleteRequest{"Host"}
|
||||
}
|
||||
|
||||
/* (3) Origin */
|
||||
if !bypassOriginPolicy && len(r.origin) == 0 {
|
||||
return fmt.Errorf("Missing 'Origin' header")
|
||||
return &IncompleteRequest{"Origin"}
|
||||
}
|
||||
|
||||
/* (4) Connection */
|
||||
if !r.hasConnection {
|
||||
return fmt.Errorf("Missing 'Connection' header");
|
||||
return &IncompleteRequest{"Connection"}
|
||||
}
|
||||
|
||||
/* (5) Upgrade */
|
||||
if !r.hasUpgrade {
|
||||
return fmt.Errorf("Missing 'Upgrade' header");
|
||||
return &IncompleteRequest{"Upgrade"}
|
||||
}
|
||||
|
||||
/* (6) Sec-WebSocket-Version */
|
||||
if !r.hasVersion {
|
||||
return fmt.Errorf("Missing 'Sec-WebSocket-Version' header");
|
||||
return &IncompleteRequest{"Sec-WebSocket-Version"}
|
||||
}
|
||||
|
||||
/* (7) Sec-WebSocket-Key */
|
||||
if len(r.key) < 1 {
|
||||
return fmt.Errorf("Missing 'Sec-WebSocket-Key' header");
|
||||
return &IncompleteRequest{"Sec-WebSocket-Key"}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -32,13 +32,13 @@ func Parse(r io.Reader) (request *T, err error) {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return req, fmt.Errorf("Cannot read from reader: %s", err)
|
||||
return req, err
|
||||
}
|
||||
|
||||
err = req.parseHeader(line)
|
||||
|
||||
if err != nil {
|
||||
return req, fmt.Errorf("Parsing error: %s\n", err);
|
||||
return req, err
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -46,7 +46,6 @@ func Parse(r io.Reader) (request *T, err error) {
|
|||
/* (3) Check completion */
|
||||
err = req.isComplete()
|
||||
if err != nil {
|
||||
fmt.Printf("not complete: %s\b", err)
|
||||
req.code = response.BAD_REQUEST
|
||||
return req, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// httpMethod represents available http methods
|
||||
type httpMethod byte
|
||||
const (
|
||||
OPTIONS httpMethod = iota
|
||||
GET
|
||||
HEAD
|
||||
POST
|
||||
PUT
|
||||
DELETE
|
||||
)
|
||||
|
||||
|
||||
// RequestLine represents the HTTP Request line
|
||||
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
|
||||
type RequestLine struct {
|
||||
method httpMethod
|
||||
uri string
|
||||
version byte
|
||||
}
|
||||
|
||||
|
||||
// parseRequestLine parses the first HTTP request line
|
||||
func (r *RequestLine) Parse(b []byte) error {
|
||||
|
||||
/* (1) Split by ' ' */
|
||||
parts := bytes.Split(b, []byte(" "))
|
||||
|
||||
/* (2) Fail when missing parts */
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
|
||||
}
|
||||
|
||||
/* (3) Extract HTTP method */
|
||||
err := r.extractHttpMethod(parts[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
/* (4) Extract URI */
|
||||
err = r.extractURI(parts[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
/* (5) Extract version */
|
||||
err = r.extractHttpVersion(parts[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
|
||||
// GetURI returns the actual URI
|
||||
func (r RequestLine) GetURI() string {
|
||||
return r.uri
|
||||
}
|
||||
|
||||
|
||||
|
||||
// extractHttpMethod extracts the HTTP method from a []byte
|
||||
// and checks for errors
|
||||
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
|
||||
func (r *RequestLine) extractHttpMethod(b []byte) error {
|
||||
|
||||
switch string(b) {
|
||||
// case "OPTIONS": r.method = OPTIONS
|
||||
case "GET": r.method = GET
|
||||
// case "HEAD": r.method = HEAD
|
||||
// case "POST": r.method = POST
|
||||
// case "PUT": r.method = PUT
|
||||
// case "DELETE": r.method = DELETE
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Invalid HTTP method '%s', expected 'GET'", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// extractURI extracts the URI from a []byte and checks for errors
|
||||
// allowed format: /([^/]/)*/?
|
||||
func (r *RequestLine) extractURI(b []byte) error {
|
||||
|
||||
/* (1) Check format */
|
||||
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
|
||||
if !checker.Match(b) {
|
||||
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
|
||||
}
|
||||
|
||||
/* (2) Store */
|
||||
r.uri = string(b)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
|
||||
// extractHttpVersion extracts the version and checks for errors
|
||||
// allowed format: [1-9] or [1.9].[0-9]
|
||||
func (r *RequestLine) extractHttpVersion(b []byte) error {
|
||||
|
||||
/* (1) Extract version parts */
|
||||
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
|
||||
|
||||
if !extractor.Match(b) {
|
||||
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
|
||||
}
|
||||
|
||||
/* (2) Extract version number */
|
||||
matches := extractor.FindSubmatch(b)
|
||||
var version byte = matches[1][0] - '0'
|
||||
|
||||
/* (3) Extract subversion (if exists) */
|
||||
var subVersion byte = 0
|
||||
if len(matches[2]) > 0 {
|
||||
subVersion = matches[2][0] - '0'
|
||||
}
|
||||
|
||||
/* (4) Store version (x 10 to fit uint8) */
|
||||
r.version = version * 10 + subVersion
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"io"
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
// /* (1) Parse request */
|
||||
// req, _ := request.Parse(s)
|
||||
|
||||
// /* (3) Build response */
|
||||
// res := req.BuildResponse()
|
||||
|
||||
// /* (4) Write into socket */
|
||||
// _, err := res.Send(s)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("Upgrade write error: %s", err)
|
||||
// }
|
||||
|
||||
// if res.GetStatusCode() != 101 {
|
||||
// s.Close()
|
||||
// return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
|
||||
// }
|
||||
|
||||
func TestEOFSocket(t *testing.T){
|
||||
|
||||
socket := new(bytes.Buffer)
|
||||
|
||||
_, err := Parse(socket)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Empty socket expected EOF, got no error")
|
||||
} else if err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("Empty socket expected EOF, got '%s'", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInvalidRequestLine(t *testing.T){
|
||||
|
||||
socket := new(bytes.Buffer)
|
||||
cases := []struct{
|
||||
Reqline string
|
||||
HasError bool
|
||||
}{
|
||||
{ "abc", true },
|
||||
{ "a c", true },
|
||||
{ "a c", true },
|
||||
{ "a c", true },
|
||||
{ "a b c", true },
|
||||
|
||||
{ "GET invaliduri HTTP/1.1", true },
|
||||
{ "GET /validuri HTTP/1.1", false },
|
||||
|
||||
{ "POST /validuri HTTP/1.1", true },
|
||||
{ "PUT /validuri HTTP/1.1", true },
|
||||
{ "DELETE /validuri HTTP/1.1", true },
|
||||
{ "OPTIONS /validuri HTTP/1.1", true },
|
||||
{ "UNKNOWN /validuri HTTP/1.1", true },
|
||||
|
||||
{ "GET / HTTP/52", true },
|
||||
{ "GET / HTTP/1.", true },
|
||||
{ "GET / HTTP/.1", true },
|
||||
{ "GET / HTTP/1.1", false },
|
||||
{ "GET / HTTP/2", false },
|
||||
}
|
||||
|
||||
for ti, tc := range cases {
|
||||
|
||||
socket.Reset()
|
||||
socket.Write( []byte(tc.Reqline) )
|
||||
socket.Write( []byte("\r\n\r\n") )
|
||||
|
||||
_, err := Parse(socket)
|
||||
|
||||
if !tc.HasError {
|
||||
|
||||
// no error -> ok
|
||||
if err == nil {
|
||||
continue
|
||||
// error for the end of the request -> ok
|
||||
} else if _, ok := err.(*IncompleteRequest); ok {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("[%d] Expected no error", ti)
|
||||
}
|
||||
|
||||
// missing required error -> error
|
||||
if tc.HasError && err == nil {
|
||||
t.Errorf("[%d] Expected error", ti)
|
||||
continue
|
||||
}
|
||||
|
||||
ir, ok := err.(*InvalidRequest);
|
||||
|
||||
// not InvalidRequest err -> error
|
||||
if !ok || ir.Field != "Request-Line" {
|
||||
t.Errorf("[%d] expected InvalidRequest", ti)
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,6 +1,5 @@
|
|||
package request
|
||||
|
||||
import "git.xdrm.io/gws/internal/http/upgrade/request/parser/reqline"
|
||||
import "git.xdrm.io/gws/internal/http/upgrade/response"
|
||||
|
||||
// If origin is required
|
||||
|
@ -14,7 +13,7 @@ type T struct {
|
|||
code response.StatusCode
|
||||
|
||||
// request line
|
||||
request reqline.T
|
||||
request RequestLine
|
||||
|
||||
// data to check origin (depends of reading order)
|
||||
host string
|
||||
|
|
Loading…
Reference in New Issue