rust_mcp_transport/
message_dispatcher.rs

1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3    self, ClientMessage, FromMessage, McpMessage, MessageFromClient, MessageFromServer,
4    ServerMessage,
5};
6use rust_mcp_schema::{RequestId, RpcError};
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::atomic::AtomicI64;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::AsyncWriteExt;
13use tokio::sync::oneshot;
14use tokio::sync::Mutex;
15
16use crate::error::{TransportError, TransportResult};
17use crate::utils::await_timeout;
18use crate::McpDispatch;
19
20/// Provides a dispatcher for sending MCP messages and handling responses.
21///
22/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
23/// and response handling. It supports both client-to-server and server-to-client message flows through
24/// implementations of the `McpDispatch` trait. The dispatcher uses a transport mechanism
25/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
26/// a configurable timeout mechanism for asynchronous responses.
27pub struct MessageDispatcher<R> {
28    pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
29    writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
30    message_id_counter: Arc<AtomicI64>,
31    timeout_msec: u64,
32}
33
34impl<R> MessageDispatcher<R> {
35    /// Creates a new `MessageDispatcher` instance with the given configuration.
36    ///
37    /// # Arguments
38    /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
39    /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
40    /// * `message_id_counter` - An atomic counter for generating unique request IDs.
41    /// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses.
42    ///
43    /// # Returns
44    /// A new `MessageDispatcher` instance configured for MCP message handling.
45    pub fn new(
46        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
47        writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
48        message_id_counter: Arc<AtomicI64>,
49        timeout_msec: u64,
50    ) -> Self {
51        Self {
52            pending_requests,
53            writable_std,
54            message_id_counter,
55            timeout_msec,
56        }
57    }
58
59    /// Determines the request ID for an outgoing MCP message.
60    ///
61    /// For requests, generates a new ID using the internal counter. For responses or errors,
62    /// uses the provided `request_id`. Notifications receive no ID.
63    ///
64    /// # Arguments
65    /// * `message` - The MCP message to evaluate.
66    /// * `request_id` - An optional existing request ID (required for responses/errors).
67    ///
68    /// # Returns
69    /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
70    fn request_id_for_message(
71        &self,
72        message: &impl McpMessage,
73        request_id: Option<RequestId>,
74    ) -> Option<RequestId> {
75        // we need to produce next request_id for requests
76        if message.is_request() {
77            // request_id should be None for requests
78            assert!(request_id.is_none());
79            Some(RequestId::Integer(
80                self.message_id_counter
81                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
82            ))
83        } else if !message.is_notification() {
84            // `request_id` must not be `None` for errors, notifications and responses
85            assert!(request_id.is_some());
86            request_id
87        } else {
88            None
89        }
90    }
91}
92
93#[async_trait]
94impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
95    /// Sends a message from the client to the server and awaits a response if applicable.
96    ///
97    /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
98    /// `ServerMessage` response if the message is a request. Notifications and responses return
99    /// `Ok(None)`.
100    ///
101    /// # Arguments
102    /// * `message` - The client message to send.
103    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
104    ///
105    /// # Returns
106    /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
107    /// or `None` for notifications/responses, or an error if the operation fails.
108    ///
109    /// # Errors
110    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
111    async fn send(
112        &self,
113        message: MessageFromClient,
114        request_id: Option<RequestId>,
115    ) -> TransportResult<Option<ServerMessage>> {
116        let mut writable_std = self.writable_std.lock().await;
117
118        // returns the request_id to be used to construct the message
119        // a new requestId will be returned for Requests and Notification
120        let outgoing_request_id = self.request_id_for_message(&message, request_id);
121
122        let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
123            // Store the sender in the pending requests map
124            if message.is_request() {
125                if let Some(request_id) = &outgoing_request_id {
126                    let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
127                    let mut pending_requests = self.pending_requests.lock().await;
128                    // store request id in the hashmap while waiting for a matching response
129                    pending_requests.insert(request_id.clone(), tx_response);
130                    Some(rx_response)
131                } else {
132                    None
133                }
134            } else {
135                None
136            }
137        };
138
139        let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
140
141        //serialize the message and write it to the writable_std
142        let message_str = serde_json::to_string(&mpc_message)
143            .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
144
145        writable_std.write_all(message_str.as_bytes()).await?;
146        writable_std.write_all(b"\n").await?; // new line
147        writable_std.flush().await?;
148
149        if let Some(rx) = rx_response {
150            // Wait for the response with timeout
151            match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
152                Ok(response) => Ok(Some(response)),
153                Err(error) => match error {
154                    TransportError::OneshotRecvError(_) => {
155                        Err(schema_utils::SdkError::connection_closed().into())
156                    }
157                    _ => Err(error),
158                },
159            }
160        } else {
161            Ok(None)
162        }
163    }
164}
165
166#[async_trait]
167impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
168    /// Sends a message from the server to the client and awaits a response if applicable.
169    ///
170    /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
171    /// `ClientMessage` response if the message is a request. Notifications and responses return
172    /// `Ok(None)`.
173    ///
174    /// # Arguments
175    /// * `message` - The server message to send.
176    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
177    ///
178    /// # Returns
179    /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
180    /// or `None` for notifications/responses, or an error if the operation fails.
181    ///
182    /// # Errors
183    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
184    async fn send(
185        &self,
186        message: MessageFromServer,
187        request_id: Option<RequestId>,
188    ) -> TransportResult<Option<ClientMessage>> {
189        let mut writable_std = self.writable_std.lock().await;
190
191        // returns the request_id to be used to construct the message
192        // a new requestId will be returned for Requests and Notification
193        let outgoing_request_id = self.request_id_for_message(&message, request_id);
194
195        let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
196            // Store the sender in the pending requests map
197            if message.is_request() {
198                if let Some(request_id) = &outgoing_request_id {
199                    let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
200                    let mut pending_requests = self.pending_requests.lock().await;
201                    // store request id in the hashmap while waiting for a matching response
202                    pending_requests.insert(request_id.clone(), tx_response);
203                    Some(rx_response)
204                } else {
205                    None
206                }
207            } else {
208                None
209            }
210        };
211
212        let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
213
214        //serialize the message and write it to the writable_std
215        let message_str = serde_json::to_string(&mpc_message)
216            .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
217
218        writable_std.write_all(message_str.as_bytes()).await?;
219        writable_std.write_all(b"\n").await?; // new line
220        writable_std.flush().await?;
221
222        if let Some(rx) = rx_response {
223            match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
224                Ok(response) => Ok(Some(response)),
225                Err(error) => Err(error),
226            }
227        } else {
228            Ok(None)
229        }
230    }
231}