rust_mcp_transport/message_dispatcher.rs
1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3 self, ClientMessage, FromMessage, McpMessage, MessageFromClient, MessageFromServer,
4 ServerMessage,
5};
6use rust_mcp_schema::{RequestId, RpcError};
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::atomic::AtomicI64;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::AsyncWriteExt;
13use tokio::sync::oneshot;
14use tokio::sync::Mutex;
15
16use crate::error::{TransportError, TransportResult};
17use crate::utils::await_timeout;
18use crate::McpDispatch;
19
/// Provides a dispatcher for sending MCP messages and handling responses.
///
/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
/// and response handling. It supports both client-to-server and server-to-client message flows through
/// implementations of the `McpDispatch` trait. The dispatcher serializes outgoing messages to JSON and
/// writes them, newline-terminated, to an async writer (e.g., stdout). It tracks pending requests and
/// waits for each response with a configurable timeout.
///
/// `R` is the message type delivered back as the response to an outgoing request.
pub struct MessageDispatcher<R> {
    /// Outstanding requests keyed by their ID; each sender delivers the matching response.
    pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
    /// Mutex-guarded async writer that serialized outgoing messages are written to.
    writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
    /// Shared counter from which IDs for outgoing requests are drawn.
    message_id_counter: Arc<AtomicI64>,
    /// How long to wait for a response to a request, in milliseconds.
    timeout_msec: u64,
}
33
34impl<R> MessageDispatcher<R> {
35 /// Creates a new `MessageDispatcher` instance with the given configuration.
36 ///
37 /// # Arguments
38 /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
39 /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
40 /// * `message_id_counter` - An atomic counter for generating unique request IDs.
41 /// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses.
42 ///
43 /// # Returns
44 /// A new `MessageDispatcher` instance configured for MCP message handling.
45 pub fn new(
46 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
47 writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
48 message_id_counter: Arc<AtomicI64>,
49 timeout_msec: u64,
50 ) -> Self {
51 Self {
52 pending_requests,
53 writable_std,
54 message_id_counter,
55 timeout_msec,
56 }
57 }
58
59 /// Determines the request ID for an outgoing MCP message.
60 ///
61 /// For requests, generates a new ID using the internal counter. For responses or errors,
62 /// uses the provided `request_id`. Notifications receive no ID.
63 ///
64 /// # Arguments
65 /// * `message` - The MCP message to evaluate.
66 /// * `request_id` - An optional existing request ID (required for responses/errors).
67 ///
68 /// # Returns
69 /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
70 fn request_id_for_message(
71 &self,
72 message: &impl McpMessage,
73 request_id: Option<RequestId>,
74 ) -> Option<RequestId> {
75 // we need to produce next request_id for requests
76 if message.is_request() {
77 // request_id should be None for requests
78 assert!(request_id.is_none());
79 Some(RequestId::Integer(
80 self.message_id_counter
81 .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
82 ))
83 } else if !message.is_notification() {
84 // `request_id` must not be `None` for errors, notifications and responses
85 assert!(request_id.is_some());
86 request_id
87 } else {
88 None
89 }
90 }
91}
92
93#[async_trait]
94impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
95 /// Sends a message from the client to the server and awaits a response if applicable.
96 ///
97 /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
98 /// `ServerMessage` response if the message is a request. Notifications and responses return
99 /// `Ok(None)`.
100 ///
101 /// # Arguments
102 /// * `message` - The client message to send.
103 /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
104 ///
105 /// # Returns
106 /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
107 /// or `None` for notifications/responses, or an error if the operation fails.
108 ///
109 /// # Errors
110 /// Returns a `TransportError` if serialization, writing, or timeout occurs.
111 async fn send(
112 &self,
113 message: MessageFromClient,
114 request_id: Option<RequestId>,
115 ) -> TransportResult<Option<ServerMessage>> {
116 let mut writable_std = self.writable_std.lock().await;
117
118 // returns the request_id to be used to construct the message
119 // a new requestId will be returned for Requests and Notification
120 let outgoing_request_id = self.request_id_for_message(&message, request_id);
121
122 let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
123 // Store the sender in the pending requests map
124 if message.is_request() {
125 if let Some(request_id) = &outgoing_request_id {
126 let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
127 let mut pending_requests = self.pending_requests.lock().await;
128 // store request id in the hashmap while waiting for a matching response
129 pending_requests.insert(request_id.clone(), tx_response);
130 Some(rx_response)
131 } else {
132 None
133 }
134 } else {
135 None
136 }
137 };
138
139 let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
140
141 //serialize the message and write it to the writable_std
142 let message_str = serde_json::to_string(&mpc_message)
143 .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
144
145 writable_std.write_all(message_str.as_bytes()).await?;
146 writable_std.write_all(b"\n").await?; // new line
147 writable_std.flush().await?;
148
149 if let Some(rx) = rx_response {
150 // Wait for the response with timeout
151 match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
152 Ok(response) => Ok(Some(response)),
153 Err(error) => match error {
154 TransportError::OneshotRecvError(_) => {
155 Err(schema_utils::SdkError::connection_closed().into())
156 }
157 _ => Err(error),
158 },
159 }
160 } else {
161 Ok(None)
162 }
163 }
164}
165
166#[async_trait]
167impl McpDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
168 /// Sends a message from the server to the client and awaits a response if applicable.
169 ///
170 /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
171 /// `ClientMessage` response if the message is a request. Notifications and responses return
172 /// `Ok(None)`.
173 ///
174 /// # Arguments
175 /// * `message` - The server message to send.
176 /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
177 ///
178 /// # Returns
179 /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
180 /// or `None` for notifications/responses, or an error if the operation fails.
181 ///
182 /// # Errors
183 /// Returns a `TransportError` if serialization, writing, or timeout occurs.
184 async fn send(
185 &self,
186 message: MessageFromServer,
187 request_id: Option<RequestId>,
188 ) -> TransportResult<Option<ClientMessage>> {
189 let mut writable_std = self.writable_std.lock().await;
190
191 // returns the request_id to be used to construct the message
192 // a new requestId will be returned for Requests and Notification
193 let outgoing_request_id = self.request_id_for_message(&message, request_id);
194
195 let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
196 // Store the sender in the pending requests map
197 if message.is_request() {
198 if let Some(request_id) = &outgoing_request_id {
199 let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
200 let mut pending_requests = self.pending_requests.lock().await;
201 // store request id in the hashmap while waiting for a matching response
202 pending_requests.insert(request_id.clone(), tx_response);
203 Some(rx_response)
204 } else {
205 None
206 }
207 } else {
208 None
209 }
210 };
211
212 let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
213
214 //serialize the message and write it to the writable_std
215 let message_str = serde_json::to_string(&mpc_message)
216 .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?;
217
218 writable_std.write_all(message_str.as_bytes()).await?;
219 writable_std.write_all(b"\n").await?; // new line
220 writable_std.flush().await?;
221
222 if let Some(rx) = rx_response {
223 match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
224 Ok(response) => Ok(Some(response)),
225 Err(error) => Err(error),
226 }
227 } else {
228 Ok(None)
229 }
230 }
231}