1use crate::error::{TransportError, TransportResult};
2use crate::schema::{RequestId, RpcError};
3use crate::utils::{await_timeout, current_timestamp};
4use crate::McpDispatch;
5use crate::{
6 event_store::EventStore,
7 schema::{
8 schema_utils::{
9 self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage,
10 ServerMessages,
11 },
12 JsonrpcError,
13 },
14 SessionId, StreamId,
15};
16use async_trait::async_trait;
17use futures::future::join_all;
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::AsyncWriteExt;
23use tokio::sync::oneshot::{self};
24use tokio::sync::Mutex;
25
26pub const ID_SEPARATOR: u8 = b'|';
27
28pub struct MessageDispatcher<R> {
36 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
37 writable_std: Option<Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>>,
38 writable_tx: Option<
39 tokio::sync::mpsc::Sender<(
40 String,
41 tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
42 )>,
43 >,
44 request_timeout: Duration,
45 session_id: Option<SessionId>,
47 stream_id: Option<StreamId>,
48 event_store: Option<Arc<dyn EventStore>>,
49}
50
51impl<R> MessageDispatcher<R> {
52 pub fn new(
63 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
64 writable_std: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
65 request_timeout: Duration,
66 ) -> Self {
67 Self {
68 pending_requests,
69 writable_std: Some(writable_std),
70 writable_tx: None,
71 request_timeout,
72 session_id: None,
73 stream_id: None,
74 event_store: None,
75 }
76 }
77
78 pub fn new_with_acknowledgement(
79 pending_requests: Arc<Mutex<HashMap<RequestId, oneshot::Sender<R>>>>,
80 writable_tx: tokio::sync::mpsc::Sender<(
81 String,
82 tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
83 )>,
84 request_timeout: Duration,
85 ) -> Self {
86 Self {
87 pending_requests,
88 writable_tx: Some(writable_tx),
89 writable_std: None,
90 request_timeout,
91 session_id: None,
92 stream_id: None,
93 event_store: None,
94 }
95 }
96
97 pub fn make_resumable(
100 &mut self,
101 session_id: SessionId,
102 stream_id: StreamId,
103 event_store: Arc<dyn EventStore>,
104 ) {
105 self.session_id = Some(session_id);
106 self.stream_id = Some(stream_id);
107 self.event_store = Some(event_store);
108 }
109
110 async fn store_pending_request(
111 &self,
112 request_id: RequestId,
113 ) -> tokio::sync::oneshot::Receiver<R> {
114 let (tx_response, rx_response) = oneshot::channel::<R>();
115 let mut pending_requests = self.pending_requests.lock().await;
116 pending_requests.insert(request_id.clone(), tx_response);
118 rx_response
119 }
120
121 async fn store_pending_request_for_message<M: McpMessage + RpcMessage>(
122 &self,
123 message: &M,
124 ) -> Option<tokio::sync::oneshot::Receiver<R>> {
125 if message.is_request() {
126 if let Some(request_id) = message.request_id() {
127 Some(self.store_pending_request(request_id.clone()).await)
128 } else {
129 None
130 }
131 } else {
132 None
133 }
134 }
135}
136
137#[async_trait]
139impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
140 for MessageDispatcher<ServerMessage>
141{
142 async fn send_message(
158 &self,
159 messages: ClientMessages,
160 request_timeout: Option<Duration>,
161 ) -> TransportResult<Option<ServerMessages>> {
162 match messages {
163 ClientMessages::Single(message) => {
164 let rx_response: Option<tokio::sync::oneshot::Receiver<ServerMessage>> =
165 self.store_pending_request_for_message(&message).await;
166
167 let message_payload = serde_json::to_string(&message).map_err(|_| {
169 crate::error::TransportError::JsonrpcError(RpcError::parse_error())
170 })?;
171
172 self.write_str(message_payload.as_str(), true).await?;
173
174 if let Some(rx) = rx_response {
175 match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
177 Ok(response) => Ok(Some(ServerMessages::Single(response))),
178 Err(error) => match error {
179 TransportError::ChannelClosed(_) => {
180 Err(schema_utils::SdkError::connection_closed().into())
181 }
182 _ => Err(error),
183 },
184 }
185 } else {
186 Ok(None)
187 }
188 }
189 ClientMessages::Batch(client_messages) => {
190 let (request_ids, pending_tasks): (Vec<_>, Vec<_>) = client_messages
191 .iter()
192 .filter(|message| message.is_request())
193 .map(|message| {
194 (
195 message.request_id().unwrap(), self.store_pending_request_for_message(message),
197 )
198 })
199 .unzip();
200
201 let tasks = join_all(pending_tasks).await;
203
204 let message_payload = serde_json::to_string(&client_messages).map_err(|_| {
206 crate::error::TransportError::JsonrpcError(RpcError::parse_error())
207 })?;
208 self.write_str(message_payload.as_str(), true).await?;
209
210 if request_ids.is_empty() {
212 return Ok(None);
213 }
214
215 let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| {
216 rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)))
217 });
218
219 let results: Vec<_> = join_all(timeout_wrapped_futures)
220 .await
221 .into_iter()
222 .zip(request_ids)
223 .map(|(res, request_id)| match res {
224 Ok(response) => response,
225 Err(error) => ServerMessage::Error(JsonrpcError::new(
226 RpcError::internal_error().with_message(error.to_string()),
227 request_id.to_owned(),
228 )),
229 })
230 .collect();
231
232 Ok(Some(ServerMessages::Batch(results)))
233 }
234 }
235 }
236
237 async fn send(
238 &self,
239 message: ClientMessage,
240 request_timeout: Option<Duration>,
241 ) -> TransportResult<Option<ServerMessage>> {
242 let response = self.send_message(message.into(), request_timeout).await?;
243 match response {
244 Some(r) => Ok(Some(r.as_single()?)),
245 None => Ok(None),
246 }
247 }
248
249 async fn send_batch(
250 &self,
251 message: Vec<ClientMessage>,
252 request_timeout: Option<Duration>,
253 ) -> TransportResult<Option<Vec<ServerMessage>>> {
254 let response = self.send_message(message.into(), request_timeout).await?;
255 match response {
256 Some(r) => Ok(Some(r.as_batch()?)),
257 None => Ok(None),
258 }
259 }
260
261 async fn write_str(&self, payload: &str, _skip_store: bool) -> TransportResult<()> {
265 if let Some(writable_std) = self.writable_std.as_ref() {
266 let mut writable_std = writable_std.lock().await;
267 writable_std.write_all(payload.as_bytes()).await?;
268 writable_std.write_all(b"\n").await?; writable_std.flush().await?;
270 return Ok(());
271 };
272
273 if let Some(writable_tx) = self.writable_tx.as_ref() {
274 let (resp_tx, resp_rx) = oneshot::channel();
275 writable_tx
276 .send((payload.to_string(), resp_tx))
277 .await
278 .map_err(|err| TransportError::Internal(format!("{err}")))?; return resp_rx.await?; }
281
282 Err(TransportError::Internal("Invalid dispatcher!".to_string()))
283 }
284}
285
286#[async_trait]
288impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
289 for MessageDispatcher<ClientMessage>
290{
291 async fn send_message(
307 &self,
308 messages: ServerMessages,
309 request_timeout: Option<Duration>,
310 ) -> TransportResult<Option<ClientMessages>> {
311 match messages {
312 ServerMessages::Single(message) => {
313 let rx_response: Option<tokio::sync::oneshot::Receiver<ClientMessage>> =
314 self.store_pending_request_for_message(&message).await;
315
316 let message_payload = serde_json::to_string(&message).map_err(|_| {
317 crate::error::TransportError::JsonrpcError(RpcError::parse_error())
318 })?;
319
320 self.write_str(message_payload.as_str(), false).await?;
321
322 if let Some(rx) = rx_response {
323 match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await {
324 Ok(response) => Ok(Some(ClientMessages::Single(response))),
325 Err(error) => Err(error),
326 }
327 } else {
328 Ok(None)
329 }
330 }
331 ServerMessages::Batch(server_messages) => {
332 let (request_ids, pending_tasks): (Vec<_>, Vec<_>) = server_messages
333 .iter()
334 .filter(|message| message.is_request())
335 .map(|message| {
336 (
337 message.request_id().unwrap(), self.store_pending_request_for_message(message),
339 )
340 })
341 .unzip();
342
343 let message_payload = serde_json::to_string(&server_messages).map_err(|_| {
345 crate::error::TransportError::JsonrpcError(RpcError::parse_error())
346 })?;
347
348 self.write_str(message_payload.as_str(), false).await?;
349
350 if pending_tasks.is_empty() {
352 return Ok(None);
353 }
354
355 let tasks = join_all(pending_tasks).await;
356
357 let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| {
358 rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)))
359 });
360
361 let results: Vec<_> = join_all(timeout_wrapped_futures)
362 .await
363 .into_iter()
364 .zip(request_ids)
365 .map(|(res, request_id)| match res {
366 Ok(response) => response,
367 Err(error) => ClientMessage::Error(JsonrpcError::new(
368 RpcError::internal_error().with_message(error.to_string()),
369 request_id.to_owned(),
370 )),
371 })
372 .collect();
373
374 Ok(Some(ClientMessages::Batch(results)))
375 }
376 }
377 }
378
379 async fn send(
380 &self,
381 message: ServerMessage,
382 request_timeout: Option<Duration>,
383 ) -> TransportResult<Option<ClientMessage>> {
384 let response = self.send_message(message.into(), request_timeout).await?;
385 match response {
386 Some(r) => Ok(Some(r.as_single()?)),
387 None => Ok(None),
388 }
389 }
390
391 async fn send_batch(
392 &self,
393 message: Vec<ServerMessage>,
394 request_timeout: Option<Duration>,
395 ) -> TransportResult<Option<Vec<ClientMessage>>> {
396 let response = self.send_message(message.into(), request_timeout).await?;
397 match response {
398 Some(r) => Ok(Some(r.as_batch()?)),
399 None => Ok(None),
400 }
401 }
402
403 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
407 let mut event_id = None;
408
409 if !skip_store && !payload.trim().is_empty() {
410 if let (Some(session_id), Some(stream_id), Some(event_store)) = (
411 self.session_id.as_ref(),
412 self.stream_id.as_ref(),
413 self.event_store.as_ref(),
414 ) {
415 event_id = event_store
416 .store_event(
417 session_id.clone(),
418 stream_id.clone(),
419 current_timestamp(),
420 payload.to_owned(),
421 )
422 .await
423 .map(Some)
424 .unwrap_or_else(|err| {
425 tracing::error!("{err}");
426 None
427 });
428 };
429 }
430
431 if let Some(writable_std) = self.writable_std.as_ref() {
432 let mut writable_std = writable_std.lock().await;
433 if let Some(id) = event_id {
434 writable_std.write_all(id.as_bytes()).await?;
435 writable_std.write_all(&[ID_SEPARATOR]).await?; }
437 writable_std.write_all(payload.as_bytes()).await?;
438 writable_std.write_all(b"\n").await?; writable_std.flush().await?;
440 return Ok(());
441 };
442
443 if let Some(writable_tx) = self.writable_tx.as_ref() {
444 let (resp_tx, resp_rx) = oneshot::channel();
445 writable_tx
446 .send((payload.to_string(), resp_tx))
447 .await
448 .map_err(|err| TransportError::Internal(err.to_string()))?; return resp_rx.await?; }
451
452 Err(TransportError::Internal("Invalid dispatcher!".to_string()))
453 }
454}