Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import McpError, UrlElicitationRequiredError
from .shared.session import MessageMiddleware
from .types import (
CallToolRequest,
ClientCapabilities,
Expand All @@ -23,6 +24,8 @@
InitializeRequest,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListPromptsRequest,
Expand Down Expand Up @@ -87,8 +90,11 @@
"InitializeResult",
"InitializedNotification",
"JSONRPCError",
"JSONRPCMessage",
"JSONRPCNotification",
"JSONRPCRequest",
"JSONRPCResponse",
"MessageMiddleware",
"ListPromptsRequest",
"ListPromptsResult",
"ListResourcesRequest",
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.session import BaseSession, MessageMiddleware, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -123,13 +123,17 @@ def __init__(
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
send_middleware: list["MessageMiddleware"] | None = None,
receive_middleware: list["MessageMiddleware"] | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
send_middleware=send_middleware,
receive_middleware=receive_middleware,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
Expand Down
13 changes: 12 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
MessageMiddleware,
RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -91,8 +92,18 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
*,
send_middleware: list["MessageMiddleware"] | None = None,
receive_middleware: list["MessageMiddleware"] | None = None,
) -> None:
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
super().__init__(
read_stream,
write_stream,
types.ClientRequest,
types.ClientNotification,
send_middleware=send_middleware,
receive_middleware=receive_middleware,
)
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
)
Expand Down
53 changes: 47 additions & 6 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import logging
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
Expand Down Expand Up @@ -43,6 +44,10 @@

RequestId = str | int

# Middleware type for transforming messages before sending or after receiving.
# Can be sync (returns JSONRPCMessage) or async (returns Awaitable[JSONRPCMessage]).
MessageMiddleware = Callable[[JSONRPCMessage], JSONRPCMessage | Awaitable[JSONRPCMessage]]


class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""
Expand Down Expand Up @@ -190,6 +195,9 @@ def __init__(
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
*,
send_middleware: list[MessageMiddleware] | None = None,
receive_middleware: list[MessageMiddleware] | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand All @@ -202,6 +210,24 @@ def __init__(
self._progress_callbacks = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()
# Pre-compute whether each middleware is async to avoid checking on every message
self._send_middleware: list[tuple[MessageMiddleware, bool]] = [
(m, inspect.iscoroutinefunction(m)) for m in (send_middleware or [])
]
self._receive_middleware: list[tuple[MessageMiddleware, bool]] = [
(m, inspect.iscoroutinefunction(m)) for m in (receive_middleware or [])
]

async def _apply_middleware(
self, message: JSONRPCMessage, middleware_list: list[tuple[MessageMiddleware, bool]]
) -> JSONRPCMessage:
"""Apply a list of middleware functions to a message."""
for middleware, is_async in middleware_list:
result = middleware(message)
if is_async:
result = await result # type: ignore[misc]
message = result # type: ignore[assignment]
return message # type: ignore[return-value]

def add_response_router(self, router: ResponseRouter) -> None:
"""
Expand Down Expand Up @@ -278,7 +304,9 @@ async def send_request(
**request_data,
)

await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
message = JSONRPCMessage(jsonrpc_request)
message = await self._apply_middleware(message, self._send_middleware)
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))

# request read timeout takes precedence over session read timeout
timeout = None
Expand Down Expand Up @@ -328,24 +356,30 @@ async def send_notification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
message = JSONRPCMessage(jsonrpc_notification)
message = await self._apply_middleware(message, self._send_middleware)
session_message = SessionMessage( # pragma: no cover
message=JSONRPCMessage(jsonrpc_notification),
message=message,
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
await self._write_stream.send(session_message)

async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
message = JSONRPCMessage(jsonrpc_error)
message = await self._apply_middleware(message, self._send_middleware)
session_message = SessionMessage(message=message)
await self._write_stream.send(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
message = JSONRPCMessage(jsonrpc_response)
message = await self._apply_middleware(message, self._send_middleware)
session_message = SessionMessage(message=message)
await self._write_stream.send(session_message)

async def _receive_loop(self) -> None:
Expand All @@ -357,7 +391,14 @@ async def _receive_loop(self) -> None:
async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
continue

# Apply receive middleware to transform the message
if self._receive_middleware:
transformed_msg = await self._apply_middleware(message.message, self._receive_middleware)
message = SessionMessage(message=transformed_msg, metadata=message.metadata) # noqa: PLW2901

if isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
Expand Down
Loading