rust_mcp_transport/
message_dispatcher.rs

1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3    ClientMessage, FromMessage, MCPMessage, MessageFromClient, MessageFromServer, ServerMessage,
4};
5use rust_mcp_schema::{RequestId, RpcError};
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::atomic::AtomicI64;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::oneshot;
13use tokio::sync::Mutex;
14
15use crate::error::TransportResult;
16use crate::utils::await_timeout;
17use crate::McpDispatch;
18
19/// Provides a dispatcher for sending MCP messages and handling responses.
20///
21/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
22/// and response handling. It supports both client-to-server and server-to-client message flows through
23/// implementations of the `McpDispatch` trait. The dispatcher uses a transport mechanism
24/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
25/// a configurable timeout mechanism for asynchronous responses.
26pub struct MessageDispatcher<R> {
27    pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
28    writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
29    message_id_counter: Arc<AtomicI64>,
30    timeout_msec: u64,
31}
32
33impl<R> MessageDispatcher<R> {
34    /// Creates a new `MessageDispatcher` instance with the given configuration.
35    ///
36    /// # Arguments
37    /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
38    /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
39    /// * `message_id_counter` - An atomic counter for generating unique request IDs.
40    /// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses.
41    ///
42    /// # Returns
43    /// A new `MessageDispatcher` instance configured for MCP message handling.
44    pub fn new(
45        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
46        writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
47        message_id_counter: Arc<AtomicI64>,
48        timeout_msec: u64,
49    ) -> Self {
50        Self {
51            pending_requests,
52            writable_std,
53            message_id_counter,
54            timeout_msec,
55        }
56    }
57
58    /// Determines the request ID for an outgoing MCP message.
59    ///
60    /// For requests, generates a new ID using the internal counter. For responses or errors,
61    /// uses the provided `request_id`. Notifications receive no ID.
62    ///
63    /// # Arguments
64    /// * `message` - The MCP message to evaluate.
65    /// * `request_id` - An optional existing request ID (required for responses/errors).
66    ///
67    /// # Returns
68    /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
69    fn request_id_for_message(
70        &self,
71        message: &impl MCPMessage,
72        request_id: Option<RequestId>,
73    ) -> Option<RequestId> {
74        // we need to produce next request_id for requests
75        if message.is_request() {
76            // request_id should be None for requests
77            assert!(request_id.is_none());
78            Some(RequestId::Integer(
79                self.message_id_counter
80                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
81            ))
82        } else if !message.is_notification() {
83            // `request_id` must not be `None` for errors, notifications and responses
84            assert!(request_id.is_some());
85            request_id
86        } else {
87            None
88        }
89    }
90}
91
92#[async_trait]
93impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
94    /// Sends a message from the client to the server and awaits a response if applicable.
95    ///
96    /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
97    /// `ServerMessage` response if the message is a request. Notifications and responses return
98    /// `Ok(None)`.
99    ///
100    /// # Arguments
101    /// * `message` - The client message to send.
102    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
103    ///
104    /// # Returns
105    /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
106    /// or `None` for notifications/responses, or an error if the operation fails.
107    ///
108    /// # Errors
109    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
110    async fn send(
111        &self,
112        message: MessageFromClient,
113        request_id: Option<RequestId>,
114    ) -> TransportResult<Option<ServerMessage>> {
115        let mut writable_std = self.writable_std.lock().await;
116
117        // returns the request_id to be used to construct the message
118        // a new requestId will be returned for Requests and Notification
119        let outgoing_request_id = self.request_id_for_message(&message, request_id);
120
121        let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
122            // Store the sender in the pending requests map
123            if message.is_request() {
124                if let Some(request_id) = &outgoing_request_id {
125                    let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
126                    let mut pending_requests = self.pending_requests.lock().await;
127                    // store request id in the hashmap while waiting for a matching response
128                    pending_requests.insert(request_id.clone(), tx_response);
129                    Some(rx_response)
130                } else {
131                    None
132                }
133            } else {
134                None
135            }
136        };
137
138        let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
139
140        //serialize the message and write it to the writable_std
141        let message_str = serde_json::to_string(&mpc_message)
142            .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
143
144        writable_std.write_all(message_str.as_bytes()).await?;
145        writable_std.write_all(b"\n").await?; // new line
146        writable_std.flush().await?;
147
148        if let Some(rx) = rx_response {
149            match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
150                Ok(response) => Ok(Some(response)),
151                Err(error) => Err(error),
152            }
153        } else {
154            Ok(None)
155        }
156    }
157}
158
159#[async_trait]
160impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
161    /// Sends a message from the server to the client and awaits a response if applicable.
162    ///
163    /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
164    /// `ClientMessage` response if the message is a request. Notifications and responses return
165    /// `Ok(None)`.
166    ///
167    /// # Arguments
168    /// * `message` - The server message to send.
169    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
170    ///
171    /// # Returns
172    /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
173    /// or `None` for notifications/responses, or an error if the operation fails.
174    ///
175    /// # Errors
176    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
177    async fn send(
178        &self,
179        message: MessageFromServer,
180        request_id: Option<RequestId>,
181    ) -> TransportResult<Option<ClientMessage>> {
182        let mut writable_std = self.writable_std.lock().await;
183
184        // returns the request_id to be used to construct the message
185        // a new requestId will be returned for Requests and Notification
186        let outgoing_request_id = self.request_id_for_message(&message, request_id);
187
188        let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
189            // Store the sender in the pending requests map
190            if message.is_request() {
191                if let Some(request_id) = &outgoing_request_id {
192                    let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
193                    let mut pending_requests = self.pending_requests.lock().await;
194                    // store request id in the hashmap while waiting for a matching response
195                    pending_requests.insert(request_id.clone(), tx_response);
196                    Some(rx_response)
197                } else {
198                    None
199                }
200            } else {
201                None
202            }
203        };
204
205        let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
206
207        //serialize the message and write it to the writable_std
208        let message_str = serde_json::to_string(&mpc_message)
209            .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
210
211        writable_std.write_all(message_str.as_bytes()).await?;
212        writable_std.write_all(b"\n").await?; // new line
213        writable_std.flush().await?;
214
215        if let Some(rx) = rx_response {
216            match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
217                Ok(response) => Ok(Some(response)),
218                Err(error) => Err(error),
219            }
220        } else {
221            Ok(None)
222        }
223    }
224}