rust_mcp_transport/
message_dispatcher.rs

1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3    ClientMessage, FromMessage, MCPMessage, MessageFromClient, MessageFromServer, ServerMessage,
4};
5use rust_mcp_schema::{JsonrpcErrorError, RequestId};
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::atomic::AtomicI64;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::oneshot;
13use tokio::sync::Mutex;
14
15use crate::error::TransportResult;
16use crate::utils::await_timeout;
17use crate::MCPDispatch;
18
19/// Provides a dispatcher for sending MCP messages and handling responses.
20///
21/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
22/// and response handling. It supports both client-to-server and server-to-client message flows through
23/// implementations of the `MCPDispatch` trait. The dispatcher uses a transport mechanism
24/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
25/// a configurable timeout mechanism for asynchronous responses.
26pub struct MessageDispatcher<R> {
27    pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
28    writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
29    message_id_counter: Arc<AtomicI64>,
30    timeout_msec: u64,
31}
32
33impl<R> MessageDispatcher<R> {
34    /// Creates a new `MessageDispatcher` instance with the given configuration.
35    ///
36    /// # Arguments
37    /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
38    /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
39    /// * `message_id_counter` - An atomic counter for generating unique request IDs.
40    /// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses.
41    ///
42    /// # Returns
43    /// A new `MessageDispatcher` instance configured for MCP message handling.
44    pub fn new(
45        pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
46        writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
47        message_id_counter: Arc<AtomicI64>,
48        timeout_msec: u64,
49    ) -> Self {
50        Self {
51            pending_requests,
52            writable_std,
53            message_id_counter,
54            timeout_msec,
55        }
56    }
57
58    /// Determines the request ID for an outgoing MCP message.
59    ///
60    /// For requests, generates a new ID using the internal counter. For responses or errors,
61    /// uses the provided `request_id`. Notifications receive no ID.
62    ///
63    /// # Arguments
64    /// * `message` - The MCP message to evaluate.
65    /// * `request_id` - An optional existing request ID (required for responses/errors).
66    ///
67    /// # Returns
68    /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
69    fn request_id_for_message(
70        &self,
71        message: &impl MCPMessage,
72        request_id: Option<RequestId>,
73    ) -> Option<RequestId> {
74        // we need to produce next request_id for requests
75        if message.is_request() {
76            // request_id should be None for requests
77            assert!(request_id.is_none());
78            Some(RequestId::Integer(
79                self.message_id_counter
80                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
81            ))
82        } else if !message.is_notification() {
83            // `request_id` must not be `None` for errors, notifications and responses
84            assert!(request_id.is_some());
85            request_id
86        } else {
87            None
88        }
89    }
90}
91
92#[async_trait]
93impl MCPDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
94    /// Sends a message from the client to the server and awaits a response if applicable.
95    ///
96    /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
97    /// `ServerMessage` response if the message is a request. Notifications and responses return
98    /// `Ok(None)`.
99    ///
100    /// # Arguments
101    /// * `message` - The client message to send.
102    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
103    ///
104    /// # Returns
105    /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
106    /// or `None` for notifications/responses, or an error if the operation fails.
107    ///
108    /// # Errors
109    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
110    async fn send(
111        &self,
112        message: MessageFromClient,
113        request_id: Option<RequestId>,
114    ) -> TransportResult<Option<ServerMessage>> {
115        let mut writable_std = self.writable_std.lock().await;
116
117        // returns the request_id to be used to construct the message
118        // a new requestId will be returned for Requests and Notification
119        let outgoing_request_id = self.request_id_for_message(&message, request_id);
120
121        let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
122            // Store the sender in the pending requests map
123            if message.is_request() {
124                if let Some(request_id) = &outgoing_request_id {
125                    let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
126                    let mut pending_requests = self.pending_requests.lock().await;
127                    // store request id in the hashmap while waiting for a matching response
128                    pending_requests.insert(request_id.clone(), tx_response);
129                    Some(rx_response)
130                } else {
131                    None
132                }
133            } else {
134                None
135            }
136        };
137
138        let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
139
140        //serialize the message and write it to the writable_std
141        let message_str = serde_json::to_string(&mpc_message).map_err(|_| {
142            crate::error::TransportError::JsonrpcError(JsonrpcErrorError::parse_error())
143        })?;
144
145        writable_std.write_all(message_str.as_bytes()).await?;
146        writable_std.write_all(b"\n").await?; // new line
147        writable_std.flush().await?;
148
149        if let Some(rx) = rx_response {
150            match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
151                Ok(response) => Ok(Some(response)),
152                Err(error) => Err(error),
153            }
154        } else {
155            Ok(None)
156        }
157    }
158}
159
160#[async_trait]
161impl MCPDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
162    /// Sends a message from the server to the client and awaits a response if applicable.
163    ///
164    /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
165    /// `ClientMessage` response if the message is a request. Notifications and responses return
166    /// `Ok(None)`.
167    ///
168    /// # Arguments
169    /// * `message` - The server message to send.
170    /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
171    ///
172    /// # Returns
173    /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
174    /// or `None` for notifications/responses, or an error if the operation fails.
175    ///
176    /// # Errors
177    /// Returns a `TransportError` if serialization, writing, or timeout occurs.
178    async fn send(
179        &self,
180        message: MessageFromServer,
181        request_id: Option<RequestId>,
182    ) -> TransportResult<Option<ClientMessage>> {
183        let mut writable_std = self.writable_std.lock().await;
184
185        // returns the request_id to be used to construct the message
186        // a new requestId will be returned for Requests and Notification
187        let outgoing_request_id = self.request_id_for_message(&message, request_id);
188
189        let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
190            // Store the sender in the pending requests map
191            if message.is_request() {
192                if let Some(request_id) = &outgoing_request_id {
193                    let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
194                    let mut pending_requests = self.pending_requests.lock().await;
195                    // store request id in the hashmap while waiting for a matching response
196                    pending_requests.insert(request_id.clone(), tx_response);
197                    Some(rx_response)
198                } else {
199                    None
200                }
201            } else {
202                None
203            }
204        };
205
206        let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
207
208        //serialize the message and write it to the writable_std
209        let message_str = serde_json::to_string(&mpc_message).map_err(|_| {
210            crate::error::TransportError::JsonrpcError(JsonrpcErrorError::parse_error())
211        })?;
212
213        writable_std.write_all(message_str.as_bytes()).await?;
214        writable_std.write_all(b"\n").await?; // new line
215        writable_std.flush().await?;
216
217        if let Some(rx) = rx_response {
218            match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
219                Ok(response) => Ok(Some(response)),
220                Err(error) => Err(error),
221            }
222        } else {
223            Ok(None)
224        }
225    }
226}