// rust_mcp_transport/message_dispatcher.rs
1use async_trait::async_trait;
2use rust_mcp_schema::schema_utils::{
3 ClientMessage, FromMessage, MCPMessage, MessageFromClient, MessageFromServer, ServerMessage,
4};
5use rust_mcp_schema::{JsonrpcErrorError, RequestId};
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::atomic::AtomicI64;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::oneshot;
13use tokio::sync::Mutex;
14
15use crate::error::TransportResult;
16use crate::utils::await_timeout;
17use crate::MCPDispatch;
18
19/// Provides a dispatcher for sending MCP messages and handling responses.
20///
21/// `MessageDispatcher` facilitates MCP communication by managing message sending, request tracking,
22/// and response handling. It supports both client-to-server and server-to-client message flows through
23/// implementations of the `MCPDispatch` trait. The dispatcher uses a transport mechanism
24/// (e.g., stdin/stdout) to serialize and send messages, and it tracks pending requests with
25/// a configurable timeout mechanism for asynchronous responses.
26pub struct MessageDispatcher<R> {
27 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
28 writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
29 message_id_counter: Arc<AtomicI64>,
30 timeout_msec: u64,
31}
32
33impl<R> MessageDispatcher<R> {
34 /// Creates a new `MessageDispatcher` instance with the given configuration.
35 ///
36 /// # Arguments
37 /// * `pending_requests` - A thread-safe map for storing pending request IDs and their response channels.
38 /// * `writable_std` - A mutex-protected, pinned writer (e.g., stdout) for sending serialized messages.
39 /// * `message_id_counter` - An atomic counter for generating unique request IDs.
40 /// * `timeout_msec` - The timeout duration in milliseconds for awaiting responses.
41 ///
42 /// # Returns
43 /// A new `MessageDispatcher` instance configured for MCP message handling.
44 pub fn new(
45 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
46 writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
47 message_id_counter: Arc<AtomicI64>,
48 timeout_msec: u64,
49 ) -> Self {
50 Self {
51 pending_requests,
52 writable_std,
53 message_id_counter,
54 timeout_msec,
55 }
56 }
57
58 /// Determines the request ID for an outgoing MCP message.
59 ///
60 /// For requests, generates a new ID using the internal counter. For responses or errors,
61 /// uses the provided `request_id`. Notifications receive no ID.
62 ///
63 /// # Arguments
64 /// * `message` - The MCP message to evaluate.
65 /// * `request_id` - An optional existing request ID (required for responses/errors).
66 ///
67 /// # Returns
68 /// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
69 fn request_id_for_message(
70 &self,
71 message: &impl MCPMessage,
72 request_id: Option<RequestId>,
73 ) -> Option<RequestId> {
74 // we need to produce next request_id for requests
75 if message.is_request() {
76 // request_id should be None for requests
77 assert!(request_id.is_none());
78 Some(RequestId::Integer(
79 self.message_id_counter
80 .fetch_add(1, std::sync::atomic::Ordering::Relaxed),
81 ))
82 } else if !message.is_notification() {
83 // `request_id` must not be `None` for errors, notifications and responses
84 assert!(request_id.is_some());
85 request_id
86 } else {
87 None
88 }
89 }
90}
91
92#[async_trait]
93impl MCPDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerMessage> {
94 /// Sends a message from the client to the server and awaits a response if applicable.
95 ///
96 /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a
97 /// `ServerMessage` response if the message is a request. Notifications and responses return
98 /// `Ok(None)`.
99 ///
100 /// # Arguments
101 /// * `message` - The client message to send.
102 /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
103 ///
104 /// # Returns
105 /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response,
106 /// or `None` for notifications/responses, or an error if the operation fails.
107 ///
108 /// # Errors
109 /// Returns a `TransportError` if serialization, writing, or timeout occurs.
110 async fn send(
111 &self,
112 message: MessageFromClient,
113 request_id: Option<RequestId>,
114 ) -> TransportResult<Option<ServerMessage>> {
115 let mut writable_std = self.writable_std.lock().await;
116
117 // returns the request_id to be used to construct the message
118 // a new requestId will be returned for Requests and Notification
119 let outgoing_request_id = self.request_id_for_message(&message, request_id);
120
121 let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> = {
122 // Store the sender in the pending requests map
123 if message.is_request() {
124 if let Some(request_id) = &outgoing_request_id {
125 let (tx_response, rx_response) = oneshot::channel::<ServerMessage>();
126 let mut pending_requests = self.pending_requests.lock().await;
127 // store request id in the hashmap while waiting for a matching response
128 pending_requests.insert(request_id.clone(), tx_response);
129 Some(rx_response)
130 } else {
131 None
132 }
133 } else {
134 None
135 }
136 };
137
138 let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?;
139
140 //serialize the message and write it to the writable_std
141 let message_str = serde_json::to_string(&mpc_message).map_err(|_| {
142 crate::error::TransportError::JsonrpcError(JsonrpcErrorError::parse_error())
143 })?;
144
145 writable_std.write_all(message_str.as_bytes()).await?;
146 writable_std.write_all(b"\n").await?; // new line
147 writable_std.flush().await?;
148
149 if let Some(rx) = rx_response {
150 match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
151 Ok(response) => Ok(Some(response)),
152 Err(error) => Err(error),
153 }
154 } else {
155 Ok(None)
156 }
157 }
158}
159
160#[async_trait]
161impl MCPDispatch<ClientMessage, MessageFromServer> for MessageDispatcher<ClientMessage> {
162 /// Sends a message from the server to the client and awaits a response if applicable.
163 ///
164 /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a
165 /// `ClientMessage` response if the message is a request. Notifications and responses return
166 /// `Ok(None)`.
167 ///
168 /// # Arguments
169 /// * `message` - The server message to send.
170 /// * `request_id` - An optional request ID (used for responses/errors, None for requests).
171 ///
172 /// # Returns
173 /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response,
174 /// or `None` for notifications/responses, or an error if the operation fails.
175 ///
176 /// # Errors
177 /// Returns a `TransportError` if serialization, writing, or timeout occurs.
178 async fn send(
179 &self,
180 message: MessageFromServer,
181 request_id: Option<RequestId>,
182 ) -> TransportResult<Option<ClientMessage>> {
183 let mut writable_std = self.writable_std.lock().await;
184
185 // returns the request_id to be used to construct the message
186 // a new requestId will be returned for Requests and Notification
187 let outgoing_request_id = self.request_id_for_message(&message, request_id);
188
189 let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> = {
190 // Store the sender in the pending requests map
191 if message.is_request() {
192 if let Some(request_id) = &outgoing_request_id {
193 let (tx_response, rx_response) = oneshot::channel::<ClientMessage>();
194 let mut pending_requests = self.pending_requests.lock().await;
195 // store request id in the hashmap while waiting for a matching response
196 pending_requests.insert(request_id.clone(), tx_response);
197 Some(rx_response)
198 } else {
199 None
200 }
201 } else {
202 None
203 }
204 };
205
206 let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?;
207
208 //serialize the message and write it to the writable_std
209 let message_str = serde_json::to_string(&mpc_message).map_err(|_| {
210 crate::error::TransportError::JsonrpcError(JsonrpcErrorError::parse_error())
211 })?;
212
213 writable_std.write_all(message_str.as_bytes()).await?;
214 writable_std.write_all(b"\n").await?; // new line
215 writable_std.flush().await?;
216
217 if let Some(rx) = rx_response {
218 match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
219 Ok(response) => Ok(Some(response)),
220 Err(error) => Err(error),
221 }
222 } else {
223 Ok(None)
224 }
225 }
226}