package request

import (
	"fmt"
	"strconv"
	"strings"

	"git.xdrm.io/gws/internal/http/upgrade/request/parser/header"
	"git.xdrm.io/gws/internal/http/upgrade/response"
)

// extractHostPort checks and extracts the 'Host' header.
//
// The header must carry exactly one value of the form "host" or
// "host:port"; a bracketed IPv6 literal such as "[::1]:8080" keeps its
// brackets in r.host. On success r.host is set, and r.port too when a port
// is present. If the 'Origin' header has already been read, the origin
// policy is checked once the host is known (with or without a port).
//
// On failure r.code is set (response.BAD_REQUEST for a malformed header,
// response.FORBIDDEN for an origin policy rejection) and an error is returned.
func (r *T) extractHostPort(bb header.HeaderValue) error {
	if len(bb) != 1 {
		r.code = response.BAD_REQUEST
		return fmt.Errorf("Host header must have a unique value")
	}

	hostport := string(bb[0])
	host, port := hostport, ""
	hasPort := false

	// The port separator is the first ':' after the host. For a bracketed
	// IPv6 literal the colons inside the brackets belong to the address,
	// so the search starts right after the closing ']'.
	searchFrom := 0
	if strings.HasPrefix(hostport, "[") {
		if end := strings.IndexByte(hostport, ']'); end >= 0 {
			searchFrom = end + 1
		}
	}
	if i := strings.IndexByte(hostport[searchFrom:], ':'); i >= 0 {
		i += searchFrom
		host, port, hasPort = hostport[:i], hostport[i+1:], true
	}

	r.host = host

	if hasPort {
		readPort, err := strconv.ParseUint(port, 10, 16)
		if err != nil {
			r.code = response.BAD_REQUEST
			return fmt.Errorf("Cannot read port number '%s'", port)
		}
		r.port = uint16(readPort)
	}

	// if 'Origin' header is already read, check it now that the host is known
	if len(r.origin) > 0 {
		if err := r.checkOriginPolicy(); err != nil {
			r.code = response.FORBIDDEN
			return err
		}
	}

	return nil
}

// extractOrigin checks and extracts the 'Origin' header.
//
// Nothing is done when the package-level bypassOriginPolicy flag is set.
// Otherwise the header must have exactly one value, which is stored in
// r.origin; if the 'Host' header has already been stored, the origin policy
// is checked right away. On failure r.code is set to response.FORBIDDEN.
func (r *T) extractOrigin(bb header.HeaderValue) error {
	// bypass
	if bypassOriginPolicy {
		return nil
	}

	if len(bb) != 1 {
		r.code = response.FORBIDDEN
		return fmt.Errorf("Origin header must have a unique value")
	}

	r.origin = string(bb[0])

	// if host already stored, check origin policy
	if len(r.host) > 0 {
		if err := r.checkOriginPolicy(); err != nil {
			r.code = response.FORBIDDEN
			return err
		}
	}

	return nil
}

// checkOriginPolicy checks the origin policy based on the 'host' value.
//
// Not implemented yet: it always marks the policy as valid and returns nil.
func (r *T) checkOriginPolicy() error {
	// TODO: Origin policy, for now BYPASS
	r.validPolicy = true
	return nil
}

// checkConnection checks the 'Connection' header.
//
// At least one of its values MUST be 'Upgrade' (case-insensitive, surrounding
// whitespace ignored). On success r.hasConnection is set; otherwise r.code is
// set to response.BAD_REQUEST and an error is returned.
func (r *T) checkConnection(bb header.HeaderValue) error {
	for _, b := range bb {
		if strings.EqualFold(strings.TrimSpace(string(b)), "upgrade") {
			r.hasConnection = true
			return nil
		}
	}

	r.code = response.BAD_REQUEST
	return fmt.Errorf("Connection header must be 'Upgrade'")
}

// checkUpgrade checks the 'Upgrade' header.
//
// It MUST have exactly one value, equal to 'websocket' (case-insensitive,
// surrounding whitespace ignored). On success r.hasUpgrade is set; otherwise
// r.code is set to response.BAD_REQUEST and an error is returned.
func (r *T) checkUpgrade(bb header.HeaderValue) error {
	if len(bb) != 1 {
		r.code = response.BAD_REQUEST
		return fmt.Errorf("Upgrade header must have only 1 element")
	}

	if strings.EqualFold(strings.TrimSpace(string(bb[0])), "websocket") {
		r.hasUpgrade = true
		return nil
	}

	r.code = response.BAD_REQUEST
	return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0])
}

// checkVersion checks the 'Sec-WebSocket-Version' header.
//
// It MUST have exactly one value, '13'. On success r.hasVersion is set;
// otherwise r.code is set to response.UPGRADE_REQUIRED.
func (r *T) checkVersion(bb header.HeaderValue) error {
	if len(bb) != 1 || string(bb[0]) != "13" {
		r.code = response.UPGRADE_REQUIRED
		return fmt.Errorf("Sec-WebSocket-Version header must be '13'")
	}

	r.hasVersion = true
	return nil
}

// extractKey extracts the 'Sec-WebSocket-Key' header.
//
// It MUST have exactly one value, 24 bytes long (the base64 form of the
// client nonce; only the length is checked here). On success the value is
// stored in r.key; otherwise r.code is set to response.BAD_REQUEST.
func (r *T) extractKey(bb header.HeaderValue) error {
	// checked separately so an empty value list cannot index bb[0]
	if len(bb) != 1 {
		r.code = response.BAD_REQUEST
		return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d values", len(bb))
	}

	if len(bb[0]) != 24 {
		r.code = response.BAD_REQUEST
		return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d bytes", len(bb[0]))
	}

	r.key = bb[0]
	return nil
}

// extractProtocols extracts the 'Sec-WebSocket-Protocol' header.
//
// It can contain multiple values; they are stored as-is in r.protocols.
func (r *T) extractProtocols(bb header.HeaderValue) error {
	r.protocols = bb
	return nil
}