rust_mcp_transport/
client_streamable_http.rs

1use crate::error::TransportError;
2use crate::mcp_stream::MCPStream;
3
4use crate::schema::{
5    schema_utils::{
6        ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage,
7        ServerMessages,
8    },
9    RequestId,
10};
11use crate::utils::{
12    http_delete, http_post, CancellationTokenSource, ReadableChannel, StreamableHttpStream,
13    WritableChannel,
14};
15use crate::{error::TransportResult, IoStream, McpDispatch, MessageDispatcher, Transport};
16use crate::{SessionId, TransportDispatcher, TransportOptions};
17use async_trait::async_trait;
18use bytes::Bytes;
19use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
20use reqwest::Client;
21use std::collections::HashMap;
22use std::pin::Pin;
23use std::{sync::Arc, time::Duration};
24use tokio::io::{BufReader, BufWriter};
25use tokio::sync::oneshot::Sender;
26use tokio::sync::{mpsc, oneshot, Mutex};
27use tokio::task::JoinHandle;
28
/// Capacity of the internal mpsc channels linking the MCP stream to the HTTP tasks.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;
/// Retry count used when `RequestOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRY: usize = 5;
/// Delay between retries, in seconds, used when `RequestOptions::retry_delay` is `None`.
const DEFAULT_RETRY_TIME_SECONDS: u64 = 1;
/// How long, in seconds, `shut_down` waits for background tasks before giving up.
const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5;
33
34pub struct StreamableTransportOptions {
35    pub mcp_url: String,
36    pub request_options: RequestOptions,
37}
38
39impl StreamableTransportOptions {
40    pub async fn terminate_session(&self, session_id: Option<&SessionId>) {
41        let client = Client::new();
42        match http_delete(&client, &self.mcp_url, session_id, None).await {
43            Ok(_) => {}
44            Err(TransportError::Http(status_code)) => {
45                tracing::info!("Session termination failed with status code {status_code}",);
46            }
47            Err(error) => {
48                tracing::info!("Session termination failed with error :{error}");
49            }
50        };
51    }
52}
53
/// Per-request settings for the streamable HTTP client transport.
///
/// `None` for `retry_delay` / `max_retries` lets the transport pick its own defaults.
#[derive(Clone)]
pub struct RequestOptions {
    /// How long to wait for a response to a request.
    pub request_timeout: Duration,
    /// Delay between reconnection / retry attempts.
    pub retry_delay: Option<Duration>,
    /// Maximum number of retry attempts.
    pub max_retries: Option<usize>,
    /// Extra HTTP headers (name -> value) attached to outgoing requests.
    pub custom_headers: Option<HashMap<String, String>>,
}

impl std::fmt::Debug for RequestOptions {
    /// Formats the options without exposing header values: those frequently carry
    /// credentials (e.g. `Authorization`), so only the (sorted) header names are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let header_names = self.custom_headers.as_ref().map(|headers| {
            let mut names: Vec<&str> = headers.keys().map(String::as_str).collect();
            names.sort_unstable();
            names
        });
        f.debug_struct("RequestOptions")
            .field("request_timeout", &self.request_timeout)
            .field("retry_delay", &self.retry_delay)
            .field("max_retries", &self.max_retries)
            .field("custom_headers", &header_names)
            .finish()
    }
}
60
61impl Default for RequestOptions {
62    fn default() -> Self {
63        Self {
64            request_timeout: TransportOptions::default().timeout,
65            retry_delay: None,
66            max_retries: None,
67            custom_headers: None,
68        }
69    }
70}
71
72pub struct ClientStreamableTransport<R>
73where
74    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
75{
76    /// Optional cancellation token source for shutting down the transport
77    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
78    /// Flag indicating if the transport is shut down
79    is_shut_down: Mutex<bool>,
80    /// Timeout duration for MCP messages
81    request_timeout: Duration,
82    /// HTTP client for making requests
83    client: Client,
84    /// URL for the SSE endpoint
85    mcp_server_url: String,
86    /// Delay between retry attempts
87    retry_delay: Duration,
88    /// Maximum number of retry attempts
89    max_retries: usize,
90    /// Optional custom HTTP headers
91    custom_headers: Option<HeaderMap>,
92    sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
93    post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
94    message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
95    error_stream: tokio::sync::RwLock<Option<IoStream>>,
96    pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
97    session_id: Arc<tokio::sync::RwLock<Option<SessionId>>>,
98    standalone: bool,
99}
100
101impl<R> ClientStreamableTransport<R>
102where
103    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
104{
105    pub fn new(
106        options: &StreamableTransportOptions,
107        session_id: Option<SessionId>,
108        standalone: bool,
109    ) -> TransportResult<Self> {
110        let client = Client::new();
111
112        let headers = match &options.request_options.custom_headers {
113            Some(h) => Some(Self::validate_headers(h)?),
114            None => None,
115        };
116
117        let mcp_server_url = options.mcp_url.to_owned();
118        Ok(Self {
119            shutdown_source: tokio::sync::RwLock::new(None),
120            is_shut_down: Mutex::new(false),
121            request_timeout: options.request_options.request_timeout,
122            client,
123            mcp_server_url,
124            retry_delay: options
125                .request_options
126                .retry_delay
127                .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)),
128            max_retries: options
129                .request_options
130                .max_retries
131                .unwrap_or(DEFAULT_MAX_RETRY),
132            sse_task: tokio::sync::RwLock::new(None),
133            post_task: tokio::sync::RwLock::new(None),
134            custom_headers: headers,
135            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
136            error_stream: tokio::sync::RwLock::new(None),
137            pending_requests: Arc::new(Mutex::new(HashMap::new())),
138            session_id: Arc::new(tokio::sync::RwLock::new(session_id)),
139            standalone,
140        })
141    }
142
143    fn validate_headers(headers: &HashMap<String, String>) -> TransportResult<HeaderMap> {
144        let mut header_map = HeaderMap::new();
145        for (key, value) in headers {
146            let header_name =
147                key.parse::<HeaderName>()
148                    .map_err(|e| TransportError::Configuration {
149                        message: format!("Invalid header name: {e}"),
150                    })?;
151            let header_value =
152                HeaderValue::from_str(value).map_err(|e| TransportError::Configuration {
153                    message: format!("Invalid header value: {e}"),
154                })?;
155            header_map.insert(header_name, header_value);
156        }
157        Ok(header_map)
158    }
159
160    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
161        let mut lock = self.message_sender.write().await;
162        *lock = Some(sender);
163    }
164
165    pub(crate) async fn set_error_stream(
166        &self,
167        error_stream: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
168    ) {
169        let mut lock = self.error_stream.write().await;
170        *lock = Some(IoStream::Readable(error_stream));
171    }
172}
173
174#[async_trait]
175impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for ClientStreamableTransport<M>
176where
177    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
178    S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
179    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
180    OR: Clone + Send + Sync + serde::Serialize + 'static,
181    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
182{
183    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
184    where
185        MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
186    {
187        if self.standalone {
188            // Create CancellationTokenSource and token
189            let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
190            let mut lock = self.shutdown_source.write().await;
191            *lock = Some(cancellation_source);
192
193            let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
194            let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
195
196            let max_retries = self.max_retries;
197            let retry_delay = self.retry_delay;
198
199            let post_url = self.mcp_server_url.clone();
200            let custom_headers = self.custom_headers.clone();
201            let cancellation_token_post = cancellation_token.clone();
202            let cancellation_token_sse = cancellation_token.clone();
203
204            let session_id_clone = self.session_id.clone();
205
206            let mut streamable_http = StreamableHttpStream {
207                client: self.client.clone(),
208                mcp_url: post_url,
209                max_retries,
210                retry_delay,
211                read_tx,
212                session_id: session_id_clone, //Arc<RwLock<Option<String>>>
213            };
214
215            let session_id = self.session_id.read().await.to_owned();
216
217            let sse_response = streamable_http
218                .make_standalone_stream_connection(&cancellation_token_sse, &custom_headers, None)
219                .await?;
220
221            let sse_task_handle = tokio::spawn(async move {
222                if let Err(error) = streamable_http
223                    .run_standalone(&cancellation_token_sse, &custom_headers, sse_response)
224                    .await
225                {
226                    if !matches!(error, TransportError::Cancelled(_)) {
227                        tracing::warn!("{error}");
228                    }
229                }
230            });
231
232            let mut sse_task_lock = self.sse_task.write().await;
233            *sse_task_lock = Some(sse_task_handle);
234
235            let post_url = self.mcp_server_url.clone();
236            let client = self.client.clone();
237            let custom_headers = self.custom_headers.clone();
238
239            // Initiate a task to process POST requests from messages received via the writable stream.
240            let post_task_handle = tokio::spawn(async move {
241                loop {
242                    tokio::select! {
243                    _ = cancellation_token_post.cancelled() =>
244                    {
245                            break;
246                    },
247                    data = write_rx.recv() => {
248                        match data{
249                          Some(data) => {
250                              // trim the trailing \n before making a request
251                              let payload = String::from_utf8_lossy(&data).trim().to_string();
252
253                             if let Err(e) = http_post(
254                                  &client,
255                                  &post_url,
256                                  payload.to_string(),
257                                  session_id.as_ref(),
258                                  custom_headers.as_ref(),
259                              )
260                              .await{
261                                tracing::error!("Failed to POST message: {e}")
262                          }
263                        },
264                        None => break, // Exit if channel is closed
265                        }
266                       }
267                    }
268                }
269            });
270            let mut post_task_lock = self.post_task.write().await;
271            *post_task_lock = Some(post_task_handle);
272
273            // Create writable stream
274            let writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>> =
275                Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx })));
276
277            // Create readable stream
278            let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
279                Box::pin(BufReader::new(ReadableChannel {
280                    read_rx,
281                    buffer: Bytes::new(),
282                }));
283
284            let (stream, sender, error_stream) = MCPStream::create(
285                readable,
286                writable,
287                IoStream::Writable(Box::pin(tokio::io::stderr())),
288                self.pending_requests.clone(),
289                self.request_timeout,
290                cancellation_token,
291            );
292
293            self.set_message_sender(sender).await;
294
295            if let IoStream::Readable(error_stream) = error_stream {
296                self.set_error_stream(error_stream).await;
297            }
298            Ok(stream)
299        } else {
300            // Create CancellationTokenSource and token
301            let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
302            let mut lock = self.shutdown_source.write().await;
303            *lock = Some(cancellation_source);
304
305            // let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
306            let (write_tx, mut write_rx): (
307                tokio::sync::mpsc::Sender<(
308                    String,
309                    tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
310                )>,
311                tokio::sync::mpsc::Receiver<(
312                    String,
313                    tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
314                )>,
315            ) = tokio::sync::mpsc::channel(DEFAULT_CHANNEL_CAPACITY); // Buffer size as needed
316            let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
317
318            let max_retries = self.max_retries;
319            let retry_delay = self.retry_delay;
320
321            let post_url = self.mcp_server_url.clone();
322            let custom_headers = self.custom_headers.clone();
323            let cancellation_token_post = cancellation_token.clone();
324            let cancellation_token_sse = cancellation_token.clone();
325
326            let session_id_clone = self.session_id.clone();
327
328            let mut streamable_http = StreamableHttpStream {
329                client: self.client.clone(),
330                mcp_url: post_url,
331                max_retries,
332                retry_delay,
333                read_tx,
334                session_id: session_id_clone, //Arc<RwLock<Option<String>>>
335            };
336
337            // Initiate a task to process POST requests from messages received via the writable stream.
338            let post_task_handle = tokio::spawn(async move {
339                loop {
340                    tokio::select! {
341                    _ = cancellation_token_post.cancelled() =>
342                    {
343                            break;
344                    },
345                    data = write_rx.recv() => {
346                        match data{
347                          Some((data, ack_tx)) => {
348                            // trim the trailing \n before making a request
349                            let payload = data.trim().to_string();
350                            let result = streamable_http.run(payload, &cancellation_token_sse, &custom_headers).await;
351                            let _ = ack_tx.send(result);// Ignore error if receiver dropped
352                        },
353                        None => break, // Exit if channel is closed
354                        }
355                       }
356                    }
357                }
358            });
359            let mut post_task_lock = self.post_task.write().await;
360            *post_task_lock = Some(post_task_handle);
361
362            // Create readable stream
363            let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
364                Box::pin(BufReader::new(ReadableChannel {
365                    read_rx,
366                    buffer: Bytes::new(),
367                }));
368
369            let (stream, sender, error_stream) = MCPStream::create_with_ack(
370                readable,
371                write_tx,
372                IoStream::Writable(Box::pin(tokio::io::stderr())),
373                self.pending_requests.clone(),
374                self.request_timeout,
375                cancellation_token,
376            );
377
378            self.set_message_sender(sender).await;
379
380            if let IoStream::Readable(error_stream) = error_stream {
381                self.set_error_stream(error_stream).await;
382            }
383
384            Ok(stream)
385        }
386    }
387
388    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
389        self.message_sender.clone() as _
390    }
391
392    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
393        &self.error_stream as _
394    }
395    async fn shut_down(&self) -> TransportResult<()> {
396        // Trigger cancellation
397        let mut cancellation_lock = self.shutdown_source.write().await;
398        if let Some(source) = cancellation_lock.as_ref() {
399            source.cancel()?;
400        }
401        *cancellation_lock = None; // Clear cancellation_source
402
403        // Mark as shut down
404        let mut is_shut_down_lock = self.is_shut_down.lock().await;
405        *is_shut_down_lock = true;
406
407        // Get task handle
408        let post_task = self.post_task.write().await.take();
409
410        // // Wait for tasks to complete with a timeout
411        let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS);
412        let shutdown_future = async {
413            if let Some(post_handle) = post_task {
414                let _ = post_handle.await;
415            }
416            Ok::<(), TransportError>(())
417        };
418
419        tokio::select! {
420            result = shutdown_future => {
421                result // result of task completion
422            }
423            _ = tokio::time::sleep(timeout) => {
424                tracing::warn!("Shutdown timed out after {:?}", timeout);
425                Err(TransportError::ShutdownTimeout)
426            }
427        }
428    }
429    async fn is_shut_down(&self) -> bool {
430        let result = self.is_shut_down.lock().await;
431        *result
432    }
433    async fn consume_string_payload(&self, _: &str) -> TransportResult<()> {
434        Err(TransportError::Internal(
435            "Invalid invocation of consume_string_payload() function for ClientStreamableTransport"
436                .to_string(),
437        ))
438    }
439
440    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
441        let mut pending_requests = self.pending_requests.lock().await;
442        pending_requests.remove(request_id)
443    }
444
445    async fn keep_alive(
446        &self,
447        _: Duration,
448        _: oneshot::Sender<()>,
449    ) -> TransportResult<JoinHandle<()>> {
450        Err(TransportError::Internal(
451            "Invalid invocation of keep_alive() function for ClientStreamableTransport".to_string(),
452        ))
453    }
454
455    async fn session_id(&self) -> Option<SessionId> {
456        let guard = self.session_id.read().await;
457        guard.clone()
458    }
459}
460
461#[async_trait]
462impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
463    for ClientStreamableTransport<ServerMessage>
464{
465    async fn send_message(
466        &self,
467        message: ClientMessages,
468        request_timeout: Option<Duration>,
469    ) -> TransportResult<Option<ServerMessages>> {
470        let sender = self.message_sender.read().await;
471
472        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
473
474        sender.send_message(message, request_timeout).await
475    }
476
477    async fn send(
478        &self,
479        message: ClientMessage,
480        request_timeout: Option<Duration>,
481    ) -> TransportResult<Option<ServerMessage>> {
482        let sender = self.message_sender.read().await;
483
484        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
485
486        sender.send(message, request_timeout).await
487    }
488
489    async fn send_batch(
490        &self,
491        message: Vec<ClientMessage>,
492        request_timeout: Option<Duration>,
493    ) -> TransportResult<Option<Vec<ServerMessage>>> {
494        let sender = self.message_sender.read().await;
495        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
496        sender.send_batch(message, request_timeout).await
497    }
498
499    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
500        let sender = self.message_sender.read().await;
501        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
502        sender.write_str(payload, skip_store).await
503    }
504}
505
506impl
507    TransportDispatcher<
508        ServerMessages,
509        MessageFromClient,
510        ServerMessage,
511        ClientMessages,
512        ClientMessage,
513    > for ClientStreamableTransport<ServerMessage>
514{
515}