diff --git a/README.md b/README.md index 094eefccc..e8d30b59c 100644 --- a/README.md +++ b/README.md @@ -712,19 +712,114 @@ In order to dump the configuration, pass `--dump-config` command line option. > **⚠ Warning:** This feature is experimental and currently under development. -OLS can gather real-time information from your cluster to assist with specific queries. You can enable this feature by adding the following configuration: +OLS can gather real-time information from your cluster to assist with specific queries using MCP (Model Context Protocol) servers. + +## MCP Server Configuration + +MCP servers provide tools and capabilities to the AI agents. Only MCP servers defined in the `olsconfig.yaml` configuration are available to the agents. + +### Basic Configuration + +Each MCP server requires: +- `name`: Unique identifier for the MCP server +- `url`: The HTTP endpoint where the MCP server is running + +Optional fields: +- `authorization_headers`: Authentication headers for secure communication +- `timeout`: Request timeout in seconds + +**Minimal Example:** + +```yaml +mcp_servers: + - name: openshift + url: http://localhost:8080 +``` + +### MCP Server Authentication + +OLS supports three methods for authenticating with MCP servers: + +#### 1. Static Tokens from Files (Recommended for Service Credentials) + +Store authentication tokens in secret files and reference them in your configuration. Ideal for API keys and service tokens: + +```yaml +mcp_servers: + - name: api-service + url: http://api-service:8080 + authorization_headers: + Authorization: /var/secrets/api-token # Path to file containing token + X-API-Key: /var/secrets/api-key # Multiple headers supported + timeout: 30 # Optional timeout in seconds +``` + +#### 2. Kubernetes Token (User Context) + +Use the special `kubernetes` placeholder to automatically inject the authenticated user's Kubernetes token. This requires the `k8s` authentication module: + ```yaml mcp_servers: - name: openshift - transport: stdio - stdio: - command: python - args: - - ./mcp_local/openshift.py + url: http://openshift-mcp-server:8080 + authorization_headers: + Authorization: kubernetes # Uses user's k8s token from request ``` -OLS utilizes tools based on the oc CLI to collect relevant cluster context. The following safeguards are in place: -- Tools operate in read-only mode—they can retrieve data but cannot modify the cluster. -- Tools run using only the user's token (from the request). If the user lacks the necessary permissions, tool outputs may include permission errors. + +**Important**: The `kubernetes` placeholder only works when `authentication_config.module` is set to `k8s`. If not configured properly, the MCP server will be skipped with a warning. + +#### 3. Client-Provided Tokens (Per-Request) + +Use the `client` placeholder to allow clients to provide their own tokens per-request via the `MCP-HEADERS` HTTP header: + +```yaml +mcp_servers: + - name: github + url: http://github-mcp-server:8080 + authorization_headers: + Authorization: client # Client provides token via MCP-HEADERS header +``` + +Clients can discover which servers accept client-provided tokens by calling: +```bash +GET /v1/mcp-auth/client-options +``` + +Response: +```json +{ + "servers": [ + { + "name": "github", + "client_auth_headers": ["Authorization"] + } + ] +} +``` + +Then provide tokens in the query request: +```bash +curl -H "MCP-HEADERS: {\"github\": {\"Authorization\": \"Bearer github_token\"}}" \ + -X POST /v1/query -d '{"query": "..."}' +``` + +### OpenShift MCP Server Example + +For cluster context gathering with the OpenShift MCP server: + +```yaml +mcp_servers: + - name: openshift + url: http://openshift-mcp-server:8080 + authorization_headers: + Authorization: kubernetes # Uses authenticated user's token + timeout: 30 +``` + +**Safeguards:** +- Tools operate in read-only mode—they can retrieve data but cannot modify the cluster +- Tools run using only the user's token (from the request) +- If the user lacks necessary permissions, tool outputs may include permission errors # Usage @@ -743,9 +838,6 @@ To enable GradIO web UI you need to have the following `dev_config` section in y ```yaml dev_config: enable_dev_ui: true - ... - ... - ... ``` diff --git a/docs/config.puml b/docs/config.puml index 09f403b97..fb52f3e97 100644 --- a/docs/config.puml +++ b/docs/config.puml @@ -75,10 +75,11 @@ class "LoggingConfig" as ols.app.models.config.LoggingConfig { } class "MCPServerConfig" as ols.app.models.config.MCPServerConfig { name : str - sse : Optional[SseTransportConfig] - stdio : Optional[StdioTransportConfig] - transport : Literal['sse', 'stdio'] - correct_transport_specified() -> Self + url : str + authorization_headers : dict[str, str] + resolved_authorization_headers : dict[str, str] + timeout : Optional[int] + resolve_auth_headers() -> Self } class "MCPServers" as ols.app.models.config.MCPServers { servers : list[MCPServerConfig] @@ -196,18 +197,6 @@ class "ReferenceContentIndex" as ols.app.models.config.ReferenceContentIndex { class "SchedulerConfig" as ols.app.models.config.SchedulerConfig { period : int } -class "SseTransportConfig" as ols.app.models.config.SseTransportConfig { - sse_read_timeout : int - timeout : int - url : str -} -class "StdioTransportConfig" as ols.app.models.config.StdioTransportConfig { - args : list[str] - command : str - cwd : str - encoding : str - env : dict[str, str | int] -} class "TLSConfig" as ols.app.models.config.TLSConfig { tls_certificate_path : Optional[FilePath] tls_key_password : Optional[str] diff --git a/examples/olsconfig.yaml b/examples/olsconfig.yaml index ca732301a..ba4ee4379 100644 --- a/examples/olsconfig.yaml +++ b/examples/olsconfig.yaml @@ -113,8 +113,7 @@ dev_config: # MCP servers - enables tool calling if present mcp_servers: - name: openshift - transport: stdio - stdio: - command: python - args: - - ./mcp_local/openshift.py + url: http://localhost:8080 + authorization_headers: + Authorization: kubernetes # Special placeholder - uses k8s token from request + # timeout: 30 # Optional timeout in seconds diff --git a/ols/app/models/config.py b/ols/app/models/config.py index fc91bc412..73c520999 100644 --- a/ols/app/models/config.py +++ b/ols/app/models/config.py @@ -3,7 +3,7 @@ import logging import os import re -from typing import Any, Literal, Optional, Self +from typing import Any, Optional, Self from pydantic import ( AnyHttpUrl, @@ -11,6 +11,7 @@ Field, FilePath, PositiveInt, + PrivateAttr, field_validator, model_validator, ) @@ -305,7 +306,7 @@ class ProviderConfig(BaseModel): url: Optional[AnyHttpUrl] = None credentials: Optional[str] = None project_id: Optional[str] = None - models: dict[str, ModelConfig] = {} + models: dict[str, ModelConfig] = Field(default_factory=dict) api_version: Optional[str] = None deployment_name: Optional[str] = None openai_config: Optional[OpenAIConfig] = None @@ -538,7 +539,7 @@ def validate_yaml(self) -> None: class LLMProviders(BaseModel): """LLM providers configuration.""" - providers: dict[str, ProviderConfig] = {} + providers: dict[str, ProviderConfig] = Field(default_factory=dict) def __init__( self, @@ -568,76 +569,86 @@ def validate_yaml(self) -> None: v.validate_yaml() -class StdioTransportConfig(BaseModel): - """Stdio transport configuration for MCP server.""" +class MCPServerConfig(BaseModel): + """MCP server configuration. - command: str - args: list[str] = [] - env: dict[str, str | int] = constants.STDIO_TRANSPORT_DEFAULT_ENV - cwd: str = constants.STDIO_TRANSPORT_DEFAULT_CWD - encoding: str = constants.STDIO_TRANSPORT_DEFAULT_ENCODING + MCP (Model Context Protocol) servers provide tools and capabilities to the + AI agents. These are configured by this structure. Only MCP servers + defined in the olsconfig.yaml configuration are available to the agents. + """ + name: str = Field( + title="MCP name", + description="MCP server name that must be unique", + ) -class SseTransportConfig(BaseModel): - """SSE transport configuration for MCP server.""" + url: str = Field( + title="MCP server URL", + description="URL of the MCP server", + ) - url: str - timeout: int = constants.SSE_TRANSPORT_DEFAULT_TIMEOUT - sse_read_timeout: int = constants.SSE_TRANSPORT_DEFAULT_READ_TIMEOUT - headers: dict[str, str] = Field(default_factory=dict) + timeout: Optional[int] = Field( + default=None, + title="Request timeout", + description=( + "Timeout in seconds for requests to the MCP server. " + "If not specified, the default timeout will be used." + ), + ) + headers: dict[str, str] = Field( + default_factory=dict, + title="Authorization headers", + description=( + "Headers to send to the MCP server. " + "The map contains the header name and the path to a file containing " + "the header value (secret). " + "There are 2 special cases: " + f"1. Usage of the kubernetes token in the header. " + f"To specify this use a string '{constants.MCP_KUBERNETES_PLACEHOLDER}' " + f"instead of the file path. " + f"2. Usage of the client provided token in the header. " + f"To specify this use a string '{constants.MCP_CLIENT_PLACEHOLDER}' " + f"instead of the file path." + ), + ) -class StreamableHttpTransportConfig(BaseModel): - """Streamable HTTP transport configuration for MCP server.""" + _resolved_headers: dict[str, str] = PrivateAttr(default_factory=dict) - url: str - timeout: int = constants.STREAMABLE_HTTP_TRANSPORT_DEFAULT_TIMEOUT - sse_read_timeout: int = constants.STREAMABLE_HTTP_TRANSPORT_DEFAULT_READ_TIMEOUT - headers: dict[str, str] = Field(default_factory=dict) + @property + def resolved_headers(self) -> dict[str, str]: + """Resolved headers (computed from headers).""" + return self._resolved_headers -class MCPServerConfig(BaseModel): - """MCP server configuration.""" +class ToolFilteringConfig(BaseModel): + """Configuration for tool filtering using hybrid RAG retrieval. - name: str - transport: Literal["sse", "stdio", "streamable_http"] - stdio: Optional[StdioTransportConfig] = None - sse: Optional[SseTransportConfig] = None - streamable_http: Optional[StreamableHttpTransportConfig] = None + If this config is present, tool filtering is enabled. If absent, all tools are used. + """ - @model_validator(mode="after") - def correct_transport_specified(self) -> Self: - """Check if correct transport is specified.""" - if self.transport == "stdio": - if self.stdio is None: - raise ValueError( - "Stdio transport selected but 'stdio' config not provided" - ) - if self.sse is not None or self.streamable_http is not None: - raise ValueError( - "Stdio transport selected but 'sse' or 'streamable_http' " - "config should not be provided" - ) - elif self.transport == "sse": - if self.sse is None: - raise ValueError("SSE transport selected but 'sse' config not provided") - if self.stdio is not None or self.streamable_http is not None: - raise ValueError( - "SSE transport selected but 'stdio' or 'streamable_http' " - "config should not be provided" - ) - elif self.transport == "streamable_http": - if self.streamable_http is None: - raise ValueError( - "Streamable HTTP transport selected but 'streamable_http' " - "config not provided" - ) - if self.stdio is not None or self.sse is not None: - raise ValueError( - "Streamable HTTP transport selected but 'stdio' or 'sse' " - "config should not be provided" - ) - return self + embed_model_path: Optional[str] = Field( + default=None, + description="Path to sentence transformer model for embeddings", + ) + + alpha: float = Field( + default=0.8, + ge=0.0, + le=1.0, + description="Weight for dense vs sparse retrieval (1.0 = full dense, 0.0 = full sparse)", + ) + + top_k: int = Field( + default=10, ge=1, le=50, description="Number of tools to retrieve" + ) + + threshold: float = Field( + default=0.01, + ge=0.0, + le=1.0, + description="Minimum similarity threshold for filtering results", + ) class MCPServers(BaseModel): @@ -1065,6 +1076,8 @@ class OLSConfig(BaseModel): proxy_config: Optional[ProxyConfig] = None + tool_filtering: Optional[ToolFilteringConfig] = None + def __init__( self, data: Optional[dict] = None, ignore_missing_certs: bool = False ) -> None: @@ -1117,6 +1130,8 @@ def __init__( ) self.quota_handlers = QuotaHandlersConfig(data.get("quota_handlers", None)) self.proxy_config = ProxyConfig(data.get("proxy_config")) + if data.get("tool_filtering", None) is not None: + self.tool_filtering = ToolFilteringConfig(**data.get("tool_filtering")) def __eq__(self, other: object) -> bool: """Compare two objects for equality.""" @@ -1140,6 +1155,7 @@ def __eq__(self, other: object) -> bool: == other.expire_llm_is_ready_persistent_state and self.quota_handlers == other.quota_handlers and self.proxy_config == other.proxy_config + and self.tool_filtering == other.tool_filtering ) return False @@ -1235,6 +1251,9 @@ def __init__( # initialize MCP servers self.mcp_servers = MCPServers(servers=data.get("mcp_servers", [])) + # Validate MCP servers now that auth config is available + self._validate_mcp_servers() + # Always initialize dev config, even if there's no config for it. self.dev_config = DevConfig(**data.get("dev_config", {})) @@ -1272,6 +1291,21 @@ def _validate_default_provider_and_model(self) -> None: f"default_model specifies an unknown model {selected_default_model}" ) + def _validate_mcp_servers(self) -> None: + """Validate MCP servers with auth module context. + + Filters out servers where authorization headers cannot be resolved. + """ + auth_module = getattr( + getattr(self.ols_config, "authentication_config", None), + "module", + None, + ) + self.mcp_servers.servers = checks.validate_mcp_servers( + self.mcp_servers.servers, + auth_module, + ) + def validate_yaml(self) -> None: """Validate all configurations.""" self.llm_providers.validate_yaml() diff --git a/ols/constants.py b/ols/constants.py index ec3172489..b242efc35 100644 --- a/ols/constants.py +++ b/ols/constants.py @@ -200,6 +200,9 @@ class GenericLLMParameters: # Default authentication module DEFAULT_AUTHENTICATION_MODULE = "k8s" +# Authentication module for testing with token +NOOP_WITH_TOKEN_AUTHENTICATION_MODULE = "noop-with-token" # noqa: S105 + # All supported authentication modules SUPPORTED_AUTHENTICATION_MODULES = {"k8s", "noop", "noop-with-token"} @@ -224,16 +227,12 @@ class GenericLLMParameters: USER_QUOTA_LIMITER = "user_limiter" CLUSTER_QUOTA_LIMITER = "cluster_limiter" -# MCP transport types -MCP_TRANSPORT_STDIO = "stdio" -MCP_TRANSPORT_SSE = "sse" -SSE_TRANSPORT_DEFAULT_TIMEOUT = 5 # in seconds -SSE_TRANSPORT_DEFAULT_READ_TIMEOUT = 10 # in seconds -STDIO_TRANSPORT_DEFAULT_ENCODING = "utf-8" -STDIO_TRANSPORT_DEFAULT_ENV: dict[str, str | int] = {} -STDIO_TRANSPORT_DEFAULT_CWD = "." -STREAMABLE_HTTP_TRANSPORT_DEFAULT_TIMEOUT = 5 # in seconds -STREAMABLE_HTTP_TRANSPORT_DEFAULT_READ_TIMEOUT = 10 # in seconds +# MCP transport default timeout +MCP_HTTP_TRANSPORT_DEFAULT_TIMEOUT = 5 # in seconds + +# MCP authorization header placeholders +MCP_KUBERNETES_PLACEHOLDER = "kubernetes" +MCP_CLIENT_PLACEHOLDER = "client" # timeout value for a single llm with tools round # Keeping it really high at this moment (until this is configurable) diff --git a/ols/src/config_status/config_status.py b/ols/src/config_status/config_status.py index 787e5166b..cd69b9c31 100644 --- a/ols/src/config_status/config_status.py +++ b/ols/src/config_status/config_status.py @@ -82,7 +82,10 @@ def extract_config_status(cfg: Config) -> ConfigStatus: if p.tls_security_profile and p.tls_security_profile.profile_type ] - mcp_servers = {s.name: s.transport for s in mcp_cfg.servers} + mcp_servers = { + s.name: s.url.split("://")[0] if "://" in s.url else "http" + for s in mcp_cfg.servers + } quota_management_enabled = ols_cfg.quota_handlers is not None and ( ols_cfg.quota_handlers.limiters is not None diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index dc409fdbe..b2da3b132 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -21,7 +21,6 @@ from ols.customize import reranker from ols.src.prompts.prompt_generator import GeneratePrompt from ols.src.query_helpers.query_helper import QueryHelper -from ols.src.tools.mcp_config_builder import MCPConfigBuilder from ols.src.tools.tools import execute_tool_calls from ols.utils.token_handler import TokenHandler @@ -118,7 +117,10 @@ class DocsSummarizer(QueryHelper): """A class for summarizing documentation context.""" def __init__( - self, *args: Any, user_token: Optional[str] = None, **kwargs: Any + self, + *args: Any, + user_token: Optional[str] = None, + **kwargs: Any, ) -> None: """Initialize the DocsSummarizer. @@ -133,14 +135,8 @@ def __init__( # tools part self.user_token = user_token - mcp_config_builder = MCPConfigBuilder( - self.user_token, config.mcp_servers.servers - ) - try: - self.mcp_servers = mcp_config_builder.dump_client_config() - except Exception as e: - logger.error("Failed to resolve MCP server(s): %s", e) - self.mcp_servers = {} + self.mcp_servers = self._build_mcp_config() + if self.mcp_servers: logger.info("MCP servers provided: %s", list(self.mcp_servers.keys())) self._tool_calling_enabled = True @@ -163,6 +159,87 @@ def _prepare_llm(self) -> None: self.generic_llm_params, ) + def _get_token_value( + self, value: str, header_name: str, server_name: str + ) -> Optional[str]: + """Resolve header value by substituting placeholders. + + Args: + value: Header value (may be a placeholder or actual value) + header_name: Name of the header + server_name: Name of the MCP server (for logging) + + Returns: + Resolved header value, or None if resolution failed + """ + if value == constants.MCP_KUBERNETES_PLACEHOLDER: + # If we reach here, auth module is k8s (validated at config load) + # and user_token is guaranteed to be present from Authorization header + return f"Bearer {self.user_token}" + + if value == constants.MCP_CLIENT_PLACEHOLDER: + # Client placeholder - value should come from client-provided headers + logger.debug( + "MCP server '%s' header '%s' requires client-provided value", + server_name, + header_name, + ) + return None # Will be filled from client headers + + # Already resolved (from file) at config load time + return value + + def _build_mcp_config(self) -> dict[str, Any]: + """Build MCP client configuration from config. + + Resolves authorization headers, substituting placeholders with runtime values + (e.g., "kubernetes" → user token). + + Returns: + Dictionary mapping server names to their config for MultiServerMCPClient. + Returns empty dict if no MCP servers configured or on error. + """ + if not config.mcp_servers or not config.mcp_servers.servers: + return {} + + servers_config: dict[str, Any] = {} + + try: + for server in config.mcp_servers.servers: + # Resolve authorization headers from config + # Pattern adapted from LCore implementation + headers = {} + for name, value in server.resolved_headers.items(): + h_value = self._get_token_value(value, name, server.name) + if h_value is not None: + headers[name] = h_value + + # Skip server if auth headers were configured but not all could be resolved + if server.headers and len(headers) != len(server.headers): + logger.warning( + "Skipping MCP server %s: required %d auth headers but only resolved %d", + server.name, + len(server.headers), + len(headers), + ) + continue + + # Build MultiServerMCPClient config format + servers_config[server.name] = { + "transport": "http", + "url": server.url, + } + if headers: + servers_config[server.name]["headers"] = headers + if server.timeout: + servers_config[server.name]["timeout"] = server.timeout + + except Exception as e: + logger.error("Failed to build MCP config: %s", e) + return {} + + return servers_config + def _prepare_prompt( self, query: str, diff --git a/ols/src/tools/mcp_config_builder.py b/ols/src/tools/mcp_config_builder.py deleted file mode 100644 index 878d7a668..000000000 --- a/ols/src/tools/mcp_config_builder.py +++ /dev/null @@ -1,107 +0,0 @@ -"""MCPConfigBuilder for building MultiServerMCPClient configuration.""" - -import logging -import os -from datetime import timedelta -from typing import Any - -from ols.app.models.config import MCPServerConfig -from ols.utils import checks - -logger = logging.getLogger(__name__) - -# Constant, defining usage of kubernetes token -KUBERNETES_PLACEHOLDER = "kubernetes" - - -class MCPConfigBuilder: - """Builds MCP config for MultiServerMCPClient.""" - - def __init__( - self, user_token: str, mcp_server_configs: list[MCPServerConfig] - ) -> None: - """Initialize the MCPConfigBuilder with user token and server config list.""" - self.user_token = user_token - self.mcp_server_configs = mcp_server_configs - - def include_auth_to_stdio(self, server_envs: dict[str, str]) -> dict[str, str]: - """Resolve OpenShift stdio env config.""" - logger.debug("Updating env configuration of openshift stdio mcp server") - env = {**server_envs} - - if "OC_USER_TOKEN" in env: - logger.warning("OC_USER_TOKEN is set, overriding with actual user token.") - env["OC_USER_TOKEN"] = self.user_token - - if "KUBECONFIG" not in env: - if "KUBECONFIG" in os.environ: - logger.info("Using KUBECONFIG from environment.") - env["KUBECONFIG"] = os.environ["KUBECONFIG"] - elif ( - "KUBERNETES_SERVICE_HOST" in os.environ - and "KUBERNETES_SERVICE_PORT" in os.environ - ): - logger.info("Using KUBERNETES_SERVICE_* from environment.") - env["KUBERNETES_SERVICE_HOST"] = os.environ["KUBERNETES_SERVICE_HOST"] - env["KUBERNETES_SERVICE_PORT"] = os.environ["KUBERNETES_SERVICE_PORT"] - else: - logger.error("Missing necessary KUBECONFIG/KUBERNETES_SERVICE_* envs.") - return env - - def dump_client_config(self) -> dict[str, Any]: - """Convert server configs to MultiServerMCPClient config format.""" - servers_conf: dict[str, Any] = {} - - for server_conf in self.mcp_server_configs: - servers_conf[server_conf.name] = { - "transport": server_conf.transport, - } - - if server_conf.stdio: - stdio_conf = server_conf.stdio.model_dump() - if server_conf.name == "openshift": - stdio_conf["env"] = self.include_auth_to_stdio( - server_conf.stdio.env - ) - servers_conf[server_conf.name].update(stdio_conf) - continue - - if server_conf.sse: - sse_conf = server_conf.sse.model_dump() - sse_conf["headers"] = self._resolve_tokens_to_value(sse_conf["headers"]) - servers_conf[server_conf.name].update(sse_conf) - continue - - if server_conf.streamable_http: - http_conf = server_conf.streamable_http.model_dump() - http_conf["headers"] = self._resolve_tokens_to_value( - http_conf["headers"] - ) - servers_conf[server_conf.name].update(http_conf) - # Note: Streamable HTTP transport expects timedelta instead of - # int as for the sse - blame langchain-mcp-adapters for - # inconsistency - for timeout in ("timeout", "sse_read_timeout"): - servers_conf[server_conf.name][timeout] = timedelta( - seconds=servers_conf[server_conf.name][timeout] # type: ignore [assignment] - ) - - return servers_conf - - def _resolve_tokens_to_value(self, headers: dict[str, str]) -> dict[str, Any]: - """Convert header definitions to values.""" - updated = {} - for name, value in headers.items(): - if value == KUBERNETES_PLACEHOLDER: - updated[name] = f"Bearer {self.user_token}" - else: - try: - # load token value - with open(value, "r", encoding="utf-8") as token_store: - token = token_store.read() - updated[name] = token - except Exception as e: - raise checks.InvalidConfigurationError( - f"token value refers to non existent file '{value}', error {e}" - ) - return updated diff --git a/ols/utils/checks.py b/ols/utils/checks.py index c9e1efa66..4c5f335a5 100644 --- a/ols/utils/checks.py +++ b/ols/utils/checks.py @@ -7,6 +7,8 @@ from pydantic import AnyHttpUrl, FilePath +from ols import constants + class InvalidConfigurationError(Exception): """OLS Configuration is invalid.""" @@ -97,3 +99,145 @@ def get_log_level(value: str) -> int: f"{[k.lower() for k in logging.getLevelNamesMapping()]}" ) return log_level + + +def resolve_headers( + headers: dict[str, str], + auth_module: Optional[str] = None, +) -> dict[str, str]: + """Resolve authorization headers by reading secret files or preserving special values. + + Args: + headers: Map of header names to secret locations or special keywords. + - If value is "kubernetes": preserved unchanged for later substitution during request. + Only valid when authentication module is "k8s" or "noop-with-token". + "noop-with-token" is for testing only - the real k8s token must be passed at + request time. + If used with other auth modules, a warning is logged and the server is skipped. + - If value is "client": preserved unchanged for later substitution during request. + - Otherwise: Treated as file path and read the secret from that file. + auth_module: The authentication module being used (e.g., "k8s", "noop-with-token"). + Used to validate that "kubernetes" placeholder is only used with appropriate auth + modules. + + Returns: + Map of header names to resolved header values or special keywords. + Returns empty dict if any header fails to resolve (kubernetes placeholder + with non-k8s/non-noop-with-token auth, or secret file cannot be read). + + Examples: + >>> # With file paths + >>> resolve_headers({"Authorization": "/var/secrets/token"}) + {"Authorization": "secret-value-from-file"} + + >>> # With kubernetes special case (kept as-is, requires k8s or noop-with-token auth) + >>> resolve_authorization_headers( + ... {"Authorization": "kubernetes"}, + ... auth_module="k8s" + ... ) + {"Authorization": "kubernetes"} + + >>> # With client special case (kept as-is) + >>> resolve_headers({"Authorization": "client"}) + {"Authorization": "client"} + """ + logger = logging.getLogger(__name__) + resolved: dict[str, str] = {} + + for header_name, header_value in headers.items(): + match header_value.strip(): + case constants.MCP_KUBERNETES_PLACEHOLDER: + # Validate that kubernetes placeholder is only used with k8s or noop-with-token auth + # (noop-with-token is allowed for testing purposes) + if auth_module not in ( + constants.DEFAULT_AUTHENTICATION_MODULE, + constants.NOOP_WITH_TOKEN_AUTHENTICATION_MODULE, + ): + logger.warning( + "MCP server authorization header '%s' uses '%s' placeholder, but " + "authentication module is '%s'. " + "The 'kubernetes' placeholder requires authentication module to be " + "'%s' or '%s'. This MCP server will be skipped.", + header_name, + constants.MCP_KUBERNETES_PLACEHOLDER, + auth_module, + constants.DEFAULT_AUTHENTICATION_MODULE, + constants.NOOP_WITH_TOKEN_AUTHENTICATION_MODULE, + ) + return {} # Return empty dict to signal server should be skipped + resolved[header_name] = constants.MCP_KUBERNETES_PLACEHOLDER + logger.debug( + "Header %s will use Kubernetes token (resolved at request time)", + header_name, + ) + + case constants.MCP_CLIENT_PLACEHOLDER: + resolved[header_name] = constants.MCP_CLIENT_PLACEHOLDER + logger.debug( + "Header %s will use client-provided token (resolved at request time)", + header_name, + ) + + case _: + # Read secret from file path + secret_value = read_secret( + data={"path": header_value}, + path_key="path", + default_filename="", + raise_on_error=False, + ) + if secret_value: + resolved[header_name] = secret_value + logger.debug( + "Resolved header %s from secret file %s", + header_name, + header_value, + ) + else: + logger.warning( + "MCP server authorization header '%s' failed to read secret file '%s'. " + "This MCP server will be skipped.", + header_name, + header_value, + ) + return {} # Return empty dict to signal server should be skipped + + return resolved + + +def validate_mcp_servers( + servers: list, + auth_module: Optional[str], +) -> list: + """Validate and filter MCP servers, resolving their authorization headers. + + Args: + servers: List of MCPServerConfig objects to validate. + auth_module: The authentication module being used (e.g., "k8s", "noop"). + + Returns: + List of valid MCPServerConfig objects with resolved authorization headers. + Servers are excluded if any authorization header cannot be resolved. + """ + logger = logging.getLogger(__name__) + valid_servers = [] + + for server in servers: + if server.headers: + # Resolve headers with auth module context + resolved = resolve_headers( + server.headers, + auth_module=auth_module, + ) + if not resolved: + # Already logged in resolve_headers + logger.debug( + "MCP server '%s' excluded due to unresolvable authorization headers", + server.name, + ) + continue + # Store the resolved headers + server._resolved_headers = resolved + valid_servers.append(server) + + return valid_servers diff --git a/tests/config/valid_config.yaml b/tests/config/valid_config.yaml index 4883d8876..a17ac1fcb 100644 --- a/tests/config/valid_config.yaml +++ b/tests/config/valid_config.yaml @@ -41,15 +41,11 @@ ols_config: system_prompt_path: "tests/config/system_prompt.txt" mcp_servers: - name: foo - transport: stdio - stdio: - command: python - args: - - mcp_server_1.py + url: http://localhost:8080 - name: bar - transport: sse - sse: - url: 127.0.0.1:8080 + url: http://127.0.0.1:8081 + headers: + X-API-Key: /path/to/secret dev_config: disable_auth: true disable_tls: true diff --git a/tests/integration/test_ols.py b/tests/integration/test_ols.py index 648bd8030..de5078c72 100644 --- a/tests/integration/test_ols.py +++ b/tests/integration/test_ols.py @@ -1178,15 +1178,12 @@ def test_tool_calling(_setup, caplog) -> None: """Check the REST API query endpoints when tool calling is enabled.""" endpoint = "/v1/query" caplog.set_level(10) - # MCP servers config is a dict mapping server names to their configurations - mcp_servers = { - "fake-server": {"transport": "stdio", "stdio": {}}, - } + mcp_servers = {"fake-server": {"transport": "http", "url": "http://fake-server"}} with ( patch("ols.customize.prompts.QUERY_SYSTEM_INSTRUCTION", "System Instruction"), patch( - "ols.src.query_helpers.docs_summarizer.MCPConfigBuilder.dump_client_config", + "ols.src.query_helpers.docs_summarizer.DocsSummarizer._build_mcp_config", return_value=mcp_servers, ), patch( diff --git a/tests/unit/app/models/test_config.py b/tests/unit/app/models/test_config.py index a6ec9f7a9..9c583df40 100644 --- a/tests/unit/app/models/test_config.py +++ b/tests/unit/app/models/test_config.py @@ -31,9 +31,6 @@ QuotaHandlersConfig, ReferenceContent, ReferenceContentIndex, - SseTransportConfig, - StdioTransportConfig, - StreamableHttpTransportConfig, TLSConfig, TLSSecurityProfile, UserDataCollection, @@ -42,49 +39,59 @@ @pytest.fixture -def mcp_server_config_stdio_transport(): +def mcp_server_config_http(): """MCP server config map fixture.""" return { "name": "foo", - "transport": "stdio", - "stdio": {"command": "python", "args": ["server1.py"]}, + "url": "http://localhost:8080", } @pytest.fixture -def mcp_server_config_sse_transport(): - """MCP server config map fixture.""" +def mcp_server_config_http_with_auth(): + """MCP server config map fixture with authorization.""" return { "name": "bar", - "transport": "sse", - "sse": {"url": "127.0.0.1:8080"}, + "url": "http://localhost:8081", + "headers": {"Authorization": "kubernetes"}, } @pytest.fixture -def mcp_server_config_streamable_http_transport(): - """MCP server config map fixture.""" +def mcp_server_config_http_with_timeout(): + """MCP server config map fixture with timeout.""" return { "name": "gru", - "transport": "streamable_http", - "streamable_http": {"url": "127.0.0.1:8080"}, + "url": "http://localhost:8082", + "timeout": 30, } def test_mcp_server_config( - mcp_server_config_stdio_transport, - mcp_server_config_sse_transport, - mcp_server_config_streamable_http_transport, + mcp_server_config_http, + mcp_server_config_http_with_auth, + mcp_server_config_http_with_timeout, ): """Test the MCPServerConfig model.""" - mcp_server_config = MCPServerConfig(**mcp_server_config_stdio_transport) + mcp_server_config = MCPServerConfig(**mcp_server_config_http) assert mcp_server_config.name == "foo" + assert mcp_server_config.url == "http://localhost:8080" + + # Mock k8s authentication for kubernetes placeholder test + with mock.patch("ols.config") as mock_config: + mock_auth_config = mock.Mock() + mock_auth_config.module = "k8s" + mock_ols_config = mock.Mock() + mock_ols_config.authentication_config = mock_auth_config + mock_config.ols_config = mock_ols_config - mcp_server_config = MCPServerConfig(**mcp_server_config_sse_transport) - assert mcp_server_config.name == "bar" + mcp_server_config = MCPServerConfig(**mcp_server_config_http_with_auth) + assert mcp_server_config.name == "bar" + assert mcp_server_config.url == "http://localhost:8081" - mcp_server_config = MCPServerConfig(**mcp_server_config_streamable_http_transport) + mcp_server_config = MCPServerConfig(**mcp_server_config_http_with_timeout) assert mcp_server_config.name == "gru" + assert mcp_server_config.url == "http://localhost:8082" def test_mcp_server_config_required_name(): @@ -93,111 +100,26 @@ def test_mcp_server_config_required_name(): ValidationError, match=r"(?s)name.*Field required", ): - MCPServerConfig( # pyright: ignore[reportCallIssue] - transport="stdio", stdio={"command": "python", "args": ["server.py"]} - ) - - -def test_mcp_server_config_transport(): - """Test the MCPServerConfig model for missing transport option.""" - with pytest.raises( - ValidationError, - match=r"(?s)transport.*Field required", - ): - MCPServerConfig() # pyright: ignore[reportCallIssue] - - with pytest.raises( - ValidationError, - match="Input should be", - ): - MCPServerConfig(transport="unknown") # pyright: ignore[reportCallIssue] - - -def test_mcp_server_config_missing_options(): - """Test the MCPServerConfig model for missing options.""" - stdio_conf = StdioTransportConfig(command="python", args=["server.py"]) - sse_conf = SseTransportConfig(url="http://server:8080") - streamable_http_conf = StreamableHttpTransportConfig(url="http://server:8080") - - # stdio selected, but config is missing - with pytest.raises( - ValidationError, - match="Stdio transport selected but 'stdio'", - ): - MCPServerConfig( - name="foo", - transport="stdio", - ) - - # stdio selected, but other configs provided too - with pytest.raises( - ValidationError, - match="Stdio transport selected but 'sse' or 'streamable_http'", - ): - MCPServerConfig( - name="foo", - transport="stdio", - stdio=stdio_conf, - sse=sse_conf, - streamable_http=streamable_http_conf, - ) + MCPServerConfig(url="http://localhost:8080") # pyright: ignore[reportCallIssue] - # sse selected, but config is missing - with pytest.raises( - ValidationError, - match="SSE transport selected but 'sse'", - ): - MCPServerConfig( - name="foo", - transport="sse", - ) - # sse selected, but other configs provided too +def test_mcp_server_config_required_url(): + """Test the MCPServerConfig model for missing URL.""" with pytest.raises( ValidationError, - match="SSE transport selected but 'stdio' or 'streamable_http'", + match=r"(?s)url.*Field required", ): - MCPServerConfig( - name="foo", - transport="sse", - stdio=stdio_conf, - sse=sse_conf, - streamable_http=streamable_http_conf, - ) - - # streamable_http selected, but config is missing - with pytest.raises( - ValidationError, - match="Streamable HTTP transport selected but 'streamable_http'", - ): - MCPServerConfig( - name="foo", - transport="streamable_http", - ) - - # streamable_http selected, but other configs provided too - with pytest.raises( - ValidationError, - match="Streamable HTTP transport selected but 'stdio' or 'sse'", - ): - MCPServerConfig( - name="foo", - transport="streamable_http", - stdio=stdio_conf, - sse=sse_conf, - streamable_http=streamable_http_conf, - ) + MCPServerConfig(name="test") # pyright: ignore[reportCallIssue] -def test_mcp_server_config_equality(mcp_server_config_stdio_transport): +def test_mcp_server_config_equality(mcp_server_config_http): """Test the MCPServerConfig model.""" - mcp_server_config_1 = MCPServerConfig(**mcp_server_config_stdio_transport) - mcp_server_config_2 = MCPServerConfig(**mcp_server_config_stdio_transport) + mcp_server_config_1 = MCPServerConfig(**mcp_server_config_http) + mcp_server_config_2 = MCPServerConfig(**mcp_server_config_http) mcp_server_config_3 = MCPServerConfig( **{ "name": "some_other_name", - "transport": "stdio", - "stdio": {"command": "python", "args": ["server2.py"]}, + "url": "http://localhost:9090", } ) @@ -213,21 +135,29 @@ def test_mcp_server_config_equality(mcp_server_config_stdio_transport): def test_mcp_servers( - mcp_server_config_stdio_transport, - mcp_server_config_sse_transport, - mcp_server_config_streamable_http_transport, + mcp_server_config_http, + mcp_server_config_http_with_auth, + mcp_server_config_http_with_timeout, ): """Test the MCPServers model.""" mcp_servers = MCPServers() - assert mcp_servers.servers == [] - - mcp_servers = MCPServers( - servers=[ - mcp_server_config_stdio_transport, - mcp_server_config_sse_transport, - mcp_server_config_streamable_http_transport, - ] - ) + assert not mcp_servers.servers + + # Mock k8s authentication for kubernetes placeholder test + with mock.patch("ols.config") as mock_config: + mock_auth_config = mock.Mock() + mock_auth_config.module = "k8s" + mock_ols_config = mock.Mock() + mock_ols_config.authentication_config = mock_auth_config + mock_config.ols_config = mock_ols_config + + mcp_servers = MCPServers( + servers=[ + mcp_server_config_http, + mcp_server_config_http_with_auth, + mcp_server_config_http_with_timeout, + ] + ) assert len(mcp_servers.servers) == 3 assert mcp_servers.servers[0].name == "foo" assert mcp_servers.servers[1].name == "bar" @@ -238,8 +168,7 @@ def test_mcp_servers_duplicity(): """Test the MCPServers model.""" mcp_server_config = MCPServerConfig( name="foo", - transport="stdio", - stdio=StdioTransportConfig(command="python", args=["server1.py"]), + url="http://localhost:8080", ) with pytest.raises(ValidationError, match="Duplicate server name: 'foo'"): @@ -255,25 +184,23 @@ def test_mcp_servers_invalid_input(): MCPServers(servers={}) # pyright: ignore[reportArgumentType] -def test_mcp_servers_equality( - mcp_server_config_stdio_transport, mcp_server_config_sse_transport -): +def test_mcp_servers_equality(mcp_server_config_http, mcp_server_config_http_with_auth): """Test the MCPServers model.""" mcp_servers_1 = MCPServers( servers=[ - mcp_server_config_stdio_transport, - mcp_server_config_sse_transport, + mcp_server_config_http, + mcp_server_config_http_with_auth, ] ) mcp_servers_2 = MCPServers( servers=[ - mcp_server_config_stdio_transport, - mcp_server_config_sse_transport, + mcp_server_config_http, + mcp_server_config_http_with_auth, ] ) mcp_servers_3 = MCPServers( servers=[ - mcp_server_config_stdio_transport, + mcp_server_config_http, ] ) @@ -288,56 +215,6 @@ def test_mcp_servers_equality( assert mcp_servers_1 != other_value -def test_sse_transport_configuration_on_no_data(): - """Test the SSE transport configuration handling when no data are provided.""" - with pytest.raises(ValidationError, match=r"(?s)url.*Field required"): - SseTransportConfig() # pyright: ignore[reportCallIssue] - - -def test_sse_transport_defaults(): - """Test the SSE transport configuration defaults.""" - sse_transport = SseTransportConfig(url="http://localhost:8080") - assert sse_transport.url == "http://localhost:8080" - assert sse_transport.timeout == constants.SSE_TRANSPORT_DEFAULT_TIMEOUT - assert ( - sse_transport.sse_read_timeout == constants.SSE_TRANSPORT_DEFAULT_READ_TIMEOUT - ) - assert sse_transport.headers == {} - - -def test_streamable_http_transport_configuration_on_no_data(): - """Test the SSE transport configuration handling when no data are provided.""" - with pytest.raises(ValidationError, match=r"(?s)url.*Field required"): - StreamableHttpTransportConfig() # pyright: ignore[reportCallIssue] - - -def test_streamable_http_transport_defaults(): - """Test the SSE transport configuration defaults.""" - sse_transport = StreamableHttpTransportConfig(url="http://localhost:8080") - assert sse_transport.url == "http://localhost:8080" - assert sse_transport.timeout == constants.STREAMABLE_HTTP_TRANSPORT_DEFAULT_TIMEOUT - assert ( - sse_transport.sse_read_timeout - == constants.STREAMABLE_HTTP_TRANSPORT_DEFAULT_READ_TIMEOUT - ) - - -def test_stdio_transport_configuration_on_no_data(): - """Test the STDIO transport configuration handling when no data are provided.""" - with pytest.raises(ValidationError, match=r"(?s)command.*Field required"): - StdioTransportConfig() # pyright: ignore[reportCallIssue] - - -def test_stdio_transport_defaults(): - """Test the STDIO transport configuration defaults.""" - stdio_transport = StdioTransportConfig(command="python") - assert stdio_transport.command == "python" - assert stdio_transport.args == [] - assert stdio_transport.env == constants.STDIO_TRANSPORT_DEFAULT_ENV - assert stdio_transport.cwd == constants.STDIO_TRANSPORT_DEFAULT_CWD - assert stdio_transport.encoding == constants.STDIO_TRANSPORT_DEFAULT_ENCODING - - def test_model_parameters(): """Test the ModelParameters model.""" default_params = ModelParameters() @@ -1557,21 +1434,37 @@ def test_llm_providers(): ] ) assert len(llm_providers.providers) == 1 - assert llm_providers.providers["test_provider_name"].name == "test_provider_name" - assert llm_providers.providers["test_provider_name"].type == "bam" - assert llm_providers.providers["test_provider_name"].url == "test_provider_url" - assert llm_providers.providers["test_provider_name"].credentials == "secret_key" - assert len(llm_providers.providers["test_provider_name"].models) == 1 assert ( - llm_providers.providers["test_provider_name"].models["test_model_name"].name + llm_providers.providers["test_provider_name"].name == "test_provider_name" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["test_provider_name"].type == "bam" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["test_provider_name"].url == "test_provider_url" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["test_provider_name"].credentials == "secret_key" + ) # pyright: ignore[reportIndexIssue] + assert ( + len(llm_providers.providers["test_provider_name"].models) == 1 + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["test_provider_name"] + .models["test_model_name"] + .name # pyright: ignore[reportIndexIssue] == "test_model_name" ) assert ( - str(llm_providers.providers["test_provider_name"].models["test_model_name"].url) + str( + llm_providers.providers["test_provider_name"].models["test_model_name"].url + ) # pyright: ignore[reportIndexIssue] == "http://test.url/" ) assert ( - llm_providers.providers["test_provider_name"] + llm_providers.providers[ + "test_provider_name" + ] # pyright: ignore[reportIndexIssue] .models["test_model_name"] .credentials == "secret_key" @@ -1609,8 +1502,12 @@ def test_llm_providers_type_defaulting(): ] ) assert len(llm_providers.providers) == 1 - assert llm_providers.providers["bam"].name == "bam" - assert llm_providers.providers["bam"].type == "bam" + assert ( + llm_providers.providers["bam"].name == "bam" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["bam"].type == "bam" + ) # pyright: ignore[reportIndexIssue] llm_providers = LLMProviders( [ @@ -1627,8 +1524,12 @@ def test_llm_providers_type_defaulting(): ] ) assert len(llm_providers.providers) == 1 - assert llm_providers.providers["test_provider"].name == "test_provider" - assert llm_providers.providers["test_provider"].type == "bam" + assert ( + llm_providers.providers["test_provider"].name == "test_provider" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["test_provider"].type == "bam" + ) # pyright: ignore[reportIndexIssue] def test_llm_providers_type_validation(): @@ -1690,10 +1591,16 @@ def test_llm_providers_watsonx_required_projectid(): ] ) assert len(llm_providers.providers) == 1 - assert llm_providers.providers["watsonx"].name == "watsonx" - assert llm_providers.providers["watsonx"].type == "watsonx" assert ( - llm_providers.providers["watsonx"].project_id + llm_providers.providers["watsonx"].name == "watsonx" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["watsonx"].type == "watsonx" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers[ + "watsonx" + ].project_id # pyright: ignore[reportIndexIssue] == "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX" ) @@ -1710,10 +1617,16 @@ def test_llm_providers_watsonx_required_projectid(): ] ) assert len(llm_providers.providers) == 1 - assert llm_providers.providers["test_provider"].name == "test_provider" - assert llm_providers.providers["test_provider"].type == "watsonx" assert ( - llm_providers.providers["test_provider"].project_id + llm_providers.providers["test_provider"].name == "test_provider" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers["test_provider"].type == "watsonx" + ) # pyright: ignore[reportIndexIssue] + assert ( + llm_providers.providers[ + "test_provider" + ].project_id # pyright: ignore[reportIndexIssue] == "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX" ) @@ -1992,12 +1905,12 @@ def test_tls_config_incorrect_certificate_path(): config2 = TLSConfig( { - "tls_certificate_path": "/etc/shadow", + "tls_certificate_path": "/nonexistent/path/cert.pem", "tls_key_path": "tests/config/key", "tls_key_password_path": "tests/config/password", } ) - with pytest.raises(InvalidConfigurationError, match="is not readable"): + with pytest.raises(InvalidConfigurationError, match="is not a file"): config2.validate_yaml() @@ -2470,7 +2383,7 @@ def test_ols_config_equality(subtests): # compare OLSConfig with other object assert ols_config_1 != "foo" - assert ols_config_2 != {} + assert ols_config_2 def test_config(): @@ -2540,46 +2453,63 @@ def test_config(): ) assert len(config.llm_providers.providers) == 3 assert ( - config.llm_providers.providers["test_provider_name"].name + config.llm_providers.providers[ + "test_provider_name" + ].name # pyright: ignore[reportIndexIssue] == "test_provider_name" ) assert ( - config.llm_providers.providers["test_provider_name"].url == "test_provider_url" + config.llm_providers.providers["test_provider_name"].url + == "test_provider_url" # pyright: ignore[reportIndexIssue] ) assert ( - config.llm_providers.providers["test_provider_name"].credentials == "secret_key" + config.llm_providers.providers["test_provider_name"].credentials + == "secret_key" # pyright: ignore[reportIndexIssue] ) - assert len(config.llm_providers.providers["test_provider_name"].models) == 1 assert ( - config.llm_providers.providers["test_provider_name"] + len(config.llm_providers.providers["test_provider_name"].models) == 1 + ) # pyright: ignore[reportIndexIssue] + assert ( + config.llm_providers.providers[ + "test_provider_name" + ] # pyright: ignore[reportIndexIssue] .models["test_model_name"] .name == "test_model_name" ) assert ( str( - config.llm_providers.providers["test_provider_name"] + config.llm_providers.providers[ + "test_provider_name" + ] # pyright: ignore[reportIndexIssue] .models["test_model_name"] .url ) == "http://test_model_url/" ) assert ( - config.llm_providers.providers["test_provider_name"] + config.llm_providers.providers[ + "test_provider_name" + ] # pyright: ignore[reportIndexIssue] .models["test_model_name"] .credentials == "secret_key" ) assert ( - config.llm_providers.providers["rhoai_provider_name"].certificates_store + config.llm_providers.providers[ + "rhoai_provider_name" + ].certificates_store # pyright: ignore[reportIndexIssue] == "/foo/bar/baz/ols.pem" ) assert ( - config.llm_providers.providers["rhelai_provider_name"].certificates_store + config.llm_providers.providers[ + "rhelai_provider_name" + ].certificates_store # pyright: ignore[reportIndexIssue] == "/foo/bar/baz/ols.pem" ) assert ( - config.llm_providers.providers["test_provider_name"].certificates_store is None + config.llm_providers.providers["test_provider_name"].certificates_store + is None # pyright: ignore[reportIndexIssue] ) assert config.ols_config.default_provider == "test_default_provider" @@ -3599,8 +3529,6 @@ def test_user_data_config__transcripts(tmpdir): def test_user_data_config__config_status(tmpdir): """Tests the UserDataCollection model, config_status part.""" - import os - parent_dir = os.path.dirname(tmpdir.strpath) # config status is inferred from feedback/transcripts settings @@ -3640,7 +3568,7 @@ def test_dev_config_defaults(): dev_config = DevConfig() assert dev_config.pyroscope_url is None assert dev_config.enable_dev_ui is False - assert dev_config.llm_params == {} + assert not dev_config.llm_params assert dev_config.disable_auth is False assert dev_config.disable_tls is False assert dev_config.k8s_auth_token is None diff --git a/tests/unit/config_status/test_config_status.py b/tests/unit/config_status/test_config_status.py index 846c8a5e4..4f47003b0 100644 --- a/tests/unit/config_status/test_config_status.py +++ b/tests/unit/config_status/test_config_status.py @@ -18,7 +18,6 @@ QueryFilter, ReferenceContent, ReferenceContentIndex, - StdioTransportConfig, TLSSecurityProfile, UserDataCollection, ) @@ -130,14 +129,13 @@ def test_extract_config_status_with_mcp_servers(self): config = create_minimal_config() mcp_server = MCPServerConfig( name="test_server", - transport="stdio", - stdio=StdioTransportConfig(command="python", args=["test.py"]), + url="http://localhost:8080", ) config.mcp_servers.servers = [mcp_server] status = extract_config_status(config) - assert status.mcp_servers == {"test_server": "stdio"} + assert status.mcp_servers == {"test_server": "http"} def test_extract_config_status_with_provider_tls_config(self): """Test extracting config status with provider TLS security profile.""" @@ -194,7 +192,7 @@ def test_store_config_status(self, tmpdir): query_redactor_enabled=True, query_filter_count=1, providers_with_tls_config=["my_openai"], - mcp_servers={"openshift": "stdio"}, + mcp_servers={"openshift": "http"}, quota_management_enabled=False, token_history_enabled=False, proxy_enabled=False, @@ -212,14 +210,14 @@ def test_store_config_status(self, tmpdir): config_status_file = Path(storage_path) / "test-uuid.json" assert config_status_file.exists() - with open(config_status_file) as f: + with open(config_status_file, encoding="utf-8") as f: stored_data = json.load(f) assert stored_data["timestamp"] == "2024-01-01T00:00:00+00:00" assert stored_data["providers"] == {"openai": ["my_openai"]} assert stored_data["models"] == {"my_openai": ["gpt-4"]} assert stored_data["rag_indexes"] == ["ocp-docs-4_17", "user-docs-v1"] - assert stored_data["mcp_servers"] == {"openshift": "stdio"} + assert stored_data["mcp_servers"] == {"openshift": "http"} def test_store_config_status_creates_directory(self, tmpdir): """Test that store_config_status creates the storage directory if needed.""" diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py index 89e74769f..ba3fef2f8 100644 --- a/tests/unit/query_helpers/test_docs_summarizer.py +++ b/tests/unit/query_helpers/test_docs_summarizer.py @@ -296,7 +296,7 @@ def test_tool_calling_tool_execution(caplog): "ols.src.query_helpers.docs_summarizer.DocsSummarizer._invoke_llm" ) as mock_invoke, patch( - "ols.src.query_helpers.docs_summarizer.MCPConfigBuilder.dump_client_config", + "ols.src.query_helpers.docs_summarizer.DocsSummarizer._build_mcp_config", return_value=mcp_servers_config, ), ): diff --git a/tests/unit/tools/test_mcp_config_builder.py b/tests/unit/tools/test_mcp_config_builder.py deleted file mode 100644 index 627195015..000000000 --- a/tests/unit/tools/test_mcp_config_builder.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Tests for MCPConfigBuilder.""" - -import os -import tempfile -from datetime import timedelta -from unittest.mock import patch - -from ols.app.models.config import ( - MCPServerConfig, - SseTransportConfig, - StdioTransportConfig, - StreamableHttpTransportConfig, -) -from ols.src.tools.mcp_config_builder import KUBERNETES_PLACEHOLDER, MCPConfigBuilder - - -def test_mcp_config_builder_dump_client_config(): - """Test MCPConfigBuilder.dump_client_config method.""" - mcp_server_configs = [ - MCPServerConfig( - name="openshift", - transport="stdio", - stdio=StdioTransportConfig( - command="hello", - env={"X": "Y"}, - ), - ), - MCPServerConfig( - name="not-openshift", - transport="stdio", - stdio=StdioTransportConfig( - command="hello", - env={"X": "Y"}, - ), - ), - ] - user_token = "fake-token" # noqa: S105 - - # patch the environment variable to avoid using values from the system - with patch.dict(os.environ, {}, clear=True): - builder = MCPConfigBuilder(user_token, mcp_server_configs) - mcp_config = builder.dump_client_config() - - assert mcp_config == { - "openshift": { - "transport": "stdio", - "command": "hello", - "args": [], - "env": {"X": "Y", "OC_USER_TOKEN": "fake-token"}, - "cwd": ".", - "encoding": "utf-8", - }, - "not-openshift": { - "transport": "stdio", - "command": "hello", - "args": [], - "env": {"X": "Y"}, - "cwd": ".", - "encoding": "utf-8", - }, - } - - -class TestMCPConfigBuilder: - """Test MCPConfigBuilder class.""" - - @staticmethod - def test_include_auth_to_stdio(): - """Test include_auth_to_stdio method.""" - user_token = "fake-token" # noqa: S105 - envs = {"A": 42, "KUBECONFIG": "bla"} - - builder = MCPConfigBuilder(user_token, []) - mcp_config = builder.include_auth_to_stdio(envs) - - expected = {**envs, "OC_USER_TOKEN": user_token} - assert mcp_config == expected - - @staticmethod - def test_token_set_in_env(caplog): - """Test include_auth_to_stdio with token set in env.""" - # OC_USER_TOKEN set in env - is logged and overriden - user_token = "fake-token" # noqa: S105 - envs = {"OC_USER_TOKEN": "different-value"} - - builder = MCPConfigBuilder(user_token, []) - with patch.dict(os.environ, {}, clear=True): - mcp_config = builder.include_auth_to_stdio(envs) - - expected = {"OC_USER_TOKEN": user_token} - assert mcp_config == expected - assert "overriding with actual user token" in caplog.text - - @staticmethod - def test_kubeconfig_from_environ(caplog): - """Test include_auth_to_stdio with KUBECONFIG from environment.""" - # KUBECONFIG is not set in env - value from os.environ is used - caplog.set_level(20) # info - envs = {"A": 42} - user_token = "fake-token" # noqa: S105 - - builder = MCPConfigBuilder(user_token, []) - with patch.dict(os.environ, {"KUBECONFIG": "os value"}): - mcp_config = builder.include_auth_to_stdio(envs) - - expected = {**envs, "OC_USER_TOKEN": user_token, "KUBECONFIG": "os value"} - assert mcp_config == expected - assert "Using KUBECONFIG from environment" in caplog.text - - @staticmethod - def test_kubernetes_service_from_environ(caplog): - """Test include_auth_to_stdio with KUBERNETES_SERVICE_* from environment.""" - # KUBECONFIG is not set, but KUBERNETES_SERVICE_* is available - caplog.set_level(20) # info - envs = {"A": 42} - user_token = "fake-token" # noqa: S105 - - builder = MCPConfigBuilder(user_token, []) - with patch.dict( - os.environ, - {"KUBERNETES_SERVICE_HOST": "k8s-host", "KUBERNETES_SERVICE_PORT": "8443"}, - ): - mcp_config = builder.include_auth_to_stdio(envs) - - expected = { - **envs, - "OC_USER_TOKEN": user_token, - "KUBERNETES_SERVICE_HOST": "k8s-host", - "KUBERNETES_SERVICE_PORT": "8443", - } - assert mcp_config == expected - assert "Using KUBERNETES_SERVICE_* from environment" in caplog.text - - @staticmethod - def test_missing_kubeconfig_and_kubernetes_service(caplog): - """Test include_auth_to_stdio with missing KUBECONFIG and KUBERNETES_SERVICE_*.""" - # Both KUBECONFIG and KUBERNETES_SERVICE_* are missing - envs = {} - user_token = "fake-token" # noqa: S105 - - builder = MCPConfigBuilder(user_token, []) - with patch.dict(os.environ, {}, clear=True): - mcp_config = builder.include_auth_to_stdio(envs) - - expected = {"OC_USER_TOKEN": user_token} - assert mcp_config == expected - assert "Missing necessary KUBECONFIG/KUBERNETES_SERVICE_* envs" in caplog.text - - @staticmethod - def test_dump_client_config_with_sse(): - """Test dump_client_config with SSE configuration.""" - file_descriptor, file_path = tempfile.mkstemp(suffix=".tmp") - try: - with os.fdopen(file_descriptor, "w") as open_file: - open_file.write("value") - mcp_server_configs = [ - MCPServerConfig( - name="sse-server", - transport="sse", - sse=SseTransportConfig( - url="https://example.com/events", - headers={ - "X-Custom-Header": file_path, - "kubernetes": KUBERNETES_PLACEHOLDER, - }, - ), - ), - ] - user_token = "fake-token" # noqa: S105 - - builder = MCPConfigBuilder(user_token, mcp_server_configs) - try: - result = builder.dump_client_config() - except Exception as e: - print(f"failed creating config {e}") - assert False - assert result["sse-server"]["transport"] == "sse" - assert result["sse-server"]["url"] == "https://example.com/events" - assert result["sse-server"]["headers"]["X-Custom-Header"] == "value" - assert ( - result["sse-server"]["headers"]["kubernetes"] == f"Bearer {user_token}" - ) - finally: - os.unlink(file_path) - - @staticmethod - def test_dump_client_config_with_mixed_transports(): - """Test dump_client_config with both SSE and stdio configurations.""" - mcp_server_configs = [ - MCPServerConfig( - name="openshift", - transport="stdio", - stdio=StdioTransportConfig( - command="hello", - env={"X": "Y"}, - ), - ), - MCPServerConfig( - name="sse-server", - transport="sse", - sse=SseTransportConfig( - url="https://example.com/events", - ), - ), - ] - user_token = "fake-token" # noqa: S105 - - with patch.dict(os.environ, {}, clear=True): - builder = MCPConfigBuilder(user_token, mcp_server_configs) - result = builder.dump_client_config() - - assert "openshift" in result - assert "sse-server" in result - assert result["openshift"]["transport"] == "stdio" - assert result["sse-server"]["transport"] == "sse" - assert result["openshift"]["env"]["OC_USER_TOKEN"] == user_token - - @staticmethod - def test_dump_client_config_with_streamable_http(): - """Test dump_client_config with streamable HTTP configuration.""" - file_descriptor, file_path = tempfile.mkstemp(suffix=".tmp") - try: - with os.fdopen(file_descriptor, "w") as open_file: - open_file.write("value") - mcp_server_configs = [ - MCPServerConfig( - name="streamable-server", - transport="streamable_http", - streamable_http=StreamableHttpTransportConfig( - url="https://example.com/stream", - headers={ - "X-Custom-Header": file_path, - "kubernetes": KUBERNETES_PLACEHOLDER, - }, - timeout=30, - sse_read_timeout=60, - ), - ), - ] - user_token = "fake-token" # noqa: S105 - - builder = MCPConfigBuilder(user_token, mcp_server_configs) - try: - result = builder.dump_client_config() - except Exception as e: - print(f"failed creating config {e}") - assert False - - assert result["streamable-server"]["transport"] == "streamable_http" - assert result["streamable-server"]["url"] == "https://example.com/stream" - assert result["streamable-server"]["headers"]["X-Custom-Header"] == "value" - assert ( - result["streamable-server"]["headers"]["kubernetes"] - == f"Bearer {user_token}" - ) - # Verify that timeout values are converted to timedelta objects - assert result["streamable-server"]["timeout"] == timedelta(seconds=30) - assert result["streamable-server"]["sse_read_timeout"] == timedelta( - seconds=60 - ) - finally: - os.unlink(file_path) - - @staticmethod - def test_dump_client_config_with_all_transports(): - """Test dump_client_config with stdio, SSE, and streamable HTTP configurations.""" - mcp_server_configs = [ - MCPServerConfig( - name="openshift", - transport="stdio", - stdio=StdioTransportConfig( - command="hello", - env={"X": "Y"}, - ), - ), - MCPServerConfig( - name="sse-server", - transport="sse", - sse=SseTransportConfig( - url="https://example.com/events", - ), - ), - MCPServerConfig( - name="streamable-server", - transport="streamable_http", - streamable_http=StreamableHttpTransportConfig( - url="https://example.com/stream", - ), - ), - ] - user_token = "fake-token" # noqa: S105 - - with patch.dict(os.environ, {}, clear=True): - builder = MCPConfigBuilder(user_token, mcp_server_configs) - result = builder.dump_client_config() - - assert "openshift" in result - assert "sse-server" in result - assert "streamable-server" in result - assert result["openshift"]["transport"] == "stdio" - assert result["sse-server"]["transport"] == "sse" - assert result["streamable-server"]["transport"] == "streamable_http" - assert result["openshift"]["env"]["OC_USER_TOKEN"] == user_token diff --git a/tests/unit/utils/test_checks.py b/tests/unit/utils/test_checks.py new file mode 100644 index 000000000..ee7d710e2 --- /dev/null +++ b/tests/unit/utils/test_checks.py @@ -0,0 +1,154 @@ +"""Unit tests for checks utilities.""" + +from pathlib import Path + +from ols.constants import NOOP_WITH_TOKEN_AUTHENTICATION_MODULE +from ols.utils.checks import resolve_headers + + +def test_resolve_headers_empty() -> None: + """Test resolving empty authorization headers.""" + result = resolve_headers({}) + assert not result + + +def test_resolve_headers_with_file(tmp_path: Path) -> None: + """Test resolving authorization headers from file.""" + # Create a temporary secret file + secret_file = tmp_path / "secret.txt" + secret_file.write_text("my-secret-token") + + headers = {"Authorization": str(secret_file)} + result = resolve_headers(headers) + + assert result == {"Authorization": "my-secret-token"} + + +def test_resolve_headers_with_file_strips_whitespace( + tmp_path: Path, +) -> None: + """Test that resolving headers strips trailing whitespace from file content.""" + secret_file = tmp_path / "secret.txt" + secret_file.write_text(" my-secret-token\n ") + + headers = {"Authorization": str(secret_file)} + result = resolve_headers(headers) + + # rstrip() only removes trailing whitespace, not leading + assert result == {"Authorization": " my-secret-token"} + + +def test_resolve_headers_with_nonexistent_file() -> None: + """Test resolving headers with nonexistent file logs warning and skips.""" + headers = {"Authorization": "/nonexistent/path/to/secret.txt"} + result = resolve_headers(headers) + + # Should return empty dict when file doesn't exist + assert not result + + +def test_resolve_headers_client_token() -> None: + """Test that client token keyword is preserved.""" + headers = {"Authorization": "client"} + result = resolve_headers(headers) + + # Should keep "client" as-is for later substitution + assert result == {"Authorization": "client"} + + +def test_resolve_headers_kubernetes_token() -> None: + """Test that kubernetes keyword is preserved when k8s auth is configured.""" + headers = {"Authorization": "kubernetes"} + result = resolve_headers(headers, auth_module="k8s") + + # Should keep "kubernetes" as-is for later substitution + assert result == {"Authorization": "kubernetes"} + + +def test_resolve_headers_kubernetes_token_with_noop_with_token() -> None: + """Test that kubernetes keyword is preserved when noop_with_token auth is configured.""" + headers = {"Authorization": "kubernetes"} + result = resolve_headers(headers, auth_module=NOOP_WITH_TOKEN_AUTHENTICATION_MODULE) + + # Should keep "kubernetes" as-is for later substitution (for testing) + assert result == {"Authorization": "kubernetes"} + + +def test_resolve_headers_multiple_headers(tmp_path: Path) -> None: + """Test resolving multiple authorization headers.""" + # Create multiple secret files + auth_file = tmp_path / "auth.txt" + auth_file.write_text("auth-token") + api_key_file = tmp_path / "api_key.txt" + api_key_file.write_text("api-key-value") + + headers = { + "Authorization": str(auth_file), + "X-API-Key": str(api_key_file), + } + result = resolve_headers(headers) + + assert result == { + "Authorization": "auth-token", + "X-API-Key": "api-key-value", + } + + +def test_resolve_headers_mixed_types(tmp_path: Path) -> None: + """Test resolving mixed header types (file, client, kubernetes).""" + # Create a secret file + secret_file = tmp_path / "secret.txt" + secret_file.write_text("file-secret") + + headers = { + "Authorization": "client", + "X-API-Key": str(secret_file), + "X-K8s-Token": "kubernetes", + } + + result = resolve_headers(headers, auth_module="k8s") + + # Special keywords should be preserved, file should be resolved + assert result["Authorization"] == "client" + assert result["X-API-Key"] == "file-secret" + assert result["X-K8s-Token"] == "kubernetes" + + +def test_resolve_headers_file_read_error(tmp_path: Path) -> None: + """Test handling of file read errors.""" + # Create a directory instead of a file to cause an error + secret_dir = tmp_path / "secret_dir" + secret_dir.mkdir() + + headers = {"Authorization": str(secret_dir)} + result = resolve_headers(headers) + + # Should handle error gracefully and return empty dict + assert not result + + +def test_resolve_headers_kubernetes_requires_k8s_auth(caplog) -> None: + """Test that kubernetes placeholder logs warning and returns empty dict with non-k8s auth.""" + headers = {"Authorization": "kubernetes"} + + result = resolve_headers(headers, auth_module="azure") + + # Should return empty dict when kubernetes placeholder used with + # non-k8s/non-noop_with_token auth + assert result == {} + assert "kubernetes" in caplog.text.lower() + assert "k8s" in caplog.text + assert "azure" in caplog.text + assert "skipped" in caplog.text.lower() + + +def test_resolve_headers_kubernetes_with_no_auth_module(caplog) -> None: + """Test that kubernetes placeholder logs warning when auth module is None.""" + headers = {"Authorization": "kubernetes"} + + result = resolve_headers(headers, auth_module=None) + + # Should return empty dict + assert result == {} + assert "kubernetes" in caplog.text.lower() + assert "skipped" in caplog.text.lower() diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index 892544d04..a841dd3cc 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -651,15 +651,9 @@ def test_valid_config_stream(): minTLSVersion: VersionTLS13 mcp_servers: - name: foo - transport: stdio - stdio: - command: python - args: - - mcp_server_1.py + url: http://foo-server:8080/mcp - name: bar - transport: sse - sse: - url: 127.0.0.1:8080 + url: http://bar-server:8080/mcp dev_config: enable_dev_ui: true disable_auth: false @@ -679,98 +673,33 @@ def test_valid_config_file(): try: config.reload_from_yaml_file("tests/config/valid_config.yaml") - expected_config = Config( - { - "llm_providers": [ - { - "name": "p1", - "type": "bam", - "url": "https://url1", - "credentials_path": "tests/config/secret/apitoken", - "models": [ - { - "name": "m1", - "url": "https://murl1", - "credentials_path": "tests/config/secret/apitoken", - "context_window_size": 450, - "parameters": {"max_tokens_for_response": 100}, - }, - { - "name": "m2", - "url": "https://murl2", - }, - ], - }, - { - "name": "p2", - "type": "openai", - "url": "https://url2", - "models": [ - { - "name": "m1", - "url": "https://murl1", - }, - { - "name": "m2", - "url": "https://murl2", - }, - ], - }, - ], - "ols_config": { - "max_workers": 1, - "reference_content": { - "indexes": [ - { - "product_docs_index_path": "tests/config", - "product_docs_index_id": "product", - } - ], - }, - "conversation_cache": { - "type": "memory", - "memory": { - "max_entries": 1000, - }, - }, - "logging_config": { - "logging_level": "INFO", - }, - "default_provider": "p1", - "default_model": "m1", - "certificate_directory": "/foo/bar/baz/xyzzy", - "system_prompt_path": "tests/config/system_prompt.txt", - "user_data_collection": {"transcripts_disabled": True}, - }, - "mcp_servers": [ - { - "name": "foo", - "transport": "stdio", - "stdio": { - "command": "python", - "args": ["mcp_server_1.py"], - "env": {}, - "cwd": ".", - "encoding": "utf-8", - }, - }, - { - "name": "bar", - "transport": "sse", - "sse": { - "url": "127.0.0.1:8080", - "timeout": 5, - "sse_read_timeout": 10, - }, - }, - ], - } - ) - assert config.config == expected_config + # Verify LLM providers + assert "p1" in config.config.llm_providers.providers + assert "p2" in config.config.llm_providers.providers + assert config.config.llm_providers.providers["p1"].type == "bam" + assert config.config.llm_providers.providers["p2"].type == "openai" + + # Verify OLS config + assert config.ols_config.max_workers == 1 + assert config.ols_config.default_provider == "p1" + assert config.ols_config.default_model == "m1" assert config.ols_config.user_data_collection is not None assert config.ols_config.user_data_collection.feedback_disabled is True assert config.ols_config.quota_handlers is not None + + # Verify MCP servers assert config.mcp_servers is not None + # Only one server should remain (second one skipped due to missing secret file) + assert len(config.mcp_servers.servers) == 1 + + # First MCP server (the only one remaining) + assert config.mcp_servers.servers[0].name == "foo" + assert config.mcp_servers.servers[0].url == "http://localhost:8080" + assert config.mcp_servers.servers[0].headers == {} + assert config.mcp_servers.servers[0].timeout is None + + # Second MCP server ("bar") was skipped during validation + # because its auth header references a non-existent file except Exception as e: print(traceback.format_exc()) pytest.fail(f"loading valid configuration failed: {e}")