rust_mcp_transport/message_dispatcher.rs
1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3 self, ClientMessage, FromMessage, McpMessage, MessageFromClient, MessageFromServer,
4 ServerMessage,
5};
6use rust_mcp_schema::{RequestId, RpcError};
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::atomic::AtomicI64;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::AsyncWriteExt;
13use tokio::sync::oneshot;
14use tokio::sync::Mutex;
15
16use crate::error::{TransportError, TransportResult};
17use crate::utils::await_timeout;
18use crate::McpDispatch;
19
20/// Provides a dispatcher for sending MCP messages and handling responses.
21///
22/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
23/// and response handling. It supports both client-to-server and server-to-client message flows through
24/// implementations of the `McpDispatch` trait. The dispatcher uses a transport mechanism
25/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
26/// a configurable timeout mechanism for asynchronous responses.
27pub struct MessageDispatcher<R> {
28 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
29 writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
30 message_id_counter: Arc<AtomicI64>,
31 request_timeout: Duration,
32}
33
34impl<R> MessageDispatcher<R> {
35 /// Creates a new `MessageDispatcher` instance with the given configuration.
36 ///
37 /// # Arguments
38 /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
39 /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
40 /// * `message_id_counter` - An atomic counter for generating unique request IDs.
41 /// * `request_timeout` - The timeout duration in milliseconds for awaiting responses.
42 ///
43 /// # Returns
44 /// A new `MessageDispatcher` instance configured for MCP message handling.
45 pub fn new(
46 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
47 writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
48 message_id_counter: Arc<AtomicI64>,
49 request_timeout: Duration,
50 ) -> Self {
51 Self {
52 pending_requests,
53 writable_std,
54 message_id_counter,
55 request_timeout,
56 }
57 }
58
59 /// Determines the request ID for an outgoing MCP message.
60 ///
61 /// For requests, generates a new ID using the internal counter. For responses or errors,
62 /// uses the provided `request_id`. Notifications receive no ID.
63 ///
64 /// # Arguments
65 /// * `message` - The MCP message to evaluate.
66 /// * `request_id` - An optional existing request ID (required for responses/errors).
67 ///
68 /// # Returns
69 /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
70 fn request_id_for_message(
71 &self,
72 message: &impl McpMessage,
73 request_id: Option<RequestId>,
74 ) -> Option<RequestId> {
75 // we need to produce next request_id for requests
76 if message.is_request() {
77 // request_id should be None for requests
78 assert!(request_id.is_none());
79 Some(RequestId::Integer(
80 self.message_id_counter
81 .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
82 ))
83 } else if !message.is_notification() {
84 // `request_id` must not be `None` for errors, notifications and responses
85 assert!(request_id.is_some());
86 request_id
87 } else {
88 None
89 }
90 }
91}
92
93#[async_trait]
94impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
95 /// Sends a message from the client to the server and awaits a response if applicable.
96 ///
97 /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
98 /// `ServerMessage` response if the message is a request. Notifications and responses return
99 /// `Ok(None)`.
100 ///
101 /// # Arguments
102 /// * `message` - The client message to send.
103 /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
104 ///
105 /// # Returns
106 /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
107 /// or `None` for notifications/responses, or an error if the operation fails.
108 ///
109 /// # Errors
110 /// Returns a `TransportError` if serialization, writing, or timeout occurs.
111 async fn send(
112 &self,
113 message: MessageFromClient,
114 request_id: Option<RequestId>,
115 request_timeout: Option<Duration>,
116 ) -> TransportResult<Option<ServerMessage>> {
117 let mut writable_std = self.writable_std.lock().await;
118
119 // returns the request_id to be used to construct the message
120 // a new requestId will be returned for Requests and Notification
121 let outgoing_request_id = self.request_id_for_message(&message, request_id);
122
123 let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
124 // Store the sender in the pending requests map
125 if message.is_request() {
126 if let Some(request_id) = &outgoing_request_id {
127 let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
128 let mut pending_requests = self.pending_requests.lock().await;
129 // store request id in the hashmap while waiting for a matching response
130 pending_requests.insert(request_id.clone(), tx_response);
131 Some(rx_response)
132 } else {
133 None
134 }
135 } else {
136 None
137 }
138 };
139
140 let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
141
142 //serialize the message and write it to the writable_std
143 let message_str = serde_json::to_string(&mpc_message)
144 .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
145
146 writable_std.write_all(message_str.as_bytes()).await?;
147 writable_std.write_all(b"\n").await?; // new line
148 writable_std.flush().await?;
149
150 if let Some(rx) = rx_response {
151 // Wait for the response with timeout
152 match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
153 Ok(response) => Ok(Some(response)),
154 Err(error) => match error {
155 TransportError::OneshotRecvError(_) => {
156 Err(schema_utils::SdkError::connection_closed().into())
157 }
158 _ => Err(error),
159 },
160 }
161 } else {
162 Ok(None)
163 }
164 }
165}
166
167#[async_trait]
168impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
169 /// Sends a message from the server to the client and awaits a response if applicable.
170 ///
171 /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
172 /// `ClientMessage` response if the message is a request. Notifications and responses return
173 /// `Ok(None)`.
174 ///
175 /// # Arguments
176 /// * `message` - The server message to send.
177 /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
178 ///
179 /// # Returns
180 /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
181 /// or `None` for notifications/responses, or an error if the operation fails.
182 ///
183 /// # Errors
184 /// Returns a `TransportError` if serialization, writing, or timeout occurs.
185 async fn send(
186 &self,
187 message: MessageFromServer,
188 request_id: Option<RequestId>,
189 request_timeout: Option<Duration>,
190 ) -> TransportResult<Option<ClientMessage>> {
191 let mut writable_std = self.writable_std.lock().await;
192
193 // returns the request_id to be used to construct the message
194 // a new requestId will be returned for Requests and Notification
195 let outgoing_request_id = self.request_id_for_message(&message, request_id);
196
197 let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
198 // Store the sender in the pending requests map
199 if message.is_request() {
200 if let Some(request_id) = &outgoing_request_id {
201 let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
202 let mut pending_requests = self.pending_requests.lock().await;
203 // store request id in the hashmap while waiting for a matching response
204 pending_requests.insert(request_id.clone(), tx_response);
205 Some(rx_response)
206 } else {
207 None
208 }
209 } else {
210 None
211 }
212 };
213
214 let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
215
216 //serialize the message and write it to the writable_std
217 let message_str = serde_json::to_string(&mpc_message)
218 .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
219
220 writable_std.write_all(message_str.as_bytes()).await?;
221 writable_std.write_all(b"\n").await?; // new line
222 writable_std.flush().await?;
223
224 if let Some(rx) = rx_response {
225 match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
226 Ok(response) => Ok(Some(response)),
227 Err(error) => Err(error),
228 }
229 } else {
230 Ok(None)
231 }
232 }
233}