20 changes: 18 additions & 2 deletions src/mcp/client/streamable_http.py
@@ -334,12 +334,16 @@ async def _handle_sse_response(
await response.aclose()
return # Normal completion, no reconnect needed
except Exception as e:
logger.debug(f"SSE stream ended: {e}") # pragma: no cover
logger.debug(f"SSE stream ended: {e}")

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
if last_event_id is not None:
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
else:
# No event ID received before disconnect - cannot reconnect,
# send error to unblock the client
await self._send_disconnect_error(ctx)

async def _handle_reconnection(
self,
@@ -352,6 +356,7 @@ async def _handle_reconnection(
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
await self._send_disconnect_error(ctx)
return

# Always wait - use server value or default
@@ -417,6 +422,17 @@ async def _send_session_terminated_error(self, read_stream_writer: StreamWriter,
session_message = SessionMessage(jsonrpc_error)
await read_stream_writer.send(session_message)

async def _send_disconnect_error(self, ctx: RequestContext) -> None:
"""Send a disconnect error to unblock the client waiting on the read stream."""
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
request_id = ctx.session_message.message.id
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=-32000, message="SSE stream disconnected before receiving a response"),
)
await ctx.read_stream_writer.send(SessionMessage(jsonrpc_error))

async def post_writer(
self,
client: httpx.AsyncClient,
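For context, a minimal sketch (not part of this diff) of the JSON-RPC error that the new _send_disconnect_error helper places on the read stream. The request id value here is hypothetical; in the real helper it is copied from the pending JSONRPCRequest, and the import paths are assumed from the SDK layout:

    from mcp.shared.message import SessionMessage
    from mcp.types import ErrorData, JSONRPCError

    # Hypothetical request id; the real helper copies it from the pending JSONRPCRequest.
    request_id = 1

    jsonrpc_error = JSONRPCError(
        jsonrpc="2.0",
        id=request_id,
        error=ErrorData(code=-32000, message="SSE stream disconnected before receiving a response"),
    )

    # Serialized, this is what travels over the in-memory read stream:
    # {"jsonrpc": "2.0", "id": 1, "error": {"code": -32000,
    #  "message": "SSE stream disconnected before receiving a response"}}
    session_message = SessionMessage(jsonrpc_error)

ClientSession resolves the pending request with this error, so the caller sees an McpError instead of hanging indefinitely.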
178 changes: 178 additions & 0 deletions tests/issues/test_1811_sse_disconnect_hang.py
@@ -0,0 +1,178 @@
"""Test for issue #1811 - client hangs after SSE disconnection.

When the SSE stream disconnects before the server sends a response (e.g., due to
a read timeout), the client's read_stream_writer was never sent an error message,
causing the client to hang indefinitely on .receive(). The fix sends a JSONRPCError
when the stream disconnects without a resumable event ID.
"""

import multiprocessing
import socket
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import anyio
import httpx
import pytest
from starlette.applications import Starlette
from starlette.routing import Mount

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.exceptions import McpError
from mcp.types import TextContent, Tool
from tests.test_helpers import wait_for_server

SERVER_NAME = "test_sse_disconnect_server"


def get_free_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def create_slow_server_app() -> Starlette: # pragma: no cover
"""Create a server with a tool that takes a long time to respond."""
server = Server(SERVER_NAME)

@server.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="slow_tool",
description="A tool that takes a long time",
input_schema={"type": "object", "properties": {}},
)
]

@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[TextContent]:
# Sleep long enough that the client timeout fires first
await anyio.sleep(30)
return [TextContent(type="text", text="done")]

session_manager = StreamableHTTPSessionManager(app=server, stateless=True)

@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
async with session_manager.run():
yield

return Starlette(
routes=[Mount("/mcp", app=session_manager.handle_request)],
lifespan=lifespan,
)


def create_fast_server_app() -> Starlette: # pragma: no cover
"""Create a server with a fast tool for sanity testing."""
server = Server(SERVER_NAME)

@server.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="fast_tool",
description="A fast tool",
input_schema={"type": "object", "properties": {}},
)
]

@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[TextContent]:
return [TextContent(type="text", text="fast result")]

session_manager = StreamableHTTPSessionManager(app=server, stateless=True)

@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
async with session_manager.run():
yield

return Starlette(
routes=[Mount("/mcp", app=session_manager.handle_request)],
lifespan=lifespan,
)


def run_server(port: int, slow: bool = True) -> None: # pragma: no cover
"""Run the server in a separate process."""
import uvicorn

app = create_slow_server_app() if slow else create_fast_server_app()
uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning")


@pytest.fixture
def slow_server_url():
"""Start the slow server and return its URL."""
port = get_free_port()
proc = multiprocessing.Process(target=run_server, args=(port, True), daemon=True)
proc.start()
wait_for_server(port)

yield f"http://127.0.0.1:{port}"

proc.kill()
proc.join(timeout=2)


@pytest.fixture
def fast_server_url():
"""Start the fast server and return its URL."""
port = get_free_port()
proc = multiprocessing.Process(target=run_server, args=(port, False), daemon=True)
proc.start()
wait_for_server(port)

yield f"http://127.0.0.1:{port}"

proc.kill()
proc.join(timeout=2)


@pytest.mark.anyio
async def test_client_receives_error_on_sse_disconnect(slow_server_url: str):
"""Client should receive an error instead of hanging when SSE stream disconnects.

When the read timeout fires before the server sends a response, the SSE stream
is closed. Previously, if no event ID had been received, the client would hang
forever. Now it should raise McpError with the disconnect message.
"""
# Use a short read timeout so the SSE stream disconnects quickly
short_timeout_client = httpx.AsyncClient(
timeout=httpx.Timeout(5.0, read=0.5),
)

async with streamable_http_client(
f"{slow_server_url}/mcp/",
http_client=short_timeout_client,
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

# Call the slow tool - the read timeout should fire
# and the client should receive an error instead of hanging
with pytest.raises(McpError, match="SSE stream disconnected"): # pragma: no branch
await session.call_tool("slow_tool", {})


@pytest.mark.anyio
async def test_fast_tool_still_works_normally(fast_server_url: str):
"""Ensure normal (fast) tool calls still work correctly after the fix."""
client = httpx.AsyncClient(timeout=httpx.Timeout(5.0))

async with streamable_http_client(
f"{fast_server_url}/mcp/",
http_client=client,
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

result = await session.call_tool("fast_tool", {})
assert result.content[0].type == "text"
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "fast result"
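Beyond the test harness, a hedged sketch of how application code could handle the new disconnect error when calling a slow tool; the URL and timeout values are illustrative, and the client construction mirrors the test above:

    import httpx

    from mcp.client.session import ClientSession
    from mcp.client.streamable_http import streamable_http_client
    from mcp.shared.exceptions import McpError


    async def call_slow_tool(url: str) -> None:
        # Short read timeout (illustrative value) so a stalled SSE stream fails fast.
        http_client = httpx.AsyncClient(timeout=httpx.Timeout(5.0, read=0.5))
        async with streamable_http_client(url, http_client=http_client) as (read_stream, write_stream, _):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                try:
                    await session.call_tool("slow_tool", {})
                except McpError as exc:
                    # Previously this call could hang forever; now it fails with a clear error.
                    print(f"Tool call failed: {exc}")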