rust_mcp_transport/
message_dispatcher.rs

1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3    self, ClientMessage, FromMessage, McpMessage, MessageFromClient, MessageFromServer,
4    ServerMessage,
5};
6use rust_mcp_schema::{RequestId, RpcError};
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::atomic::AtomicI64;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::AsyncWriteExt;
13use tokio::sync::oneshot;
14use tokio::sync::Mutex;
15
16use crate::error::{TransportError, TransportResult};
17use crate::utils::await_timeout;
18use crate::McpDispatch;
19
20/// Provides a dispatcher for sending MCP messages and handling responses.
21///
22/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
23/// and response handling. It supports both client-to-server and server-to-client message flows through
24/// implementations of the `McpDispatch` trait. The dispatcher uses a transport mechanism
25/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
26/// a configurable timeout mechanism for asynchronous responses.
27pub struct MessageDispatcher<R> {
28    pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
29    writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
30    message_id_counter: Arc<AtomicI64>,
31    request_timeout: Duration,
32}
33
34impl<R> MessageDispatcher<R> {
35    /// Creates a new `MessageDispatcher` instance with the given configuration.
36    ///
37    /// # Arguments
38    /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
39    /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
40    /// * `message_id_counter` - An atomic counter for generating unique request IDs.
41    /// * `request_timeout` - The timeout duration in milliseconds for awaiting responses.
42    ///
43    /// # Returns
44    /// A new `MessageDispatcher` instance configured for MCP message handling.
45    pub fn new(
46        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
47        writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
48        message_id_counter: Arc<AtomicI64>,
49        request_timeout: Duration,
50    ) -> Self {
51        Self {
52            pending_requests,
53            writable_std,
54            message_id_counter,
55            request_timeout,
56        }
57    }
58
59    /// Determines the request ID for an outgoing MCP message.
60    ///
61    /// For requests, generates a new ID using the internal counter. For responses or errors,
62    /// uses the provided `request_id`. Notifications receive no ID.
63    ///
64    /// # Arguments
65    /// * `message` - The MCP message to evaluate.
66    /// * `request_id` - An optional existing request ID (required for responses/errors).
67    ///
68    /// # Returns
69    /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
70    fn request_id_for_message(
71        &self,
72        message: &impl McpMessage,
73        request_id: Option<RequestId>,
74    ) -> Option<RequestId> {
75        // we need to produce next request_id for requests
76        if message.is_request() {
77            // request_id should be None for requests
78            assert!(request_id.is_none());
79            Some(RequestId::Integer(
80                self.message_id_counter
81                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
82            ))
83        } else if !message.is_notification() {
84            // `request_id` must not be `None` for errors, notifications and responses
85            assert!(request_id.is_some());
86            request_id
87        } else {
88            None
89        }
90    }
91}
92
93#[async_trait]
94impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
95    /// Sends a message from the client to the server and awaits a response if applicable.
96    ///
97    /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
98    /// `ServerMessage` response if the message is a request. Notifications and responses return
99    /// `Ok(None)`.
100    ///
101    /// # Arguments
102    /// * `message` - The client message to send.
103    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
104    ///
105    /// # Returns
106    /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
107    /// or `None` for notifications/responses, or an error if the operation fails.
108    ///
109    /// # Errors
110    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
111    async fn send(
112        &self,
113        message: MessageFromClient,
114        request_id: Option<RequestId>,
115        request_timeout: Option<Duration>,
116    ) -> TransportResult<Option<ServerMessage>> {
117        let mut writable_std = self.writable_std.lock().await;
118
119        // returns the request_id to be used to construct the message
120        // a new requestId will be returned for Requests and Notification
121        let outgoing_request_id = self.request_id_for_message(&message, request_id);
122
123        let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
124            // Store the sender in the pending requests map
125            if message.is_request() {
126                if let Some(request_id) = &outgoing_request_id {
127                    let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
128                    let mut pending_requests = self.pending_requests.lock().await;
129                    // store request id in the hashmap while waiting for a matching response
130                    pending_requests.insert(request_id.clone(), tx_response);
131                    Some(rx_response)
132                } else {
133                    None
134                }
135            } else {
136                None
137            }
138        };
139
140        let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
141
142        //serialize the message and write it to the writable_std
143        let message_str = serde_json::to_string(&mpc_message)
144            .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
145
146        writable_std.write_all(message_str.as_bytes()).await?;
147        writable_std.write_all(b"\n").await?; // new line
148        writable_std.flush().await?;
149
150        if let Some(rx) = rx_response {
151            // Wait for the response with timeout
152            match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
153                Ok(response) => Ok(Some(response)),
154                Err(error) => match error {
155                    TransportError::OneshotRecvError(_) => {
156                        Err(schema_utils::SdkError::connection_closed().into())
157                    }
158                    _ => Err(error),
159                },
160            }
161        } else {
162            Ok(None)
163        }
164    }
165}
166
167#[async_trait]
168impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
169    /// Sends a message from the server to the client and awaits a response if applicable.
170    ///
171    /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
172    /// `ClientMessage` response if the message is a request. Notifications and responses return
173    /// `Ok(None)`.
174    ///
175    /// # Arguments
176    /// * `message` - The server message to send.
177    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
178    ///
179    /// # Returns
180    /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
181    /// or `None` for notifications/responses, or an error if the operation fails.
182    ///
183    /// # Errors
184    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
185    async fn send(
186        &self,
187        message: MessageFromServer,
188        request_id: Option<RequestId>,
189        request_timeout: Option<Duration>,
190    ) -> TransportResult<Option<ClientMessage>> {
191        let mut writable_std = self.writable_std.lock().await;
192
193        // returns the request_id to be used to construct the message
194        // a new requestId will be returned for Requests and Notification
195        let outgoing_request_id = self.request_id_for_message(&message, request_id);
196
197        let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
198            // Store the sender in the pending requests map
199            if message.is_request() {
200                if let Some(request_id) = &outgoing_request_id {
201                    let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
202                    let mut pending_requests = self.pending_requests.lock().await;
203                    // store request id in the hashmap while waiting for a matching response
204                    pending_requests.insert(request_id.clone(), tx_response);
205                    Some(rx_response)
206                } else {
207                    None
208                }
209            } else {
210                None
211            }
212        };
213
214        let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
215
216        //serialize the message and write it to the writable_std
217        let message_str = serde_json::to_string(&mpc_message)
218            .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
219
220        writable_std.write_all(message_str.as_bytes()).await?;
221        writable_std.write_all(b"\n").await?; // new line
222        writable_std.flush().await?;
223
224        if let Some(rx) = rx_response {
225            match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
226                Ok(response) => Ok(Some(response)),
227                Err(error) => Err(error),
228            }
229        } else {
230            Ok(None)
231        }
232    }
233}