// rust_mcp_transport/message_dispatcher.rs

1use crate::error::{TransportError, TransportResult};
2use crate::schema::{RequestId, RpcError};
3use crate::utils::{await_timeout, current_timestamp};
4use crate::McpDispatch;
5use crate::{
6    event_store::EventStore,
7    schema::{
8        schema_utils::{
9            self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage,
10            ServerMessages,
11        },
12        JsonrpcError,
13    },
14    SessionId, StreamId,
15};
16use async_trait::async_trait;
17use futures::future::join_all;
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::AsyncWriteExt;
23use tokio::sync::oneshot::{self};
24use tokio::sync::Mutex;
25
/// Byte written between a stored resumability event id and the JSON payload
/// by the server-side `write_str` (only when an `EventStore` produced an id).
pub const ID_SEPARATOR: u8 = b'|';
27
/// Provides a dispatcher for sending MCP messages and handling responses.
///
/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
/// and response handling. It supports both client-to-server and server-to-client message flows through
/// implementations of the `McpDispatch` trait. The dispatcher uses a transport mechanism
/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
/// a configurable timeout mechanism for asynchronous responses.
pub struct MessageDispatcher<R> {
    /// Outstanding requests awaiting a response, keyed by request id; the
    /// stored sender delivers the matching response to the awaiting caller.
    pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
    /// Direct stream transport (e.g. stdout); `Some` for dispatchers built via `new`.
    writable_std: Option<Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
    /// Channel transport with per-message delivery acknowledgement; `Some` for
    /// dispatchers built via `new_with_acknowledgement`. Exactly one of
    /// `writable_std` / `writable_tx` is expected to be set.
    writable_tx: Option<
        tokio::sync::mpsc::Sender<(
            String,
            tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
        )>,
    >,
    /// Default timeout applied when a caller does not pass one per request.
    request_timeout: Duration,
    // resumability support (all three are set together by `make_resumable`)
    session_id: Option<SessionId>,
    stream_id: Option<StreamId>,
    event_store: Option<Arc<dyn EventStore>>,
}
50
51impl<R> MessageDispatcher<R> {
52    /// Creates a new `MessageDispatcher` instance with the given configuration.
53    ///
54    /// # Arguments
55    /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
56    /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
57    /// * `message_id_counter` - An atomic counter for generating unique request IDs.
58    /// * `request_timeout` - The timeout duration in milliseconds for awaiting responses.
59    ///
60    /// # Returns
61    /// A new `MessageDispatcher` instance configured for MCP message handling.
62    pub fn new(
63        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
64        writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
65        request_timeout: Duration,
66    ) -> Self {
67        Self {
68            pending_requests,
69            writable_std: Some(writable_std),
70            writable_tx: None,
71            request_timeout,
72            session_id: None,
73            stream_id: None,
74            event_store: None,
75        }
76    }
77
78    pub fn new_with_acknowledgement(
79        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
80        writable_tx: tokio::sync::mpsc::Sender<(
81            String,
82            tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
83        )>,
84        request_timeout: Duration,
85    ) -> Self {
86        Self {
87            pending_requests,
88            writable_tx: Some(writable_tx),
89            writable_std: None,
90            request_timeout,
91            session_id: None,
92            stream_id: None,
93            event_store: None,
94        }
95    }
96
97    /// Supports resumability for streamable HTTP transports by setting the session ID,
98    /// stream ID, and event store.
99    pub fn make_resumable(
100        &mut self,
101        session_id: SessionId,
102        stream_id: StreamId,
103        event_store: Arc<dyn EventStore>,
104    ) {
105        self.session_id = Some(session_id);
106        self.stream_id = Some(stream_id);
107        self.event_store = Some(event_store);
108    }
109
110    async fn store_pending_request(
111        &self,
112        request_id: RequestId,
113    ) -> tokio::sync::oneshot::Receiver<R> {
114        let (tx_response, rx_response) = oneshot::channel::<R>();
115        let mut pending_requests = self.pending_requests.lock().await;
116        // store request id in the hashmap while waiting for a matching response
117        pending_requests.insert(request_id.clone(), tx_response);
118        rx_response
119    }
120
121    async fn store_pending_request_for_message<M: McpMessage + RpcMessage>(
122        &self,
123        message: &M,
124    ) -> Option<tokio::sync::oneshot::Receiver<R>> {
125        if message.is_request() {
126            if let Some(request_id) = message.request_id() {
127                Some(self.store_pending_request(request_id.clone()).await)
128            } else {
129                None
130            }
131        } else {
132            None
133        }
134    }
135}
136
137// Client side dispatcher
138#[async_trait]
139impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
140    for MessageDispatcher<ServerMessage>
141{
142    /// Sends a message from the client to the server and awaits a response if applicable.
143    ///
144    /// Serializes the `ClientMessages` to JSON, writes it to the transport, and waits for a
145    /// `ServerMessages` response if the message is a request. Notifications and responses return
146    /// `Ok(None)`.
147    ///
148    /// # Arguments
149    /// * `messages` - The client message to send, coulld be a single message or batch.
150    ///
151    /// # Returns
152    /// A `TransportResult` containing `Some(ServerMessages)` for requests with a response,
153    /// or `None` for notifications/responses, or an error if the operation fails.
154    ///
155    /// # Errors
156    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
157    async fn send_message(
158        &self,
159        messages: ClientMessages,
160        request_timeout: Option<Duration>,
161    ) -> TransportResult<Option<ServerMessages>> {
162        match messages {
163            ClientMessages::Single(message) => {
164                let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> =
165                    self.store_pending_request_for_message(&message).await;
166
167                //serialize the message and write it to the writable_std
168                let message_payload = serde_json::to_string(&message).map_err(|_| {
169                    crate::error::TransportError::JsonrpcError(RpcError::parse_error())
170                })?;
171
172                self.write_str(message_payload.as_str(), true).await?;
173
174                if let Some(rx) = rx_response {
175                    // Wait for the response with timeout
176                    match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
177                        Ok(response) => Ok(Some(ServerMessages::Single(response))),
178                        Err(error) => match error {
179                            TransportError::ChannelClosed(_) => {
180                                Err(schema_utils::SdkError::connection_closed().into())
181                            }
182                            _ => Err(error),
183                        },
184                    }
185                } else {
186                    Ok(None)
187                }
188            }
189            ClientMessages::Batch(client_messages) => {
190                let (request_ids, pending_tasks): (Vec<_>, Vec<_>) = client_messages
191                    .iter()
192                    .filter(|message| message.is_request())
193                    .map(|message| {
194                        (
195                            message.request_id().unwrap(), // guaranteed to have request_id
196                            self.store_pending_request_for_message(message),
197                        )
198                    })
199                    .unzip();
200
201                // Ensure all request IDs are stored before sending the request
202                let tasks = join_all(pending_tasks).await;
203
204                // send the batch messages to the server
205                let message_payload = serde_json::to_string(&client_messages).map_err(|_| {
206                    crate::error::TransportError::JsonrpcError(RpcError::parse_error())
207                })?;
208                self.write_str(message_payload.as_str(), true).await?;
209
210                // no request in the batch, no need to wait for the result
211                if request_ids.is_empty() {
212                    return Ok(None);
213                }
214
215                let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| {
216                    rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)))
217                });
218
219                let results: Vec<_> = join_all(timeout_wrapped_futures)
220                    .await
221                    .into_iter()
222                    .zip(request_ids)
223                    .map(|(res, request_id)| match res {
224                        Ok(response) => response,
225                        Err(error) => ServerMessage::Error(JsonrpcError::new(
226                            RpcError::internal_error().with_message(error.to_string()),
227                            request_id.to_owned(),
228                        )),
229                    })
230                    .collect();
231
232                Ok(Some(ServerMessages::Batch(results)))
233            }
234        }
235    }
236
237    async fn send(
238        &self,
239        message: ClientMessage,
240        request_timeout: Option<Duration>,
241    ) -> TransportResult<Option<ServerMessage>> {
242        let response = self.send_message(message.into(), request_timeout).await?;
243        match response {
244            Some(r) => Ok(Some(r.as_single()?)),
245            None => Ok(None),
246        }
247    }
248
249    async fn send_batch(
250        &self,
251        message: Vec<ClientMessage>,
252        request_timeout: Option<Duration>,
253    ) -> TransportResult<Option<Vec<ServerMessage>>> {
254        let response = self.send_message(message.into(), request_timeout).await?;
255        match response {
256            Some(r) => Ok(Some(r.as_batch()?)),
257            None => Ok(None),
258        }
259    }
260
261    /// Writes a string payload to the underlying asynchronous writable stream,
262    /// appending a newline character and flushing the stream afterward.
263    ///
264    async fn write_str(&self, payload: &str, _skip_store: bool) -> TransportResult<()> {
265        if let Some(writable_std) = self.writable_std.as_ref() {
266            let mut writable_std = writable_std.lock().await;
267            writable_std.write_all(payload.as_bytes()).await?;
268            writable_std.write_all(b"\n").await?; // new line
269            writable_std.flush().await?;
270            return Ok(());
271        };
272
273        if let Some(writable_tx) = self.writable_tx.as_ref() {
274            let (resp_tx, resp_rx) = oneshot::channel();
275            writable_tx
276                .send((payload.to_string(), resp_tx))
277                .await
278                .map_err(|err| TransportError::Internal(format!("{err}")))?; // Send fails if channel closed
279            return resp_rx.await?; // Await the POST result; propagates the error if POST failed
280        }
281
282        Err(TransportError::Internal("Invalid dispatcher!".to_string()))
283    }
284}
285
286// Server side dispatcher, Sends S and Returns R
287#[async_trait]
288impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
289    for MessageDispatcher<ClientMessage>
290{
291    /// Sends a message from the server to the client and awaits a response if applicable.
292    ///
293    /// Serializes the `ServerMessages` to JSON, writes it to the transport, and waits for a
294    /// `ClientMessages` response if the message is a request. Notifications and responses return
295    /// `Ok(None)`.
296    ///
297    /// # Arguments
298    /// * `messages` - The client message to send, coulld be a single message or batch.
299    ///
300    /// # Returns
301    /// A `TransportResult` containing `Some(ClientMessages)` for requests with a response,
302    /// or `None` for notifications/responses, or an error if the operation fails.
303    ///
304    /// # Errors
305    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
306    async fn send_message(
307        &self,
308        messages: ServerMessages,
309        request_timeout: Option<Duration>,
310    ) -> TransportResult<Option<ClientMessages>> {
311        match messages {
312            ServerMessages::Single(message) => {
313                let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> =
314                    self.store_pending_request_for_message(&message).await;
315
316                let message_payload = serde_json::to_string(&message).map_err(|_| {
317                    crate::error::TransportError::JsonrpcError(RpcError::parse_error())
318                })?;
319
320                self.write_str(message_payload.as_str(), false).await?;
321
322                if let Some(rx) = rx_response {
323                    match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
324                        Ok(response) => Ok(Some(ClientMessages::Single(response))),
325                        Err(error) => Err(error),
326                    }
327                } else {
328                    Ok(None)
329                }
330            }
331            ServerMessages::Batch(server_messages) => {
332                let (request_ids, pending_tasks): (Vec<_>, Vec<_>) = server_messages
333                    .iter()
334                    .filter(|message| message.is_request())
335                    .map(|message| {
336                        (
337                            message.request_id().unwrap(), // guaranteed to have request_id
338                            self.store_pending_request_for_message(message),
339                        )
340                    })
341                    .unzip();
342
343                // send the batch messages to the client
344                let message_payload = serde_json::to_string(&server_messages).map_err(|_| {
345                    crate::error::TransportError::JsonrpcError(RpcError::parse_error())
346                })?;
347
348                self.write_str(message_payload.as_str(), false).await?;
349
350                // no request in the batch, no need to wait for the result
351                if pending_tasks.is_empty() {
352                    return Ok(None);
353                }
354
355                let tasks = join_all(pending_tasks).await;
356
357                let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| {
358                    rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)))
359                });
360
361                let results: Vec<_> = join_all(timeout_wrapped_futures)
362                    .await
363                    .into_iter()
364                    .zip(request_ids)
365                    .map(|(res, request_id)| match res {
366                        Ok(response) => response,
367                        Err(error) => ClientMessage::Error(JsonrpcError::new(
368                            RpcError::internal_error().with_message(error.to_string()),
369                            request_id.to_owned(),
370                        )),
371                    })
372                    .collect();
373
374                Ok(Some(ClientMessages::Batch(results)))
375            }
376        }
377    }
378
379    async fn send(
380        &self,
381        message: ServerMessage,
382        request_timeout: Option<Duration>,
383    ) -> TransportResult<Option<ClientMessage>> {
384        let response = self.send_message(message.into(), request_timeout).await?;
385        match response {
386            Some(r) => Ok(Some(r.as_single()?)),
387            None => Ok(None),
388        }
389    }
390
391    async fn send_batch(
392        &self,
393        message: Vec<ServerMessage>,
394        request_timeout: Option<Duration>,
395    ) -> TransportResult<Option<Vec<ClientMessage>>> {
396        let response = self.send_message(message.into(), request_timeout).await?;
397        match response {
398            Some(r) => Ok(Some(r.as_batch()?)),
399            None => Ok(None),
400        }
401    }
402
403    /// Writes a string payload to the underlying asynchronous writable stream,
404    /// appending a newline character and flushing the stream afterward.
405    ///
406    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
407        let mut event_id = None;
408
409        if !skip_store && !payload.trim().is_empty() {
410            if let (Some(session_id), Some(stream_id), Some(event_store)) = (
411                self.session_id.as_ref(),
412                self.stream_id.as_ref(),
413                self.event_store.as_ref(),
414            ) {
415                event_id = event_store
416                    .store_event(
417                        session_id.clone(),
418                        stream_id.clone(),
419                        current_timestamp(),
420                        payload.to_owned(),
421                    )
422                    .await
423                    .map(Some)
424                    .unwrap_or_else(|err| {
425                        tracing::error!("{err}");
426                        None
427                    });
428            };
429        }
430
431        if let Some(writable_std) = self.writable_std.as_ref() {
432            let mut writable_std = writable_std.lock().await;
433            if let Some(id) = event_id {
434                writable_std.write_all(id.as_bytes()).await?;
435                writable_std.write_all(&[ID_SEPARATOR]).await?; // separate id from message
436            }
437            writable_std.write_all(payload.as_bytes()).await?;
438            writable_std.write_all(b"\n").await?; // new line
439            writable_std.flush().await?;
440            return Ok(());
441        };
442
443        if let Some(writable_tx) = self.writable_tx.as_ref() {
444            let (resp_tx, resp_rx) = oneshot::channel();
445            writable_tx
446                .send((payload.to_string(), resp_tx))
447                .await
448                .map_err(|err| TransportError::Internal(err.to_string()))?; // Send fails if channel closed
449            return resp_rx.await?; // Await the POST result; propagates the error if POST failed
450        }
451
452        Err(TransportError::Internal("Invalid dispatcher!".to_string()))
453    }
454}