1use crate::error::TransportError;
2use crate::mcp_stream::MCPStream;
3
4use crate::schema::{
5 schema_utils::{
6 ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage,
7 ServerMessages,
8 },
9 RequestId,
10};
11use crate::utils::{
12 http_delete, http_post, CancellationTokenSource, ReadableChannel, StreamableHttpStream,
13 WritableChannel,
14};
15use crate::{error::TransportResult, IoStream, McpDispatch, MessageDispatcher, Transport};
16use crate::{SessionId, TransportDispatcher, TransportOptions};
17use async_trait::async_trait;
18use bytes::Bytes;
19use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
20use reqwest::Client;
21use std::collections::HashMap;
22use std::pin::Pin;
23use std::{sync::Arc, time::Duration};
24use tokio::io::{BufReader, BufWriter};
25use tokio::sync::oneshot::Sender;
26use tokio::sync::{mpsc, oneshot, Mutex};
27use tokio::task::JoinHandle;
28
/// Seconds `shut_down` waits for the background tasks before reporting a timeout.
const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5;
/// Retry delay, in seconds, used when `RequestOptions::retry_delay` is `None`.
const DEFAULT_RETRY_TIME_SECONDS: u64 = 1;
/// Retry count used when `RequestOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRY: usize = 5;
/// Capacity of the internal `mpsc` channels created by `start`.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;
33
34pub struct StreamableTransportOptions {
35 pub mcp_url: String,
36 pub request_options: RequestOptions,
37}
38
39impl StreamableTransportOptions {
40 pub async fn terminate_session(&self, session_id: Option<&SessionId>) {
41 let client = Client::new();
42 match http_delete(&client, &self.mcp_url, session_id, None).await {
43 Ok(_) => {}
44 Err(TransportError::Http(status_code)) => {
45 tracing::info!("Session termination failed with status code {status_code}",);
46 }
47 Err(error) => {
48 tracing::info!("Session termination failed with error :{error}");
49 }
50 };
51 }
52}
53
/// Per-request tuning for the streamable-HTTP client transport.
pub struct RequestOptions {
    /// Timeout handed to the message stream when waiting for responses.
    pub request_timeout: Duration,

    /// Delay between retries; `None` falls back to `DEFAULT_RETRY_TIME_SECONDS`.
    pub retry_delay: Option<Duration>,
    /// Maximum number of retries; `None` falls back to `DEFAULT_MAX_RETRY`.
    pub max_retries: Option<usize>,

    /// Extra headers (name -> value) added to outgoing requests. They are validated when the
    /// transport is constructed.
    pub custom_headers: Option<HashMap<String, String>>,
}
60
61impl Default for RequestOptions {
62 fn default() -> Self {
63 Self {
64 request_timeout: TransportOptions::default().timeout,
65 retry_delay: None,
66 max_retries: None,
67 custom_headers: None,
68 }
69 }
70}
71
72pub struct ClientStreamableTransport<R>
73where
74 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
75{
76 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
78 is_shut_down: Mutex<bool>,
80 request_timeout: Duration,
82 client: Client,
84 mcp_server_url: String,
86 retry_delay: Duration,
88 max_retries: usize,
90 custom_headers: Option<HeaderMap>,
92 sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
93 post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
94 message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
95 error_stream: tokio::sync::RwLock<Option<IoStream>>,
96 pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
97 session_id: Arc<tokio::sync::RwLock<Option<SessionId>>>,
98 standalone: bool,
99}
100
101impl<R> ClientStreamableTransport<R>
102where
103 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
104{
105 pub fn new(
106 options: &StreamableTransportOptions,
107 session_id: Option<SessionId>,
108 standalone: bool,
109 ) -> TransportResult<Self> {
110 let client = Client::new();
111
112 let headers = match &options.request_options.custom_headers {
113 Some(h) => Some(Self::validate_headers(h)?),
114 None => None,
115 };
116
117 let mcp_server_url = options.mcp_url.to_owned();
118 Ok(Self {
119 shutdown_source: tokio::sync::RwLock::new(None),
120 is_shut_down: Mutex::new(false),
121 request_timeout: options.request_options.request_timeout,
122 client,
123 mcp_server_url,
124 retry_delay: options
125 .request_options
126 .retry_delay
127 .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)),
128 max_retries: options
129 .request_options
130 .max_retries
131 .unwrap_or(DEFAULT_MAX_RETRY),
132 sse_task: tokio::sync::RwLock::new(None),
133 post_task: tokio::sync::RwLock::new(None),
134 custom_headers: headers,
135 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
136 error_stream: tokio::sync::RwLock::new(None),
137 pending_requests: Arc::new(Mutex::new(HashMap::new())),
138 session_id: Arc::new(tokio::sync::RwLock::new(session_id)),
139 standalone,
140 })
141 }
142
143 fn validate_headers(headers: &HashMap<String, String>) -> TransportResult<HeaderMap> {
144 let mut header_map = HeaderMap::new();
145 for (key, value) in headers {
146 let header_name =
147 key.parse::<HeaderName>()
148 .map_err(|e| TransportError::Configuration {
149 message: format!("Invalid header name: {e}"),
150 })?;
151 let header_value =
152 HeaderValue::from_str(value).map_err(|e| TransportError::Configuration {
153 message: format!("Invalid header value: {e}"),
154 })?;
155 header_map.insert(header_name, header_value);
156 }
157 Ok(header_map)
158 }
159
160 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
161 let mut lock = self.message_sender.write().await;
162 *lock = Some(sender);
163 }
164
165 pub(crate) async fn set_error_stream(
166 &self,
167 error_stream: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
168 ) {
169 let mut lock = self.error_stream.write().await;
170 *lock = Some(IoStream::Readable(error_stream));
171 }
172}
173
174#[async_trait]
175impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for ClientStreamableTransport<M>
176where
177 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
178 S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
179 M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
180 OR: Clone + Send + Sync + serde::Serialize + 'static,
181 OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
182{
183 async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
184 where
185 MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
186 {
187 if self.standalone {
188 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
190 let mut lock = self.shutdown_source.write().await;
191 *lock = Some(cancellation_source);
192
193 let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
194 let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
195
196 let max_retries = self.max_retries;
197 let retry_delay = self.retry_delay;
198
199 let post_url = self.mcp_server_url.clone();
200 let custom_headers = self.custom_headers.clone();
201 let cancellation_token_post = cancellation_token.clone();
202 let cancellation_token_sse = cancellation_token.clone();
203
204 let session_id_clone = self.session_id.clone();
205
206 let mut streamable_http = StreamableHttpStream {
207 client: self.client.clone(),
208 mcp_url: post_url,
209 max_retries,
210 retry_delay,
211 read_tx,
212 session_id: session_id_clone, };
214
215 let session_id = self.session_id.read().await.to_owned();
216
217 let sse_response = streamable_http
218 .make_standalone_stream_connection(&cancellation_token_sse, &custom_headers, None)
219 .await?;
220
221 let sse_task_handle = tokio::spawn(async move {
222 if let Err(error) = streamable_http
223 .run_standalone(&cancellation_token_sse, &custom_headers, sse_response)
224 .await
225 {
226 if !matches!(error, TransportError::Cancelled(_)) {
227 tracing::warn!("{error}");
228 }
229 }
230 });
231
232 let mut sse_task_lock = self.sse_task.write().await;
233 *sse_task_lock = Some(sse_task_handle);
234
235 let post_url = self.mcp_server_url.clone();
236 let client = self.client.clone();
237 let custom_headers = self.custom_headers.clone();
238
239 let post_task_handle = tokio::spawn(async move {
241 loop {
242 tokio::select! {
243 _ = cancellation_token_post.cancelled() =>
244 {
245 break;
246 },
247 data = write_rx.recv() => {
248 match data{
249 Some(data) => {
250 let payload = String::from_utf8_lossy(&data).trim().to_string();
252
253 if let Err(e) = http_post(
254 &client,
255 &post_url,
256 payload.to_string(),
257 session_id.as_ref(),
258 custom_headers.as_ref(),
259 )
260 .await{
261 tracing::error!("Failed to POST message: {e}")
262 }
263 },
264 None => break, }
266 }
267 }
268 }
269 });
270 let mut post_task_lock = self.post_task.write().await;
271 *post_task_lock = Some(post_task_handle);
272
273 let writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>> =
275 Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx })));
276
277 let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
279 Box::pin(BufReader::new(ReadableChannel {
280 read_rx,
281 buffer: Bytes::new(),
282 }));
283
284 let (stream, sender, error_stream) = MCPStream::create(
285 readable,
286 writable,
287 IoStream::Writable(Box::pin(tokio::io::stderr())),
288 self.pending_requests.clone(),
289 self.request_timeout,
290 cancellation_token,
291 );
292
293 self.set_message_sender(sender).await;
294
295 if let IoStream::Readable(error_stream) = error_stream {
296 self.set_error_stream(error_stream).await;
297 }
298 Ok(stream)
299 } else {
300 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
302 let mut lock = self.shutdown_source.write().await;
303 *lock = Some(cancellation_source);
304
305 let (write_tx, mut write_rx): (
307 tokio::sync::mpsc::Sender<(
308 String,
309 tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
310 )>,
311 tokio::sync::mpsc::Receiver<(
312 String,
313 tokio::sync::oneshot::Sender<crate::error::TransportResult<()>>,
314 )>,
315 ) = tokio::sync::mpsc::channel(DEFAULT_CHANNEL_CAPACITY); let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
317
318 let max_retries = self.max_retries;
319 let retry_delay = self.retry_delay;
320
321 let post_url = self.mcp_server_url.clone();
322 let custom_headers = self.custom_headers.clone();
323 let cancellation_token_post = cancellation_token.clone();
324 let cancellation_token_sse = cancellation_token.clone();
325
326 let session_id_clone = self.session_id.clone();
327
328 let mut streamable_http = StreamableHttpStream {
329 client: self.client.clone(),
330 mcp_url: post_url,
331 max_retries,
332 retry_delay,
333 read_tx,
334 session_id: session_id_clone, };
336
337 let post_task_handle = tokio::spawn(async move {
339 loop {
340 tokio::select! {
341 _ = cancellation_token_post.cancelled() =>
342 {
343 break;
344 },
345 data = write_rx.recv() => {
346 match data{
347 Some((data, ack_tx)) => {
348 let payload = data.trim().to_string();
350 let result = streamable_http.run(payload, &cancellation_token_sse, &custom_headers).await;
351 let _ = ack_tx.send(result);},
353 None => break, }
355 }
356 }
357 }
358 });
359 let mut post_task_lock = self.post_task.write().await;
360 *post_task_lock = Some(post_task_handle);
361
362 let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
364 Box::pin(BufReader::new(ReadableChannel {
365 read_rx,
366 buffer: Bytes::new(),
367 }));
368
369 let (stream, sender, error_stream) = MCPStream::create_with_ack(
370 readable,
371 write_tx,
372 IoStream::Writable(Box::pin(tokio::io::stderr())),
373 self.pending_requests.clone(),
374 self.request_timeout,
375 cancellation_token,
376 );
377
378 self.set_message_sender(sender).await;
379
380 if let IoStream::Readable(error_stream) = error_stream {
381 self.set_error_stream(error_stream).await;
382 }
383
384 Ok(stream)
385 }
386 }
387
388 fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
389 self.message_sender.clone() as _
390 }
391
392 fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
393 &self.error_stream as _
394 }
395 async fn shut_down(&self) -> TransportResult<()> {
396 let mut cancellation_lock = self.shutdown_source.write().await;
398 if let Some(source) = cancellation_lock.as_ref() {
399 source.cancel()?;
400 }
401 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
405 *is_shut_down_lock = true;
406
407 let post_task = self.post_task.write().await.take();
409
410 let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS);
412 let shutdown_future = async {
413 if let Some(post_handle) = post_task {
414 let _ = post_handle.await;
415 }
416 Ok::<(), TransportError>(())
417 };
418
419 tokio::select! {
420 result = shutdown_future => {
421 result }
423 _ = tokio::time::sleep(timeout) => {
424 tracing::warn!("Shutdown timed out after {:?}", timeout);
425 Err(TransportError::ShutdownTimeout)
426 }
427 }
428 }
429 async fn is_shut_down(&self) -> bool {
430 let result = self.is_shut_down.lock().await;
431 *result
432 }
433 async fn consume_string_payload(&self, _: &str) -> TransportResult<()> {
434 Err(TransportError::Internal(
435 "Invalid invocation of consume_string_payload() function for ClientStreamableTransport"
436 .to_string(),
437 ))
438 }
439
440 async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
441 let mut pending_requests = self.pending_requests.lock().await;
442 pending_requests.remove(request_id)
443 }
444
445 async fn keep_alive(
446 &self,
447 _: Duration,
448 _: oneshot::Sender<()>,
449 ) -> TransportResult<JoinHandle<()>> {
450 Err(TransportError::Internal(
451 "Invalid invocation of keep_alive() function for ClientStreamableTransport".to_string(),
452 ))
453 }
454
455 async fn session_id(&self) -> Option<SessionId> {
456 let guard = self.session_id.read().await;
457 guard.clone()
458 }
459}
460
461#[async_trait]
462impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
463 for ClientStreamableTransport<ServerMessage>
464{
465 async fn send_message(
466 &self,
467 message: ClientMessages,
468 request_timeout: Option<Duration>,
469 ) -> TransportResult<Option<ServerMessages>> {
470 let sender = self.message_sender.read().await;
471
472 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
473
474 sender.send_message(message, request_timeout).await
475 }
476
477 async fn send(
478 &self,
479 message: ClientMessage,
480 request_timeout: Option<Duration>,
481 ) -> TransportResult<Option<ServerMessage>> {
482 let sender = self.message_sender.read().await;
483
484 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
485
486 sender.send(message, request_timeout).await
487 }
488
489 async fn send_batch(
490 &self,
491 message: Vec<ClientMessage>,
492 request_timeout: Option<Duration>,
493 ) -> TransportResult<Option<Vec<ServerMessage>>> {
494 let sender = self.message_sender.read().await;
495 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
496 sender.send_batch(message, request_timeout).await
497 }
498
499 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
500 let sender = self.message_sender.read().await;
501 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
502 sender.write_str(payload, skip_store).await
503 }
504}
505
506impl
507 TransportDispatcher<
508 ServerMessages,
509 MessageFromClient,
510 ServerMessage,
511 ClientMessages,
512 ClientMessage,
513 > for ClientStreamableTransport<ServerMessage>
514{
515}