diff --git a/docs/migration.md b/docs/migration.md index eac51061c..a3942ff15 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -179,6 +179,60 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or `::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +### Replace `RootModel` by union types with `TypeAdapter` validation + +The following union types are no longer `RootModel` subclasses: + +- `ClientRequest` +- `ServerRequest` +- `ClientNotification` +- `ServerNotification` +- `ClientResult` +- `ServerResult` +- `JSONRPCMessage` + +This means you can no longer access `.root` on these types or use `model_validate()` directly on them. Instead, use the provided `TypeAdapter` instances for validation. + +**Before (v1):** + +```python +from mcp.types import ClientRequest, ServerNotification + +# Using RootModel.model_validate() +request = ClientRequest.model_validate(data) +actual_request = request.root # Accessing the wrapped value + +notification = ServerNotification.model_validate(data) +actual_notification = notification.root +``` + +**After (v2):** + +```python +from mcp.types import client_request_adapter, server_notification_adapter + +# Using TypeAdapter.validate_python() +request = client_request_adapter.validate_python(data) +# No .root access needed - request is the actual type + +notification = server_notification_adapter.validate_python(data) +# No .root access needed - notification is the actual type +``` + +**Available adapters:** + +| Union Type | Adapter | +|------------|---------| +| `ClientRequest` | `client_request_adapter` | +| `ServerRequest` | `server_request_adapter` | +| `ClientNotification` | `client_notification_adapter` | +| `ServerNotification` | `server_notification_adapter` | +| `ClientResult` | `client_result_adapter` | +| `ServerResult` | `server_result_adapter` | +| `JSONRPCMessage` | `jsonrpc_message_adapter` | + +All adapters are exported from `mcp.types`. + ### Resource URI type changed from `AnyUrl` to `str` The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected. diff --git a/pyproject.toml b/pyproject.toml index 4925e603d..2a9ad077e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,18 +125,23 @@ extend-exclude = ["README.md"] [tool.ruff.lint] select = [ - "C4", # flake8-comprehensions - "C90", # mccabe - "D212", # pydocstyle: multi-line docstring summary should start at the first line - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "PERF", # Perflint - "PL", # Pylint - "UP", # pyupgrade + "C4", # flake8-comprehensions + "C90", # mccabe + "D212", # pydocstyle: multi-line docstring summary should start at the first line + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "PERF", # Perflint + "PL", # Pylint + "UP", # pyupgrade + "TID251", # https://docs.astral.sh/ruff/rules/banned-api/ ] ignore = ["PERF203", "PLC0415", "PLR0402"] +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead." + + [tool.ruff.lint.mccabe] max-complexity = 24 # Default is 10 diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index 6b233cd07..d6cde09fa 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -242,7 +242,7 @@ def build_capability(self) -> types.ClientTasksCapability | None: def handles_request(request: types.ServerRequest) -> bool: """Check if this handler handles the given request type.""" return isinstance( - request.root, + request, types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, ) @@ -259,7 +259,7 @@ async def handle_request( types.ClientResult | types.ErrorData ) - match responder.request.root: + match responder.request: case types.GetTaskRequest(params=params): response = await self.get_task(ctx, params) client_response = client_response_type.validate_python(response) @@ -281,7 +281,7 @@ async def handle_request( await responder.respond(client_response) case _: # pragma: no cover - raise ValueError(f"Unhandled request type: {type(responder.request.root)}") + raise ValueError(f"Unhandled request type: {type(responder.request)}") # Backwards compatibility aliases diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py index 1b3825549..2f890245c 100644 --- a/src/mcp/client/experimental/tasks.py +++ b/src/mcp/client/experimental/tasks.py @@ -92,15 +92,13 @@ async def call_tool_as_task( _meta = types.RequestParams.Meta(**meta) return await self._session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - task=types.TaskMetadata(ttl=ttl), - _meta=_meta, - ), - ) + types.CallToolRequest( + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + task=types.TaskMetadata(ttl=ttl), + _meta=_meta, + ), ), types.CreateTaskResult, ) @@ -115,10 +113,8 @@ async def get_task(self, task_id: str) -> types.GetTaskResult: GetTaskResult containing the task status and metadata """ return await self._session.send_request( - types.ClientRequest( - types.GetTaskRequest( - params=types.GetTaskRequestParams(task_id=task_id), - ) + types.GetTaskRequest( + params=types.GetTaskRequestParams(task_id=task_id), ), types.GetTaskResult, ) @@ -142,10 +138,8 @@ async def get_task_result( The task result, validated against result_type """ return await self._session.send_request( - types.ClientRequest( - types.GetTaskPayloadRequest( - params=types.GetTaskPayloadRequestParams(task_id=task_id), - ) + types.GetTaskPayloadRequest( + params=types.GetTaskPayloadRequestParams(task_id=task_id), ), result_type, ) @@ -164,9 +158,7 @@ async def list_tasks( """ params = types.PaginatedRequestParams(cursor=cursor) if cursor else None return await self._session.send_request( - types.ClientRequest( - types.ListTasksRequest(params=params), - ), + types.ListTasksRequest(params=params), types.ListTasksResult, ) @@ -180,10 +172,8 @@ async def cancel_task(self, task_id: str) -> types.CancelTaskResult: CancelTaskResult with the updated task state """ return await self._session.send_request( - types.ClientRequest( - types.CancelTaskRequest( - params=types.CancelTaskRequestParams(task_id=task_id), - ) + types.CancelTaskRequest( + params=types.CancelTaskRequestParams(task_id=task_id), ), types.CancelTaskResult, ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7aeee2cd8..3f727441e 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -122,13 +122,7 @@ def __init__( sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: - super().__init__( - read_stream, - write_stream, - types.ServerRequest, - types.ServerNotification, - read_timeout_seconds=read_timeout_seconds, - ) + super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities @@ -143,6 +137,14 @@ def __init__( # Experimental: Task handlers (use defaults if not provided) self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() + @property + def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: + return types.server_request_adapter + + @property + def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]: + return types.server_notification_adapter + async def initialize(self) -> types.InitializeResult: sampling = ( (self._sampling_capabilities or types.SamplingCapability()) @@ -167,20 +169,18 @@ async def initialize(self) -> types.InitializeResult: ) result = await self.send_request( - types.ClientRequest( - types.InitializeRequest( - params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities( - sampling=sampling, - elicitation=elicitation, - experimental=None, - roots=roots, - tasks=self._task_handlers.build_capability(), - ), - client_info=self._client_info, + types.InitializeRequest( + params=types.InitializeRequestParams( + protocol_version=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + sampling=sampling, + elicitation=elicitation, + experimental=None, + roots=roots, + tasks=self._task_handlers.build_capability(), ), - ) + client_info=self._client_info, + ), ), types.InitializeResult, ) @@ -190,7 +190,7 @@ async def initialize(self) -> types.InitializeResult: self._server_capabilities = result.capabilities - await self.send_notification(types.ClientNotification(types.InitializedNotification())) + await self.send_notification(types.InitializedNotification()) return result @@ -218,10 +218,7 @@ def experimental(self) -> ExperimentalClientFeatures: async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - return await self.send_request( - types.ClientRequest(types.PingRequest()), - types.EmptyResult, - ) + return await self.send_request(types.PingRequest(), types.EmptyResult) async def send_progress_notification( self, @@ -232,14 +229,12 @@ async def send_progress_notification( ) -> None: """Send a progress notification.""" await self.send_notification( - types.ClientNotification( - types.ProgressNotification( - params=types.ProgressNotificationParams( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ), + types.ProgressNotification( + params=types.ProgressNotificationParams( + progress_token=progress_token, + progress=progress, + total=total, + message=message, ), ) ) @@ -247,11 +242,7 @@ async def send_progress_notification( async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" return await self.send_request( # pragma: no cover - types.ClientRequest( - types.SetLevelRequest( - params=types.SetLevelRequestParams(level=level), - ) - ), + types.SetLevelRequest(params=types.SetLevelRequestParams(level=level)), types.EmptyResult, ) @@ -261,10 +252,7 @@ async def list_resources(self, *, params: types.PaginatedRequestParams | None = Args: params: Full pagination parameters including cursor and any future fields """ - return await self.send_request( - types.ClientRequest(types.ListResourcesRequest(params=params)), - types.ListResourcesResult, - ) + return await self.send_request(types.ListResourcesRequest(params=params), types.ListResourcesResult) async def list_resource_templates( self, *, params: types.PaginatedRequestParams | None = None @@ -275,28 +263,28 @@ async def list_resource_templates( params: Full pagination parameters including cursor and any future fields """ return await self.send_request( - types.ClientRequest(types.ListResourceTemplatesRequest(params=params)), + types.ListResourceTemplatesRequest(params=params), types.ListResourceTemplatesResult, ) async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( - types.ClientRequest(types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri)))), + types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri))), types.ReadResourceResult, ) async def subscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( # pragma: no cover - types.ClientRequest(types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri)))), + types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri))), types.EmptyResult, ) async def unsubscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( # pragma: no cover - types.ClientRequest(types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri)))), + types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri))), types.EmptyResult, ) @@ -316,10 +304,8 @@ async def call_tool( _meta = types.RequestParams.Meta(**meta) result = await self.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta), ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, @@ -364,17 +350,15 @@ async def list_prompts(self, *, params: types.PaginatedRequestParams | None = No params: Full pagination parameters including cursor and any future fields """ return await self.send_request( - types.ClientRequest(types.ListPromptsRequest(params=params)), + types.ListPromptsRequest(params=params), types.ListPromptsResult, ) async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( - types.ClientRequest( - types.GetPromptRequest( - params=types.GetPromptRequestParams(name=name, arguments=arguments), - ) + types.GetPromptRequest( + params=types.GetPromptRequestParams(name=name, arguments=arguments), ), types.GetPromptResult, ) @@ -391,14 +375,12 @@ async def complete( context = types.CompletionContext(arguments=context_arguments) return await self.send_request( - types.ClientRequest( - types.CompleteRequest( - params=types.CompleteRequestParams( - ref=ref, - argument=types.CompletionArgument(**argument), - context=context, - ), - ) + types.CompleteRequest( + params=types.CompleteRequestParams( + ref=ref, + argument=types.CompletionArgument(**argument), + context=context, + ), ), types.CompleteResult, ) @@ -410,7 +392,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None params: Full pagination parameters including cursor and any future fields """ result = await self.send_request( - types.ClientRequest(types.ListToolsRequest(params=params)), + types.ListToolsRequest(params=params), types.ListToolsResult, ) @@ -423,7 +405,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None async def send_roots_list_changed(self) -> None: # pragma: no cover """Send a roots/list_changed notification.""" - await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) + await self.send_notification(types.RootsListChangedNotification()) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( @@ -440,7 +422,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques return None # Core request handling - match responder.request.root: + match responder.request: case types.CreateMessageRequest(params=params): with responder: # Check if this is a task-augmented request @@ -469,7 +451,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques case types.PingRequest(): # pragma: no cover with responder: - return await responder.respond(types.ClientResult(root=types.EmptyResult())) + return await responder.respond(types.EmptyResult()) case _: # pragma: no cover pass # Task requests handled above by _task_handlers @@ -486,7 +468,7 @@ async def _handle_incoming( async def _received_notification(self, notification: types.ServerNotification) -> None: """Handle notifications from the server.""" # Process specific notification types - match notification.root: + match notification: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) case types.ElicitCompleteNotification(params=params): diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 3570d28c2..dec6713b1 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from typing import Any, Literal -from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError +# TODO(Marcelo): We should drop the `RootModel`. +from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError # noqa: TID251 from starlette.datastructures import FormData, QueryParams from starlette.requests import Request from starlette.responses import RedirectResponse, Response diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 14d3a6aec..28c1c261f 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -4,7 +4,7 @@ from typing import Any from uuid import uuid4 -from pydantic import BaseModel, RootModel, ValidationError +from pydantic import BaseModel, ValidationError from starlette.requests import Request from starlette.responses import Response @@ -14,11 +14,9 @@ from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata - -class RegistrationRequest(RootModel[OAuthClientMetadata]): - # this wrapper is a no-op; it's just to separate out the types exposed to the - # provider from what we use in the HTTP handler - root: OAuthClientMetadata +# this alias is a no-op; it's just to separate out the types exposed to the +# provider from what we use in the HTTP handler +RegistrationRequest = OAuthClientMetadata class RegistrationErrorResponse(BaseModel): @@ -35,6 +33,7 @@ async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: # Parse request body as JSON + # TODO(Marcelo): This is unnecessary. We should use `request.body()`. body = await request.json() client_metadata = OAuthClientMetadata.model_validate(body) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0d3c247c2..14f6f6872 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Annotated, Any, Literal -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, TypeAdapter, ValidationError from starlette.requests import Request from mcp.server.auth.errors import stringify_pydantic_error @@ -40,18 +40,8 @@ class RefreshTokenRequest(BaseModel): resource: str | None = Field(None, description="Resource indicator for the token") -class TokenRequest( - RootModel[ - Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, - Field(discriminator="grant_type"), - ] - ] -): - root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, - Field(discriminator="grant_type"), - ] +TokenRequest = Annotated[AuthorizationCodeRequest | RefreshTokenRequest, Field(discriminator="grant_type")] +token_request_adapter = TypeAdapter[TokenRequest](TokenRequest) class TokenErrorResponse(BaseModel): @@ -62,11 +52,10 @@ class TokenErrorResponse(BaseModel): error_uri: AnyHttpUrl | None = None -class TokenSuccessResponse(RootModel[OAuthToken]): - # this is just a wrapper over OAuthToken; the only reason we do this - # is to have some separation between the HTTP response type, and the - # type returned by the provider - root: OAuthToken +# this is just an alias over OAuthToken; the only reason we do this +# is to have some separation between the HTTP response type, and the +# type returned by the provider +TokenSuccessResponse = OAuthToken @dataclass @@ -107,7 +96,8 @@ async def handle(self, request: Request): try: form_data = await request.form() - token_request = TokenRequest.model_validate(dict(form_data)).root + # TODO(Marcelo): Can someone check if this `dict()` wrapper is necessary? + token_request = token_request_adapter.validate_python(dict(form_data)) except ValidationError as validation_error: # pragma: no cover return self.response( TokenErrorResponse( @@ -186,12 +176,7 @@ async def handle(self, request: Request): # Exchange authorization code for tokens tokens = await self.provider.exchange_authorization_code(client_info, auth_code) except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description)) case RefreshTokenRequest(): # pragma: no cover refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) @@ -229,11 +214,6 @@ async def handle(self, request: Request): # Exchange refresh token for new tokens tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description)) - return self.response(TokenSuccessResponse(root=tokens)) + return self.response(tokens) diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py index 2bccf6603..a189c3cbc 100644 --- a/src/mcp/server/experimental/session_features.py +++ b/src/mcp/server/experimental/session_features.py @@ -49,7 +49,7 @@ async def get_task(self, task_id: str) -> types.GetTaskResult: GetTaskResult containing the task status """ return await self._session.send_request( - types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id))), + types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), types.GetTaskResult, ) @@ -68,7 +68,7 @@ async def get_task_result( The task result, validated against result_type """ return await self._session.send_request( - types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id))), + types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)), result_type, ) @@ -120,13 +120,11 @@ async def elicit_as_task( require_task_augmented_elicitation(client_caps) create_result = await self._session.send_request( - types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=types.TaskMetadata(ttl=ttl), - ) + types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requested_schema=requested_schema, + task=types.TaskMetadata(ttl=ttl), ) ), types.CreateTaskResult, @@ -185,21 +183,19 @@ async def create_message_as_task( validate_tool_use_result_messages(messages) create_result = await self._session.send_request( - types.ServerRequest( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=types.TaskMetadata(ttl=ttl), - ) + types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + max_tokens=max_tokens, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + task=types.TaskMetadata(ttl=ttl), ) ), types.CreateTaskResult, diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index feb1df652..871cefd9f 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -39,7 +39,6 @@ Result, SamplingCapability, SamplingMessage, - ServerNotification, Task, TaskMetadata, TaskStatusNotification, @@ -156,17 +155,15 @@ async def _send_notification(self) -> None: """Send a task status notification to the client.""" task = self._ctx.task await self._session.send_notification( - ServerNotification( - TaskStatusNotification( - params=TaskStatusNotificationParams( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) + TaskStatusNotification( + params=TaskStatusNotificationParams( + task_id=task.task_id, + status=task.status, + status_message=task.status_message, + created_at=task.created_at, + last_updated_at=task.last_updated_at, + ttl=task.ttl, + poll_interval=task.poll_interval, ) ) ) diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index be2296594..eda50cef3 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -6,14 +6,7 @@ from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints import pydantic_core -from pydantic import ( - BaseModel, - ConfigDict, - Field, - RootModel, - WithJsonSchema, - create_model, -) +from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind from typing_extensions import is_typeddict @@ -484,6 +477,8 @@ def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]: def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: """Create a RootModel for dict[str, T] types.""" + # TODO(Marcelo): We should not rely on RootModel for this. + from pydantic import RootModel # noqa: TID251 class DictModel(RootModel[dict_annotation]): pass @@ -495,9 +490,7 @@ class DictModel(RootModel[dict_annotation]): return DictModel -def _convert_to_content( - result: Any, -) -> Sequence[ContentBlock]: +def _convert_to_content(result: Any) -> Sequence[ContentBlock]: """Convert a result to a sequence of content objects. Note: This conversion logic comes from previous versions of FastMCP and is being diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 2c5addb6f..49387daad 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -141,16 +141,14 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: message=f"Task not found: {req.params.task_id}", ) ) - return ServerResult( - GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) + return GetTaskResult( + task_id=task.task_id, + status=task.status, + status_message=task.status_message, + created_at=task.created_at, + last_updated_at=task.last_updated_at, + ttl=task.ttl, + poll_interval=task.poll_interval, ) self._request_handlers[GetTaskRequest] = _default_get_task @@ -158,29 +156,29 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: # Register get_task_result handler if not already registered if GetTaskPayloadRequest not in self._request_handlers: - async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: + async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: ctx = self._server.request_context result = await support.handler.handle(req, ctx.session, ctx.request_id) - return ServerResult(result) + return result self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result # Register list_tasks handler if not already registered if ListTasksRequest not in self._request_handlers: - async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: + async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: cursor = req.params.cursor if req.params else None tasks, next_cursor = await support.store.list_tasks(cursor) - return ServerResult(ListTasksResult(tasks=tasks, next_cursor=next_cursor)) + return ListTasksResult(tasks=tasks, next_cursor=next_cursor) self._request_handlers[ListTasksRequest] = _default_list_tasks # Register cancel_task handler if not already registered if CancelTaskRequest not in self._request_handlers: - async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult: + async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: result = await cancel_task(support.store, req.params.task_id) - return ServerResult(result) + return result self._request_handlers[CancelTaskRequest] = _default_cancel_task @@ -201,9 +199,9 @@ def decorator( logger.debug("Registering handler for ListTasksRequest") wrapper = create_call_wrapper(func, ListTasksRequest) - async def handler(req: ListTasksRequest) -> ServerResult: + async def handler(req: ListTasksRequest) -> ListTasksResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[ListTasksRequest] = handler return func @@ -226,9 +224,9 @@ def decorator( logger.debug("Registering handler for GetTaskRequest") wrapper = create_call_wrapper(func, GetTaskRequest) - async def handler(req: GetTaskRequest) -> ServerResult: + async def handler(req: GetTaskRequest) -> GetTaskResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[GetTaskRequest] = handler return func @@ -252,9 +250,9 @@ def decorator( logger.debug("Registering handler for GetTaskPayloadRequest") wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - async def handler(req: GetTaskPayloadRequest) -> ServerResult: + async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[GetTaskPayloadRequest] = handler return func @@ -278,9 +276,9 @@ def decorator( logger.debug("Registering handler for CancelTaskRequest") wrapper = create_call_wrapper(func, CancelTaskRequest) - async def handler(req: CancelTaskRequest) -> ServerResult: + async def handler(req: CancelTaskRequest) -> CancelTaskResult: result = await wrapper(req) - return ServerResult(result) + return result self._request_handlers[CancelTaskRequest] = handler return func diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9d600a6b8..cd92ce9d8 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -271,10 +271,10 @@ async def handler(req: types.ListPromptsRequest): result = await wrapper(req) # Handle both old style (list[Prompt]) and new style (ListPromptsResult) if isinstance(result, types.ListPromptsResult): - return types.ServerResult(result) + return result else: # Old style returns list[Prompt] - return types.ServerResult(types.ListPromptsResult(prompts=result)) + return types.ListPromptsResult(prompts=result) self.request_handlers[types.ListPromptsRequest] = handler return func @@ -289,7 +289,7 @@ def decorator( async def handler(req: types.GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - return types.ServerResult(prompt_get) + return prompt_get self.request_handlers[types.GetPromptRequest] = handler return func @@ -309,10 +309,10 @@ async def handler(req: types.ListResourcesRequest): result = await wrapper(req) # Handle both old style (list[Resource]) and new style (ListResourcesResult) if isinstance(result, types.ListResourcesResult): - return types.ServerResult(result) + return result else: # Old style returns list[Resource] - return types.ServerResult(types.ListResourcesResult(resources=result)) + return types.ListResourcesResult(resources=result) self.request_handlers[types.ListResourcesRequest] = handler return func @@ -325,7 +325,7 @@ def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): async def handler(_: Any): templates = await func() - return types.ServerResult(types.ListResourceTemplatesResult(resource_templates=templates)) + return types.ListResourceTemplatesResult(resource_templates=templates) self.request_handlers[types.ListResourceTemplatesRequest] = handler return func @@ -376,18 +376,12 @@ def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any ) for content_item in contents ] - return types.ServerResult( - types.ReadResourceResult( - contents=contents_list, - ) - ) + return types.ReadResourceResult(contents=contents_list) case _: # pragma: no cover raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - return types.ServerResult( # pragma: no cover - types.ReadResourceResult( - contents=[content], - ) + return types.ReadResourceResult( # pragma: no cover + contents=[content], ) self.request_handlers[types.ReadResourceRequest] = handler @@ -401,7 +395,7 @@ def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): async def handler(req: types.SetLevelRequest): await func(req.params.level) - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() self.request_handlers[types.SetLevelRequest] = handler return func @@ -414,7 +408,7 @@ def decorator(func: Callable[[str], Awaitable[None]]): async def handler(req: types.SubscribeRequest): await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() self.request_handlers[types.SubscribeRequest] = handler return func @@ -427,7 +421,7 @@ def decorator(func: Callable[[str], Awaitable[None]]): async def handler(req: types.UnsubscribeRequest): await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() self.request_handlers[types.UnsubscribeRequest] = handler return func @@ -452,7 +446,7 @@ async def handler(req: types.ListToolsRequest): for tool in result.tools: validate_and_warn_tool_name(tool.name) self._tool_cache[tool.name] = tool - return types.ServerResult(result) + return result else: # Old style returns list[Tool] # Clear and refresh the entire tool cache @@ -460,20 +454,18 @@ async def handler(req: types.ListToolsRequest): for tool in result: validate_and_warn_tool_name(tool.name) self._tool_cache[tool.name] = tool - return types.ServerResult(types.ListToolsResult(tools=result)) + return types.ListToolsResult(tools=result) self.request_handlers[types.ListToolsRequest] = handler return func return decorator - def _make_error_result(self, error_message: str) -> types.ServerResult: - """Create a ServerResult with an error CallToolResult.""" - return types.ServerResult( - types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) + def _make_error_result(self, error_message: str) -> types.CallToolResult: + """Create a CallToolResult with an error.""" + return types.CallToolResult( + content=[types.TextContent(type="text", text=error_message)], + is_error=True, ) async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: @@ -541,10 +533,10 @@ async def handler(req: types.CallToolRequest): unstructured_content: UnstructuredContent maybe_structured_content: StructuredContent | None if isinstance(results, types.CallToolResult): - return types.ServerResult(results) + return results elif isinstance(results, types.CreateTaskResult): # Task-augmented execution returns task info instead of result - return types.ServerResult(results) + return results elif isinstance(results, tuple) and len(results) == 2: # tool returned both structured and unstructured content unstructured_content, maybe_structured_content = cast(CombinationContent, results) @@ -572,12 +564,10 @@ async def handler(req: types.CallToolRequest): return self._make_error_result(f"Output validation error: {e.message}") # result - return types.ServerResult( - types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) + return types.CallToolResult( + content=list(unstructured_content), + structured_content=maybe_structured_content, + is_error=False, ) except UrlElicitationRequiredError: # Re-raise UrlElicitationRequiredError so it can be properly handled @@ -627,12 +617,10 @@ def decorator( async def handler(req: types.CompleteRequest): completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.ServerResult( - types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) + return types.CompleteResult( + completion=completion + if completion is not None + else types.Completion(values=[], total=None, has_more=None), ) self.request_handlers[types.CompleteRequest] = handler @@ -694,11 +682,11 @@ async def _handle_message( ): with warnings.catch_warnings(record=True) as w: match message: - case RequestResponder(request=types.ClientRequest(root=req)) as responder: + case RequestResponder() as responder: with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + await self._handle_request( + message, responder.request, session, lifespan_context, raise_exceptions + ) case Exception(): # pragma: no cover logger.error(f"Received exception from stream: {message}") await session.send_log_message( @@ -708,6 +696,8 @@ async def _handle_message( ) if raise_exceptions: raise message + case _: + await self._handle_notification(message) for warning in w: # pragma: no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -715,7 +705,7 @@ async def _handle_message( async def _handle_request( self, message: RequestResponder[types.ClientRequest, types.ServerResult], - req: types.ClientRequestType, + req: types.ClientRequest, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, @@ -803,4 +793,4 @@ async def _handle_notification(self, notify: Any): async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.ServerResult(types.EmptyResult()) + return types.EmptyResult() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6f80615ff..cc4973fc2 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -92,7 +92,7 @@ def __init__( init_options: InitializationOptions, stateless: bool = False, ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) + super().__init__(read_stream, write_stream) self._stateless = stateless self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized @@ -104,6 +104,14 @@ def __init__( ](0) self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + @property + def _receive_request_adapter(self) -> types.TypeAdapter[types.ClientRequest]: + return types.client_request_adapter + + @property + def _receive_notification_adapter(self) -> types.TypeAdapter[types.ClientNotification]: + return types.client_notification_adapter + @property def client_params(self) -> types.InitializeRequestParams | None: return self._client_params # pragma: no cover @@ -162,29 +170,27 @@ async def _receive_loop(self) -> None: await super()._receive_loop() async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): - match responder.request.root: + match responder.request: case types.InitializeRequest(params=params): requested_version = params.protocol_version self._initialization_state = InitializationState.Initializing self._client_params = params with responder: await responder.respond( - types.ServerResult( - types.InitializeResult( - protocol_version=requested_version - if requested_version in SUPPORTED_PROTOCOL_VERSIONS - else types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - server_info=types.Implementation( - name=self._init_options.server_name, - title=self._init_options.title, - description=self._init_options.description, - version=self._init_options.server_version, - website_url=self._init_options.website_url, - icons=self._init_options.icons, - ), - instructions=self._init_options.instructions, - ) + types.InitializeResult( + protocol_version=requested_version + if requested_version in SUPPORTED_PROTOCOL_VERSIONS + else types.LATEST_PROTOCOL_VERSION, + capabilities=self._init_options.capabilities, + server_info=types.Implementation( + name=self._init_options.server_name, + title=self._init_options.title, + description=self._init_options.description, + version=self._init_options.server_version, + website_url=self._init_options.website_url, + icons=self._init_options.icons, + ), + instructions=self._init_options.instructions, ) ) self._initialization_state = InitializationState.Initialized @@ -198,7 +204,7 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() - match notification.root: + match notification: case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: @@ -214,14 +220,12 @@ async def send_log_message( ) -> None: """Send a log message notification.""" await self.send_notification( - types.ServerNotification( - types.LoggingMessageNotification( - params=types.LoggingMessageNotificationParams( - level=level, - data=data, - logger=logger, - ), - ) + types.LoggingMessageNotification( + params=types.LoggingMessageNotificationParams( + level=level, + data=data, + logger=logger, + ), ), related_request_id, ) @@ -229,10 +233,8 @@ async def send_log_message( async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover """Send a resource updated notification.""" await self.send_notification( - types.ServerNotification( - types.ResourceUpdatedNotification( - params=types.ResourceUpdatedNotificationParams(uri=str(uri)), - ) + types.ResourceUpdatedNotification( + params=types.ResourceUpdatedNotificationParams(uri=str(uri)), ) ) @@ -322,21 +324,19 @@ async def create_message( validate_sampling_tools(client_caps, tools, tool_choice) validate_tool_use_result_messages(messages) - request = types.ServerRequest( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - max_tokens=max_tokens, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - ), - ) + request = types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ), ) metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) @@ -358,7 +358,7 @@ async def list_roots(self) -> types.ListRootsResult: if self._stateless: raise StatelessModeNotSupported(method="list_roots") return await self.send_request( - types.ServerRequest(types.ListRootsRequest()), + types.ListRootsRequest(), types.ListRootsResult, ) @@ -406,13 +406,11 @@ async def elicit_form( if self._stateless: raise StatelessModeNotSupported(method="elicitation") return await self.send_request( - types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - ), - ) + types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requested_schema=requested_schema, + ), ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), @@ -445,14 +443,12 @@ async def elicit_url( if self._stateless: raise StatelessModeNotSupported(method="elicitation") return await self.send_request( - types.ServerRequest( - types.ElicitRequest( - params=types.ElicitRequestURLParams( - message=message, - url=url, - elicitation_id=elicitation_id, - ), - ) + types.ElicitRequest( + params=types.ElicitRequestURLParams( + message=message, + url=url, + elicitation_id=elicitation_id, + ), ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), @@ -461,7 +457,7 @@ async def elicit_url( async def send_ping(self) -> types.EmptyResult: # pragma: no cover """Send a ping request.""" return await self.send_request( - types.ServerRequest(types.PingRequest()), + types.PingRequest(), types.EmptyResult, ) @@ -475,30 +471,28 @@ async def send_progress_notification( ) -> None: """Send a progress notification.""" await self.send_notification( - types.ServerNotification( - types.ProgressNotification( - params=types.ProgressNotificationParams( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ), - ) + types.ProgressNotification( + params=types.ProgressNotificationParams( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ), ), related_request_id, ) async def send_resource_list_changed(self) -> None: # pragma: no cover """Send a resource list changed notification.""" - await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification())) + await self.send_notification(types.ResourceListChangedNotification()) async def send_tool_list_changed(self) -> None: # pragma: no cover """Send a tool list changed notification.""" - await self.send_notification(types.ServerNotification(types.ToolListChangedNotification())) + await self.send_notification(types.ToolListChangedNotification()) async def send_prompt_list_changed(self) -> None: # pragma: no cover """Send a prompt list changed notification.""" - await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) + await self.send_notification(types.PromptListChangedNotification()) async def send_elicit_complete( self, @@ -516,10 +510,8 @@ async def send_elicit_complete( related_request_id: Optional ID of the request that triggered this """ await self.send_notification( - types.ServerNotification( - types.ElicitCompleteNotification( - params=types.ElicitCompleteNotificationParams(elicitation_id=elicitation_id) - ) + types.ElicitCompleteNotification( + params=types.ElicitCompleteNotificationParams(elicitation_id=elicitation_id) ), related_request_id, ) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index be1990d61..c102200ed 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -9,7 +9,7 @@ import anyio import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -179,8 +179,6 @@ def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], - receive_request_type: type[ReceiveRequestT], - receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: float | None = None, ) -> None: @@ -188,8 +186,6 @@ def __init__( self._write_stream = write_stream self._response_streams = {} self._request_id = 0 - self._receive_request_type = receive_request_type - self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} @@ -264,11 +260,7 @@ async def send_request( self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) @@ -339,26 +331,30 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er session_message = SessionMessage(message=jsonrpc_response) await self._write_stream.send(session_message) + @property + def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: + """Each subclass must provide its own request adapter.""" + raise NotImplementedError + + @property + def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: + raise NotImplementedError + async def _receive_loop(self) -> None: - async with ( - self._read_stream, - self._write_stream, - ): + async with self._read_stream, self._write_stream: try: async for message in self._read_stream: if isinstance(message, Exception): # pragma: no cover await self._handle_incoming(message) elif isinstance(message.message, JSONRPCRequest): try: - validated_request = self._receive_request_type.model_validate( + validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) responder = RequestResponder( request_id=message.message.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, + request_meta=validated_request.params.meta if validated_request.params else None, request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), @@ -369,10 +365,10 @@ async def _receive_loop(self) -> None: if not responder._completed: # type: ignore[reportPrivateUsage] await self._handle_incoming(responder) - except Exception as e: + except Exception: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server - logging.warning(f"Failed to validate request: {e}") + logging.warning("Failed to validate request", exc_info=True) logging.debug(f"Message that failed validation: {message.message}") error_response = JSONRPCError( jsonrpc="2.0", @@ -388,34 +384,31 @@ async def _receive_loop(self) -> None: elif isinstance(message.message, JSONRPCNotification): try: - notification = self._receive_notification_type.model_validate( + notification = self._receive_notification_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.request_id + if isinstance(notification, CancelledNotification): + cancelled_id = notification.params.request_id if cancelled_id in self._in_flight: # pragma: no branch await self._in_flight[cancelled_id].cancel() else: # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): # pragma: no cover - progress_token = notification.root.params.progress_token + if isinstance(notification, ProgressNotification): # pragma: no cover + progress_token = notification.params.progress_token # If there is a progress callback for this token, # call it with the progress information if progress_token in self._progress_callbacks: callback = self._progress_callbacks[progress_token] try: await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) - except Exception as e: - logging.error( - "Progress callback raised an exception: %s", - e, + notification.params.progress, + notification.params.total, + notification.params.message, ) + except Exception: + logging.exception("Progress callback raised an exception") await self._received_notification(notification) await self._handle_incoming(notification) except Exception: # pragma: no cover diff --git a/src/mcp/types.py b/src/mcp/types.py index 4c886680a..10b0c61fa 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, TypeAdapter +from pydantic import BaseModel, ConfigDict, Field, FileUrl, TypeAdapter from pydantic.alias_generators import to_camel LATEST_PROTOCOL_VERSION = "2025-11-25" @@ -1631,7 +1631,7 @@ class ElicitCompleteNotification( params: ElicitCompleteNotificationParams -ClientRequestType: TypeAlias = ( +ClientRequest = ( PingRequest | InitializeRequest | CompleteRequest @@ -1650,23 +1650,17 @@ class ElicitCompleteNotification( | ListTasksRequest | CancelTaskRequest ) +client_request_adapter = TypeAdapter[ClientRequest](ClientRequest) -class ClientRequest(RootModel[ClientRequestType]): - pass - - -ClientNotificationType: TypeAlias = ( +ClientNotification = ( CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification | TaskStatusNotification ) - - -class ClientNotification(RootModel[ClientNotificationType]): - pass +client_notification_adapter = TypeAdapter[ClientNotification](ClientNotification) # Type for elicitation schema - a JSON Schema dict @@ -1760,7 +1754,7 @@ class ElicitationRequiredErrorData(MCPModel): """List of URL mode elicitations that must be completed.""" -ClientResultType: TypeAlias = ( +ClientResult = ( EmptyResult | CreateMessageResult | CreateMessageResultWithTools @@ -1772,13 +1766,10 @@ class ElicitationRequiredErrorData(MCPModel): | CancelTaskResult | CreateTaskResult ) +client_result_adapter = TypeAdapter[ClientResult](ClientResult) -class ClientResult(RootModel[ClientResultType]): - pass - - -ServerRequestType: TypeAlias = ( +ServerRequest = ( PingRequest | CreateMessageRequest | ListRootsRequest @@ -1788,13 +1779,10 @@ class ClientResult(RootModel[ClientResultType]): | ListTasksRequest | CancelTaskRequest ) +server_request_adapter = TypeAdapter[ServerRequest](ServerRequest) -class ServerRequest(RootModel[ServerRequestType]): - pass - - -ServerNotificationType: TypeAlias = ( +ServerNotification = ( CancelledNotification | ProgressNotification | LoggingMessageNotification @@ -1805,13 +1793,10 @@ class ServerRequest(RootModel[ServerRequestType]): | ElicitCompleteNotification | TaskStatusNotification ) +server_notification_adapter = TypeAdapter[ServerNotification](ServerNotification) -class ServerNotification(RootModel[ServerNotificationType]): - pass - - -ServerResultType: TypeAlias = ( +ServerResult = ( EmptyResult | InitializeResult | CompleteResult @@ -1828,7 +1813,4 @@ class ServerNotification(RootModel[ServerNotificationType]): | CancelTaskResult | CreateTaskResult ) - - -class ServerResult(RootModel[ServerResultType]): - pass +server_result_adapter = TypeAdapter[ServerResult](ServerResult) diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index e05edb14d..06d893ac6 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -19,7 +19,7 @@ from mcp import ClientSession, types from mcp.client.streamable_http import streamable_http_client from mcp.shared.session import RequestResponder -from mcp.types import ClientNotification, RootsListChangedNotification +from mcp.types import RootsListChangedNotification from tests.test_helpers import wait_for_server @@ -135,9 +135,7 @@ async def message_handler( # pragma: no cover await session.initialize() # The test server returns a 204 instead of the expected 202 - await session.send_notification( - ClientNotification(RootsListChangedNotification(method="notifications/roots/list_changed")) - ) + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) if returned_exception: # pragma: no cover pytest.fail(f"Server encountered an exception: {returned_exception}") diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index f47299cf8..c7bf8fafa 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -3,6 +3,7 @@ import anyio import pytest +from pydantic import TypeAdapter from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, RequestId, SendResultT @@ -23,20 +24,23 @@ async def _send_response( ) -> None: # pragma: no cover pass + @property + def _receive_request_adapter(self) -> TypeAdapter[Any]: + return TypeAdapter(object) # pragma: no cover + + @property + def _receive_notification_adapter(self) -> TypeAdapter[Any]: + return TypeAdapter(object) # pragma: no cover + # Create streams write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create the session - session = TestSession( - read_stream_receive, - write_stream_send, - object, # Request type doesn't matter for this test - object, # Notification type doesn't matter for this test - ) + session = TestSession(read_stream_receive, write_stream_send) # Create a test request - request = ClientRequest(PingRequest()) + request = PingRequest() # Patch the _write_stream.send method to raise an exception async def mock_send(*args: Any, **kwargs: Any): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 9512a0a7c..5c1f55d23 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -12,8 +12,6 @@ from mcp.types import ( LATEST_PROTOCOL_VERSION, CallToolResult, - ClientNotification, - ClientRequest, Implementation, InitializedNotification, InitializeRequest, @@ -22,8 +20,9 @@ JSONRPCRequest, JSONRPCResponse, ServerCapabilities, - ServerResult, TextContent, + client_notification_adapter, + client_request_adapter, ) @@ -41,24 +40,22 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities( - logging=None, - resources=None, - tools=None, - experimental=None, - prompts=None, - ), - server_info=Implementation(name="mock-server", version="0.1.0"), - instructions="The server instructions.", - ) + assert isinstance(request, InitializeRequest) + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts=None, + ), + server_info=Implementation(name="mock-server", version="0.1.0"), + instructions="The server instructions.", ) async with server_to_client_send: @@ -74,7 +71,7 @@ async def mock_server(): session_notification = await client_to_server_receive.receive() jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification, JSONRPCNotification) - initialized_notification = ClientNotification.model_validate( + initialized_notification = client_notification_adapter.validate_python( jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -109,7 +106,7 @@ async def message_handler( # pragma: no cover # Check that the client sent the initialized notification assert initialized_notification - assert isinstance(initialized_notification.root, InitializedNotification) + assert isinstance(initialized_notification, InitializedNotification) @pytest.mark.anyio @@ -126,18 +123,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_client_info = request.root.params.client_info - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_client_info = request.params.client_info + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -185,18 +180,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_client_info = request.root.params.client_info - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_client_info = request.params.client_info + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -238,21 +231,19 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) # Verify client sent the latest protocol version - assert request.root.params.protocol_version == LATEST_PROTOCOL_VERSION + assert request.params.protocol_version == LATEST_PROTOCOL_VERSION # Server responds with a supported older version - result = ServerResult( - InitializeResult( - protocol_version="2024-11-05", - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version="2024-11-05", + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -295,18 +286,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) # Server responds with an unsupported version - result = ServerResult( - InitializeResult( - protocol_version="2020-01-01", # Unsupported old version - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version="2020-01-01", # Unsupported old version + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -349,18 +338,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -422,18 +409,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -503,18 +488,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -572,17 +555,15 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=expected_capabilities, - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=expected_capabilities, + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -639,17 +620,15 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) + assert isinstance(request, InitializeRequest) - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) # Answer initialization request @@ -678,9 +657,7 @@ async def mock_server(): assert "_meta" in jsonrpc_request.params assert jsonrpc_request.params["_meta"] == meta - result = ServerResult( - CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) - ) + result = CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) # Send the tools/call result await server_to_client_send.send( diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index be3547801..7bb806696 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -11,14 +11,13 @@ from mcp.shared.message import SessionMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, - ClientRequest, Implementation, InitializeRequest, InitializeResult, JSONRPCRequest, JSONRPCResponse, ServerCapabilities, - ServerResult, + client_request_adapter, ) @@ -36,18 +35,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -108,18 +105,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -190,18 +185,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: @@ -268,18 +261,16 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) - request = ClientRequest.model_validate( + request = client_request_adapter.validate_python( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) - assert isinstance(request.root, InitializeRequest) - received_capabilities = request.root.params.capabilities - - result = ServerResult( - InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) + assert isinstance(request, InitializeRequest) + received_capabilities = request.params.capabilities + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), ) async with server_to_client_send: diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index 3c19d82d0..f21abf4d0 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -23,7 +23,6 @@ CallToolResult, CancelTaskRequest, CancelTaskResult, - ClientRequest, ClientResult, CreateTaskResult, GetTaskPayloadRequest, @@ -134,13 +133,11 @@ async def run_server(app_context: AppContext): # Create a task create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, @@ -240,13 +237,11 @@ async def run_server(app_context: AppContext): # Create a task create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, @@ -343,13 +338,11 @@ async def run_server(app_context: AppContext): # Create two tasks for _ in range(2): create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, @@ -456,13 +449,11 @@ async def run_server(app_context: AppContext): # Create a task (but don't complete it) create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), ) ), CreateTaskResult, diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index db18ef359..41cecc129 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -30,7 +30,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - ClientRequest, ClientResult, CreateTaskResult, GetTaskPayloadRequest, @@ -190,14 +189,12 @@ async def run_server(app_context: AppContext): # === Step 1: Send task-augmented tool call === create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), + ), ), CreateTaskResult, ) @@ -210,7 +207,7 @@ async def run_server(app_context: AppContext): await app_context.task_done_events[task_id].wait() task_status = await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id=task_id))), + GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult, ) @@ -219,7 +216,7 @@ async def run_server(app_context: AppContext): # === Step 3: Retrieve the actual result === task_result = await client_session.send_request( - ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id))), + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)), CallToolResult, ) @@ -327,14 +324,12 @@ async def run_server(app_context: AppContext): # Send task request create_result = await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), ), CreateTaskResult, ) @@ -346,8 +341,7 @@ async def run_server(app_context: AppContext): # Check that task was auto-failed task_status = await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id=task_id))), - GetTaskResult, + GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult ) assert task_status.status == "failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 6b0bbfef3..cb8b737f2 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -26,7 +26,6 @@ CancelTaskRequest, CancelTaskRequestParams, CancelTaskResult, - ClientRequest, ClientResult, ErrorData, GetTaskPayloadRequest, @@ -89,10 +88,10 @@ async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListTasksResult) - assert len(result.root.tasks) == 2 - assert result.root.tasks[0].task_id == "task-1" - assert result.root.tasks[1].task_id == "task-2" + assert isinstance(result, ListTasksResult) + assert len(result.tasks) == 2 + assert result.tasks[0].task_id == "task-1" + assert result.tasks[1].task_id == "task-2" @pytest.mark.anyio @@ -120,9 +119,9 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, GetTaskResult) - assert result.root.task_id == "test-task-123" - assert result.root.status == "working" + assert isinstance(result, GetTaskResult) + assert result.task_id == "test-task-123" + assert result.status == "working" @pytest.mark.anyio @@ -142,7 +141,7 @@ async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPaylo result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, GetTaskPayloadResult) + assert isinstance(result, GetTaskPayloadResult) @pytest.mark.anyio @@ -169,9 +168,9 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, CancelTaskResult) - assert result.root.task_id == "test-task-123" - assert result.root.status == "cancelled" + assert isinstance(result, CancelTaskResult) + assert result.task_id == "test-task-123" + assert result.status == "cancelled" @pytest.mark.anyio @@ -253,8 +252,8 @@ async def list_tools(): result = await tools_handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListToolsResult) - tools = result.root.tools + assert isinstance(result, ListToolsResult) + tools = result.tools assert tools[0].execution is not None assert tools[0].execution.task_support == TASK_FORBIDDEN @@ -330,14 +329,12 @@ async def handle_messages(): # Call tool with task metadata await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), ), CallToolResult, ) @@ -411,24 +408,14 @@ async def handle_messages(): # pragma: no cover # Call without task metadata await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), CallToolResult, ) # Call with task metadata await client_session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ) + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), ), CallToolResult, ) @@ -507,16 +494,13 @@ async def run_server() -> None: task = await store.create_task(TaskMetadata(ttl=60000)) # Test list_tasks (default handler) - list_result = await client_session.send_request( - ClientRequest(ListTasksRequest()), - ListTasksResult, - ) + list_result = await client_session.send_request(ListTasksRequest(), ListTasksResult) assert len(list_result.tasks) == 1 assert list_result.tasks[0].task_id == task.task_id # Test get_task (default handler - found) get_result = await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id))), + GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id)), GetTaskResult, ) assert get_result.task_id == task.task_id @@ -525,7 +509,7 @@ async def run_server() -> None: # Test get_task (default handler - not found path) with pytest.raises(McpError, match="not found"): await client_session.send_request( - ClientRequest(GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task"))), + GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task")), GetTaskResult, ) @@ -538,9 +522,7 @@ async def run_server() -> None: # Test get_task_result (default handler) payload_result = await client_session.send_request( - ClientRequest( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)) - ), + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)), GetTaskPayloadResult, ) # The result should have the related-task metadata @@ -549,8 +531,7 @@ async def run_server() -> None: # Test cancel_task (default handler) cancel_result = await client_session.send_request( - ClientRequest(CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id))), - CancelTaskResult, + CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id)), CancelTaskResult ) assert cancel_result.task_id == task.task_id assert cancel_result.status == "cancelled" @@ -568,11 +549,7 @@ async def test_build_elicit_form_request() -> None: async with ServerSession( client_to_server_receive, server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), + InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), ) as server_session: # Test without task_id request = server_session._build_elicit_form_request( @@ -613,11 +590,7 @@ async def test_build_elicit_url_request() -> None: async with ServerSession( client_to_server_receive, server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), + InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), ) as server_session: # Test without related_task_id request = server_session._build_elicit_url_request( diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 1ebff7c92..26b58343c 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -26,8 +26,8 @@ def get_user_profile(user_id: str) -> str: # pragma: no cover result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest]( types.ListResourceTemplatesRequest(params=None) ) - assert isinstance(result.root, types.ListResourceTemplatesResult) - templates = result.root.resource_templates + assert isinstance(result, types.ListResourceTemplatesResult) + templates = result.resource_templates # Verify we get both templates back assert len(templates) == 2 diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 2554fbc73..44b17d337 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -60,7 +60,7 @@ async def read_resource(uri: str) -> list[ReadResourceContents]: result: ServerResult = await handler(request) # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result.root) + read_result: ReadResourceResult = cast(ReadResourceResult, result) blob_content = read_result.contents[0] # First verify our test data actually produces different encodings diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 726843b92..5f7caf7ac 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -77,14 +77,14 @@ async def handle_generic_notification( ) -> None: """Handle any server notification and route to appropriate handler.""" if isinstance(message, ServerNotification): # pragma: no branch - if isinstance(message.root, ProgressNotification): - self.progress_notifications.append(message.root.params) - elif isinstance(message.root, LoggingMessageNotification): - self.log_messages.append(message.root.params) - elif isinstance(message.root, ResourceListChangedNotification): - self.resource_notifications.append(message.root.params) - elif isinstance(message.root, ToolListChangedNotification): # pragma: no cover - self.tool_notifications.append(message.root.params) + if isinstance(message, ProgressNotification): + self.progress_notifications.append(message.params) + elif isinstance(message, LoggingMessageNotification): + self.log_messages.append(message.params) + elif isinstance(message, ResourceListChangedNotification): + self.resource_notifications.append(message.params) + elif isinstance(message, ToolListChangedNotification): # pragma: no cover + self.tool_notifications.append(message.params) # Common fixtures diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 38998b4b4..6bf4cddb3 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -41,8 +41,8 @@ async def handle_list_prompts() -> list[Prompt]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListPromptsResult) - assert result.root.prompts == test_prompts + assert isinstance(result, ListPromptsResult) + assert result.prompts == test_prompts @pytest.mark.anyio @@ -67,8 +67,8 @@ async def handle_list_resources() -> list[Resource]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListResourcesResult) - assert result.root.resources == test_resources + assert isinstance(result, ListResourcesResult) + assert result.resources == test_resources @pytest.mark.anyio @@ -114,8 +114,8 @@ async def handle_list_tools() -> list[Tool]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListToolsResult) - assert result.root.tools == test_tools + assert isinstance(result, ListToolsResult) + assert result.tools == test_tools @pytest.mark.anyio @@ -135,8 +135,8 @@ async def handle_list_prompts() -> list[Prompt]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListPromptsResult) - assert result.root.prompts == [] + assert isinstance(result, ListPromptsResult) + assert result.prompts == [] @pytest.mark.anyio @@ -156,8 +156,8 @@ async def handle_list_resources() -> list[Resource]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListResourcesResult) - assert result.root.resources == [] + assert isinstance(result, ListResourcesResult) + assert result.resources == [] @pytest.mark.anyio @@ -177,5 +177,5 @@ async def handle_list_tools() -> list[Tool]: result = await handler(request) assert isinstance(result, ServerResult) - assert isinstance(result.root, ListToolsResult) - assert result.root.tools == [] + assert isinstance(result, ListToolsResult) + assert result.tools == [] diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index aa8d42261..98f34df46 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -15,8 +15,6 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, - ClientNotification, - ClientRequest, Tool, ) @@ -59,11 +57,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async def first_request(): try: await client.session.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), CallToolResult, ) pytest.fail("First request should have been cancelled") # pragma: no cover @@ -80,13 +74,8 @@ async def first_request(): # Cancel it assert first_request_id is not None await client.session.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams( - request_id=first_request_id, - reason="Testing server recovery", - ), - ) + CancelledNotification( + params=CancelledNotificationParams(request_id=first_request_id, reason="Testing server recovery"), ) ) diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 5d4c3347f..4767ea117 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -60,7 +60,7 @@ async def test_normal_message_handling_not_affected(): # Create a mock RequestResponder responder = Mock(spec=RequestResponder) - responder.request = types.ClientRequest(root=types.PingRequest(method="ping")) + responder.request = types.PingRequest(method="ping") responder.__enter__ = Mock(return_value=responder) responder.__exit__ = Mock(return_value=None) diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 0f62fe235..10349846c 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -39,10 +39,10 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: # Call the handler result = await handler(request) - assert isinstance(result.root, types.ReadResourceResult) - assert len(result.root.contents) == 1 + assert isinstance(result, types.ReadResourceResult) + assert len(result.contents) == 1 - content = result.root.contents[0] + content = result.contents[0] assert isinstance(content, types.TextResourceContents) assert content.text == "Hello World" assert content.mime_type == "text/plain" @@ -66,10 +66,10 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: # Call the handler result = await handler(request) - assert isinstance(result.root, types.ReadResourceResult) - assert len(result.root.contents) == 1 + assert isinstance(result, types.ReadResourceResult) + assert len(result.contents) == 1 - content = result.root.contents[0] + content = result.contents[0] assert isinstance(content, types.BlobResourceContents) assert content.mime_type == "application/octet-stream" @@ -97,10 +97,10 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: # Call the handler result = await handler(request) - assert isinstance(result.root, types.ReadResourceResult) - assert len(result.root.contents) == 1 + assert isinstance(result, types.ReadResourceResult) + assert len(result.contents) == 1 - content = result.root.contents[0] + content = result.contents[0] assert isinstance(content, types.TextResourceContents) assert content.text == "Hello World" assert content.mime_type == "text/plain" diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 3c1e96c12..5de988222 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -60,7 +60,7 @@ async def run_server(): raise message if isinstance(message, ClientNotification) and isinstance( - message.root, InitializedNotification + message, InitializedNotification ): # pragma: no branch received_initialized = True return @@ -158,7 +158,7 @@ async def run_server(): raise message if isinstance(message, types.ClientNotification) and isinstance( - message.root, InitializedNotification + message, InitializedNotification ): # pragma: no branch received_initialized = True return @@ -236,11 +236,11 @@ async def run_server(): # We should receive a ping request before initialization if isinstance(message, RequestResponder) and isinstance( - message.request.root, types.PingRequest + message.request, types.PingRequest ): # pragma: no branch # Respond to the ping with message: - await message.respond(types.ServerResult(types.EmptyResult())) + await message.respond(types.EmptyResult()) return async def mock_client(): diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py index bc6145aca..18c6b5fc6 100644 --- a/tests/server/test_session_race_condition.py +++ b/tests/server/test_session_race_condition.py @@ -57,27 +57,25 @@ async def run_server(): # Handle tools/list request if isinstance(message, RequestResponder): - if isinstance(message.request.root, types.ListToolsRequest): # pragma: no branch + if isinstance(message.request, types.ListToolsRequest): # pragma: no branch tools_list_success = True # Respond with a tool list with message: await message.respond( - types.ServerResult( - types.ListToolsResult( - tools=[ - Tool( - name="example_tool", - description="An example tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) + types.ListToolsResult( + tools=[ + Tool( + name="example_tool", + description="An example tool", + input_schema={"type": "object", "properties": {}}, + ) + ] ) ) # Handle InitializedNotification if isinstance(message, types.ClientNotification): - if isinstance(message.root, types.InitializedNotification): # pragma: no branch + if isinstance(message, types.InitializedNotification): # pragma: no branch # Done - exit gracefully return diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 78896397b..d65622822 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -135,8 +135,8 @@ async def handle_client_message( raise message if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.ProgressNotification): # pragma: no branch - params = message.root.params + if isinstance(message, types.ProgressNotification): # pragma: no branch + params = message.params client_progress_updates.append( { "token": params.progress_token, @@ -332,7 +332,7 @@ async def test_progress_callback_exception_logging(): # Track logged warnings logged_errors: list[str] = [] - def mock_log_error(msg: str, *args: Any) -> None: + def mock_log_exception(msg: str, *args: Any, **kwargs: Any) -> None: logged_errors.append(msg % args if args else msg) # Create a progress callback that raises an exception @@ -368,7 +368,7 @@ async def handle_list_tools() -> list[types.Tool]: ] # Test with mocked logging - with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): + with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): async with Client(server) as client: # Call tool with a failing progress callback result = await client.call_tool( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 77bec4aa3..89fe18ebb 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -13,8 +13,6 @@ from mcp.types import ( CancelledNotification, CancelledNotificationParams, - ClientNotification, - ClientRequest, EmptyResult, ErrorData, JSONRPCError, @@ -73,10 +71,8 @@ async def make_request(client: Client): nonlocal ev_cancelled try: await client.session.send_request( - ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="slow_tool", arguments={}), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name="slow_tool", arguments={}), ), types.CallToolResult, ) @@ -97,11 +93,7 @@ async def make_request(client: Client): # Send cancellation notification assert request_id is not None await client.session.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams(request_id=request_id), - ) - ) + CancelledNotification(params=CancelledNotificationParams(request_id=request_id)) ) # Give cancellation time to process @@ -252,7 +244,7 @@ async def make_request(client_session: ClientSession): try: # Use a short timeout since we expect this to fail await client_session.send_request( - ClientRequest(types.PingRequest()), + types.PingRequest(), types.EmptyResult, request_read_timeout_seconds=0.5, ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 0c702dce2..ed86f9860 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1125,8 +1125,8 @@ async def message_handler( # pragma: no branch # Verify the notification is a ResourceUpdatedNotification resource_update_found = False for notif in notifications_received: - if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch - assert str(notif.root.params.uri) == "http://test_resource" + if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.params.uri) == "http://test_resource" resource_update_found = True assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" @@ -1167,10 +1167,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises( # pragma: no branch - McpError, - match="Session terminated", - ): + with pytest.raises(McpError, match="Session terminated"): # pragma: no branch await session.list_tools() @@ -1259,8 +1256,8 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch captured_notifications.append(message) # Look for our first notification - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - if message.root.params.data == "First notification before lock": + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + if message.params.data == "First notification before lock": nonlocal first_notification_received first_notification_received = True @@ -1291,12 +1288,8 @@ async def run_tool(): on_resumption_token_update=on_resumption_token_update, ) await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams( - name="wait_for_lock_with_notification", arguments={} - ), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), ), types.CallToolResult, metadata=metadata, @@ -1313,8 +1306,8 @@ async def run_tool(): # Verify we received exactly one notification assert len(captured_notifications) == 1 # pragma: no cover - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: no cover + assert captured_notifications[0].params.data == "First notification before lock" # pragma: no cover # Clear notifications for the second phase captured_notifications = [] # pragma: no cover @@ -1334,11 +1327,7 @@ async def run_tool(): ): async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="release_lock", arguments={}), - ) - ), + types.CallToolRequest(params=types.CallToolRequestParams(name="release_lock", arguments={})), types.CallToolResult, ) metadata = ClientMessageMetadata( @@ -1346,10 +1335,8 @@ async def run_tool(): ) result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), - ) + types.CallToolRequest( + params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), ), types.CallToolResult, metadata=metadata, @@ -1361,8 +1348,8 @@ async def run_tool(): # We should have received the remaining notifications assert len(captured_notifications) == 1 - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "Second notification after lock" # pragma: no cover + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: no cover + assert captured_notifications[0].params.data == "Second notification after lock" # pragma: no cover @pytest.mark.anyio @@ -1905,11 +1892,7 @@ async def on_resumption_token_update(token: str) -> None: on_resumption_token_update=on_resumption_token_update, ) result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="test_tool", arguments={}), - ) - ), + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), types.CallToolResult, metadata=metadata, ) @@ -1967,8 +1950,8 @@ async def message_handler( if isinstance(message, Exception): # pragma: no branch return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - captured_notifications.append(str(message.root.params.data)) + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + captured_notifications.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, @@ -2043,8 +2026,8 @@ async def message_handler( if isinstance(message, Exception): # pragma: no branch return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - all_notifications.append(str(message.root.params.data)) + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + all_notifications.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, @@ -2091,8 +2074,8 @@ async def message_handler( if isinstance(message, Exception): # pragma: no branch return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch - notification_data.append(str(message.root.params.data)) + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + notification_data.append(str(message.params.data)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, @@ -2153,15 +2136,13 @@ async def on_resumption_token(token: str) -> None: # Use send_request with metadata to track resumption tokens metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), - ) + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="tool_with_multiple_stream_closes", + # retry_interval=500ms, so sleep 600ms to ensure reconnect completes + arguments={"checkpoints": 3, "sleep_time": 0.6}, + ), ), types.CallToolResult, metadata=metadata, @@ -2202,8 +2183,8 @@ async def message_handler( if isinstance(message, Exception): return # pragma: no cover if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message.root, types.ResourceUpdatedNotification): # pragma: no branch - received_notifications.append(str(message.root.params.uri)) + if isinstance(message, types.ResourceUpdatedNotification): # pragma: no branch + received_notifications.append(str(message.params.uri)) async with streamable_http_client(f"{server_url}/mcp") as ( read_stream, diff --git a/tests/test_types.py b/tests/test_types.py index 454bac34b..f424efdbf 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -5,7 +5,6 @@ from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, - ClientRequest, CreateMessageRequestParams, CreateMessageResult, CreateMessageResultWithTools, @@ -21,6 +20,7 @@ ToolChoice, ToolResultContent, ToolUseContent, + client_request_adapter, jsonrpc_message_adapter, ) @@ -40,7 +40,7 @@ async def test_jsonrpc_request(): request = jsonrpc_message_adapter.validate_python(json_data) assert isinstance(request, JSONRPCRequest) - ClientRequest.model_validate(request.model_dump(by_alias=True, exclude_none=True)) + client_request_adapter.validate_python(request.model_dump(by_alias=True, exclude_none=True)) assert request.jsonrpc == "2.0" assert request.id == 1