From 04354dd499393cbee4409e90713c6c67faa22185 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 18 Jan 2026 21:48:00 +0100 Subject: [PATCH 1/5] Drop RootModel from JSONRPCMessage --- src/mcp/client/streamable_http.py | 2 +- .../experimental/task_result_handler.py | 10 +- src/mcp/shared/session.py | 44 ++--- src/mcp/types.py | 3 +- tests/client/test_session.py | 182 +++++++----------- tests/client/test_stdio.py | 8 +- .../tasks/client/test_capabilities.py | 49 ++--- .../tasks/client/test_handlers.py | 70 ++++--- .../experimental/tasks/server/test_server.py | 15 +- tests/issues/test_192_request_id.py | 10 +- tests/issues/test_malformed_input.py | 24 +-- tests/server/test_lifespan.py | 86 +++------ tests/server/test_session.py | 63 ++---- tests/server/test_session_race_condition.py | 47 ++--- tests/server/test_stdio.py | 16 +- tests/shared/test_session.py | 11 +- tests/shared/test_sse.py | 12 +- tests/shared/test_streamable_http.py | 3 +- 18 files changed, 252 insertions(+), 403 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 75dcd5e891..27fa423388 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -416,7 +416,7 @@ async def _send_session_terminated_error(self, read_stream_writer: StreamWriter, id=request_id, error=ErrorData(code=32600, message="Session terminated"), ) - session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(jsonrpc_error) await read_stream_writer.send(session_message) async def post_writer( diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 078de66286..4d763ef0e6 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -26,7 +26,6 @@ ErrorData, GetTaskPayloadRequest, GetTaskPayloadResult, - JSONRPCMessage, RelatedTaskMetadata, RequestId, ) @@ -107,12 +106,7 @@ async def handle( while True: task = await self._store.get_task(task_id) if task is None: - raise McpError( - ErrorData( - code=INVALID_PARAMS, - message=f"Task not found: {task_id}", - ) - ) + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Task not found: {task_id}")) await self._deliver_queued_messages(task_id, session, request_id) @@ -161,7 +155,7 @@ async def _deliver_queued_messages( # Send the message with relatedRequestId for routing session_message = SessionMessage( - message=JSONRPCMessage(message.message), + message=message.message, metadata=ServerMessageMetadata(related_request_id=request_id), ) await self.send_message(session, session_message) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 8006933541..be1990d618 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -24,7 +24,6 @@ ClientResult, ErrorData, JSONRPCError, - JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -271,7 +270,7 @@ async def send_request( **request_data, ) - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) # request read timeout takes precedence over session read timeout timeout = None @@ -321,7 +320,7 @@ async def send_notification( **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) session_message = SessionMessage( # pragma: no cover - message=JSONRPCMessage(jsonrpc_notification), + message=jsonrpc_notification, metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) await self._write_stream.send(session_message) @@ -329,7 +328,7 @@ async def send_notification( async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(message=jsonrpc_error) await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( @@ -337,7 +336,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er id=request_id, result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + session_message = SessionMessage(message=jsonrpc_response) await self._write_stream.send(session_message) async def _receive_loop(self) -> None: @@ -349,14 +348,14 @@ async def _receive_loop(self) -> None: async for message in self._read_stream: if isinstance(message, Exception): # pragma: no cover await self._handle_incoming(message) - elif isinstance(message.message.root, JSONRPCRequest): + elif isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True), + message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) responder = RequestResponder( - request_id=message.message.root.id, + request_id=message.message.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, @@ -374,23 +373,23 @@ async def _receive_loop(self) -> None: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server logging.warning(f"Failed to validate request: {e}") - logging.debug(f"Message that failed validation: {message.message.root}") + logging.debug(f"Message that failed validation: {message.message}") error_response = JSONRPCError( jsonrpc="2.0", - id=message.message.root.id, + id=message.message.id, error=ErrorData( code=INVALID_PARAMS, message="Invalid request parameters", data="", ), ) - session_message = SessionMessage(message=JSONRPCMessage(error_response)) + session_message = SessionMessage(message=error_response) await self._write_stream.send(session_message) - elif isinstance(message.message.root, JSONRPCNotification): + elif isinstance(message.message, JSONRPCNotification): try: notification = self._receive_notification_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True), + message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) # Handle cancellation notifications @@ -419,10 +418,11 @@ async def _receive_loop(self) -> None: ) await self._received_notification(notification) await self._handle_incoming(notification) - except Exception as e: # pragma: no cover + except Exception: # pragma: no cover # For other validation errors, log and continue logging.warning( - f"Failed to validate notification: {e}. Message was: {message.message.root}" + f"Failed to validate notification:. Message was: {message.message}", + exc_info=True, ) else: # Response or error await self._handle_response(message) @@ -475,27 +475,25 @@ async def _handle_response(self, message: SessionMessage) -> None: Checks response routers first (e.g., for task-related responses), then falls back to the normal response stream mechanism. """ - root = message.message.root - # This check is always true at runtime: the caller (_receive_loop) only invokes # this method in the else branch after checking for JSONRPCRequest and # JSONRPCNotification. However, the type checker can't infer this from the # method signature, so we need this guard for type narrowing. - if not isinstance(root, JSONRPCResponse | JSONRPCError): + if not isinstance(message.message, JSONRPCResponse | JSONRPCError): return # pragma: no cover # Normalize response ID to handle type mismatches (e.g., "0" vs 0) - response_id = self._normalize_request_id(root.id) + response_id = self._normalize_request_id(message.message.id) # First, check response routers (e.g., TaskResultHandler) - if isinstance(root, JSONRPCError): + if isinstance(message.message, JSONRPCError): # Route error to routers for router in self._response_routers: - if router.route_error(response_id, root.error): + if router.route_error(response_id, message.message.error): return # Handled else: # Route success response to routers - response_data: dict[str, Any] = root.result or {} + response_data: dict[str, Any] = message.message.result or {} for router in self._response_routers: if router.route_response(response_id, response_data): return # Handled @@ -503,7 +501,7 @@ async def _handle_response(self, message: SessionMessage) -> None: # Fall back to normal response streams stream = self._response_streams.pop(response_id, None) if stream: # pragma: no cover - await stream.send(root) + await stream.send(message.message) else: # pragma: no cover await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) diff --git a/src/mcp/types.py b/src/mcp/types.py index b2afd977df..cad0c5f058 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -197,8 +197,7 @@ class JSONRPCError(MCPModel): error: ErrorData -class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): - pass +JSONRPCMessage = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError class EmptyResult(Result): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 78df8ed191..9512a0a7c4 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -18,7 +18,6 @@ InitializedNotification, InitializeRequest, InitializeResult, - JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -41,7 +40,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -65,18 +64,16 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) session_notification = await client_to_server_receive.receive() jsonrpc_notification = session_notification.message - assert isinstance(jsonrpc_notification.root, JSONRPCNotification) + assert isinstance(jsonrpc_notification, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -128,7 +125,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -146,12 +143,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -189,7 +184,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -207,12 +202,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -220,10 +213,7 @@ async def mock_server(): await client_to_server_receive.receive() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -247,7 +237,7 @@ async def test_client_session_version_negotiation_success(): async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -268,12 +258,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -281,10 +269,7 @@ async def mock_server(): await client_to_server_receive.receive() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -309,7 +294,7 @@ async def test_client_session_version_negotiation_failure(): async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -327,21 +312,16 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -368,7 +348,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -386,12 +366,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -399,10 +377,7 @@ async def mock_server(): await client_to_server_receive.receive() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -446,7 +421,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -464,12 +439,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -529,7 +502,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -547,12 +520,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -600,7 +571,7 @@ async def test_get_server_capabilities(): async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -617,12 +588,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -669,7 +638,7 @@ async def mock_server(): # Receive initialization request from client session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -686,12 +655,10 @@ async def mock_server(): # Answer initialization request await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -702,14 +669,14 @@ async def mock_server(): # Wait for the client to send a 'tools/call' request session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) - assert jsonrpc_request.root.method == "tools/call" + assert jsonrpc_request.method == "tools/call" if meta is not None: - assert jsonrpc_request.root.params - assert "_meta" in jsonrpc_request.root.params - assert jsonrpc_request.root.params["_meta"] == meta + assert jsonrpc_request.params + assert "_meta" in jsonrpc_request.params + assert jsonrpc_request.params["_meta"] == meta result = ServerResult( CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) @@ -718,12 +685,10 @@ async def mock_server(): # Send the tools/call result await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -732,20 +697,18 @@ async def mock_server(): # The client requires this step to validate the tool output schema session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) - assert jsonrpc_request.root.method == "tools/list" + assert jsonrpc_request.method == "tools/list" result = types.ListToolsResult(tools=[mocked_tool]) await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -753,10 +716,7 @@ async def mock_server(): server_to_client_send.close() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 61b7ce4faf..4059a92682 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -47,8 +47,8 @@ async def test_stdio_client(): async with stdio_client(server_parameters) as (read_stream, write_stream): # Test sending and receiving messages messages = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=2, result={}), ] async with write_stream: @@ -67,8 +67,8 @@ async def test_stdio_client(): break assert len(read_messages) == 2 - assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) - assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) + assert read_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + assert read_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) @pytest.mark.anyio diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index de73b8c062..be35478016 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -15,7 +15,6 @@ Implementation, InitializeRequest, InitializeResult, - JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, ServerCapabilities, @@ -36,7 +35,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -54,12 +53,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -110,7 +107,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -128,12 +125,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -194,7 +189,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -212,12 +207,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -274,7 +267,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -292,12 +285,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py index 0e4e8f45a5..0cac3c7362 100644 --- a/tests/experimental/tasks/client/test_handlers.py +++ b/tests/experimental/tasks/client/test_handlers.py @@ -151,15 +151,11 @@ async def run_client() -> None: await client_ready.wait() typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="test-task-123")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-1", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + request = types.JSONRPCRequest(jsonrpc="2.0", id="req-1", **typed_request.model_dump(by_alias=True)) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) assert response.id == "req-1" @@ -219,10 +215,10 @@ async def run_client() -> None: id="req-2", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) assert isinstance(response.result, dict) @@ -277,10 +273,10 @@ async def run_client() -> None: id="req-3", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) result = ListTasksResult.model_validate(response.result) @@ -340,10 +336,10 @@ async def run_client() -> None: id="req-4", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) result = CancelTaskResult.model_validate(response.result) @@ -448,11 +444,11 @@ async def run_client() -> None: id="req-sampling", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) # Step 2: Client responds with CreateTaskResult response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) task_result = CreateTaskResult.model_validate(response.result) @@ -469,10 +465,10 @@ async def run_client() -> None: id="req-poll", **typed_poll.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + await client_streams.server_send.send(SessionMessage(poll_request)) poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message.root + poll_response = poll_response_msg.message assert isinstance(poll_response, types.JSONRPCResponse) status = GetTaskResult.model_validate(poll_response.result) @@ -485,10 +481,10 @@ async def run_client() -> None: id="req-result", **typed_result_req.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + await client_streams.server_send.send(SessionMessage(result_request)) result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message.root + result_response = result_response_msg.message assert isinstance(result_response, types.JSONRPCResponse) assert isinstance(result_response.result, dict) @@ -588,11 +584,11 @@ async def run_client() -> None: id="req-elicit", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) # Step 2: Client responds with CreateTaskResult response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) task_result = CreateTaskResult.model_validate(response.result) @@ -609,10 +605,10 @@ async def run_client() -> None: id="req-poll", **typed_poll.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + await client_streams.server_send.send(SessionMessage(poll_request)) poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message.root + poll_response = poll_response_msg.message assert isinstance(poll_response, types.JSONRPCResponse) status = GetTaskResult.model_validate(poll_response.result) @@ -625,10 +621,10 @@ async def run_client() -> None: id="req-result", **typed_result_req.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + await client_streams.server_send.send(SessionMessage(result_request)) result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message.root + result_response = result_response_msg.message assert isinstance(result_response, types.JSONRPCResponse) # Verify the elicitation result @@ -667,10 +663,10 @@ async def run_client() -> None: id="req-unhandled", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert ( "not supported" in response.error.message.lower() @@ -706,10 +702,10 @@ async def run_client() -> None: id="req-result", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -742,10 +738,10 @@ async def run_client() -> None: id="req-list", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -778,10 +774,10 @@ async def run_client() -> None: id="req-cancel", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -822,10 +818,10 @@ async def run_client() -> None: id="req-sampling", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -868,10 +864,10 @@ async def run_client() -> None: id="req-elicit", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 94b37e6d07..6b0bbfef3e 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -36,7 +36,6 @@ GetTaskRequestParams, GetTaskResult, JSONRPCError, - JSONRPCMessage, JSONRPCNotification, JSONRPCResponse, ListTasksRequest, @@ -724,7 +723,7 @@ async def test_send_message() -> None: # Create a test message notification = JSONRPCNotification(jsonrpc="2.0", method="test/notification") message = SessionMessage( - message=JSONRPCMessage(notification), + message=notification, metadata=ServerMessageMetadata(related_request_id="test-req-1"), ) @@ -733,8 +732,8 @@ async def test_send_message() -> None: # Verify it was sent to the stream received = await server_to_client_receive.receive() - assert isinstance(received.message.root, JSONRPCNotification) - assert received.message.root.method == "test/notification" + assert isinstance(received.message, JSONRPCNotification) + assert received.message.method == "test/notification" finally: # pragma: no cover await server_to_client_send.aclose() await server_to_client_receive.aclose() @@ -776,7 +775,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Simulate receiving a response from client response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=JSONRPCMessage(response)) + message = SessionMessage(message=response) # Send from "client" side await client_to_server_send.send(message) @@ -831,7 +830,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Simulate receiving an error response from client error_data = ErrorData(code=INVALID_REQUEST, message="Test error") error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=JSONRPCMessage(error_response)) + message = SessionMessage(message=error_response) # Send from "client" side await client_to_server_send.send(message) @@ -894,7 +893,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Send a response - should skip first router and be handled by second response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=JSONRPCMessage(response)) + message = SessionMessage(message=response) await client_to_server_send.send(message) with anyio.fail_after(5): @@ -953,7 +952,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Send an error - should skip first router and be handled by second error_data = ErrorData(code=INVALID_REQUEST, message="Test error") error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=JSONRPCMessage(error_response)) + message = SessionMessage(message=error_response) await client_to_server_send.send(message) with anyio.fail_after(5): diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index ca4a95e5d9..de96dbe23a 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -66,7 +66,7 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) + await client_writer.send(SessionMessage(init_req)) response = await server_reader.receive() # Get init response but don't need to check it # Send initialized notification @@ -75,12 +75,12 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(SessionMessage(JSONRPCMessage(root=initialized_notification))) + await client_writer.send(SessionMessage(initialized_notification)) # Send ping request with custom ID ping_request = JSONRPCRequest(id=custom_request_id, method="ping", params={}, jsonrpc="2.0") - await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) + await client_writer.send(SessionMessage(ping_request)) # Read response response = await server_reader.receive() @@ -88,8 +88,8 @@ async def run_server(): # Verify response ID matches request ID assert isinstance(response, SessionMessage) assert isinstance(response.message, JSONRPCMessage) - assert isinstance(response.message.root, JSONRPCResponse) - assert response.message.root.id == custom_request_id, "Response ID should match request ID" + assert isinstance(response.message, JSONRPCResponse) + assert response.message.id == custom_request_id, "Response ID should match request ID" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 34498ba747..cb60ca42a6 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -1,21 +1,13 @@ # Claude Debug """Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" -from typing import Any - import anyio import pytest from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage -from mcp.types import ( - INVALID_PARAMS, - JSONRPCError, - JSONRPCMessage, - JSONRPCRequest, - ServerCapabilities, -) +from mcp.types import INVALID_PARAMS, JSONRPCError, JSONRPCMessage, JSONRPCRequest, ServerCapabilities @pytest.mark.anyio @@ -37,7 +29,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): ) # Wrap in session message - request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + request_message = SessionMessage(message=malformed_request) # Start a server session async with ServerSession( @@ -58,7 +50,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): # Check that we received an error response instead of a crash try: response_message = write_receive_stream.receive_nowait() - response = response_message.message.root + response = response_message.message # Verify it's a proper JSON-RPC error response assert isinstance(response, JSONRPCError) @@ -75,14 +67,14 @@ async def test_malformed_initialize_request_does_not_crash_server(): method="tools/call", # params=None # Missing required params ) - another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request)) + another_request_message = SessionMessage(message=another_malformed_request) await read_send_stream.send(another_request_message) await anyio.sleep(0.1) # Should get another error response, not a crash second_response_message = write_receive_stream.receive_nowait() - second_response = second_response_message.message.root + second_response = second_response_message.message assert isinstance(second_response, JSONRPCError) assert second_response.id == "test_id_2" @@ -125,7 +117,7 @@ async def test_multiple_concurrent_malformed_requests(): method="initialize", # params=None # Missing required params ) - request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + request_message = SessionMessage(message=malformed_request) malformed_requests.append(request_message) # Send all requests @@ -136,11 +128,11 @@ async def test_multiple_concurrent_malformed_requests(): await anyio.sleep(0.2) # Verify we get error responses for all requests - error_responses: list[Any] = [] + error_responses: list[JSONRPCMessage] = [] try: while True: response_message = write_receive_stream.receive_nowait() - error_responses.append(response_message.message.root) + error_responses.append(response_message.message) except anyio.WouldBlock: pass # No more messages diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 1382785942..caeb0530d5 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -82,13 +82,11 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), ) ) ) @@ -96,27 +94,16 @@ async def run_server(): response = response.message # Send initialized notification - await send_stream1.send( - SessionMessage( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) - ) + await send_stream1.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"))) # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, ) ) ) @@ -125,8 +112,8 @@ async def run_server(): response = await receive_stream2.receive() response = response.message assert isinstance(response, JSONRPCMessage) - assert isinstance(response.root, JSONRPCResponse) - assert response.root.result["content"][0]["text"] == "true" + assert isinstance(response, JSONRPCResponse) + assert response.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() @@ -162,13 +149,7 @@ def check_lifespan(ctx: Context[ServerSession, None]) -> bool: return True # Run server in background task - async with ( - anyio.create_task_group() as tg, - send_stream1, - receive_stream1, - send_stream2, - receive_stream2, - ): + async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: async def run_server(): await server._mcp_server.run( @@ -188,13 +169,11 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), ) ) ) @@ -202,27 +181,16 @@ async def run_server(): response = response.message # Send initialized notification - await send_stream1.send( - SessionMessage( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) - ) + await send_stream1.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"))) # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, ) ) ) @@ -231,8 +199,8 @@ async def run_server(): response = await receive_stream2.receive() response = response.message assert isinstance(response, JSONRPCMessage) - assert isinstance(response.root, JSONRPCResponse) - assert response.root.result["content"][0]["text"] == "true" + assert isinstance(response, JSONRPCResponse) + assert response.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index ced1d92ff7..3c1e96c126 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -169,25 +169,23 @@ async def mock_client(): # Send initialization request with older protocol version (2024-11-05) await client_to_server_send.send( SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version="2024-11-05", - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=types.InitializeRequestParams( + protocol_version="2024-11-05", + capabilities=types.ClientCapabilities(), + client_info=types.Implementation(name="test-client", version="1.0.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) # Wait for the initialize response init_response_message = await server_to_client_receive.receive() - assert isinstance(init_response_message.message.root, types.JSONRPCResponse) - result_data = init_response_message.message.root.result + assert isinstance(init_response_message.message, types.JSONRPCResponse) + result_data = init_response_message.message.result init_result = types.InitializeResult.model_validate(result_data) # Check that the server responded with the requested protocol version @@ -196,14 +194,7 @@ async def mock_client(): # Send initialized notification await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) + SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) ) async with ( @@ -256,24 +247,14 @@ async def mock_client(): nonlocal ping_response_received, ping_response_id # Send ping request before any initialization - await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=42, - method="ping", - ) - ) - ) - ) + await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=42, method="ping"))) # Wait for the ping response ping_response_message = await server_to_client_receive.receive() - assert isinstance(ping_response_message.message.root, types.JSONRPCResponse) + assert isinstance(ping_response_message.message, types.JSONRPCResponse) ping_response_received = True - ping_response_id = ping_response_message.message.root.id + ping_response_id = ping_response_message.message.id async with ( client_to_server_send, @@ -493,22 +474,14 @@ async def mock_client(): # Try to send a non-ping request before initialization await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="prompts/list", - ) - ) - ) + SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=1, method="prompts/list")) ) # Wait for the error response error_message = await server_to_client_receive.receive() - if isinstance(error_message.message.root, types.JSONRPCError): # pragma: no branch + if isinstance(error_message.message, types.JSONRPCError): # pragma: no branch error_response_received = True - error_code = error_message.message.root.error.code + error_code = error_message.message.error.code async with ( client_to_server_send, diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py index aa256f5b0e..bc6145acaf 100644 --- a/tests/server/test_session_race_condition.py +++ b/tests/server/test_session_race_condition.py @@ -87,54 +87,35 @@ async def mock_client(): # Step 1: Send InitializeRequest await client_to_server_send.send( SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=types.InitializeRequestParams( + protocol_version=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities(), + client_info=types.Implementation(name="test-client", version="1.0.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) # Step 2: Wait for InitializeResult init_msg = await server_to_client_receive.receive() - assert isinstance(init_msg.message.root, types.JSONRPCResponse) + assert isinstance(init_msg.message, types.JSONRPCResponse) # Step 3: Immediately send tools/list BEFORE InitializedNotification # This is the race condition scenario - await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/list", - ) - ) - ) - ) + await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=2, method="tools/list"))) # Step 4: Check the response tools_msg = await server_to_client_receive.receive() - if isinstance(tools_msg.message.root, types.JSONRPCError): # pragma: no cover - error_received = tools_msg.message.root.error.message + if isinstance(tools_msg.message, types.JSONRPCError): # pragma: no cover + error_received = tools_msg.message.error.message # Step 5: Send InitializedNotification await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) + SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) ) async with ( diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 13cdde3d61..71281eb9ac 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -14,8 +14,8 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=2, result={}), ] for message in messages: @@ -37,13 +37,13 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) - assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) + assert received_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + assert received_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) # Test sending responses from the server responses = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})), + JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=4, result={}), ] async with write_stream: @@ -57,5 +57,5 @@ async def test_stdio_server(): received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines] assert len(received_responses) == 2 - assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) - assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) + assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 8b4ebd81f4..77bec4aa33 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -18,7 +18,6 @@ EmptyResult, ErrorData, JSONRPCError, - JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, TextContent, @@ -130,7 +129,7 @@ async def mock_server(): """Receive a request and respond with a string ID instead of integer.""" message = await server_read.receive() assert isinstance(message, SessionMessage) - root = message.message.root + root = message.message assert isinstance(root, JSONRPCRequest) # Get the original request ID (which is an integer) request_id = root.id @@ -142,7 +141,7 @@ async def mock_server(): id=str(request_id), # Convert to string to simulate mismatch result={}, ) - await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + await server_write.send(SessionMessage(message=response)) async def make_request(client_session: ClientSession): nonlocal result_holder @@ -185,7 +184,7 @@ async def mock_server(): """Receive a request and respond with an error using a string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) - root = message.message.root + root = message.message assert isinstance(root, JSONRPCRequest) request_id = root.id assert isinstance(request_id, int) @@ -196,7 +195,7 @@ async def mock_server(): id=str(request_id), # Convert to string to simulate mismatch error=ErrorData(code=-32600, message="Test error"), ) - await server_write.send(SessionMessage(message=JSONRPCMessage(error_response))) + await server_write.send(SessionMessage(message=error_response)) async def make_request(client_session: ClientSession): nonlocal error_holder @@ -247,7 +246,7 @@ async def mock_server(): id="not_a_number", # Non-numeric string result={}, ) - await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + await server_write.send(SessionMessage(message=response)) async def make_request(client_session: ClientSession): try: diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ad198e627b..fb006424c6 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -503,12 +503,12 @@ def test_sse_message_id_coercion(): See for more details. """ json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}' - msg = types.JSONRPCMessage.model_validate_json(json_message) - assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id="123"))) + msg = types.JSONRPCRequest.model_validate_json(json_message) + assert msg == snapshot(types.JSONRPCRequest(method="ping", jsonrpc="2.0", id="123")) json_message = '{"jsonrpc": "2.0", "id": 123, "method": "ping", "params": null}' - msg = types.JSONRPCMessage.model_validate_json(json_message) - assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))) + msg = types.JSONRPCRequest.model_validate_json(json_message) + assert msg == snapshot(types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)) @pytest.mark.parametrize( @@ -601,5 +601,5 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: msg = await read_stream.receive() # If we get here without error, the empty message was skipped successfully assert not isinstance(msg, Exception) - assert isinstance(msg.message.root, types.JSONRPCResponse) - assert msg.message.root.id == 1 + assert isinstance(msg.message, types.JSONRPCResponse) + assert msg.message.id == 1 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 8838eb62b3..0c702dce26 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -49,7 +49,6 @@ from mcp.shared.session import RequestResponder from mcp.types import ( InitializeResult, - JSONRPCMessage, JSONRPCRequest, TextContent, TextResourceContents, @@ -1859,7 +1858,7 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() ) # Create a mock message and request - mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list")) + mock_message = JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list") mock_request = MagicMock() # Call _create_session_message with OLD protocol version From 8e567ec2e1d25a23b44fdaa147f2a24765881b99 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 18 Jan 2026 22:19:45 +0100 Subject: [PATCH 2/5] Drop RootModel from JSONRPCMessage --- src/mcp/client/sse.py | 6 ++--- src/mcp/client/stdio/__init__.py | 2 +- src/mcp/client/streamable_http.py | 42 +++++++++++++++---------------- src/mcp/client/websocket.py | 2 +- src/mcp/server/sse.py | 2 +- src/mcp/server/stdio.py | 2 +- src/mcp/server/streamable_http.py | 36 +++++++++++--------------- src/mcp/server/websocket.py | 2 +- src/mcp/types.py | 3 ++- tests/client/conftest.py | 18 ++++++------- tests/server/test_stdio.py | 4 +-- tests/test_types.py | 16 ++++++------ 12 files changed, 62 insertions(+), 73 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 13d5ecb1ea..47e5b845a9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -73,9 +73,7 @@ async def sse_client( event_source.response.raise_for_status() logger.debug("SSE connection established") - async def sse_reader( - task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ): + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): try: async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") @@ -108,7 +106,7 @@ async def sse_reader( if not sse.data: continue try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + message = types.jsonrpc_message_adapter.validate_json( sse.data, by_name=False ) logger.debug(f"Received server message: {message}") diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 5ab541da88..19fdec5a38 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -150,7 +150,7 @@ async def stdout_reader(): for line in lines: try: - message = types.JSONRPCMessage.model_validate_json(line, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) except Exception as exc: # pragma: no cover logger.exception("Failed to parse JSONRPC message from server") await read_stream_writer.send(exc) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 27fa423388..555dd1290c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -25,6 +25,7 @@ JSONRPCRequest, JSONRPCResponse, RequestId, + jsonrpc_message_adapter, ) logger = logging.getLogger(__name__) @@ -95,11 +96,11 @@ def _prepare_headers(self) -> dict[str, str]: def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" - return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + return isinstance(message, JSONRPCRequest) and message.method == "initialize" def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialized notification.""" - return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" + return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized" def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None: """Extract and store session ID from response headers.""" @@ -110,15 +111,15 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None: """Extract protocol version from initialization response message.""" - if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch + if isinstance(message, JSONRPCResponse) and message.result: # pragma: no branch try: # Parse the result as InitializeResult for type safety - init_result = InitializeResult.model_validate(message.root.result, by_name=False) + init_result = InitializeResult.model_validate(message.result, by_name=False) self.protocol_version = str(init_result.protocol_version) logger.info(f"Negotiated protocol version: {self.protocol_version}") except Exception: # pragma: no cover logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True) - logger.warning(f"Raw result: {message.root.result}") + logger.warning(f"Raw result: {message.result}") async def _handle_sse_event( self, @@ -137,7 +138,7 @@ async def _handle_sse_event( await resumption_callback(sse.id) return False try: - message = JSONRPCMessage.model_validate_json(sse.data, by_name=False) + message = jsonrpc_message_adapter.validate_json(sse.data, by_name=False) logger.debug(f"SSE message: {message}") # Extract protocol version from initialization response @@ -145,8 +146,8 @@ async def _handle_sse_event( self._maybe_extract_protocol_version_from_message(message) # If this is a response and we have original_request_id, replace it - if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): - message.root.id = original_request_id + if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError): + message.id = original_request_id session_message = SessionMessage(message) await read_stream_writer.send(session_message) @@ -157,7 +158,7 @@ async def _handle_sse_event( # If this is a response or error return True indicating completion # Otherwise, return False to continue listening - return isinstance(message.root, JSONRPCResponse | JSONRPCError) + return isinstance(message, JSONRPCResponse | JSONRPCError) except Exception as exc: # pragma: no cover logger.exception("Error parsing SSE message") @@ -222,8 +223,8 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch - original_request_id = ctx.session_message.message.root.id + if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch + original_request_id = ctx.session_message.message.id async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: event_source.response.raise_for_status() @@ -257,12 +258,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: return if response.status_code == 404: # pragma: no branch - if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - message.root.id, # pragma: no cover - ) # pragma: no cover - return # pragma: no cover + if isinstance(message, JSONRPCRequest): # pragma: no branch + await self._send_session_terminated_error(ctx.read_stream_writer, message.id) + return response.raise_for_status() if is_initialization: @@ -270,7 +268,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: # The server MUST NOT send a response to notifications. - if isinstance(message.root, JSONRPCRequest): + if isinstance(message, JSONRPCRequest): content_type = response.headers.get("content-type", "").lower() if content_type.startswith("application/json"): await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) @@ -291,7 +289,7 @@ async def _handle_json_response( """Handle JSON response from the server.""" try: content = await response.aread() - message = JSONRPCMessage.model_validate_json(content, by_name=False) + message = jsonrpc_message_adapter.validate_json(content, by_name=False) # Extract protocol version from initialization response if is_initialization: @@ -365,8 +363,8 @@ async def _handle_reconnection( # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch - original_request_id = ctx.session_message.message.root.id + if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch + original_request_id = ctx.session_message.message.id try: async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: @@ -463,7 +461,7 @@ async def handle_request_async(): await self._handle_post_request(ctx) # If this is a request, start a new task to handle it - if isinstance(message.root, JSONRPCRequest): + if isinstance(message, JSONRPCRequest): tg.start_soon(handle_request_async) else: await handle_request_async() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 71860be00a..d9d0aa4975 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -51,7 +51,7 @@ async def ws_reader(): async with read_stream_writer: async for raw_text in ws: try: - message = types.JSONRPCMessage.model_validate_json(raw_text, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(raw_text, by_name=False) session_message = SessionMessage(message) await read_stream_writer.send(session_message) except ValidationError as exc: # pragma: no cover diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 46849eb82e..ea0c8db4a5 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -227,7 +227,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) logger.debug(f"Received JSON: {body}") try: - message = types.JSONRPCMessage.model_validate_json(body, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.exception("Failed to parse message") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index d494d075fa..531404f21b 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -60,7 +60,7 @@ async def stdin_reader(): async with read_stream_writer: async for line in stdin: try: - message = types.JSONRPCMessage.model_validate_json(line, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) except Exception as exc: # pragma: no cover await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 137a7da397..6b16b15549 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -42,6 +42,7 @@ JSONRPCRequest, JSONRPCResponse, RequestId, + jsonrpc_message_adapter, ) logger = logging.getLogger(__name__) @@ -301,10 +302,7 @@ def _create_error_response( error_response = JSONRPCError( jsonrpc="2.0", id="server-error", # We don't have a request ID for general errors - error=ErrorData( - code=error_code, - message=error_message, - ), + error=ErrorData(code=error_code, message=error_message), ) return Response( @@ -455,6 +453,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re body = await request.body() try: + # TODO(Marcelo): Replace `json.loads` with `pydantic_core.from_json`. raw_message = json.loads(body) except json.JSONDecodeError as e: response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) @@ -462,7 +461,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return try: # pragma: no cover - message = JSONRPCMessage.model_validate(raw_message, by_name=False) + message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) except ValidationError as e: # pragma: no cover response = self._create_error_response( f"Validation error: {str(e)}", @@ -473,9 +472,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # Check if this is an initialization request - is_initialization_request = ( - isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" - ) # pragma: no cover + is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" if is_initialization_request: # pragma: no cover # Check if the server already has an established session @@ -495,7 +492,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # For notifications and responses only, return 202 Accepted - if not isinstance(message.root, JSONRPCRequest): # pragma: no cover + if not isinstance(message, JSONRPCRequest): # pragma: no cover # Create response object and send it response = self._create_json_response( None, @@ -514,13 +511,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # For initialize requests, get from request params. # For other requests, get from header (already validated). protocol_version = ( - str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) - if is_initialization_request and message.root.params + str(message.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) + if is_initialization_request and message.params else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) ) # Extract the request ID outside the try block for proper scope - request_id = str(message.root.id) # pragma: no cover + request_id = str(message.id) # pragma: no cover # Register this stream for the request ID self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover request_stream_reader = self._request_streams[request_id][1] # pragma: no cover @@ -538,12 +535,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Use similar approach to SSE writer for consistency async for event_message in request_stream_reader: # If it's a response, this is what we're waiting for - if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError): + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): response_message = event_message.message break # For notifications and request, keep waiting else: - logger.debug(f"received: {event_message.message.root.method}") + logger.debug(f"received: {event_message.message.method}") # At this point we should have a response if response_message: @@ -589,10 +586,7 @@ async def sse_writer(): await sse_stream_writer.send(event_data) # If response, remove from pending streams and close - if isinstance( - event_message.message.root, - JSONRPCResponse | JSONRPCError, - ): + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break except anyio.ClosedResourceError: # Expected when close_sse_stream() is called @@ -984,8 +978,8 @@ async def message_router(): # pragma: no cover message = session_message.message target_request_id = None # Check if this is a response - if isinstance(message.root, JSONRPCResponse | JSONRPCError): - response_id = str(message.root.id) + if isinstance(message, JSONRPCResponse | JSONRPCError): + response_id = str(message.id) # If this response is for an existing request stream, # send it there target_request_id = response_id @@ -1022,7 +1016,7 @@ async def message_router(): # pragma: no cover self._request_streams.pop(request_stream_id, None) else: logger.debug( - f"""Request stream {request_stream_id} not found + f"""Request stream {request_stream_id} not found for message. Still processing message as the client might reconnect and replay.""" ) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 9dde5e016c..9df3e25c87 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -36,7 +36,7 @@ async def ws_reader(): async with read_stream_writer: async for msg in websocket.iter_text(): try: - client_message = types.JSONRPCMessage.model_validate_json(msg, by_name=False) + client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) except ValidationError as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/types.py b/src/mcp/types.py index cad0c5f058..4c886680aa 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, TypeAdapter from pydantic.alias_generators import to_camel LATEST_PROTOCOL_VERSION = "2025-11-25" @@ -198,6 +198,7 @@ class JSONRPCError(MCPModel): JSONRPCMessage = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError +jsonrpc_message_adapter = TypeAdapter[JSONRPCMessage](JSONRPCMessage) class EmptyResult(Result): diff --git a/tests/client/conftest.py b/tests/client/conftest.py index dfcad8215d..7314a37351 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -43,35 +43,33 @@ def clear(self) -> None: def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover """Get client-sent requests, optionally filtered by method.""" return [ - req.message.root + req.message for req in self.client.sent_messages - if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) + if isinstance(req.message, JSONRPCRequest) and (method is None or req.message.method == method) ] def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover """Get server-sent requests, optionally filtered by method.""" return [ # pragma: no cover - req.message.root + req.message for req in self.server.sent_messages - if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) + if isinstance(req.message, JSONRPCRequest) and (method is None or req.message.method == method) ] def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover """Get client-sent notifications, optionally filtered by method.""" return [ - notif.message.root + notif.message for notif in self.client.sent_messages - if isinstance(notif.message.root, JSONRPCNotification) - and (method is None or notif.message.root.method == method) + if isinstance(notif.message, JSONRPCNotification) and (method is None or notif.message.method == method) ] def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover """Get server-sent notifications, optionally filtered by method.""" return [ - notif.message.root + notif.message for notif in self.server.sent_messages - if isinstance(notif.message.root, JSONRPCNotification) - and (method is None or notif.message.root.method == method) + if isinstance(notif.message, JSONRPCNotification) and (method is None or notif.message.method == method) ] diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 71281eb9ac..9a7ddaab40 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -5,7 +5,7 @@ from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter @pytest.mark.anyio @@ -55,7 +55,7 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines] + received_responses = [jsonrpc_message_adapter.validate_json(line.strip()) for line in output_lines] assert len(received_responses) == 2 assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) diff --git a/tests/test_types.py b/tests/test_types.py index 7a9576c0be..454bac34b0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -12,7 +12,6 @@ Implementation, InitializeRequest, InitializeRequestParams, - JSONRPCMessage, JSONRPCRequest, ListToolsResult, SamplingCapability, @@ -22,6 +21,7 @@ ToolChoice, ToolResultContent, ToolUseContent, + jsonrpc_message_adapter, ) @@ -38,15 +38,15 @@ async def test_jsonrpc_request(): }, } - request = JSONRPCMessage.model_validate(json_data) - assert isinstance(request.root, JSONRPCRequest) + request = jsonrpc_message_adapter.validate_python(json_data) + assert isinstance(request, JSONRPCRequest) ClientRequest.model_validate(request.model_dump(by_alias=True, exclude_none=True)) - assert request.root.jsonrpc == "2.0" - assert request.root.id == 1 - assert request.root.method == "initialize" - assert request.root.params is not None - assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert request.jsonrpc == "2.0" + assert request.id == 1 + assert request.method == "initialize" + assert request.params is not None + assert request.params["protocolVersion"] == LATEST_PROTOCOL_VERSION @pytest.mark.anyio From cd4fcf27604680224f9b75f0454aba4bef285ee3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 19 Jan 2026 11:38:48 +0100 Subject: [PATCH 3/5] Completely drop `RootModel` from `types` module --- pyproject.toml | 4 + src/mcp/client/experimental/task_handlers.py | 6 +- src/mcp/client/experimental/tasks.py | 38 ++-- src/mcp/client/session.py | 118 +++++------ src/mcp/server/auth/handlers/register.py | 11 +- src/mcp/server/auth/handlers/token.py | 44 ++-- .../server/experimental/session_features.py | 44 ++-- src/mcp/server/experimental/task_context.py | 21 +- .../server/fastmcp/utilities/func_metadata.py | 13 +- src/mcp/server/lowlevel/experimental.py | 46 ++--- src/mcp/server/lowlevel/server.py | 84 ++++---- src/mcp/server/session.py | 150 +++++++------- src/mcp/shared/session.py | 59 +++--- src/mcp/types.py | 44 ++-- tests/client/test_notification_response.py | 6 +- tests/client/test_resource_cleanup.py | 18 +- tests/client/test_session.py | 191 ++++++++---------- .../tasks/client/test_capabilities.py | 75 +++---- tests/experimental/tasks/client/test_tasks.py | 49 ++--- .../tasks/server/test_integration.py | 36 ++-- .../experimental/tasks/server/test_server.py | 85 +++----- tests/issues/test_129_resource_templates.py | 4 +- tests/issues/test_342_base64_encoding.py | 2 +- tests/server/fastmcp/test_integration.py | 16 +- tests/server/lowlevel/test_server_listing.py | 24 +-- tests/server/test_cancel_handling.py | 17 +- .../test_lowlevel_exception_handling.py | 2 +- tests/server/test_read_resource.py | 18 +- tests/server/test_session.py | 8 +- tests/server/test_session_race_condition.py | 22 +- tests/shared/test_progress_notifications.py | 8 +- tests/shared/test_session.py | 16 +- tests/shared/test_streamable_http.py | 79 +++----- tests/test_types.py | 4 +- 34 files changed, 572 insertions(+), 790 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4925e603db..febe4f38d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,6 +137,10 @@ select = [ ] ignore = ["PERF203", "PLC0415", "PLR0402"] +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead." + + [tool.ruff.lint.mccabe] max-complexity = 24 # Default is 10 diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index 6b233cd072..d6cde09faf 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -242,7 +242,7 @@ def build_capability(self) -> types.ClientTasksCapability | None: def handles_request(request: types.ServerRequest) -> bool: """Check if this handler handles the given request type.""" return isinstance( - request.root, + request, types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, ) @@ -259,7 +259,7 @@ async def handle_request( types.ClientResult | types.ErrorData ) - match responder.request.root: + match responder.request: case types.GetTaskRequest(params=params): response = await self.get_task(ctx, params) client_response = client_response_type.validate_python(response) @@ -281,7 +281,7 @@ async def handle_request( await responder.respond(client_response) case _: # pragma: no cover - raise ValueError(f"Unhandled request type: {type(responder.request.root)}") + raise ValueError(f"Unhandled request type: {type(responder.request)}") # Backwards compatibility aliases diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py index 1b38255495..2f890245c4 100644 --- a/src/mcp/client/experimental/tasks.py +++ b/src/mcp/client/experimental/tasks.py @@ -92,15 +92,13 @@ async def call_tool_as_task( _meta = types.RequestParams.Meta(**meta) return await self._session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - task=types.TaskMetadata(ttl=ttl), - _meta=_meta, - ), - ) + types.CallToolRequest( + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + task=types.TaskMetadata(ttl=ttl), + _meta=_meta, + ), ), types.CreateTaskResult, ) @@ -115,10 +113,8 @@ async def get_task(self, task_id: str) -> types.GetTaskResult: GetTaskResult containing the task status and metadata """ return await self._session.send_request( - types.ClientRequest( - types.GetTaskRequest( - params=types.GetTaskRequestParams(task_id=task_id), - ) + types.GetTaskRequest( + params=types.GetTaskRequestParams(task_id=task_id), ), types.GetTaskResult, ) @@ -142,10 +138,8 @@ async def get_task_result( The task result, validated against result_type """ return await self._session.send_request( - types.ClientRequest( - types.GetTaskPayloadRequest( - params=types.GetTaskPayloadRequestParams(task_id=task_id), - ) + types.GetTaskPayloadRequest( + params=types.GetTaskPayloadRequestParams(task_id=task_id), ), result_type, ) @@ -164,9 +158,7 @@ async def list_tasks( """ params = types.PaginatedRequestParams(cursor=cursor) if cursor else None return await self._session.send_request( - types.ClientRequest( - types.ListTasksRequest(params=params), - ), + types.ListTasksRequest(params=params), types.ListTasksResult, ) @@ -180,10 +172,8 @@ async def cancel_task(self, task_id: str) -> types.CancelTaskResult: CancelTaskResult with the updated task state """ return await self._session.send_request( - types.ClientRequest( - types.CancelTaskRequest( - params=types.CancelTaskRequestParams(task_id=task_id), - ) + types.CancelTaskRequest( + params=types.CancelTaskRequestParams(task_id=task_id), ), types.CancelTaskResult, ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7aeee2cd8a..3f727441e0 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -122,13 +122,7 @@ def __init__( sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: - super().__init__( - read_stream, - write_stream, - types.ServerRequest, - types.ServerNotification, - read_timeout_seconds=read_timeout_seconds, - ) + super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities @@ -143,6 +137,14 @@ def __init__( # Experimental: Task handlers (use defaults if not provided) self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() + @property + def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: + return types.server_request_adapter + + @property + def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]: + return types.server_notification_adapter + async def initialize(self) -> types.InitializeResult: sampling = ( (self._sampling_capabilities or types.SamplingCapability()) @@ -167,20 +169,18 @@ async def initialize(self) -> types.InitializeResult: ) result = await self.send_request( - types.ClientRequest( - types.InitializeRequest( - params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities( - sampling=sampling, - elicitation=elicitation, - experimental=None, - roots=roots, - tasks=self._task_handlers.build_capability(), - ), - client_info=self._client_info, + types.InitializeRequest( + params=types.InitializeRequestParams( + protocol_version=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + sampling=sampling, + elicitation=elicitation, + experimental=None, + roots=roots, + tasks=self._task_handlers.build_capability(), ), - ) + client_info=self._client_info, + ), ), types.InitializeResult, ) @@ -190,7 +190,7 @@ async def initialize(self) -> types.InitializeResult: self._server_capabilities = result.capabilities - await self.send_notification(types.ClientNotification(types.InitializedNotification())) + await self.send_notification(types.InitializedNotification()) return result @@ -218,10 +218,7 @@ def experimental(self) -> ExperimentalClientFeatures: async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - return await self.send_request( - types.ClientRequest(types.PingRequest()), - types.EmptyResult, - ) + return await self.send_request(types.PingRequest(), types.EmptyResult) async def send_progress_notification( self, @@ -232,14 +229,12 @@ async def send_progress_notification( ) -> None: """Send a progress notification.""" await self.send_notification( - types.ClientNotification( - types.ProgressNotification( - params=types.ProgressNotificationParams( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ), + types.ProgressNotification( + params=types.ProgressNotificationParams( + progress_token=progress_token, + progress=progress, + total=total, + message=message, ), ) ) @@ -247,11 +242,7 @@ async def send_progress_notification( async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" return await self.send_request( # pragma: no cover - types.ClientRequest( - types.SetLevelRequest( - params=types.SetLevelRequestParams(level=level), - ) - ), + types.SetLevelRequest(params=types.SetLevelRequestParams(level=level)), types.EmptyResult, ) @@ -261,10 +252,7 @@ async def list_resources(self, *, params: types.PaginatedRequestParams | None = Args: params: Full pagination parameters including cursor and any future fields """ - return await self.send_request( - types.ClientRequest(types.ListResourcesRequest(params=params)), - types.ListResourcesResult, - ) + return await self.send_request(types.ListResourcesRequest(params=params), types.ListResourcesResult) async def list_resource_templates( self, *, params: types.PaginatedRequestParams | None = None @@ -275,28 +263,28 @@ async def list_resource_templates( params: Full pagination parameters including cursor and any future fields """ return await self.send_request( - types.ClientRequest(types.ListResourceTemplatesRequest(params=params)), + types.ListResourceTemplatesRequest(params=params), types.ListResourceTemplatesResult, ) async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( - types.ClientRequest(types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri)))), + types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri))), types.ReadResourceResult, ) async def subscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( # pragma: no cover - types.ClientRequest(types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri)))), + types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri))), types.EmptyResult, ) async def unsubscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( # pragma: no cover - types.ClientRequest(types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri)))), + types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri))), types.EmptyResult, ) @@ -316,10 +304,8 @@ async def call_tool( _meta = types.RequestParams.Meta(**meta) result = await self.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta), ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, @@ -364,17 +350,15 @@ async def list_prompts(self, *, params: types.PaginatedRequestParams | None = No params: Full pagination parameters including cursor and any future fields """ return await self.send_request( - types.ClientRequest(types.ListPromptsRequest(params=params)), + types.ListPromptsRequest(params=params), types.ListPromptsResult, ) async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( - types.ClientRequest( - types.GetPromptRequest( - params=types.GetPromptRequestParams(name=name, arguments=arguments), - ) + types.GetPromptRequest( + params=types.GetPromptRequestParams(name=name, arguments=arguments), ), types.GetPromptResult, ) @@ -391,14 +375,12 @@ async def complete( context = types.CompletionContext(arguments=context_arguments) return await self.send_request( - types.ClientRequest( - types.CompleteRequest( - params=types.CompleteRequestParams( - ref=ref, - argument=types.CompletionArgument(**argument), - context=context, - ), - ) + types.CompleteRequest( + params=types.CompleteRequestParams( + ref=ref, + argument=types.CompletionArgument(**argument), + context=context, + ), ), types.CompleteResult, ) @@ -410,7 +392,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None params: Full pagination parameters including cursor and any future fields """ result = await self.send_request( - types.ClientRequest(types.ListToolsRequest(params=params)), + types.ListToolsRequest(params=params), types.ListToolsResult, ) @@ -423,7 +405,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None async def send_roots_list_changed(self) -> None: # pragma: no cover """Send a roots/list_changed notification.""" - await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) + await self.send_notification(types.RootsListChangedNotification()) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( @@ -440,7 +422,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques return None # Core request handling - match responder.request.root: + match responder.request: case types.CreateMessageRequest(params=params): with responder: # Check if this is a task-augmented request @@ -469,7 +451,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques case types.PingRequest(): # pragma: no cover with responder: - return await responder.respond(types.ClientResult(root=types.EmptyResult())) + return await responder.respond(types.EmptyResult()) case _: # pragma: no cover pass # Task requests handled above by _task_handlers @@ -486,7 +468,7 @@ async def _handle_incoming( async def _received_notification(self, notification: types.ServerNotification) -> None: """Handle notifications from the server.""" # Process specific notification types - match notification.root: + match notification: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) case types.ElicitCompleteNotification(params=params): diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 14d3a6aecc..28c1c261f9 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -4,7 +4,7 @@ from typing import Any from uuid import uuid4 -from pydantic import BaseModel, RootModel, ValidationError +from pydantic import BaseModel, ValidationError from starlette.requests import Request from starlette.responses import Response @@ -14,11 +14,9 @@ from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata - -class RegistrationRequest(RootModel[OAuthClientMetadata]): - # this wrapper is a no-op; it's just to separate out the types exposed to the - # provider from what we use in the HTTP handler - root: OAuthClientMetadata +# this alias is a no-op; it's just to separate out the types exposed to the +# provider from what we use in the HTTP handler +RegistrationRequest = OAuthClientMetadata class RegistrationErrorResponse(BaseModel): @@ -35,6 +33,7 @@ async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: # Parse request body as JSON + # TODO(Marcelo): This is unnecessary. We should use `request.body()`. body = await request.json() client_metadata = OAuthClientMetadata.model_validate(body) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0d3c247c27..14f6f68720 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Annotated, Any, Literal -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, TypeAdapter, ValidationError from starlette.requests import Request from mcp.server.auth.errors import stringify_pydantic_error @@ -40,18 +40,8 @@ class RefreshTokenRequest(BaseModel): resource: str | None = Field(None, description="Resource indicator for the token") -class TokenRequest( - RootModel[ - Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, - Field(discriminator="grant_type"), - ] - ] -): - root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, - Field(discriminator="grant_type"), - ] +TokenRequest = Annotated[AuthorizationCodeRequest | RefreshTokenRequest, Field(discriminator="grant_type")] +token_request_adapter = TypeAdapter[TokenRequest](TokenRequest) class TokenErrorResponse(BaseModel): @@ -62,11 +52,10 @@ class TokenErrorResponse(BaseModel): error_uri: AnyHttpUrl | None = None -class TokenSuccessResponse(RootModel[OAuthToken]): - # this is just a wrapper over OAuthToken; the only reason we do this - # is to have some separation between the HTTP response type, and the - # type returned by the provider - root: OAuthToken +# this is just an alias over OAuthToken; the only reason we do this +# is to have some separation between the HTTP response type, and the +# type returned by the provider +TokenSuccessResponse = OAuthToken @dataclass @@ -107,7 +96,8 @@ async def handle(self, request: Request): try: form_data = await request.form() - token_request = TokenRequest.model_validate(dict(form_data)).root + # TODO(Marcelo): Can someone check if this `dict()` wrapper is necessary? + token_request = token_request_adapter.validate_python(dict(form_data)) except ValidationError as validation_error: # pragma: no cover return self.response( TokenErrorResponse( @@ -186,12 +176,7 @@ async def handle(self, request: Request): # Exchange authorization code for tokens tokens = await self.provider.exchange_authorization_code(client_info, auth_code) except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description)) case RefreshTokenRequest(): # pragma: no cover refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) @@ -229,11 +214,6 @@ async def handle(self, request: Request): # Exchange refresh token for new tokens tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description)) - return self.response(TokenSuccessResponse(root=tokens)) + return self.response(tokens) diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py index 2bccf66031..a189c3cbca 100644 --- a/src/mcp/server/experimental/session_features.py +++ b/src/mcp/server/experimental/session_features.py @@ -49,7 +49,7 @@ async def get_task(self, task_id: str) -> types.GetTaskResult: GetTaskResult containing the task status """ return await self._session.send_request( - types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id))), + types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), types.GetTaskResult, ) @@ -68,7 +68,7 @@ async def get_task_result( The task result, validated against result_type """ return await self._session.send_request( - types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id))), + types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)), result_type, ) @@ -120,13 +120,11 @@ async def elicit_as_task( require_task_augmented_elicitation(client_caps) create_result = await self._session.send_request( - types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=types.TaskMetadata(ttl=ttl), - ) + types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requested_schema=requested_schema, + task=types.TaskMetadata(ttl=ttl), ) ), types.CreateTaskResult, @@ -185,21 +183,19 @@ async def create_message_as_task( validate_tool_use_result_messages(messages) create_result = await self._session.send_request( - types.ServerRequest( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=types.TaskMetadata(ttl=ttl), - ) + types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + max_tokens=max_tokens, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + task=types.TaskMetadata(ttl=ttl), ) ), types.CreateTaskResult, diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index feb1df652f..871cefd9f5 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -39,7 +39,6 @@ Result, SamplingCapability, SamplingMessage, - ServerNotification, Task, TaskMetadata, TaskStatusNotification, @@ -156,17 +155,15 @@ async def _send_notification(self) -> None: """Send a task status notification to the client.""" task = self._ctx.task await self._session.send_notification( - ServerNotification( - TaskStatusNotification( - params=TaskStatusNotificationParams( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) + TaskStatusNotification( + params=TaskStatusNotificationParams( + task_id=task.task_id, + status=task.status, + status_message=task.status_message, + created_at=task.created_at, + last_updated_at=task.last_updated_at, + ttl=task.ttl, + poll_interval=task.poll_interval, ) ) ) diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index be2296594a..dd6b466b41 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -6,14 +6,7 @@ from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints import pydantic_core -from pydantic import ( - BaseModel, - ConfigDict, - Field, - RootModel, - WithJsonSchema, - create_model, -) +from pydantic import BaseModel, ConfigDict, Field, RootModel, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind from typing_extensions import is_typeddict @@ -495,9 +488,7 @@ class DictModel(RootModel[dict_annotation]): return DictModel -def _convert_to_content( - result: Any, -) -> Sequence[ContentBlock]: +def _convert_to_content(result: Any) -> Sequence[ContentBlock]: """Convert a result to a sequence of content objects. Note: This conversion logic comes from previous versions of FastMCP and is being diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 2c5addb6fc..49387daad7 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -141,16 +141,14 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: message=f"Task not found: {req.params.task_id}", ) ) - return ServerResult( - GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) + return GetTaskResult( + task_id=task.task_id, + status=task.status, + status_message=task.status_message, + created_at=task.created_at, + last_updated_at=task.last_updated_at, + ttl=task.ttl, + poll_interval=task.poll_interval, ) self._request_handlers[GetTaskRequest] = _default_get_task @@ -158,29 +156,29 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: # Register get_task_result handler if not already registered if GetTaskPayloadRequest not in self._request_handlers: - async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: + async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: ctx = self._server.request_context result = await support.handler.handle(req, ctx.session, ctx.request_id) - return ServerResult(result) + return result self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result # Register list_tasks handler if not already registered if ListTasksRequest not in self._request_handlers: - async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: + async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: cursor = req.params.cursor if req.params else None tasks, next_cursor = await support.store.list_tasks(cursor) - return ServerResult(ListTasksResult(tasks=tasks, next_cursor=next_cursor)) + return ListTasksResult(tasks=tasks, next_cursor=next_cursor) self._request_handlers[ListTasksRequest] = _default_list_tasks # Register cancel_task handler if not already registered if CancelTaskRequest not in self._request_handlers: - async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult: + async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: result = await cancel_task(support.store, req.params.task_id) - return ServerResult(result) + return result self._request_handlers[CancelTaskRequest] = _default_cancel_task @@ -201,9 +199,9 @@ def decorator( logger.debug("Registering handler for ListTasksRequest") wrapper = create_call_wrapper(func, ListTasksRequest) - async def handler(req: ListTasksRequest) -> ServerResult: + async def handler(req: ListTasksRequest) -> ListTasksResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[ListTasksRequest] = handler return func @@ -226,9 +224,9 @@ def decorator( logger.debug("Registering handler for GetTaskRequest") wrapper = create_call_wrapper(func, GetTaskRequest) - async def handler(req: GetTaskRequest) -> ServerResult: + async def handler(req: GetTaskRequest) -> GetTaskResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[GetTaskRequest] = handler return func @@ -252,9 +250,9 @@ def decorator( logger.debug("Registering handler for GetTaskPayloadRequest") wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - async def handler(req: GetTaskPayloadRequest) -> ServerResult: + async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[GetTaskPayloadRequest] = handler return func @@ -278,9 +276,9 @@ def decorator( logger.debug("Registering handler for CancelTaskRequest") wrapper = create_call_wrapper(func, CancelTaskRequest) - async def handler(req: CancelTaskRequest) -> ServerResult: + async def handler(req: CancelTaskRequest) -> CancelTaskResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[CancelTaskRequest] = handler return func diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9d600a6b8e..cd92ce9d8a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -271,10 +271,10 @@ async def handler(req: types.ListPromptsRequest): result = await wrapper(req) # Handle both old style (list[Prompt]) and new style (ListPromptsResult) if isinstance(result, types.ListPromptsResult): - return types.ServerResult(result) + return result else: # Old style returns list[Prompt] - return types.ServerResult(types.ListPromptsResult(prompts=result)) + return types.ListPromptsResult(prompts=result) self.request_handlers[types.ListPromptsRequest] = handler return func @@ -289,7 +289,7 @@ def decorator( async def handler(req: types.GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - return types.ServerResult(prompt_get) + return prompt_get self.request_handlers[types.GetPromptRequest] = handler return func @@ -309,10 +309,10 @@ async def handler(req: types.ListResourcesRequest): result = await wrapper(req) # Handle both old style (list[Resource]) and new style (ListResourcesResult) if isinstance(result, types.ListResourcesResult): - return types.ServerResult(result) + return result else: # Old style returns list[Resource] - return types.ServerResult(types.ListResourcesResult(resources=result)) + return types.ListResourcesResult(resources=result) self.request_handlers[types.ListResourcesRequest] = handler return func @@ -325,7 +325,7 @@ def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): async def handler(_: Any): templates = await func() - return types.ServerResult(types.ListResourceTemplatesResult(resource_templates=templates)) + return types.ListResourceTemplatesResult(resource_templates=templates) self.request_handlers[types.ListResourceTemplatesRequest] = handler return func @@ -376,18 +376,12 @@ def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any ) for content_item in contents ] - return types.ServerResult( - types.ReadResourceResult( - contents=contents_list, - ) - ) + return types.ReadResourceResult(contents=contents_list) case _: # pragma: no cover raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - return types.ServerResult( # pragma: no cover - types.ReadResourceResult( - contents=[content], - ) + return types.ReadResourceResult( # pragma: no cover + contents=[content], ) self.request_handlers[types.ReadResourceRequest] = handler @@ -401,7 +395,7 @@ def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): async def handler(req: types.SetLevelRequest): await func(req.params.level) - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() self.request_handlers[types.SetLevelRequest] = handler return func @@ -414,7 +408,7 @@ def decorator(func: Callable[[str], Awaitable[None]]): async def handler(req: types.SubscribeRequest): await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() self.request_handlers[types.SubscribeRequest] = handler return func @@ -427,7 +421,7 @@ def decorator(func: Callable[[str], Awaitable[None]]): async def handler(req: types.UnsubscribeRequest): await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() self.request_handlers[types.UnsubscribeRequest] = handler return func @@ -452,7 +446,7 @@ async def handler(req: types.ListToolsRequest): for tool in result.tools: validate_and_warn_tool_name(tool.name) self._tool_cache[tool.name] = tool - return types.ServerResult(result) + return result else: # Old style returns list[Tool] # Clear and refresh the entire tool cache @@ -460,20 +454,18 @@ async def handler(req: types.ListToolsRequest): for tool in result: validate_and_warn_tool_name(tool.name) self._tool_cache[tool.name] = tool - return types.ServerResult(types.ListToolsResult(tools=result)) + return types.ListToolsResult(tools=result) self.request_handlers[types.ListToolsRequest] = handler return func return decorator - def _make_error_result(self, error_message: str) -> types.ServerResult: - """Create a ServerResult with an error CallToolResult.""" - return types.ServerResult( - types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) + def _make_error_result(self, error_message: str) -> types.CallToolResult: + """Create a CallToolResult with an error.""" + return types.CallToolResult( + content=[types.TextContent(type="text", text=error_message)], + is_error=True, ) async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: @@ -541,10 +533,10 @@ async def handler(req: types.CallToolRequest): unstructured_content: UnstructuredContent maybe_structured_content: StructuredContent | None if isinstance(results, types.CallToolResult): - return types.ServerResult(results) + return results elif isinstance(results, types.CreateTaskResult): # Task-augmented execution returns task info instead of result - return types.ServerResult(results) + return results elif isinstance(results, tuple) and len(results) == 2: # tool returned both structured and unstructured content unstructured_content, maybe_structured_content = cast(CombinationContent, results) @@ -572,12 +564,10 @@ async def handler(req: types.CallToolRequest): return self._make_error_result(f"Output validation error: {e.message}") # result - return types.ServerResult( - types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) + return types.CallToolResult( + content=list(unstructured_content), + structured_content=maybe_structured_content, + is_error=False, ) except UrlElicitationRequiredError: # Re-raise UrlElicitationRequiredError so it can be properly handled @@ -627,12 +617,10 @@ def decorator( async def handler(req: types.CompleteRequest): completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.ServerResult( - types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) + return types.CompleteResult( + completion=completion + if completion is not None + else types.Completion(values=[], total=None, has_more=None), ) self.request_handlers[types.CompleteRequest] = handler @@ -694,11 +682,11 @@ async def _handle_message( ): with warnings.catch_warnings(record=True) as w: match message: - case RequestResponder(request=types.ClientRequest(root=req)) as responder: + case RequestResponder() as responder: with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + await self._handle_request( + message, responder.request, session, lifespan_context, raise_exceptions + ) case Exception(): # pragma: no cover logger.error(f"Received exception from stream: {message}") await session.send_log_message( @@ -708,6 +696,8 @@ async def _handle_message( ) if raise_exceptions: raise message + case _: + await self._handle_notification(message) for warning in w: # pragma: no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -715,7 +705,7 @@ async def _handle_message( async def _handle_request( self, message: RequestResponder[types.ClientRequest, types.ServerResult], - req: types.ClientRequestType, + req: types.ClientRequest, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, @@ -803,4 +793,4 @@ async def _handle_notification(self, notify: Any): async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6f80615ff5..cc4973fc2f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -92,7 +92,7 @@ def __init__( init_options: InitializationOptions, stateless: bool = False, ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) + super().__init__(read_stream, write_stream) self._stateless = stateless self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized @@ -104,6 +104,14 @@ def __init__( ](0) self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + @property + def _receive_request_adapter(self) -> types.TypeAdapter[types.ClientRequest]: + return types.client_request_adapter + + @property + def _receive_notification_adapter(self) -> types.TypeAdapter[types.ClientNotification]: + return types.client_notification_adapter + @property def client_params(self) -> types.InitializeRequestParams | None: return self._client_params # pragma: no cover @@ -162,29 +170,27 @@ async def _receive_loop(self) -> None: await super()._receive_loop() async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): - match responder.request.root: + match responder.request: case types.InitializeRequest(params=params): requested_version = params.protocol_version self._initialization_state = InitializationState.Initializing self._client_params = params with responder: await responder.respond( - types.ServerResult( - types.InitializeResult( - protocol_version=requested_version - if requested_version in SUPPORTED_PROTOCOL_VERSIONS - else types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - server_info=types.Implementation( - name=self._init_options.server_name, - title=self._init_options.title, - description=self._init_options.description, - version=self._init_options.server_version, - website_url=self._init_options.website_url, - icons=self._init_options.icons, - ), - instructions=self._init_options.instructions, - ) + types.InitializeResult( + protocol_version=requested_version + if requested_version in SUPPORTED_PROTOCOL_VERSIONS + else types.LATEST_PROTOCOL_VERSION, + capabilities=self._init_options.capabilities, + server_info=types.Implementation( + name=self._init_options.server_name, + title=self._init_options.title, + description=self._init_options.description, + version=self._init_options.server_version, + website_url=self._init_options.website_url, + icons=self._init_options.icons, + ), + instructions=self._init_options.instructions, ) ) self._initialization_state = InitializationState.Initialized @@ -198,7 +204,7 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() - match notification.root: + match notification: case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: @@ -214,14 +220,12 @@ async def send_log_message( ) -> None: """Send a log message notification.""" await self.send_notification( - types.ServerNotification( - types.LoggingMessageNotification( - params=types.LoggingMessageNotificationParams( - level=level, - data=data, - logger=logger, - ), - ) + types.LoggingMessageNotification( + params=types.LoggingMessageNotificationParams( + level=level, + data=data, + logger=logger, + ), ), related_request_id, ) @@ -229,10 +233,8 @@ async def send_log_message( async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover """Send a resource updated notification.""" await self.send_notification( - types.ServerNotification( - types.ResourceUpdatedNotification( - params=types.ResourceUpdatedNotificationParams(uri=str(uri)), - ) + types.ResourceUpdatedNotification( + params=types.ResourceUpdatedNotificationParams(uri=str(uri)), ) ) @@ -322,21 +324,19 @@ async def create_message( validate_sampling_tools(client_caps, tools, tool_choice) validate_tool_use_result_messages(messages) - request = types.ServerRequest( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - max_tokens=max_tokens, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - ), - ) + request = types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ), ) metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) @@ -358,7 +358,7 @@ async def list_roots(self) -> types.ListRootsResult: if self._stateless: raise StatelessModeNotSupported(method="list_roots") return await self.send_request( - types.ServerRequest(types.ListRootsRequest()), + types.ListRootsRequest(), types.ListRootsResult, ) @@ -406,13 +406,11 @@ async def elicit_form( if self._stateless: raise StatelessModeNotSupported(method="elicitation") return await self.send_request( - types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - ), - ) + types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requested_schema=requested_schema, + ), ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), @@ -445,14 +443,12 @@ async def elicit_url( if self._stateless: raise StatelessModeNotSupported(method="elicitation") return await self.send_request( - types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestURLParams( - message=message, - url=url, - elicitation_id=elicitation_id, - ), - ) + types.ElicitRequest( + params=types.ElicitRequestURLParams( + message=message, + url=url, + elicitation_id=elicitation_id, + ), ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), @@ -461,7 +457,7 @@ async def elicit_url( async def send_ping(self) -> types.EmptyResult: # pragma: no cover """Send a ping request.""" return await self.send_request( - types.ServerRequest(types.PingRequest()), + types.PingRequest(), types.EmptyResult, ) @@ -475,30 +471,28 @@ async def send_progress_notification( ) -> None: """Send a progress notification.""" await self.send_notification( - types.ServerNotification( - types.ProgressNotification( - params=types.ProgressNotificationParams( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ), - ) + types.ProgressNotification( + params=types.ProgressNotificationParams( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ), ), related_request_id, ) async def send_resource_list_changed(self) -> None: # pragma: no cover """Send a resource list changed notification.""" - await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification())) + await self.send_notification(types.ResourceListChangedNotification()) async def send_tool_list_changed(self) -> None: # pragma: no cover """Send a tool list changed notification.""" - await self.send_notification(types.ServerNotification(types.ToolListChangedNotification())) + await self.send_notification(types.ToolListChangedNotification()) async def send_prompt_list_changed(self) -> None: # pragma: no cover """Send a prompt list changed notification.""" - await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) + await self.send_notification(types.PromptListChangedNotification()) async def send_elicit_complete( self, @@ -516,10 +510,8 @@ async def send_elicit_complete( related_request_id: Optional ID of the request that triggered this """ await self.send_notification( - types.ServerNotification( - types.ElicitCompleteNotification( - params=types.ElicitCompleteNotificationParams(elicitation_id=elicitation_id) - ) + types.ElicitCompleteNotification( + params=types.ElicitCompleteNotificationParams(elicitation_id=elicitation_id) ), related_request_id, ) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index be1990d618..c102200ed6 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -9,7 +9,7 @@ import anyio import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -179,8 +179,6 @@ def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], - receive_request_type: type[ReceiveRequestT], - receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: float | None = None, ) -> None: @@ -188,8 +186,6 @@ def __init__( self._write_stream = write_stream self._response_streams = {} self._request_id = 0 - self._receive_request_type = receive_request_type - self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} @@ -264,11 +260,7 @@ async def send_request( self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) @@ -339,26 +331,30 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er session_message = SessionMessage(message=jsonrpc_response) await self._write_stream.send(session_message) + @property + def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: + """Each subclass must provide its own request adapter.""" + raise NotImplementedError + + @property + def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: + raise NotImplementedError + async def _receive_loop(self) -> None: - async with ( - self._read_stream, - self._write_stream, - ): + async with self._read_stream, self._write_stream: try: async for message in self._read_stream: if isinstance(message, Exception): # pragma: no cover await self._handle_incoming(message) elif isinstance(message.message, JSONRPCRequest): try: - validated_request = self._receive_request_type.model_validate( + validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) responder = RequestResponder( request_id=message.message.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, + request_meta=validated_request.params.meta if validated_request.params else None, request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), @@ -369,10 +365,10 @@ async def _receive_loop(self) -> None: if not responder._completed: # type: ignore[reportPrivateUsage] await self._handle_incoming(responder) - except Exception as e: + except Exception: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server - logging.warning(f"Failed to validate request: {e}") + logging.warning("Failed to validate request", exc_info=True) logging.debug(f"Message that failed validation: {message.message}") error_response = JSONRPCError( jsonrpc="2.0", @@ -388,34 +384,31 @@ async def _receive_loop(self) -> None: elif isinstance(message.message, JSONRPCNotification): try: - notification = self._receive_notification_type.model_validate( + notification = self._receive_notification_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.request_id + if isinstance(notification, CancelledNotification): + cancelled_id = notification.params.request_id if cancelled_id in self._in_flight: # pragma: no branch await self._in_flight[cancelled_id].cancel() else: # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): # pragma: no cover - progress_token = notification.root.params.progress_token + if isinstance(notification, ProgressNotification): # pragma: no cover + progress_token = notification.params.progress_token # If there is a progress callback for this token, # call it with the progress information if progress_token in self._progress_callbacks: callback = self._progress_callbacks[progress_token] try: await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) - except Exception as e: - logging.error( - "Progress callback raised an exception: %s", - e, + notification.params.progress, + notification.params.total, + notification.params.message, ) + except Exception: + logging.exception("Progress callback raised an exception") await self._received_notification(notification) await self._handle_incoming(notification) except Exception: # pragma: no cover diff --git a/src/mcp/types.py b/src/mcp/types.py index 4c886680aa..10b0c61fac 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, TypeAdapter +from pydantic import BaseModel, ConfigDict, Field, FileUrl, TypeAdapter from pydantic.alias_generators import to_camel LATEST_PROTOCOL_VERSION = "2025-11-25" @@ -1631,7 +1631,7 @@ class ElicitCompleteNotification( params: ElicitCompleteNotificationParams -ClientRequestType: TypeAlias = ( +ClientRequest = ( PingRequest | InitializeRequest | CompleteRequest @@ -1650,23 +1650,17 @@ class ElicitCompleteNotification( | ListTasksRequest | CancelTaskRequest ) +client_request_adapter = TypeAdapter[ClientRequest](ClientRequest) -class ClientRequest(RootModel[ClientRequestType]): - pass - - -ClientNotificationType: TypeAlias = ( +ClientNotification = ( CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification | TaskStatusNotification ) - - -class ClientNotification(RootModel[ClientNotificationType]): - pass +client_notification_adapter = TypeAdapter[ClientNotification](ClientNotification) # Type for elicitation schema - a JSON Schema dict @@ -1760,7 +1754,7 @@ class ElicitationRequiredErrorData(MCPModel): """List of URL mode elicitations that must be completed.""" -ClientResultType: TypeAlias = ( +ClientResult = ( EmptyResult | CreateMessageResult | CreateMessageResultWithTools @@ -1772,13 +1766,10 @@ class ElicitationRequiredErrorData(MCPModel): | CancelTaskResult | CreateTaskResult ) +client_result_adapter = TypeAdapter[ClientResult](ClientResult) -class ClientResult(RootModel[ClientResultType]): - pass - - -ServerRequestType: TypeAlias = ( +ServerRequest = ( PingRequest | CreateMessageRequest | ListRootsRequest @@ -1788,13 +1779,10 @@ class ClientResult(RootModel[ClientResultType]): | ListTasksRequest | CancelTaskRequest ) +server_request_adapter = TypeAdapter[ServerRequest](ServerRequest) -class ServerRequest(RootModel[ServerRequestType]): - pass - - -ServerNotificationType: TypeAlias = ( +ServerNotification = ( CancelledNotification | ProgressNotification | LoggingMessageNotification @@ -1805,13 +1793,10 @@ class ServerRequest(RootModel[ServerRequestType]): | ElicitCompleteNotification | TaskStatusNotification ) +server_notification_adapter = TypeAdapter[ServerNotification](ServerNotification) -class ServerNotification(RootModel[ServerNotificationType]): - pass - - -ServerResultType: TypeAlias = ( +ServerResult = ( EmptyResult | InitializeResult | CompleteResult @@ -1828,7 +1813,4 @@ class ServerNotification(RootModel[ServerNotificationType]): | CancelTaskResult | CreateTaskResult ) - - -class ServerResult(RootModel[ServerResultType]): - pass +server_result_adapter = TypeAdapter[ServerResult](ServerResult) diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index e05edb14dc..06d893ac68 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -19,7 +19,7 @@ from mcp import ClientSession, types from mcp.client.streamable_http import streamable_http_client from mcp.shared.session import RequestResponder -from mcp.types import ClientNotification, RootsListChangedNotification +from mcp.types import RootsListChangedNotification from tests.test_helpers import wait_for_server @@ -135,9 +135,7 @@ async def message_handler( # pragma: no cover await session.initialize() # The test server returns a 204 instead of the expected 202 - await session.send_notification( - ClientNotification(RootsListChangedNotification(method="notifications/roots/list_changed")) - ) + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) if returned_exception: # pragma: no cover pytest.fail(f"Server encountered an exception: {returned_exception}") diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index f47299cf8b..c7bf8fafa4 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -3,6 +3,7 @@ import anyio import pytest +from pydantic import TypeAdapter from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, RequestId, SendResultT @@ -23,20 +24,23 @@ async def _send_response( ) -> None: # pragma: no cover pass + @property + def _receive_request_adapter(self) -> TypeAdapter[Any]: + return TypeAdapter(object) # pragma: no cover + + @property + def _receive_notification_adapter(self) -> TypeAdapter[Any]: + return TypeAdapter(object) # pragma: no cover + # Create streams write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create the session - session = TestSession( - read_stream_receive, - write_stream_send, - object, # Request type doesn't matter for this test - object, # Notification type doesn't matter for this test - ) + session = TestSession(read_stream_receive, write_stream_send) # Create a test request - request = ClientRequest(PingRequest()) + request = PingRequest() # Patch the _write_stream.send method to raise an exception async def mock_send(*args: Any, **kwargs: Any): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 9512a0a7c4..5c1f55d238 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -12,8 +12,6 @@ from mcp.types import ( LATEST_PROTOCOL_VERSION, CallToolResult, - ClientNotification, - ClientRequest, Implementation, InitializedNotification, InitializeRequest, @@ -22,8 +20,9 @@ JSONRPCRequest, JSONRPCResponse, ServerCapabilities, - ServerResult, TextContent, + client_notification_adapter, + client_request_adapter, ) @@ -41,24 +40,22 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities( - logging=None, - resources=None, - tools=None, - experimental=None, - prompts=None, - ), - server_info=Implementation(name="mock-server", version="0.1.0"), - instructions="The server instructions.", - ) + assert isinstance(request, InitializeRequest) + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts=None, + ), + server_info=Implementation(name="mock-server", version="0.1.0"), + instructions="The server instructions.", ) async with server_to_client_send: @@ -74,7 +71,7 @@ async def mock_server(): session_notification = await client_to_server_receive.receive() jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification, JSONRPCNotification) - initialized_notification = ClientNotification.model_validate( + initialized_notification = client_notification_adapter.validate_python( jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -109,7 +106,7 @@ async def message_handler( # pragma: no cover # Check that the client sent the initialized notification assert initialized_notification - assert isinstance(initialized_notification.root, InitializedNotification) + assert isinstance(initialized_notification, InitializedNotification) @pytest.mark.anyio @@ -126,18 +123,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_client_info = request.root.params.client_info - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_client_info = request.params.client_info + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -185,18 +180,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_client_info = request.root.params.client_info - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_client_info = request.params.client_info + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -238,21 +231,19 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) # Verify client sent the latest protocol version - assert request.root.params.protocol_version == LATEST_PROTOCOL_VERSION + assert request.params.protocol_version == LATEST_PROTOCOL_VERSION # Server responds with a supported older version - result = ServerResult( - InitializeResult( - protocol_version="2024-11-05", - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version="2024-11-05", + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -295,18 +286,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) # Server responds with an unsupported version - result = ServerResult( - InitializeResult( - protocol_version="2020-01-01", # Unsupported old version - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version="2020-01-01", # Unsupported old version + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -349,18 +338,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -422,18 +409,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -503,18 +488,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -572,17 +555,15 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=expected_capabilities, - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=expected_capabilities, + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -639,17 +620,15 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) # Answer initialization request @@ -678,9 +657,7 @@ async def mock_server(): assert "_meta" in jsonrpc_request.params assert jsonrpc_request.params["_meta"] == meta - result = ServerResult( - CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) - ) + result = CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) # Send the tools/call result await server_to_client_send.send( diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index be35478016..7bb8066966 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -11,14 +11,13 @@ from mcp.shared.message import SessionMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, - ClientRequest, Implementation, InitializeRequest, InitializeResult, JSONRPCRequest, JSONRPCResponse, ServerCapabilities, - ServerResult, + client_request_adapter, ) @@ -36,18 +35,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -108,18 +105,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -190,18 +185,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -268,18 +261,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index 3c19d82d0d..f21abf4d0f 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -23,7 +23,6 @@ CallToolResult, CancelTaskRequest, CancelTaskResult, - ClientRequest, ClientResult, CreateTaskResult, GetTaskPayloadRequest, @@ -134,13 +133,11 @@ async def run_server(app_context: AppContext): # Create a task create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, @@ -240,13 +237,11 @@ async def run_server(app_context: AppContext): # Create a task create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, @@ -343,13 +338,11 @@ async def run_server(app_context: AppContext): # Create two tasks for _ in range(2): create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, @@ -456,13 +449,11 @@ async def run_server(app_context: AppContext): # Create a task (but don't complete it) create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index db18ef3599..41cecc1295 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -30,7 +30,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - ClientRequest, ClientResult, CreateTaskResult, GetTaskPayloadRequest, @@ -190,14 +189,12 @@ async def run_server(app_context: AppContext): # === Step 1: Send task-augmented tool call === create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), + ), ), CreateTaskResult, ) @@ -210,7 +207,7 @@ async def run_server(app_context: AppContext): await app_context.task_done_events[task_id].wait() task_status = await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id=task_id))), + GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult, ) @@ -219,7 +216,7 @@ async def run_server(app_context: AppContext): # === Step 3: Retrieve the actual result === task_result = await client_session.send_request( - ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id))), + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)), CallToolResult, ) @@ -327,14 +324,12 @@ async def run_server(app_context: AppContext): # Send task request create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), ), CreateTaskResult, ) @@ -346,8 +341,7 @@ async def run_server(app_context: AppContext): # Check that task was auto-failed task_status = await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id=task_id))), - GetTaskResult, + GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult ) assert task_status.status == "failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 6b0bbfef3e..cb8b737f28 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -26,7 +26,6 @@ CancelTaskRequest, CancelTaskRequestParams, CancelTaskResult, - ClientRequest, ClientResult, ErrorData, GetTaskPayloadRequest, @@ -89,10 +88,10 @@ async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListTasksResult) - assert len(result.root.tasks) == 2 - assert result.root.tasks[0].task_id == "task-1" - assert result.root.tasks[1].task_id == "task-2" + assert isinstance(result, ListTasksResult) + assert len(result.tasks) == 2 + assert result.tasks[0].task_id == "task-1" + assert result.tasks[1].task_id == "task-2" @pytest.mark.anyio @@ -120,9 +119,9 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, GetTaskResult) - assert result.root.task_id == "test-task-123" - assert result.root.status == "working" + assert isinstance(result, GetTaskResult) + assert result.task_id == "test-task-123" + assert result.status == "working" @pytest.mark.anyio @@ -142,7 +141,7 @@ async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPaylo result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, GetTaskPayloadResult) + assert isinstance(result, GetTaskPayloadResult) @pytest.mark.anyio @@ -169,9 +168,9 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, CancelTaskResult) - assert result.root.task_id == "test-task-123" - assert result.root.status == "cancelled" + assert isinstance(result, CancelTaskResult) + assert result.task_id == "test-task-123" + assert result.status == "cancelled" @pytest.mark.anyio @@ -253,8 +252,8 @@ async def list_tools(): result = await tools_handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListToolsResult) - tools = result.root.tools + assert isinstance(result, ListToolsResult) + tools = result.tools assert tools[0].execution is not None assert tools[0].execution.task_support == TASK_FORBIDDEN @@ -330,14 +329,12 @@ async def handle_messages(): # Call tool with task metadata await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), ), CallToolResult, ) @@ -411,24 +408,14 @@ async def handle_messages(): # pragma: no cover # Call without task metadata await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), CallToolResult, ) # Call with task metadata await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), ), CallToolResult, ) @@ -507,16 +494,13 @@ async def run_server() -> None: task = await store.create_task(TaskMetadata(ttl=60000)) # Test list_tasks (default handler) - list_result = await client_session.send_request( - ClientRequest(ListTasksRequest()), - ListTasksResult, - ) + list_result = await client_session.send_request(ListTasksRequest(), ListTasksResult) assert len(list_result.tasks) == 1 assert list_result.tasks[0].task_id == task.task_id # Test get_task (default handler - found) get_result = await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id))), + GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id)), GetTaskResult, ) assert get_result.task_id == task.task_id @@ -525,7 +509,7 @@ async def run_server() -> None: # Test get_task (default handler - not found path) with pytest.raises(McpError, match="not found"): await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task"))), + GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task")), GetTaskResult, ) @@ -538,9 +522,7 @@ async def run_server() -> None: # Test get_task_result (default handler) payload_result = await client_session.send_request( - ClientRequest( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)) - ), + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)), GetTaskPayloadResult, ) # The result should have the related-task metadata @@ -549,8 +531,7 @@ async def run_server() -> None: # Test cancel_task (default handler) cancel_result = await client_session.send_request( - ClientRequest(CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id))), - CancelTaskResult, + CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id)), CancelTaskResult ) assert cancel_result.task_id == task.task_id assert cancel_result.status == "cancelled" @@ -568,11 +549,7 @@ async def test_build_elicit_form_request() -> None: async with ServerSession( client_to_server_receive, server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), + InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), ) as server_session: # Test without task_id request = server_session._build_elicit_form_request( @@ -613,11 +590,7 @@ async def test_build_elicit_url_request() -> None: async with ServerSession( client_to_server_receive, server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), + InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), ) as server_session: # Test without related_task_id request = server_session._build_elicit_url_request( diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 1ebff7c92f..26b58343c3 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -26,8 +26,8 @@ def get_user_profile(user_id: str) -> str: # pragma: no cover result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest]( types.ListResourceTemplatesRequest(params=None) ) - assert isinstance(result.root, types.ListResourceTemplatesResult) - templates = result.root.resource_templates + assert isinstance(result, types.ListResourceTemplatesResult) + templates = result.resource_templates # Verify we get both templates back assert len(templates) == 2 diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 2554fbc735..44b17d3372 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -60,7 +60,7 @@ async def read_resource(uri: str) -> list[ReadResourceContents]: result: ServerResult = await handler(request) # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result.root) + read_result: ReadResourceResult = cast(ReadResourceResult, result) blob_content = read_result.contents[0] # First verify our test data actually produces different encodings diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 726843b92f..5f7caf7aca 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -77,14 +77,14 @@ async def handle_generic_notification( ) -> None: """Handle any server notification and route to appropriate handler.""" if isinstance(message, ServerNotification): # pragma: no branch - if isinstance(message.root, ProgressNotification): - self.progress_notifications.append(message.root.params) - elif isinstance(message.root, LoggingMessageNotification): - self.log_messages.append(message.root.params) - elif isinstance(message.root, ResourceListChangedNotification): - self.resource_notifications.append(message.root.params) - elif isinstance(message.root, ToolListChangedNotification): # pragma: no cover - self.tool_notifications.append(message.root.params) + if isinstance(message, ProgressNotification): + self.progress_notifications.append(message.params) + elif isinstance(message, LoggingMessageNotification): + self.log_messages.append(message.params) + elif isinstance(message, ResourceListChangedNotification): + self.resource_notifications.append(message.params) + elif isinstance(message, ToolListChangedNotification): # pragma: no cover + self.tool_notifications.append(message.params) # Common fixtures diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 38998b4b42..6bf4cddb39 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -41,8 +41,8 @@ async def handle_list_prompts() -> list[Prompt]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListPromptsResult) - assert result.root.prompts == test_prompts + assert isinstance(result, ListPromptsResult) + assert result.prompts == test_prompts @pytest.mark.anyio @@ -67,8 +67,8 @@ async def handle_list_resources() -> list[Resource]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListResourcesResult) - assert result.root.resources == test_resources + assert isinstance(result, ListResourcesResult) + assert result.resources == test_resources @pytest.mark.anyio @@ -114,8 +114,8 @@ async def handle_list_tools() -> list[Tool]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListToolsResult) - assert result.root.tools == test_tools + assert isinstance(result, ListToolsResult) + assert result.tools == test_tools @pytest.mark.anyio @@ -135,8 +135,8 @@ async def handle_list_prompts() -> list[Prompt]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListPromptsResult) - assert result.root.prompts == [] + assert isinstance(result, ListPromptsResult) + assert result.prompts == [] @pytest.mark.anyio @@ -156,8 +156,8 @@ async def handle_list_resources() -> list[Resource]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListResourcesResult) - assert result.root.resources == [] + assert isinstance(result, ListResourcesResult) + assert result.resources == [] @pytest.mark.anyio @@ -177,5 +177,5 @@ async def handle_list_tools() -> list[Tool]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListToolsResult) - assert result.root.tools == [] + assert isinstance(result, ListToolsResult) + assert result.tools == [] diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index aa8d42261e..98f34df465 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -15,8 +15,6 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, - ClientNotification, - ClientRequest, Tool, ) @@ -59,11 +57,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async def first_request(): try: await client.session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), CallToolResult, ) pytest.fail("First request should have been cancelled") # pragma: no cover @@ -80,13 +74,8 @@ async def first_request(): # Cancel it assert first_request_id is not None await client.session.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams( - request_id=first_request_id, - reason="Testing server recovery", - ), - ) + CancelledNotification( + params=CancelledNotificationParams(request_id=first_request_id, reason="Testing server recovery"), ) ) diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 5d4c3347f6..4767ea1177 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -60,7 +60,7 @@ async def test_normal_message_handling_not_affected(): # Create a mock RequestResponder responder = Mock(spec=RequestResponder) - responder.request = types.ClientRequest(root=types.PingRequest(method="ping")) + responder.request = types.PingRequest(method="ping") responder.__enter__ = Mock(return_value=responder) responder.__exit__ = Mock(return_value=None) diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 0f62fe235a..10349846cc 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -39,10 +39,10 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: # Call the handler result = await handler(request) - assert isinstance(result.root, types.ReadResourceResult) - assert len(result.root.contents) == 1 + assert isinstance(result, types.ReadResourceResult) + assert len(result.contents) == 1 - content = result.root.contents[0] + content = result.contents[0] assert isinstance(content, types.TextResourceContents) assert content.text == "Hello World" assert content.mime_type == "text/plain" @@ -66,10 +66,10 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: # Call the handler result = await handler(request) - assert isinstance(result.root, types.ReadResourceResult) - assert len(result.root.contents) == 1 + assert isinstance(result, types.ReadResourceResult) + assert len(result.contents) == 1 - content = result.root.contents[0] + content = result.contents[0] assert isinstance(content, types.BlobResourceContents) assert content.mime_type == "application/octet-stream" @@ -97,10 +97,10 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: # Call the handler result = await handler(request) - assert isinstance(result.root, types.ReadResourceResult) - assert len(result.root.contents) == 1 + assert isinstance(result, types.ReadResourceResult) + assert len(result.contents) == 1 - content = result.root.contents[0] + content = result.contents[0] assert isinstance(content, types.TextResourceContents) assert content.text == "Hello World" assert content.mime_type == "text/plain" diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 3c1e96c126..5de9882223 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -60,7 +60,7 @@ async def run_server(): raise message if isinstance(message, ClientNotification) and isinstance( - message.root, InitializedNotification + message, InitializedNotification ): # pragma: no branch received_initialized = True return @@ -158,7 +158,7 @@ async def run_server(): raise message if isinstance(message, types.ClientNotification) and isinstance( - message.root, InitializedNotification + message, InitializedNotification ): # pragma: no branch received_initialized = True return @@ -236,11 +236,11 @@ async def run_server(): # We should receive a ping request before initialization if isinstance(message, RequestResponder) and isinstance( - message.request.root, types.PingRequest + message.request, types.PingRequest ): # pragma: no branch # Respond to the ping with message: - await message.respond(types.ServerResult(types.EmptyResult())) + await message.respond(types.EmptyResult()) return async def mock_client(): diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py index bc6145acaf..18c6b5fc6a 100644 --- a/tests/server/test_session_race_condition.py +++ b/tests/server/test_session_race_condition.py @@ -57,27 +57,25 @@ async def run_server(): # Handle tools/list request if isinstance(message, RequestResponder): - if isinstance(message.request.root, types.ListToolsRequest): # pragma: no branch + if isinstance(message.request, types.ListToolsRequest): # pragma: no branch tools_list_success = True # Respond with a tool list with message: await message.respond( - types.ServerResult( - types.ListToolsResult( - tools=[ - Tool( - name="example_tool", - description="An example tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) + types.ListToolsResult( + tools=[ + Tool( + name="example_tool", + description="An example tool", + input_schema={"type": "object", "properties": {}}, + ) + ] ) ) # Handle InitializedNotification if isinstance(message, types.ClientNotification): - if isinstance(message.root, types.InitializedNotification): # pragma: no branch + if isinstance(message, types.InitializedNotification): # pragma: no branch # Done - exit gracefully return diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 78896397b2..d65622822f 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -135,8 +135,8 @@ async def handle_client_message( raise message if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.ProgressNotification): # pragma: no branch - params = message.root.params + if isinstance(message, types.ProgressNotification): # pragma: no branch + params = message.params client_progress_updates.append( { "token": params.progress_token, @@ -332,7 +332,7 @@ async def test_progress_callback_exception_logging(): # Track logged warnings logged_errors: list[str] = [] - def mock_log_error(msg: str, *args: Any) -> None: + def mock_log_exception(msg: str, *args: Any, **kwargs: Any) -> None: logged_errors.append(msg % args if args else msg) # Create a progress callback that raises an exception @@ -368,7 +368,7 @@ async def handle_list_tools() -> list[types.Tool]: ] # Test with mocked logging - with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): + with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): async with Client(server) as client: # Call tool with a failing progress callback result = await client.call_tool( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 77bec4aa33..89fe18ebbc 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -13,8 +13,6 @@ from mcp.types import ( CancelledNotification, CancelledNotificationParams, - ClientNotification, - ClientRequest, EmptyResult, ErrorData, JSONRPCError, @@ -73,10 +71,8 @@ async def make_request(client: Client): nonlocal ev_cancelled try: await client.session.send_request( - ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="slow_tool", arguments={}), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name="slow_tool", arguments={}), ), types.CallToolResult, ) @@ -97,11 +93,7 @@ async def make_request(client: Client): # Send cancellation notification assert request_id is not None await client.session.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams(request_id=request_id), - ) - ) + CancelledNotification(params=CancelledNotificationParams(request_id=request_id)) ) # Give cancellation time to process @@ -252,7 +244,7 @@ async def make_request(client_session: ClientSession): try: # Use a short timeout since we expect this to fail await client_session.send_request( - ClientRequest(types.PingRequest()), + types.PingRequest(), types.EmptyResult, request_read_timeout_seconds=0.5, ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 0c702dce26..ed86f9860e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1125,8 +1125,8 @@ async def message_handler( # pragma: no branch # Verify the notification is a ResourceUpdatedNotification resource_update_found = False for notif in notifications_received: - if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch - assert str(notif.root.params.uri) == "http://test_resource" + if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.params.uri) == "http://test_resource" resource_update_found = True assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" @@ -1167,10 +1167,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises( # pragma: no branch - McpError, - match="Session terminated", - ): + with pytest.raises(McpError, match="Session terminated"): # pragma: no branch await session.list_tools() @@ -1259,8 +1256,8 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch captured_notifications.append(message) # Look for our first notification - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - if message.root.params.data == "First notification before lock": + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + if message.params.data == "First notification before lock": nonlocal first_notification_received first_notification_received = True @@ -1291,12 +1288,8 @@ async def run_tool(): on_resumption_token_update=on_resumption_token_update, ) await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams( - name="wait_for_lock_with_notification", arguments={} - ), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), ), types.CallToolResult, metadata=metadata, @@ -1313,8 +1306,8 @@ async def run_tool(): # Verify we received exactly one notification assert len(captured_notifications) == 1 # pragma: no cover - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: no cover + assert captured_notifications[0].params.data == "First notification before lock" # pragma: no cover # Clear notifications for the second phase captured_notifications = [] # pragma: no cover @@ -1334,11 +1327,7 @@ async def run_tool(): ): async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="release_lock", arguments={}), - ) - ), + types.CallToolRequest(params=types.CallToolRequestParams(name="release_lock", arguments={})), types.CallToolResult, ) metadata = ClientMessageMetadata( @@ -1346,10 +1335,8 @@ async def run_tool(): ) result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), ), types.CallToolResult, metadata=metadata, @@ -1361,8 +1348,8 @@ async def run_tool(): # We should have received the remaining notifications assert len(captured_notifications) == 1 - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "Second notification after lock" # pragma: no cover + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: no cover + assert captured_notifications[0].params.data == "Second notification after lock" # pragma: no cover @pytest.mark.anyio @@ -1905,11 +1892,7 @@ async def on_resumption_token_update(token: str) -> None: on_resumption_token_update=on_resumption_token_update, ) result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="test_tool", arguments={}), - ) - ), + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), types.CallToolResult, metadata=metadata, ) @@ -1967,8 +1950,8 @@ async def message_handler( if isinstance(message, Exception): # pragma: no branch return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - captured_notifications.append(str(message.root.params.data)) + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + captured_notifications.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, @@ -2043,8 +2026,8 @@ async def message_handler( if isinstance(message, Exception): # pragma: no branch return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - all_notifications.append(str(message.root.params.data)) + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + all_notifications.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, @@ -2091,8 +2074,8 @@ async def message_handler( if isinstance(message, Exception): # pragma: no branch return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - notification_data.append(str(message.root.params.data)) + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + notification_data.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, @@ -2153,15 +2136,13 @@ async def on_resumption_token(token: str) -> None: # Use send_request with metadata to track resumption tokens metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), - ) + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="tool_with_multiple_stream_closes", + # retry_interval=500ms, so sleep 600ms to ensure reconnect completes + arguments={"checkpoints": 3, "sleep_time": 0.6}, + ), ), types.CallToolResult, metadata=metadata, @@ -2202,8 +2183,8 @@ async def message_handler( if isinstance(message, Exception): return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.ResourceUpdatedNotification): # pragma: no branch - received_notifications.append(str(message.root.params.uri)) + if isinstance(message, types.ResourceUpdatedNotification): # pragma: no branch + received_notifications.append(str(message.params.uri)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, diff --git a/tests/test_types.py b/tests/test_types.py index 454bac34b0..f424efdbf7 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -5,7 +5,6 @@ from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, - ClientRequest, CreateMessageRequestParams, CreateMessageResult, CreateMessageResultWithTools, @@ -21,6 +20,7 @@ ToolChoice, ToolResultContent, ToolUseContent, + client_request_adapter, jsonrpc_message_adapter, ) @@ -40,7 +40,7 @@ async def test_jsonrpc_request(): request = jsonrpc_message_adapter.validate_python(json_data) assert isinstance(request, JSONRPCRequest) - ClientRequest.model_validate(request.model_dump(by_alias=True, exclude_none=True)) + client_request_adapter.validate_python(request.model_dump(by_alias=True, exclude_none=True)) assert request.jsonrpc == "2.0" assert request.id == 1 From 197e264bfe65972f3d8bb4c1ebbe5b863c6a67f6 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 19 Jan 2026 14:19:01 +0100 Subject: [PATCH 4/5] add migration note --- docs/migration.md | 54 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/docs/migration.md b/docs/migration.md index eac51061cb..a3942ff15f 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -179,6 +179,60 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or `::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +### Replace `RootModel` by union types with `TypeAdapter` validation + +The following union types are no longer `RootModel` subclasses: + +- `ClientRequest` +- `ServerRequest` +- `ClientNotification` +- `ServerNotification` +- `ClientResult` +- `ServerResult` +- `JSONRPCMessage` + +This means you can no longer access `.root` on these types or use `model_validate()` directly on them. Instead, use the provided `TypeAdapter` instances for validation. + +**Before (v1):** + +```python +from mcp.types import ClientRequest, ServerNotification + +# Using RootModel.model_validate() +request = ClientRequest.model_validate(data) +actual_request = request.root # Accessing the wrapped value + +notification = ServerNotification.model_validate(data) +actual_notification = notification.root +``` + +**After (v2):** + +```python +from mcp.types import client_request_adapter, server_notification_adapter + +# Using TypeAdapter.validate_python() +request = client_request_adapter.validate_python(data) +# No .root access needed - request is the actual type + +notification = server_notification_adapter.validate_python(data) +# No .root access needed - notification is the actual type +``` + +**Available adapters:** + +| Union Type | Adapter | +|------------|---------| +| `ClientRequest` | `client_request_adapter` | +| `ServerRequest` | `server_request_adapter` | +| `ClientNotification` | `client_notification_adapter` | +| `ServerNotification` | `server_notification_adapter` | +| `ClientResult` | `client_result_adapter` | +| `ServerResult` | `server_result_adapter` | +| `JSONRPCMessage` | `jsonrpc_message_adapter` | + +All adapters are exported from `mcp.types`. + ### Resource URI type changed from `AnyUrl` to `str` The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected. From 87326d68bd9e099f2102eb8e65da9bec1704d5ff Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 19 Jan 2026 14:22:57 +0100 Subject: [PATCH 5/5] Enable ruff for https://docs.astral.sh/ruff/rules/banned-api --- pyproject.toml | 19 ++++++++++--------- src/mcp/server/auth/handlers/authorize.py | 3 ++- .../server/fastmcp/utilities/func_metadata.py | 4 +++- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index febe4f38d0..2a9ad077e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,15 +125,16 @@ extend-exclude = ["README.md"] [tool.ruff.lint] select = [ - "C4", # flake8-comprehensions - "C90", # mccabe - "D212", # pydocstyle: multi-line docstring summary should start at the first line - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "PERF", # Perflint - "PL", # Pylint - "UP", # pyupgrade + "C4", # flake8-comprehensions + "C90", # mccabe + "D212", # pydocstyle: multi-line docstring summary should start at the first line + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "PERF", # Perflint + "PL", # Pylint + "UP", # pyupgrade + "TID251", # https://docs.astral.sh/ruff/rules/banned-api/ ] ignore = ["PERF203", "PLC0415", "PLR0402"] diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 3570d28c2a..dec6713b13 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from typing import Any, Literal -from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError +# TODO(Marcelo): We should drop the `RootModel`. +from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError # noqa: TID251 from starlette.datastructures import FormData, QueryParams from starlette.requests import Request from starlette.responses import RedirectResponse, Response diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index dd6b466b41..eda50cef3f 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -6,7 +6,7 @@ from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints import pydantic_core -from pydantic import BaseModel, ConfigDict, Field, RootModel, WithJsonSchema, create_model +from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind from typing_extensions import is_typeddict @@ -477,6 +477,8 @@ def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]: def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: """Create a RootModel for dict[str, T] types.""" + # TODO(Marcelo): We should not rely on RootModel for this. + from pydantic import RootModel # noqa: TID251 class DictModel(RootModel[dict_annotation]): pass