diff --git a/http/upgrade/request/private.go b/http/upgrade/request/private.go index 5ae960c..5eec14e 100644 --- a/http/upgrade/request/private.go +++ b/http/upgrade/request/private.go @@ -44,7 +44,11 @@ func (r *T) parseHeader(b []byte) error { err = r.extractHostPort(head.Values[0]) case header.UPGRADE: fmt.Printf("[upgrade] ") + err = r.checkUpgrade(head.Values) + case header.CONNECTION: fmt.Printf("[connection] ") + err = r.checkConnection(head.Values) + case header.WSKEY: fmt.Printf("[sec-websocket-key] ") case header.ORIGIN: fmt.Printf("[origin] ") case header.WSPROTOCOL: fmt.Printf("[sec-websocket-protocol] ") @@ -91,4 +95,37 @@ func (r *T) extractHostPort(b []byte) error { return nil +} + + +// checkConnection checks the 'Connection' header +// it MUST contain 'Upgrade' +func (r T) checkConnection(bb [][]byte) error { + + for _, b := range bb { + + if strings.ToLower( string(b) ) == "upgrade" { + return nil + } + + } + + return fmt.Errorf("Connection header must be 'Upgrade'") + +} + +// checkUpgrade checks the 'Upgrade' header +// it MUST be 'websocket' +func (r T) checkUpgrade(bb [][]byte) error { + + if len(bb) != 1 { + return fmt.Errorf("Upgrade header must have only 1 element") + } + + if strings.ToLower( string(bb[0]) ) == "websocket" { + return nil + } + + return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0]) + } \ No newline at end of file