refactor upgrade.request + merge 'upgrade.reqline' + began tests

This commit is contained in:
xdrm-brackets 2018-05-04 08:41:40 +02:00
parent b713011e7b
commit dc510ad5d9
10 changed files with 298 additions and 163 deletions

View File

@ -0,0 +1,35 @@
package request
import (
"fmt"
)
// invalid request
// - multiple-value if only 1 expected
type InvalidRequest struct {
Field string
Reason string
}
func (err InvalidRequest) Error() string {
return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason)
}
// Request misses fields (request-line or headers)
type IncompleteRequest struct {
MissingField string
}
func (err IncompleteRequest) Error() string {
return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField)
}
// Request has a violated origin policy
type InvalidOriginPolicy struct {
Host string
Origin string
err error
}
func (err InvalidOriginPolicy) Error() string {
return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err)
}

View File

@ -14,7 +14,7 @@ import (
func (r *T) extractHostPort(bb header.HeaderValue) error { func (r *T) extractHostPort(bb header.HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
return fmt.Errorf("Host header must have a unique value") return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
split := strings.Split(string(bb[0]), ":") split := strings.Split(string(bb[0]), ":")
@ -30,7 +30,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
readPort, err := strconv.ParseUint(split[1], 10, 16) readPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil { if err != nil {
r.code = response.BAD_REQUEST r.code = response.BAD_REQUEST
return fmt.Errorf("Cannot read port number '%s'", split[1]) return &InvalidRequest{"Host", "cannot read port"}
} }
r.port = uint16(readPort) r.port = uint16(readPort)
@ -40,7 +40,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error {
if err != nil { if err != nil {
err = r.checkOriginPolicy() err = r.checkOriginPolicy()
r.code = response.FORBIDDEN r.code = response.FORBIDDEN
return err return &InvalidOriginPolicy{r.host, r.origin, err}
} }
} }
@ -57,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.FORBIDDEN r.code = response.FORBIDDEN
return fmt.Errorf("Origin header must have a unique value") return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
r.origin = string(bb[0]) r.origin = string(bb[0])
@ -67,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error {
err := r.checkOriginPolicy() err := r.checkOriginPolicy()
if err != nil { if err != nil {
r.code = response.FORBIDDEN r.code = response.FORBIDDEN
return err return &InvalidOriginPolicy{r.host, r.origin, err}
} }
} }
@ -96,7 +96,7 @@ func (r *T) checkConnection(bb header.HeaderValue) error {
} }
r.code = response.BAD_REQUEST r.code = response.BAD_REQUEST
return fmt.Errorf("Connection header must be 'Upgrade'") return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"}
} }
@ -106,7 +106,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
if len(bb) != 1 { if len(bb) != 1 {
r.code = response.BAD_REQUEST r.code = response.BAD_REQUEST
return fmt.Errorf("Upgrade header must have only 1 element") return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))}
} }
if strings.ToLower( string(bb[0]) ) == "websocket" { if strings.ToLower( string(bb[0]) ) == "websocket" {
@ -115,7 +115,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error {
} }
r.code = response.BAD_REQUEST r.code = response.BAD_REQUEST
return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0]) return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])}
} }
@ -125,7 +125,7 @@ func (r *T) checkVersion(bb header.HeaderValue) error {
if len(bb) != 1 || string(bb[0]) != "13" { if len(bb) != 1 || string(bb[0]) != "13" {
r.code = response.UPGRADE_REQUIRED r.code = response.UPGRADE_REQUIRED
return fmt.Errorf("Sec-WebSocket-Version header must be '13'") return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])}
} }
r.hasVersion = true r.hasVersion = true
@ -139,7 +139,7 @@ func (r *T) extractKey(bb header.HeaderValue) error {
if len(bb) != 1 || len(bb[0]) != 24 { if len(bb) != 1 || len(bb[0]) != 24 {
r.code = response.BAD_REQUEST r.code = response.BAD_REQUEST
return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d bytes", len(bb[0])) return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))}
} }
r.key = bb[0] r.key = bb[0]

View File

@ -1,73 +0,0 @@
package reqline
import (
"regexp"
"fmt"
)
// extractHttpMethod extracts the HTTP method from a []byte
// and checks for errors
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
func (r *T) extractHttpMethod(b []byte) error {
switch string(b) {
case "OPTIONS": r.method = OPTIONS
case "GET": r.method = GET
case "HEAD": r.method = HEAD
case "POST": r.method = POST
case "PUT": r.method = PUT
case "DELETE": r.method = DELETE
default:
return fmt.Errorf("Unknown HTTP method '%s'", b)
}
return nil
}
// extractURI extracts the URI from a []byte and checks for errors
// allowed format: /([^/]/)*/?
func (r *T) extractURI(b []byte) error {
/* (1) Check format */
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
if !checker.Match(b) {
return fmt.Errorf("Invalid URI format, expected an absolute path (starts with /), got '%s'", b)
}
/* (2) Store */
r.uri = string(b)
return nil
}
// extractHttpVersion extracts the version and checks for errors
// allowed format: [1-9] or [1.9].[0-9]
func (r *T) extractHttpVersion(b []byte) error {
/* (1) Extract version parts */
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
if !extractor.Match(b) {
return fmt.Errorf("Cannot parse HTTP version, expected INT or INT.INT, got '%s'", b)
}
/* (2) Extract version number */
matches := extractor.FindSubmatch(b)
var version byte = matches[1][0] - '0'
/* (3) Extract subversion (if exists) */
var subVersion byte = 0
if len(matches[2]) > 0 {
subVersion = matches[2][0] - '0'
}
/* (4) Store version (x 10 to fit uint8) */
r.version = version * 10 + subVersion
return nil
}

View File

@ -1,45 +0,0 @@
package reqline
import (
"fmt"
"bytes"
)
// parseRequestLine parses the first HTTP request line
func (r *T) Parse(b []byte) error {
/* (1) Split by ' ' */
parts := bytes.Split(b, []byte(" "))
/* (2) Fail when missing parts */
if len(parts) != 3 {
return fmt.Errorf("Malformed Request-Line must have 3 space-separated elements, got %d", len(parts))
}
/* (3) Extract HTTP method */
err := r.extractHttpMethod(parts[0])
if err != nil {
return err
}
/* (4) Extract URI */
err = r.extractURI(parts[1])
if err != nil {
return err
}
/* (5) Extract version */
err = r.extractHttpVersion(parts[2])
if err != nil {
return err
}
return nil
}
// GetURI returns the actual URI
func (r T) GetURI() string {
return r.uri
}

View File

@ -1,21 +0,0 @@
package reqline
// httpMethod represents available http methods
type httpMethod byte
const (
OPTIONS httpMethod = iota
GET
HEAD
POST
PUT
DELETE
)
// httpRequestLine represents the HTTP Request line
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
type T struct {
method httpMethod
uri string
version byte
}

View File

@ -19,7 +19,7 @@ func (r *T) parseHeader(b []byte) error {
if err != nil { if err != nil {
r.code = response.BAD_REQUEST r.code = response.BAD_REQUEST
return fmt.Errorf("Error while parsing first line: %s", err) return &InvalidRequest{"Request-Line", err.Error()}
} }
r.first = true r.first = true
@ -53,8 +53,8 @@ func (r *T) parseHeader(b []byte) error {
} }
// dispatch error
if err != nil { if err != nil {
fmt.Printf("ERR: %s\n", err)
return err return err
} }
@ -70,37 +70,37 @@ func (r T) isComplete() error {
/* (1) Request-Line */ /* (1) Request-Line */
if !r.first { if !r.first {
return fmt.Errorf("Missing HTTP Request-Line"); return &IncompleteRequest{"Request-Line"}
} }
/* (2) Host */ /* (2) Host */
if len(r.host) == 0 { if len(r.host) == 0 {
return fmt.Errorf("Missing 'Host' header") return &IncompleteRequest{"Host"}
} }
/* (3) Origin */ /* (3) Origin */
if !bypassOriginPolicy && len(r.origin) == 0 { if !bypassOriginPolicy && len(r.origin) == 0 {
return fmt.Errorf("Missing 'Origin' header") return &IncompleteRequest{"Origin"}
} }
/* (4) Connection */ /* (4) Connection */
if !r.hasConnection { if !r.hasConnection {
return fmt.Errorf("Missing 'Connection' header"); return &IncompleteRequest{"Connection"}
} }
/* (5) Upgrade */ /* (5) Upgrade */
if !r.hasUpgrade { if !r.hasUpgrade {
return fmt.Errorf("Missing 'Upgrade' header"); return &IncompleteRequest{"Upgrade"}
} }
/* (6) Sec-WebSocket-Version */ /* (6) Sec-WebSocket-Version */
if !r.hasVersion { if !r.hasVersion {
return fmt.Errorf("Missing 'Sec-WebSocket-Version' header"); return &IncompleteRequest{"Sec-WebSocket-Version"}
} }
/* (7) Sec-WebSocket-Key */ /* (7) Sec-WebSocket-Key */
if len(r.key) < 1 { if len(r.key) < 1 {
return fmt.Errorf("Missing 'Sec-WebSocket-Key' header"); return &IncompleteRequest{"Sec-WebSocket-Key"}
} }
return nil return nil

View File

@ -32,13 +32,13 @@ func Parse(r io.Reader) (request *T, err error) {
} }
if err != nil { if err != nil {
return req, fmt.Errorf("Cannot read from reader: %s", err) return req, err
} }
err = req.parseHeader(line) err = req.parseHeader(line)
if err != nil { if err != nil {
return req, fmt.Errorf("Parsing error: %s\n", err); return req, err
} }
} }
@ -46,7 +46,6 @@ func Parse(r io.Reader) (request *T, err error) {
/* (3) Check completion */ /* (3) Check completion */
err = req.isComplete() err = req.isComplete()
if err != nil { if err != nil {
fmt.Printf("not complete: %s\b", err)
req.code = response.BAD_REQUEST req.code = response.BAD_REQUEST
return req, err return req, err
} }

View File

@ -0,0 +1,135 @@
package request
import (
"fmt"
"bytes"
"regexp"
)
// httpMethod represents available http methods
type httpMethod byte
const (
OPTIONS httpMethod = iota
GET
HEAD
POST
PUT
DELETE
)
// RequestLine represents the HTTP Request line
// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1
type RequestLine struct {
method httpMethod
uri string
version byte
}
// parseRequestLine parses the first HTTP request line
func (r *RequestLine) Parse(b []byte) error {
/* (1) Split by ' ' */
parts := bytes.Split(b, []byte(" "))
/* (2) Fail when missing parts */
if len(parts) != 3 {
return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts))
}
/* (3) Extract HTTP method */
err := r.extractHttpMethod(parts[0])
if err != nil {
return err
}
/* (4) Extract URI */
err = r.extractURI(parts[1])
if err != nil {
return err
}
/* (5) Extract version */
err = r.extractHttpVersion(parts[2])
if err != nil {
return err
}
return nil
}
// GetURI returns the actual URI
func (r RequestLine) GetURI() string {
return r.uri
}
// extractHttpMethod extracts the HTTP method from a []byte
// and checks for errors
// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE
func (r *RequestLine) extractHttpMethod(b []byte) error {
switch string(b) {
// case "OPTIONS": r.method = OPTIONS
case "GET": r.method = GET
// case "HEAD": r.method = HEAD
// case "POST": r.method = POST
// case "PUT": r.method = PUT
// case "DELETE": r.method = DELETE
default:
return fmt.Errorf("Invalid HTTP method '%s', expected 'GET'", b)
}
return nil
}
// extractURI extracts the URI from a []byte and checks for errors
// allowed format: /([^/]/)*/?
func (r *RequestLine) extractURI(b []byte) error {
/* (1) Check format */
checker := regexp.MustCompile("^(?:/[^/]+)*/?$")
if !checker.Match(b) {
return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b)
}
/* (2) Store */
r.uri = string(b)
return nil
}
// extractHttpVersion extracts the version and checks for errors
// allowed format: [1-9] or [1.9].[0-9]
func (r *RequestLine) extractHttpVersion(b []byte) error {
/* (1) Extract version parts */
extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`);
if !extractor.Match(b) {
return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b)
}
/* (2) Extract version number */
matches := extractor.FindSubmatch(b)
var version byte = matches[1][0] - '0'
/* (3) Extract subversion (if exists) */
var subVersion byte = 0
if len(matches[2]) > 0 {
subVersion = matches[2][0] - '0'
}
/* (4) Store version (x 10 to fit uint8) */
r.version = version * 10 + subVersion
return nil
}

View File

@ -0,0 +1,106 @@
package request
import (
"io"
"bytes"
"testing"
)
// /* (1) Parse request */
// req, _ := request.Parse(s)
// /* (3) Build response */
// res := req.BuildResponse()
// /* (4) Write into socket */
// _, err := res.Send(s)
// if err != nil {
// return nil, fmt.Errorf("Upgrade write error: %s", err)
// }
// if res.GetStatusCode() != 101 {
// s.Close()
// return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode())
// }
func TestEOFSocket(t *testing.T){
socket := new(bytes.Buffer)
_, err := Parse(socket)
if err == nil {
t.Fatalf("Empty socket expected EOF, got no error")
} else if err != io.ErrUnexpectedEOF {
t.Fatalf("Empty socket expected EOF, got '%s'", err)
}
}
func TestInvalidRequestLine(t *testing.T){
socket := new(bytes.Buffer)
cases := []struct{
Reqline string
HasError bool
}{
{ "abc", true },
{ "a c", true },
{ "a c", true },
{ "a c", true },
{ "a b c", true },
{ "GET invaliduri HTTP/1.1", true },
{ "GET /validuri HTTP/1.1", false },
{ "POST /validuri HTTP/1.1", true },
{ "PUT /validuri HTTP/1.1", true },
{ "DELETE /validuri HTTP/1.1", true },
{ "OPTIONS /validuri HTTP/1.1", true },
{ "UNKNOWN /validuri HTTP/1.1", true },
{ "GET / HTTP/52", true },
{ "GET / HTTP/1.", true },
{ "GET / HTTP/.1", true },
{ "GET / HTTP/1.1", false },
{ "GET / HTTP/2", false },
}
for ti, tc := range cases {
socket.Reset()
socket.Write( []byte(tc.Reqline) )
socket.Write( []byte("\r\n\r\n") )
_, err := Parse(socket)
if !tc.HasError {
// no error -> ok
if err == nil {
continue
// error for the end of the request -> ok
} else if _, ok := err.(*IncompleteRequest); ok {
continue
}
t.Errorf("[%d] Expected no error", ti)
}
// missing required error -> error
if tc.HasError && err == nil {
t.Errorf("[%d] Expected error", ti)
continue
}
ir, ok := err.(*InvalidRequest);
// not InvalidRequest err -> error
if !ok || ir.Field != "Request-Line" {
t.Errorf("[%d] expected InvalidRequest", ti)
continue
}
}
}

View File

@ -1,6 +1,5 @@
package request package request
import "git.xdrm.io/gws/internal/http/upgrade/request/parser/reqline"
import "git.xdrm.io/gws/internal/http/upgrade/response" import "git.xdrm.io/gws/internal/http/upgrade/response"
// If origin is required // If origin is required
@ -14,7 +13,7 @@ type T struct {
code response.StatusCode code response.StatusCode
// request line // request line
request reqline.T request RequestLine
// data to check origin (depends of reading order) // data to check origin (depends of reading order)
host string host string