package request

import (
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"

	"git.xdrm.io/gws/internal/http/upgrade/request/parser/header"
)

// extractHostPort validates the 'Host' header and stores its host part in
// r.host and, when present, its port in r.port.
//
// The header must carry exactly one value of the form "host", "host:port",
// "[ipv6]" or "[ipv6]:port"; IPv6 literals keep their brackets in r.host.
// A port, when a ':' separator is present, must be a decimal uint16.
//
// If the 'Origin' header has already been read (r.origin is non-empty), the
// origin policy is checked here once the host is known, whether or not a
// port was given.
func (r *T) extractHostPort(bb header.HeaderValue) error {
	if len(bb) != 1 {
		return fmt.Errorf("Host header must have a unique value")
	}

	host, port, hasPort, err := splitHostHeader(string(bb[0]))
	if err != nil {
		return err
	}
	r.host = host

	if hasPort {
		readPort, err := strconv.ParseUint(port, 10, 16)
		if err != nil {
			return fmt.Errorf("Cannot read port number '%s'", port)
		}
		r.port = uint16(readPort)
	}

	// 'Origin' may have been parsed before 'Host': run the deferred policy
	// check now, regardless of whether a port was present.
	if len(r.origin) > 0 {
		return r.checkOriginPolicy()
	}
	return nil
}

// splitHostHeader splits a raw 'Host' header value into host and port.
// hasPort reports whether a ':' port separator was present (the port string
// may then be empty, which the caller rejects when parsing it).
func splitHostHeader(raw string) (host, port string, hasPort bool, err error) {
	// IPv6 literal: "[addr]" or "[addr]:port" — its colons are not separators.
	if strings.HasPrefix(raw, "[") {
		end := strings.IndexByte(raw, ']')
		if end < 0 {
			return "", "", false, fmt.Errorf("Host header has an unterminated IPv6 literal '%s'", raw)
		}
		rest := raw[end+1:]
		if rest == "" {
			return raw, "", false, nil
		}
		if rest[0] != ':' {
			return "", "", false, fmt.Errorf("Host header has unexpected characters after IPv6 literal '%s'", raw)
		}
		return raw[:end+1], rest[1:], true, nil
	}

	// reg-name or IPv4 address: "host" or "host:port".
	i := strings.IndexByte(raw, ':')
	if i < 0 {
		return raw, "", false, nil
	}
	return raw[:i], raw[i+1:], true, nil
}

// extractOrigin validates the 'Origin' header and stores it in r.origin.
// The header must carry exactly one value. If the host has already been
// stored (r.host is non-empty), the origin policy is checked immediately.
func (r *T) extractOrigin(bb header.HeaderValue) error {
	if len(bb) != 1 {
		return fmt.Errorf("Origin header must have a unique value")
	}

	r.origin = string(bb[0])

	// 'Host' may have been parsed before 'Origin': check the policy now.
	if len(r.host) > 0 {
		return r.checkOriginPolicy()
	}
	return nil
}

// checkOriginPolicy checks the origin policy based on the 'host' value.
func (r T) checkOriginPolicy() error {
	// TODO: Origin policy, for now BYPASS
	return nil
}

// checkConnection checks the 'Connection' header.
// It MUST contain the 'Upgrade' token (case-insensitive); other tokens such
// as 'keep-alive' may appear alongside it. Each value is additionally split
// on ',' and trimmed, so this works whether or not the header parser has
// already split the list.
func (r T) checkConnection(bb header.HeaderValue) error {
	for _, b := range bb {
		for _, token := range strings.Split(string(b), ",") {
			if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
				return nil
			}
		}
	}

	return fmt.Errorf("Connection header must be 'Upgrade'")
}

// checkUpgrade checks the 'Upgrade' header.
// It MUST have exactly one value, equal to 'websocket' (case-insensitive,
// surrounding whitespace ignored).
func (r T) checkUpgrade(bb header.HeaderValue) error {
	if len(bb) != 1 {
		return fmt.Errorf("Upgrade header must have only 1 element")
	}

	if !strings.EqualFold(strings.TrimSpace(string(bb[0])), "websocket") {
		return fmt.Errorf("Upgrade header must be 'websocket', got '%s'", bb[0])
	}
	return nil
}

// checkVersion checks the 'Sec-WebSocket-Version' header.
// It MUST have exactly one value, equal to '13'.
func (r T) checkVersion(bb header.HeaderValue) error {
	if len(bb) != 1 || string(bb[0]) != "13" {
		return fmt.Errorf("Sec-WebSocket-Version header must be '13'")
	}

	return nil
}

// extractKey validates the 'Sec-WebSocket-Key' header.
// It MUST have exactly one value: 24 bytes of standard base64 that decode
// to a 16-byte nonce.
//
// NOTE(review): despite its name, this only validates the key and does not
// store it on r — confirm whether callers need the key later (e.g. to build
// the handshake response).
func (r *T) extractKey(bb header.HeaderValue) error {
	if len(bb) != 1 {
		// bb[0] may not exist here, so report the number of values instead.
		return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d values", len(bb))
	}

	if len(bb[0]) != 24 {
		return fmt.Errorf("Sec-WebSocket-Key header must be a unique 24 bytes base64 value, got %d bytes", len(bb[0]))
	}

	decoded, err := base64.StdEncoding.DecodeString(string(bb[0]))
	if err != nil || len(decoded) != 16 {
		return fmt.Errorf("Sec-WebSocket-Key header must be the base64 encoding of a 16 bytes value")
	}

	return nil
}