diff --git a/internal/http/upgrade/request/errors.go b/internal/http/upgrade/request/errors.go new file mode 100644 index 0000000..7133505 --- /dev/null +++ b/internal/http/upgrade/request/errors.go @@ -0,0 +1,35 @@ +package request + +import ( + "fmt" +) +// invalid request +// - multiple-value if only 1 expected +type InvalidRequest struct { + Field string + Reason string +} + +func (err InvalidRequest) Error() string { + return fmt.Sprintf("Invalid field '%s': %s", err.Field, err.Reason) +} + +// Request misses fields (request-line or headers) +type IncompleteRequest struct { + MissingField string +} + +func (err IncompleteRequest) Error() string { + return fmt.Sprintf("imcomplete request, '%s' is invalid or missing", err.MissingField) +} + + +// Request has a violated origin policy +type InvalidOriginPolicy struct { + Host string + Origin string + err error +} +func (err InvalidOriginPolicy) Error() string { + return fmt.Sprintf("invalid origin policy; (host: '%s' origin: '%s' error: '%s')", err.Host, err.Origin, err.err) +} \ No newline at end of file diff --git a/internal/http/upgrade/request/header_check.go b/internal/http/upgrade/request/header_check.go index e56678d..9b41bb7 100644 --- a/internal/http/upgrade/request/header_check.go +++ b/internal/http/upgrade/request/header_check.go @@ -14,7 +14,7 @@ import ( func (r *T) extractHostPort(bb header.HeaderValue) error { if len(bb) != 1 { - return fmt.Errorf("Host header must have a unique value") + return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} } split := strings.Split(string(bb[0]), ":") @@ -30,7 +30,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error { readPort, err := strconv.ParseUint(split[1], 10, 16) if err != nil { r.code = response.BAD_REQUEST - return fmt.Errorf("Cannot read port number '%s'", split[1]) + return &InvalidRequest{"Host", "cannot read port"} } r.port = uint16(readPort) @@ -40,7 +40,7 @@ func (r *T) extractHostPort(bb header.HeaderValue) error { if err != nil { err = r.checkOriginPolicy() r.code = response.FORBIDDEN - return err + return &InvalidOriginPolicy{r.host, r.origin, err} } } @@ -57,7 +57,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error { if len(bb) != 1 { r.code = response.FORBIDDEN - return fmt.Errorf("Origin header must have a unique value") + return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} } r.origin = string(bb[0]) @@ -67,7 +67,7 @@ func (r *T) extractOrigin(bb header.HeaderValue) error { err := r.checkOriginPolicy() if err != nil { r.code = response.FORBIDDEN - return err + return &InvalidOriginPolicy{r.host, r.origin, err} } } @@ -96,7 +96,7 @@ func (r *T) checkConnection(bb header.HeaderValue) error { } r.code = response.BAD_REQUEST - return fmt.Errorf("Connection header must be 'Upgrade'") + return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} } @@ -106,7 +106,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error { if len(bb) != 1 { r.code = response.BAD_REQUEST - return fmt.Errorf("Upgrade header must have only 1 element") + return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} } if strings.ToLower( string(bb[0]) ) == "websocket" { @@ -115,7 +115,7 @@ func (r *T) checkUpgrade(bb header.HeaderValue) error { } r.code = response.BAD_REQUEST - return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0]) + return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} } @@ -125,7 +125,7 @@ func (r *T) checkVersion(bb header.HeaderValue) error { if len(bb) != 1 || string(bb[0]) != "13" { r.code = response.UPGRADE_REQUIRED - return fmt.Errorf("Sec-WebSocket-Version header must be '13'") + return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} } r.hasVersion = true @@ -139,7 +139,7 @@ func (r *T) extractKey(bb header.HeaderValue) error { if len(bb) != 1 || len(bb[0]) != 24 { r.code = response.BAD_REQUEST - return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d bytes", len(bb[0])) + return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} } r.key = bb[0] diff --git a/internal/http/upgrade/request/parser/reqline/private.go b/internal/http/upgrade/request/parser/reqline/private.go deleted file mode 100644 index 43b52a0..0000000 --- a/internal/http/upgrade/request/parser/reqline/private.go +++ /dev/null @@ -1,73 +0,0 @@ -package reqline - -import ( - "regexp" - "fmt" -) - - -// extractHttpMethod extracts the HTTP method from a []byte -// and checks for errors -// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE -func (r *T) extractHttpMethod(b []byte) error { - - switch string(b) { - case "OPTIONS": r.method = OPTIONS - case "GET": r.method = GET - case "HEAD": r.method = HEAD - case "POST": r.method = POST - case "PUT": r.method = PUT - case "DELETE": r.method = DELETE - - default: - return fmt.Errorf("Unknown HTTP method '%s'", b) - } - - return nil -} - - -// extractURI extracts the URI from a []byte and checks for errors -// allowed format: /([^/]/)*/? -func (r *T) extractURI(b []byte) error { - - /* (1) Check format */ - checker := regexp.MustCompile("^(?:/[^/]+)*/?$") - if !checker.Match(b) { - return fmt.Errorf("Invalid URI format, expected an absolute path (starts with /), got '%s'", b) - } - - /* (2) Store */ - r.uri = string(b) - - return nil - -} - - -// extractHttpVersion extracts the version and checks for errors -// allowed format: [1-9] or [1.9].[0-9] -func (r *T) extractHttpVersion(b []byte) error { - - /* (1) Extract version parts */ - extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`); - - if !extractor.Match(b) { - return fmt.Errorf("Cannot parse HTTP version, expected INT or INT.INT, got '%s'", b) - } - - /* (2) Extract version number */ - matches := extractor.FindSubmatch(b) - var version byte = matches[1][0] - '0' - - /* (3) Extract subversion (if exists) */ - var subVersion byte = 0 - if len(matches[2]) > 0 { - subVersion = matches[2][0] - '0' - } - - /* (4) Store version (x 10 to fit uint8) */ - r.version = version * 10 + subVersion - - return nil -} \ No newline at end of file diff --git a/internal/http/upgrade/request/parser/reqline/public.go b/internal/http/upgrade/request/parser/reqline/public.go deleted file mode 100644 index 0d242de..0000000 --- a/internal/http/upgrade/request/parser/reqline/public.go +++ /dev/null @@ -1,45 +0,0 @@ -package reqline - -import ( - "fmt" - "bytes" -) - -// parseRequestLine parses the first HTTP request line -func (r *T) Parse(b []byte) error { - - /* (1) Split by ' ' */ - parts := bytes.Split(b, []byte(" ")) - - /* (2) Fail when missing parts */ - if len(parts) != 3 { - return fmt.Errorf("Malformed Request-Line must have 3 space-separated elements, got %d", len(parts)) - } - - /* (3) Extract HTTP method */ - err := r.extractHttpMethod(parts[0]) - if err != nil { - return err - } - - /* (4) Extract URI */ - err = r.extractURI(parts[1]) - if err != nil { - return err - } - - /* (5) Extract version */ - err = r.extractHttpVersion(parts[2]) - if err != nil { - return err - } - - return nil - -} - - -// GetURI returns the actual URI -func (r T) GetURI() string { - return r.uri -} \ No newline at end of file diff --git a/internal/http/upgrade/request/parser/reqline/types.go b/internal/http/upgrade/request/parser/reqline/types.go deleted file mode 100644 index 7ee809b..0000000 --- a/internal/http/upgrade/request/parser/reqline/types.go +++ /dev/null @@ -1,21 +0,0 @@ -package reqline - -// httpMethod represents available http methods -type httpMethod byte -const ( - OPTIONS httpMethod = iota - GET - HEAD - POST - PUT - DELETE -) - - -// httpRequestLine represents the HTTP Request line -// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1 -type T struct { - method httpMethod - uri string - version byte -} \ No newline at end of file diff --git a/internal/http/upgrade/request/private.go b/internal/http/upgrade/request/private.go index d3dd204..53bda47 100644 --- a/internal/http/upgrade/request/private.go +++ b/internal/http/upgrade/request/private.go @@ -19,7 +19,7 @@ func (r *T) parseHeader(b []byte) error { if err != nil { r.code = response.BAD_REQUEST - return fmt.Errorf("Error while parsing first line: %s", err) + return &InvalidRequest{"Request-Line", err.Error()} } r.first = true @@ -53,8 +53,8 @@ func (r *T) parseHeader(b []byte) error { } + // dispatch error if err != nil { - fmt.Printf("ERR: %s\n", err) return err } @@ -70,37 +70,37 @@ func (r T) isComplete() error { /* (1) Request-Line */ if !r.first { - return fmt.Errorf("Missing HTTP Request-Line"); + return &IncompleteRequest{"Request-Line"} } /* (2) Host */ if len(r.host) == 0 { - return fmt.Errorf("Missing 'Host' header") + return &IncompleteRequest{"Host"} } /* (3) Origin */ if !bypassOriginPolicy && len(r.origin) == 0 { - return fmt.Errorf("Missing 'Origin' header") + return &IncompleteRequest{"Origin"} } /* (4) Connection */ if !r.hasConnection { - return fmt.Errorf("Missing 'Connection' header"); + return &IncompleteRequest{"Connection"} } /* (5) Upgrade */ if !r.hasUpgrade { - return fmt.Errorf("Missing 'Upgrade' header"); + return &IncompleteRequest{"Upgrade"} } /* (6) Sec-WebSocket-Version */ if !r.hasVersion { - return fmt.Errorf("Missing 'Sec-WebSocket-Version' header"); + return &IncompleteRequest{"Sec-WebSocket-Version"} } /* (7) Sec-WebSocket-Key */ if len(r.key) < 1 { - return fmt.Errorf("Missing 'Sec-WebSocket-Key' header"); + return &IncompleteRequest{"Sec-WebSocket-Key"} } return nil diff --git a/internal/http/upgrade/request/public.go b/internal/http/upgrade/request/public.go index eeb834b..9851536 100644 --- a/internal/http/upgrade/request/public.go +++ b/internal/http/upgrade/request/public.go @@ -32,13 +32,13 @@ func Parse(r io.Reader) (request *T, err error) { } if err != nil { - return req, fmt.Errorf("Cannot read from reader: %s", err) + return req, err } err = req.parseHeader(line) if err != nil { - return req, fmt.Errorf("Parsing error: %s\n", err); + return req, err } } @@ -46,7 +46,6 @@ func Parse(r io.Reader) (request *T, err error) { /* (3) Check completion */ err = req.isComplete() if err != nil { - fmt.Printf("not complete: %s\b", err) req.code = response.BAD_REQUEST return req, err } diff --git a/internal/http/upgrade/request/request-line.go b/internal/http/upgrade/request/request-line.go new file mode 100644 index 0000000..fddd8ea --- /dev/null +++ b/internal/http/upgrade/request/request-line.go @@ -0,0 +1,135 @@ +package request + +import ( + "fmt" + "bytes" + "regexp" +) + +// httpMethod represents available http methods +type httpMethod byte +const ( + OPTIONS httpMethod = iota + GET + HEAD + POST + PUT + DELETE +) + + +// RequestLine represents the HTTP Request line +// defined in rfc-2616 : https://tools.ietf.org/html/rfc2616#section-5.1 +type RequestLine struct { + method httpMethod + uri string + version byte +} + + +// parseRequestLine parses the first HTTP request line +func (r *RequestLine) Parse(b []byte) error { + + /* (1) Split by ' ' */ + parts := bytes.Split(b, []byte(" ")) + + /* (2) Fail when missing parts */ + if len(parts) != 3 { + return fmt.Errorf("expected 3 space-separated elements, got %d elements", len(parts)) + } + + /* (3) Extract HTTP method */ + err := r.extractHttpMethod(parts[0]) + if err != nil { + return err + } + + /* (4) Extract URI */ + err = r.extractURI(parts[1]) + if err != nil { + return err + } + + /* (5) Extract version */ + err = r.extractHttpVersion(parts[2]) + if err != nil { + return err + } + + return nil + +} + + +// GetURI returns the actual URI +func (r RequestLine) GetURI() string { + return r.uri +} + + + +// extractHttpMethod extracts the HTTP method from a []byte +// and checks for errors +// allowed format: OPTIONS|GET|HEAD|POST|PUT|DELETE +func (r *RequestLine) extractHttpMethod(b []byte) error { + + switch string(b) { + // case "OPTIONS": r.method = OPTIONS + case "GET": r.method = GET + // case "HEAD": r.method = HEAD + // case "POST": r.method = POST + // case "PUT": r.method = PUT + // case "DELETE": r.method = DELETE + + default: + return fmt.Errorf("Invalid HTTP method '%s', expected 'GET'", b) + } + + return nil +} + + +// extractURI extracts the URI from a []byte and checks for errors +// allowed format: /([^/]/)*/? +func (r *RequestLine) extractURI(b []byte) error { + + /* (1) Check format */ + checker := regexp.MustCompile("^(?:/[^/]+)*/?$") + if !checker.Match(b) { + return fmt.Errorf("invalid URI, expected an absolute path, got '%s'", b) + } + + /* (2) Store */ + r.uri = string(b) + + return nil + +} + + +// extractHttpVersion extracts the version and checks for errors +// allowed format: [1-9] or [1.9].[0-9] +func (r *RequestLine) extractHttpVersion(b []byte) error { + + /* (1) Extract version parts */ + extractor := regexp.MustCompile(`^HTTP/([1-9])(?:\.([0-9]))?$`); + + if !extractor.Match(b) { + return fmt.Errorf("HTTP version, expected INT or INT.INT, got '%s'", b) + } + + /* (2) Extract version number */ + matches := extractor.FindSubmatch(b) + var version byte = matches[1][0] - '0' + + /* (3) Extract subversion (if exists) */ + var subVersion byte = 0 + if len(matches[2]) > 0 { + subVersion = matches[2][0] - '0' + } + + /* (4) Store version (x 10 to fit uint8) */ + r.version = version * 10 + subVersion + + return nil +} \ No newline at end of file diff --git a/internal/http/upgrade/request/request_test.go b/internal/http/upgrade/request/request_test.go new file mode 100644 index 0000000..66e4e3c --- /dev/null +++ b/internal/http/upgrade/request/request_test.go @@ -0,0 +1,106 @@ +package request + +import ( + "io" + "bytes" + "testing" +) +// /* (1) Parse request */ +// req, _ := request.Parse(s) + +// /* (3) Build response */ +// res := req.BuildResponse() + +// /* (4) Write into socket */ +// _, err := res.Send(s) +// if err != nil { +// return nil, fmt.Errorf("Upgrade write error: %s", err) +// } + +// if res.GetStatusCode() != 101 { +// s.Close() +// return nil, fmt.Errorf("Upgrade error (HTTP %d)\n", res.GetStatusCode()) +// } + +func TestEOFSocket(t *testing.T){ + + socket := new(bytes.Buffer) + + _, err := Parse(socket) + + if err == nil { + t.Fatalf("Empty socket expected EOF, got no error") + } else if err != io.ErrUnexpectedEOF { + t.Fatalf("Empty socket expected EOF, got '%s'", err) + } + +} + +func TestInvalidRequestLine(t *testing.T){ + + socket := new(bytes.Buffer) + cases := []struct{ + Reqline string + HasError bool + }{ + { "abc", true }, + { "a c", true }, + { "a c", true }, + { "a c", true }, + { "a b c", true }, + + { "GET invaliduri HTTP/1.1", true }, + { "GET /validuri HTTP/1.1", false }, + + { "POST /validuri HTTP/1.1", true }, + { "PUT /validuri HTTP/1.1", true }, + { "DELETE /validuri HTTP/1.1", true }, + { "OPTIONS /validuri HTTP/1.1", true }, + { "UNKNOWN /validuri HTTP/1.1", true }, + + { "GET / HTTP/52", true }, + { "GET / HTTP/1.", true }, + { "GET / HTTP/.1", true }, + { "GET / HTTP/1.1", false }, + { "GET / HTTP/2", false }, + } + + for ti, tc := range cases { + + socket.Reset() + socket.Write( []byte(tc.Reqline) ) + socket.Write( []byte("\r\n\r\n") ) + + _, err := Parse(socket) + + if !tc.HasError { + + // no error -> ok + if err == nil { + continue + // error for the end of the request -> ok + } else if _, ok := err.(*IncompleteRequest); ok { + continue + } + + t.Errorf("[%d] Expected no error", ti) + } + + // missing required error -> error + if tc.HasError && err == nil { + t.Errorf("[%d] Expected error", ti) + continue + } + + ir, ok := err.(*InvalidRequest); + + // not InvalidRequest err -> error + if !ok || ir.Field != "Request-Line" { + t.Errorf("[%d] expected InvalidRequest", ti) + continue + } + + } + + +} \ No newline at end of file diff --git a/internal/http/upgrade/request/types.go b/internal/http/upgrade/request/types.go index 752da3b..b546813 100644 --- a/internal/http/upgrade/request/types.go +++ b/internal/http/upgrade/request/types.go @@ -1,6 +1,5 @@ package request -import "git.xdrm.io/gws/internal/http/upgrade/request/parser/reqline" import "git.xdrm.io/gws/internal/http/upgrade/response" // If origin is required @@ -14,7 +13,7 @@ type T struct { code response.StatusCode // request line - request reqline.T + request RequestLine // data to check origin (depends of reading order) host string