X Tutup
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared._context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.session import (
BaseSession,
ProgressFnT,
RequestResponder,
request_methods_for_union,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types._types import RequestParamsMeta

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
KNOWN_SERVER_REQUEST_METHODS = request_methods_for_union(types.ServerRequest)

logger = logging.getLogger("client")

Expand Down Expand Up @@ -141,6 +147,10 @@ def __init__(
def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
return types.server_request_adapter

@property
def _known_request_methods(self) -> frozenset[str]:
return KNOWN_SERVER_REQUEST_METHODS

@property
def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]:
return types.server_notification_adapter
Expand Down
7 changes: 7 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
from mcp.shared.session import (
BaseSession,
RequestResponder,
request_methods_for_union,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

Expand All @@ -63,6 +64,8 @@ class InitializationState(Enum):
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
)

KNOWN_CLIENT_REQUEST_METHODS = request_methods_for_union(types.ClientRequest)


class ServerSession(
BaseSession[
Expand Down Expand Up @@ -100,6 +103,10 @@ def __init__(
def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
return types.client_request_adapter

@property
def _known_request_methods(self) -> frozenset[str]:
return KNOWN_CLIENT_REQUEST_METHODS

@property
def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]:
return types.client_notification_adapter
Expand Down
27 changes: 25 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar
from typing import Any, Generic, Protocol, TypeVar, get_args

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand All @@ -17,6 +17,7 @@
from mcp.types import (
CONNECTION_CLOSED,
INVALID_PARAMS,
METHOD_NOT_FOUND,
REQUEST_TIMEOUT,
CancelledNotification,
ClientNotification,
Expand Down Expand Up @@ -45,6 +46,16 @@
RequestId = str | int


def request_methods_for_union(request_union: Any) -> frozenset[str]:
methods: set[str] = set()
for request_type in get_args(request_union):
field = getattr(request_type, "model_fields", {}).get("method")
default = getattr(field, "default", None)
if isinstance(default, str):
methods.add(default)
return frozenset(methods)


class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""

Expand Down Expand Up @@ -326,6 +337,10 @@ def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]:
"""Each subclass must provide its own request adapter."""
raise NotImplementedError

@property
def _known_request_methods(self) -> frozenset[str]:
return frozenset()

@property
def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError
Expand Down Expand Up @@ -360,10 +375,18 @@ async def _receive_loop(self) -> None:
# response instead of crashing the server
logging.warning("Failed to validate request", exc_info=True)
logging.debug(f"Message that failed validation: {message.message}")
if message.message.method not in self._known_request_methods:
error = ErrorData(code=METHOD_NOT_FOUND, message="Method not found")
else:
error = ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
)
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.id,
error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""),
error=error,
)
session_message = SessionMessage(message=error_response)
await self._write_stream.send(session_message)
Expand Down
101 changes: 101 additions & 0 deletions tests/issues/test_1561_invalid_method_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Test for issue #1561: unknown methods should return METHOD_NOT_FOUND."""

import anyio
import pytest
from pydantic import BaseModel

from mcp import types
from mcp.client.session import KNOWN_SERVER_REQUEST_METHODS, ClientSession
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, request_methods_for_union
from mcp.types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCRequest, ServerCapabilities


@pytest.mark.anyio
async def test_invalid_method_returns_method_not_found() -> None:
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10)

async with read_send_stream, read_receive_stream, write_send_stream, write_receive_stream:
async with ServerSession(
read_stream=read_receive_stream,
write_stream=write_send_stream,
init_options=InitializationOptions(
server_name="test_server",
server_version="1.0.0",
capabilities=ServerCapabilities(),
),
):
await read_send_stream.send(
SessionMessage(
message=JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="invalid/method",
params={},
)
)
)

await anyio.sleep(0.1)

response_message = write_receive_stream.receive_nowait()
response = response_message.message

assert isinstance(response, JSONRPCError)
assert response.id == 1
assert response.error.code == METHOD_NOT_FOUND
assert response.error.message == "Method not found"


class MissingDefaultMethodRequest(BaseModel):
jsonrpc: str = "2.0"
id: int = 1
method: str


def test_request_methods_for_union_ignores_non_literal_defaults() -> None:
methods = request_methods_for_union(types.ServerRequest | MissingDefaultMethodRequest)
assert methods == KNOWN_SERVER_REQUEST_METHODS


@pytest.mark.anyio
async def test_client_session_known_request_methods_match_server_request_union() -> None:
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10)

async with read_send_stream, read_receive_stream, write_send_stream, write_receive_stream:
session = ClientSession(read_stream=read_receive_stream, write_stream=write_send_stream)
assert session._known_request_methods == KNOWN_SERVER_REQUEST_METHODS


class DummyBaseSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
@property
def _receive_request_adapter(self):
return types.server_request_adapter

@property
def _receive_notification_adapter(self):
return types.server_notification_adapter


@pytest.mark.anyio
async def test_base_session_known_request_methods_default_to_empty() -> None:
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10)

async with read_send_stream, read_receive_stream, write_send_stream, write_receive_stream:
session = DummyBaseSession(read_stream=read_receive_stream, write_stream=write_send_stream)
assert session._known_request_methods == frozenset()
assert session._receive_request_adapter is types.server_request_adapter
assert session._receive_notification_adapter is types.server_notification_adapter
Loading
X Tutup