diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index 3589d0da7..1eac28796 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -64,7 +64,7 @@ async def connect( # Unwrap FastMCP to get underlying Server actual_server: Server[Any] if isinstance(self._server, FastMCP): - actual_server = self._server._mcp_server # type: ignore[reportPrivateUsage] + actual_server = self._server.mcp_server else: actual_server = self._server diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 1738c12de..fa876cf07 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -296,5 +296,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta async def send_roots_list_changed(self) -> None: """Send a notification that the roots list has changed.""" - # TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support. await self.session.send_roots_list_changed() # pragma: no cover diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 71cd81eb2..008fe485f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -214,6 +214,14 @@ def session_manager(self) -> StreamableHTTPSessionManager: """ return self._mcp_server.session_manager # pragma: no cover + @property + def mcp_server(self): + """Get the underlying MCP server instance. + + This is exposed to enable advanced use cases like in-memory testing. + """ + return self._mcp_server + @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... @@ -255,8 +263,8 @@ def run( transport: Transport protocol to use ("stdio", "sse", or "streamable-http") **kwargs: Transport-specific options (see overloads for details) """ - TRANSPORTS = Literal["stdio", "sse", "streamable-http"] - if transport not in TRANSPORTS.__args__: # type: ignore # pragma: no cover + SUPPORTED_TRANSPORTS = {"stdio", "sse", "streamable-http"} + if transport not in SUPPORTED_TRANSPORTS: # pragma: no cover raise ValueError(f"Unknown transport: {transport}") match transport: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 50a441d69..b1a3e1a49 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -133,34 +133,45 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: client_caps = self._client_params.capabilities - if capability.roots is not None: # pragma: lax no cover - if client_caps.roots is None: - return False - if capability.roots.list_changed and not client_caps.roots.list_changed: - return False + # Check roots capability + if capability.roots and not client_caps.roots: # pragma: lax no cover + return False + if ( + capability.roots + and capability.roots.list_changed + and client_caps.roots + and not client_caps.roots.list_changed + ): # pragma: lax no cover + return False - if capability.sampling is not None: # pragma: lax no cover - if client_caps.sampling is None: - return False - if capability.sampling.context is not None and client_caps.sampling.context is None: + # Check sampling capability + if capability.sampling and not client_caps.sampling: # pragma: lax no cover + return False + if capability.sampling and client_caps.sampling: # pragma: lax no cover + if capability.sampling.context and not client_caps.sampling.context: # pragma: lax no cover return False - if capability.sampling.tools is not None and client_caps.sampling.tools is None: + if capability.sampling.tools and not client_caps.sampling.tools: # pragma: lax no cover return False - if capability.elicitation is not None and client_caps.elicitation is None: # pragma: lax no cover + # Check elicitation capability + if capability.elicitation and not client_caps.elicitation: # pragma: lax no cover return False - if capability.experimental is not None: # pragma: lax no cover - if client_caps.experimental is None: + # Check experimental capability + if capability.experimental: # pragma: lax no cover + if not client_caps.experimental: # pragma: lax no cover return False - for exp_key, exp_value in capability.experimental.items(): - if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: + for exp_key, exp_value in capability.experimental.items(): # pragma: lax no cover + if ( + exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value + ): # pragma: lax no cover return False - if capability.tasks is not None: # pragma: lax no cover - if client_caps.tasks is None: + # Check tasks capability + if capability.tasks: # pragma: lax no cover + if not client_caps.tasks: # pragma: lax no cover return False - if not check_tasks_capability(capability.tasks, client_caps.tasks): + if not check_tasks_capability(capability.tasks, client_caps.tasks): # pragma: lax no cover return False return True @@ -207,6 +218,9 @@ async def _received_notification(self, notification: types.ClientNotification) - match notification: case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized + case types.RootsListChangedNotification(): + # When roots list changes, server should request updated list + await self.list_roots() # pragma: no cover case _: if self._initialization_state != InitializationState.Initialized: # pragma: no cover raise RuntimeError("Received notification before initialization was complete")