package request import ( "fmt" "strconv" "strings" "git.xdrm.io/go/ws/internal/http/upgrade/request/parser/header" "git.xdrm.io/go/ws/internal/http/upgrade/response" ) // checkHost checks and extracts the Host header func (r *T) extractHostPort(bb header.HeaderValue) error { if len(bb) != 1 { return &InvalidRequest{"Host", fmt.Sprintf("expected single value, got %d", len(bb))} } if len(bb[0]) <= 3 { return &InvalidRequest{"Host", fmt.Sprintf("expected non-empty value (got %d bytes)", len(bb[0]))} } split := strings.Split(string(bb[0]), ":") r.host = split[0] // no port if len(split) < 2 { return nil } // extract port readPort, err := strconv.ParseUint(split[1], 10, 16) if err != nil { r.code = response.BadRequest return &InvalidRequest{"Host", "cannot read port"} } r.port = uint16(readPort) // if 'Origin' header is already read, check it if len(r.origin) > 0 { if err != nil { err = r.checkOriginPolicy() r.code = response.Forbidden return &InvalidOriginPolicy{r.host, r.origin, err} } } return nil } // checkOrigin checks the Origin Header func (r *T) extractOrigin(bb header.HeaderValue) error { // bypass if bypassOriginPolicy { return nil } if len(bb) != 1 { r.code = response.Forbidden return &InvalidRequest{"Origin", fmt.Sprintf("expected single value, got %d", len(bb))} } r.origin = string(bb[0]) // if host already stored, check origin policy if len(r.host) > 0 { err := r.checkOriginPolicy() if err != nil { r.code = response.Forbidden return &InvalidOriginPolicy{r.host, r.origin, err} } } return nil } // checkOriginPolicy origin policy based on 'host' value func (r *T) checkOriginPolicy() error { // TODO: Origin policy, for now BYPASS r.validPolicy = true return nil } // checkConnection checks the 'Connection' header // it MUST contain 'Upgrade' func (r *T) checkConnection(bb header.HeaderValue) error { for _, b := range bb { if strings.ToLower(string(b)) == "upgrade" { r.hasConnection = true return nil } } r.code = response.BadRequest return &InvalidRequest{"Upgrade", "expected 'Upgrade' (case insensitive)"} } // checkUpgrade checks the 'Upgrade' header // it MUST be 'websocket' func (r *T) checkUpgrade(bb header.HeaderValue) error { if len(bb) != 1 { r.code = response.BadRequest return &InvalidRequest{"Upgrade", fmt.Sprintf("expected single value, got %d", len(bb))} } if strings.ToLower(string(bb[0])) == "websocket" { r.hasUpgrade = true return nil } r.code = response.BadRequest return &InvalidRequest{"Upgrade", fmt.Sprintf("expected 'websocket' (case insensitive), got '%s'", bb[0])} } // checkVersion checks the 'Sec-WebSocket-Version' header // it MUST be '13' func (r *T) checkVersion(bb header.HeaderValue) error { if len(bb) != 1 || string(bb[0]) != "13" { r.code = response.UpgradeRequired return &InvalidRequest{"Sec-WebSocket-Version", fmt.Sprintf("expected '13', got '%s'", bb[0])} } r.hasVersion = true return nil } // extractKey extracts the 'Sec-WebSocket-Key' header // it MUST be 24 bytes (base64) func (r *T) extractKey(bb header.HeaderValue) error { if len(bb) != 1 || len(bb[0]) != 24 { r.code = response.BadRequest return &InvalidRequest{"Sec-WebSocket-Key", fmt.Sprintf("expected 24 bytes base64, got %d bytes", len(bb[0]))} } r.key = bb[0] return nil } // extractProtocols extracts the 'Sec-WebSocket-Protocol' header // it can contain multiple values func (r *T) extractProtocols(bb header.HeaderValue) error { r.protocols = bb return nil }