1use crate::error::{TransportError, TransportResult};
2use crate::mcp_stream::MCPStream;
3use crate::message_dispatcher::MessageDispatcher;
4use crate::transport::Transport;
5use crate::utils::{
6 extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel,
7};
8use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
9use async_trait::async_trait;
10use bytes::Bytes;
11use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
12use reqwest::Client;
13use tokio::sync::oneshot::Sender;
14use tokio::task::JoinHandle;
15
16use crate::schema::{
17 schema_utils::{
18 ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage,
19 ServerMessages,
20 },
21 RequestId,
22};
23use std::cmp::Ordering;
24use std::collections::HashMap;
25use std::pin::Pin;
26use std::sync::Arc;
27use std::time::Duration;
28use tokio::io::{BufReader, BufWriter};
29use tokio::sync::{mpsc, oneshot, Mutex};
30
/// Buffer size of the internal mpsc channels that carry outgoing and incoming payloads.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;

/// Retry count used when `ClientSseTransportOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRY: usize = 5;
/// Retry delay (in seconds) used when `ClientSseTransportOptions::retry_delay` is `None`.
const DEFAULT_RETRY_TIME_SECONDS: u64 = 1;

/// Upper bound (in seconds) on how long `shut_down` waits for the background tasks to finish.
const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5;
35
/// Configuration accepted by `ClientSseTransport::new`.
///
/// Every `Option` field left as `None` falls back to a built-in default when the transport
/// is created.
#[derive(Clone)]
pub struct ClientSseTransportOptions {
    /// Timeout applied to requests sent through the transport.
    pub request_timeout: Duration,
    /// Delay between SSE reconnection attempts; `None` uses the transport default.
    pub retry_delay: Option<Duration>,
    /// Maximum number of SSE reconnection attempts; `None` uses the transport default.
    pub max_retries: Option<usize>,
    /// Extra HTTP headers sent with the SSE and POST requests. Names and values are
    /// validated when the transport is created.
    pub custom_headers: Option<HashMap<String, String>>,
}

// Hand-written rather than derived: header values frequently carry credentials
// (e.g. `Authorization`, API keys), so only the header *names* are printed.
impl std::fmt::Debug for ClientSseTransportOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let header_names = self.custom_headers.as_ref().map(|headers| {
            let mut names: Vec<&str> = headers.keys().map(String::as_str).collect();
            // HashMap order is unspecified; sort so the output is deterministic.
            names.sort_unstable();
            names
        });
        f.debug_struct("ClientSseTransportOptions")
            .field("request_timeout", &self.request_timeout)
            .field("retry_delay", &self.retry_delay)
            .field("max_retries", &self.max_retries)
            .field("custom_headers", &header_names)
            .finish()
    }
}
45
46impl Default for ClientSseTransportOptions {
48 fn default() -> Self {
49 Self {
50 request_timeout: TransportOptions::default().timeout,
51 retry_delay: None,
52 max_retries: None,
53 custom_headers: None,
54 }
55 }
56}
57
58pub struct ClientSseTransport<R>
62where
63 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
64{
65 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
67 is_shut_down: Mutex<bool>,
69 request_timeout: Duration,
71 client: Client,
73 sse_url: String,
75 base_url: String,
77 retry_delay: Duration,
79 max_retries: usize,
81 custom_headers: Option<HeaderMap>,
83 sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
84 post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
85 message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
86 error_stream: tokio::sync::RwLock<Option<IoStream>>,
87 pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
88}
89
90impl<R> ClientSseTransport<R>
91where
92 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
93{
94 pub fn new(server_url: &str, options: ClientSseTransportOptions) -> TransportResult<Self> {
105 let client = Client::new();
106
107 let base_url = match extract_origin(server_url) {
108 Some(url) => url,
109 None => {
110 let message = format!("Failed to extract origin from server URL: {server_url}");
111 tracing::error!(message);
112 return Err(TransportError::Configuration { message });
113 }
114 };
115
116 let headers = match &options.custom_headers {
117 Some(h) => Some(Self::validate_headers(h)?),
118 None => None,
119 };
120
121 Ok(Self {
122 client,
123 base_url,
124 sse_url: server_url.to_string(),
125 max_retries: options.max_retries.unwrap_or(DEFAULT_MAX_RETRY),
126 retry_delay: options
127 .retry_delay
128 .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)),
129 shutdown_source: tokio::sync::RwLock::new(None),
130 is_shut_down: Mutex::new(false),
131 request_timeout: options.request_timeout,
132 custom_headers: headers,
133 sse_task: tokio::sync::RwLock::new(None),
134 post_task: tokio::sync::RwLock::new(None),
135 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
136 error_stream: tokio::sync::RwLock::new(None),
137 pending_requests: Arc::new(Mutex::new(HashMap::new())),
138 })
139 }
140
141 fn validate_headers(headers: &HashMap<String, String>) -> TransportResult<HeaderMap> {
149 let mut header_map = HeaderMap::new();
150
151 for (key, value) in headers {
152 let header_name =
153 key.parse::<HeaderName>()
154 .map_err(|e| TransportError::Configuration {
155 message: format!("Invalid header name: {e}"),
156 })?;
157 let header_value =
158 HeaderValue::from_str(value).map_err(|e| TransportError::Configuration {
159 message: format!("Invalid header value: {e}"),
160 })?;
161 header_map.insert(header_name, header_value);
162 }
163
164 Ok(header_map)
165 }
166
167 pub fn validate_message_endpoint(&self, endpoint: String) -> TransportResult<String> {
177 if endpoint.starts_with("/") {
178 return Ok(format!("{}{}", self.base_url, endpoint));
179 }
180 if let Some(endpoint_origin) = extract_origin(&endpoint) {
181 if endpoint_origin.cmp(&self.base_url) != Ordering::Equal {
182 return Err(TransportError::Configuration {
183 message: format!(
184 "Endpoint origin does not match connection origin. expected: {} , received: {}",
185 self.base_url, endpoint_origin
186 ),
187 });
188 }
189 return Ok(endpoint);
190 }
191 Ok(endpoint)
192 }
193
194 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
195 let mut lock = self.message_sender.write().await;
196 *lock = Some(sender);
197 }
198
199 pub(crate) async fn set_error_stream(
200 &self,
201 error_stream: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
202 ) {
203 let mut lock = self.error_stream.write().await;
204 *lock = Some(IoStream::Readable(error_stream));
205 }
206}
207
208#[async_trait]
209impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for ClientSseTransport<M>
210where
211 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
212 S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
213 M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
214 OR: Clone + Send + Sync + serde::Serialize + 'static,
215 OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
216{
217 async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
225 where
226 MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
227 {
228 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
230 let mut lock = self.shutdown_source.write().await;
231 *lock = Some(cancellation_source);
232
233 let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
234 let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
235
236 let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::<Option<String>>();
238 let endpoint_event_tx = Some(endpoint_event_tx);
239
240 let sse_client = self.client.clone();
241 let sse_url = self.sse_url.clone();
242
243 let max_retries = self.max_retries;
244 let retry_delay = self.retry_delay;
245
246 let custom_headers = self.custom_headers.clone();
247
248 let read_stream = SseStream {
249 sse_client,
250 sse_url,
251 max_retries,
252 retry_delay,
253 read_tx,
254 };
255
256 let cancellation_token_sse = cancellation_token.clone();
258 let sse_task_handle = tokio::spawn(async move {
259 read_stream
260 .run(endpoint_event_tx, cancellation_token_sse, &custom_headers)
261 .await;
262 });
263 let mut sse_task_lock = self.sse_task.write().await;
264 *sse_task_lock = Some(sse_task_handle);
265
266 let err =
268 || std::io::Error::other("Failed to receive 'messages' endpoint from the server.");
269 let post_url = endpoint_event_rx
270 .await
271 .map_err(|_| err())?
272 .ok_or_else(err)?;
273
274 let post_url = self.validate_message_endpoint(post_url)?;
275
276 let client_clone = self.client.clone();
277
278 let custom_headers = self.custom_headers.clone();
279
280 let cancellation_token_post = cancellation_token.clone();
281 let post_task_handle = tokio::spawn(async move {
283 loop {
284 tokio::select! {
285
286 _ = cancellation_token_post.cancelled() =>
287 {
288 break;
289 },
290
291 data = write_rx.recv() => {
292 match data{
293 Some(data) => {
294 let body = String::from_utf8_lossy(&data).trim().to_string();
296 if let Err(e) = http_post(&client_clone, &post_url, body,None, custom_headers.as_ref()).await {
297 tracing::error!("Failed to POST message: {e}");
298 }
299 },
300 None => break, }
302 }
303 }
304 }
305 });
306 let mut post_task_lock = self.post_task.write().await;
307 *post_task_lock = Some(post_task_handle);
308
309 let writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>> =
311 Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx })));
312
313 let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
315 Box::pin(BufReader::new(ReadableChannel {
316 read_rx,
317 buffer: Bytes::new(),
318 }));
319
320 let (stream, sender, error_stream) = MCPStream::create(
321 readable,
322 writable,
323 IoStream::Writable(Box::pin(tokio::io::stderr())),
324 self.pending_requests.clone(),
325 self.request_timeout,
326 cancellation_token,
327 );
328
329 self.set_message_sender(sender).await;
330
331 if let IoStream::Readable(error_stream) = error_stream {
332 self.set_error_stream(error_stream).await;
333 }
334
335 Ok(stream)
336 }
337
338 fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
339 self.message_sender.clone() as _
340 }
341
342 fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
343 &self.error_stream as _
344 }
345
346 async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
347 Err(TransportError::Internal(
348 "Invalid invocation of consume_string_payload() function for ClientSseTransport"
349 .to_string(),
350 ))
351 }
352
353 async fn keep_alive(
354 &self,
355 _: Duration,
356 _: oneshot::Sender<()>,
357 ) -> TransportResult<JoinHandle<()>> {
358 Err(TransportError::Internal(
359 "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(),
360 ))
361 }
362
363 async fn is_shut_down(&self) -> bool {
368 let result = self.is_shut_down.lock().await;
369 *result
370 }
371
372 async fn shut_down(&self) -> TransportResult<()> {
382 let mut cancellation_lock = self.shutdown_source.write().await;
384 if let Some(source) = cancellation_lock.as_ref() {
385 source.cancel()?;
386 }
387 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
391 *is_shut_down_lock = true;
392
393 let sse_task = self.sse_task.write().await.take();
395 let post_task = self.post_task.write().await.take();
396
397 let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS);
399 let shutdown_future = async {
400 if let Some(post_handle) = post_task {
401 let _ = post_handle.await;
402 }
403 if let Some(sse_handle) = sse_task {
404 let _ = sse_handle.await;
405 }
406 Ok::<(), TransportError>(())
407 };
408
409 tokio::select! {
410 result = shutdown_future => {
411 result }
413 _ = tokio::time::sleep(timeout) => {
414 tracing::warn!("Shutdown timed out after {:?}", timeout);
415 Err(TransportError::ShutdownTimeout)
416 }
417 }
418 }
419
420 async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
421 let mut pending_requests = self.pending_requests.lock().await;
422 pending_requests.remove(request_id)
423 }
424}
425
426#[async_trait]
427impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
428 for ClientSseTransport<ServerMessage>
429{
430 async fn send_message(
431 &self,
432 message: ClientMessages,
433 request_timeout: Option<Duration>,
434 ) -> TransportResult<Option<ServerMessages>> {
435 let sender = self.message_sender.read().await;
436 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
437 sender.send_message(message, request_timeout).await
438 }
439
440 async fn send(
441 &self,
442 message: ClientMessage,
443 request_timeout: Option<Duration>,
444 ) -> TransportResult<Option<ServerMessage>> {
445 let sender = self.message_sender.read().await;
446 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
447 sender.send(message, request_timeout).await
448 }
449
450 async fn send_batch(
451 &self,
452 message: Vec<ClientMessage>,
453 request_timeout: Option<Duration>,
454 ) -> TransportResult<Option<Vec<ServerMessage>>> {
455 let sender = self.message_sender.read().await;
456 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
457 sender.send_batch(message, request_timeout).await
458 }
459
460 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
461 let sender = self.message_sender.read().await;
462 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
463 sender.write_str(payload, skip_store).await
464 }
465}
466
467impl
468 TransportDispatcher<
469 ServerMessages,
470 MessageFromClient,
471 ServerMessage,
472 ClientMessages,
473 ClientMessage,
474 > for ClientSseTransport<ServerMessage>
475{
476}