// ws/internal/http/upgrade/request_test.go

package upgrade
import (
"bytes"
"io"
"testing"
)
// TestEOFSocket verifies that Request.ReadFrom, given a source that yields
// no bytes at all, fails with exactly io.ErrUnexpectedEOF.
func TestEOFSocket(t *testing.T) {
	var empty bytes.Buffer
	req := new(Request)

	if _, err := req.ReadFrom(&empty); err != io.ErrUnexpectedEOF {
		t.Fatalf("unexpected error <%v> expected <%v>", err, io.ErrUnexpectedEOF)
	}
}
// TestInvalidRequestLine feeds individual request lines, each followed by an
// empty header section, into Request.ReadFrom.
//
// Cases marked HasError must be rejected with an *ErrInvalidRequest whose
// Field is "Request-Line". The remaining cases must either parse cleanly or
// stop with ErrIncompleteRequest. Only the request line is supplied, so an
// "incomplete" outcome is an acceptable result for a well-formed line.
func TestInvalidRequestLine(t *testing.T) {
	cases := []struct {
		Reqline  string
		HasError bool
	}{
		{"abc", true},
		{"a c", true},
		// Runs of spaces between tokens.
		{"a  c", true},
		{"a   c", true},
		{"a b c", true},
		{"GET invaliduri HTTP/1.1", true},
		{"GET /validuri HTTP/1.1", false},
		{"POST /validuri HTTP/1.1", true},
		{"PUT /validuri HTTP/1.1", true},
		{"DELETE /validuri HTTP/1.1", true},
		{"OPTIONS /validuri HTTP/1.1", true},
		{"UNKNOWN /validuri HTTP/1.1", true},
		{"GET / HTTP", true},
		{"GET / HTTP/", true},
		{"GET / 1.1", true},
		{"GET / 1", true},
		{"GET / HTTP/52", true},
		{"GET / HTTP/1.", true},
		{"GET / HTTP/.1", true},
		{"GET / HTTP/1.1", false},
		{"GET / HTTP/2", false},
	}

	socket := &bytes.Buffer{}
	for ti, tc := range cases {
		socket.Reset()
		socket.WriteString(tc.Reqline)
		socket.WriteString("\r\n\r\n")

		req := &Request{}
		_, err := req.ReadFrom(socket)

		if !tc.HasError {
			// The parser may legitimately stop because the headers that
			// should follow the request line are missing.
			if _, incomplete := err.(ErrIncompleteRequest); err != nil && !incomplete {
				t.Errorf("[%d] %q: expected no error, got %v", ti, tc.Reqline, err)
			}
			// Move on so an unexpected error is not reported a second time
			// by the invalid-request checks below.
			continue
		}

		if err == nil {
			t.Errorf("[%d] %q: expected error", ti, tc.Reqline)
			continue
		}
		ir, ok := err.(*ErrInvalidRequest)
		if !ok || ir.Field != "Request-Line" {
			t.Errorf("[%d] %q: expected InvalidRequest on Request-Line, got %v", ti, tc.Reqline, err)
		}
	}
}
// TestInvalidHost sends a fixed, valid request line followed by a single Host
// header and checks how Request.ReadFrom validates the header value.
//
// Cases marked HasError must be rejected with an *ErrInvalidRequest whose
// Field is "Host". The remaining cases must either parse cleanly or stop with
// ErrIncompleteRequest, because the other handshake headers are absent.
func TestInvalidHost(t *testing.T) {
	requestLine := []byte("GET / HTTP/1.1\r\n")
	cases := []struct {
		Host     string
		HasError bool
	}{
		{"1", true},
		{"12", true},
		{"123", true},
		{"1234", false},
		{"singlevalue", false},
		{"multi value", true},
		{"singlevalue:1", false},
		{"singlevalue:", true},
		{"singlevalue:x", true},
		{"xx:x", true},
		{":xxx", true},
		{"xxx:", true},
		{"a:12", false},
		{"google.com", false},
		{"8.8.8.8", false},
		{"google.com:8080", false},
		{"8.8.8.8:8080", false},
	}

	socket := &bytes.Buffer{}
	for ti, tc := range cases {
		socket.Reset()
		socket.Write(requestLine)
		socket.WriteString("Host: ")
		socket.WriteString(tc.Host)
		socket.WriteString("\r\n\r\n")

		req := &Request{}
		_, err := req.ReadFrom(socket)

		if !tc.HasError {
			// Reaching the end of the (deliberately truncated) request is
			// acceptable; any other error is a failure.
			if _, incomplete := err.(ErrIncompleteRequest); err != nil && !incomplete {
				t.Errorf("[%d] %q: expected no error; %v", ti, tc.Host, err)
			}
			// Move on so an unexpected error is not reported a second time
			// by the invalid-request checks below.
			continue
		}

		if err == nil {
			t.Errorf("[%d] %q: expected error", ti, tc.Host)
			continue
		}
		// An error of any other kind must fail the case rather than pass
		// silently.
		// NOTE(review): assumes ReadFrom reports validation failures as
		// *ErrInvalidRequest, the form required for request-line failures;
		// confirm for the Host path.
		ir, ok := err.(*ErrInvalidRequest)
		if !ok || ir.Field != "Host" {
			t.Errorf("[%d] %q: expected InvalidRequest on Host, got %v", ti, tc.Host, err)
		}
	}
}