package upgrade import ( "bytes" "io" "testing" ) func TestEOFSocket(t *testing.T) { socket := &bytes.Buffer{} _, err := Parse(socket) if err == nil { t.Fatalf("Empty socket expected EOF, got no error") } else if err != io.ErrUnexpectedEOF { t.Fatalf("Empty socket expected EOF, got '%s'", err) } } func TestInvalidRequestLine(t *testing.T) { socket := &bytes.Buffer{} cases := []struct { Reqline string HasError bool }{ {"abc", true}, {"a c", true}, {"a c", true}, {"a c", true}, {"a b c", true}, {"GET invaliduri HTTP/1.1", true}, {"GET /validuri HTTP/1.1", false}, {"POST /validuri HTTP/1.1", true}, {"PUT /validuri HTTP/1.1", true}, {"DELETE /validuri HTTP/1.1", true}, {"OPTIONS /validuri HTTP/1.1", true}, {"UNKNOWN /validuri HTTP/1.1", true}, {"GET / HTTP", true}, {"GET / HTTP/", true}, {"GET / 1.1", true}, {"GET / 1", true}, {"GET / HTTP/52", true}, {"GET / HTTP/1.", true}, {"GET / HTTP/.1", true}, {"GET / HTTP/1.1", false}, {"GET / HTTP/2", false}, } for ti, tc := range cases { socket.Reset() socket.Write([]byte(tc.Reqline)) socket.Write([]byte("\r\n\r\n")) _, err := Parse(socket) if !tc.HasError { // no error -> ok if err == nil { continue // error for the end of the request -> ok } else if _, ok := err.(*IncompleteRequest); ok { continue } t.Errorf("[%d] Expected no error", ti) } // missing required error -> error if tc.HasError && err == nil { t.Errorf("[%d] Expected error", ti) continue } ir, ok := err.(*InvalidRequest) // not InvalidRequest err -> error if !ok || ir.Field != "Request-Line" { t.Errorf("[%d] expected InvalidRequest", ti) continue } } } func TestInvalidHost(t *testing.T) { requestLine := []byte("GET / HTTP/1.1\r\n") socket := &bytes.Buffer{} cases := []struct { Host string HasError bool }{ {"1", true}, {"12", true}, {"123", true}, {"1234", false}, {"singlevalue", false}, {"multi value", true}, {"singlevalue:1", false}, {"singlevalue:", true}, {"singlevalue:x", true}, {"xx:x", true}, {":xxx", true}, {"xxx:", true}, {"a:12", false}, {"google.com", false}, {"8.8.8.8", false}, {"google.com:8080", false}, {"8.8.8.8:8080", false}, } for ti, tc := range cases { socket.Reset() socket.Write(requestLine) socket.Write([]byte("Host: ")) socket.Write([]byte(tc.Host)) socket.Write([]byte("\r\n\r\n")) _, err := Parse(socket) if !tc.HasError { // no error -> ok if err == nil { continue // error for the end of the request -> ok } else if _, ok := err.(*IncompleteRequest); ok { continue } t.Errorf("[%d] Expected no error; %s", ti, err) } // missing required error -> error if tc.HasError && err == nil { t.Errorf("[%d] Expected error", ti) continue } // check if InvalidRequest ir, ok := err.(*InvalidRequest) // not InvalidRequest err -> error if ok && ir.Field != "Host" { t.Errorf("[%d] expected InvalidRequest", ti) continue } } }