// File: rust_mcp_transport/client_sse.rs

1use crate::error::{TransportError, TransportResult};
2use crate::mcp_stream::MCPStream;
3use crate::message_dispatcher::MessageDispatcher;
4use crate::transport::Transport;
5use crate::utils::{
6    extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel,
7};
8use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
9use async_trait::async_trait;
10use bytes::Bytes;
11use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
12use reqwest::Client;
13use tokio::sync::oneshot::Sender;
14use tokio::task::JoinHandle;
15
16use crate::schema::{
17    schema_utils::{
18        ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage,
19        ServerMessages,
20    },
21    RequestId,
22};
23use std::cmp::Ordering;
24use std::collections::HashMap;
25use std::pin::Pin;
26use std::sync::Arc;
27use std::time::Duration;
28use tokio::io::{BufReader, BufWriter};
29use tokio::sync::{mpsc, oneshot, Mutex};
30
/// Capacity of the internal read and write byte channels created by `start`.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;
/// Retry count used when `ClientSseTransportOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRY: usize = 5;
/// Retry delay, in seconds, used when `ClientSseTransportOptions::retry_delay` is `None`.
const DEFAULT_RETRY_TIME_SECONDS: u64 = 1;
/// How long `shut_down` waits for the background tasks before reporting a timeout.
const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5;
35
/// Configuration options for the client SSE transport.
///
/// Defines settings for request timeouts, retry behavior, and custom HTTP headers.
/// `Debug` output deliberately omits header values (see the `Debug` impl below).
#[derive(Clone)]
pub struct ClientSseTransportOptions {
    /// Timeout applied to MCP requests sent over the transport.
    pub request_timeout: Duration,
    /// Delay between SSE retry attempts; `None` falls back to `DEFAULT_RETRY_TIME_SECONDS`.
    pub retry_delay: Option<Duration>,
    /// Maximum number of SSE retry attempts; `None` falls back to `DEFAULT_MAX_RETRY`.
    pub max_retries: Option<usize>,
    /// Extra HTTP headers sent with SSE and POST requests; validated by
    /// `ClientSseTransport::new`.
    pub custom_headers: Option<HashMap<String, String>>,
}

impl std::fmt::Debug for ClientSseTransportOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values often carry credentials (e.g. `Authorization`), so only the header
        // names are printed. They are sorted so the output is stable across runs.
        let header_names = self.custom_headers.as_ref().map(|headers| {
            let mut names: Vec<&str> = headers.keys().map(String::as_str).collect();
            names.sort_unstable();
            names
        });
        f.debug_struct("ClientSseTransportOptions")
            .field("request_timeout", &self.request_timeout)
            .field("retry_delay", &self.retry_delay)
            .field("max_retries", &self.max_retries)
            .field("custom_headers", &header_names)
            .finish()
    }
}
45
46/// Provides default values for ClientSseTransportOptions
47impl Default for ClientSseTransportOptions {
48    fn default() -> Self {
49        Self {
50            request_timeout: TransportOptions::default().timeout,
51            retry_delay: None,
52            max_retries: None,
53            custom_headers: None,
54        }
55    }
56}
57
58/// Client-side Server-Sent Events (SSE) transport implementation
59///
60/// Manages SSE connections, HTTP POST requests, and message streaming for client-server communication.
61pub struct ClientSseTransport<R>
62where
63    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
64{
65    /// Optional cancellation token source for shutting down the transport
66    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
67    /// Flag indicating if the transport is shut down
68    is_shut_down: Mutex<bool>,
69    /// Timeout duration for MCP messages
70    request_timeout: Duration,
71    /// HTTP client for making requests
72    client: Client,
73    /// URL for the SSE endpoint
74    sse_url: String,
75    /// Base URL extracted from the server URL
76    base_url: String,
77    /// Delay between retry attempts
78    retry_delay: Duration,
79    /// Maximum number of retry attempts
80    max_retries: usize,
81    /// Optional custom HTTP headers
82    custom_headers: Option<HeaderMap>,
83    sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
84    post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
85    message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
86    error_stream: tokio::sync::RwLock<Option<IoStream>>,
87    pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
88}
89
90impl<R> ClientSseTransport<R>
91where
92    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
93{
94    /// Creates a new ClientSseTransport instance
95    ///
96    /// Initializes the transport with the provided server URL and options.
97    ///
98    /// # Arguments
99    /// * `server_url` - The URL of the SSE server
100    /// * `options` - Configuration options for the transport
101    ///
102    /// # Returns
103    /// * `TransportResult<Self>` - The initialized transport or an error
104    pub fn new(server_url: &str, options: ClientSseTransportOptions) -> TransportResult<Self> {
105        let client = Client::new();
106
107        let base_url = match extract_origin(server_url) {
108            Some(url) => url,
109            None => {
110                let message = format!("Failed to extract origin from server URL: {server_url}");
111                tracing::error!(message);
112                return Err(TransportError::Configuration { message });
113            }
114        };
115
116        let headers = match &options.custom_headers {
117            Some(h) => Some(Self::validate_headers(h)?),
118            None => None,
119        };
120
121        Ok(Self {
122            client,
123            base_url,
124            sse_url: server_url.to_string(),
125            max_retries: options.max_retries.unwrap_or(DEFAULT_MAX_RETRY),
126            retry_delay: options
127                .retry_delay
128                .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)),
129            shutdown_source: tokio::sync::RwLock::new(None),
130            is_shut_down: Mutex::new(false),
131            request_timeout: options.request_timeout,
132            custom_headers: headers,
133            sse_task: tokio::sync::RwLock::new(None),
134            post_task: tokio::sync::RwLock::new(None),
135            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
136            error_stream: tokio::sync::RwLock::new(None),
137            pending_requests: Arc::new(Mutex::new(HashMap::new())),
138        })
139    }
140
141    /// Validates and converts a HashMap of headers into a HeaderMap
142    ///
143    /// # Arguments
144    /// * `headers` - The HashMap of header names and values
145    ///
146    /// # Returns
147    /// * `TransportResult<HeaderMap>` - The validated HeaderMap or an error
148    fn validate_headers(headers: &HashMap<String, String>) -> TransportResult<HeaderMap> {
149        let mut header_map = HeaderMap::new();
150
151        for (key, value) in headers {
152            let header_name =
153                key.parse::<HeaderName>()
154                    .map_err(|e| TransportError::Configuration {
155                        message: format!("Invalid header name: {e}"),
156                    })?;
157            let header_value =
158                HeaderValue::from_str(value).map_err(|e| TransportError::Configuration {
159                    message: format!("Invalid header value: {e}"),
160                })?;
161            header_map.insert(header_name, header_value);
162        }
163
164        Ok(header_map)
165    }
166
167    /// Validates the message endpoint URL
168    ///
169    /// Ensures the endpoint is either relative to the base URL or matches the base URL's origin.
170    ///
171    /// # Arguments
172    /// * `endpoint` - The endpoint URL to validate
173    ///
174    /// # Returns
175    /// * `TransportResult<String>` - The validated endpoint URL or an error
176    pub fn validate_message_endpoint(&self, endpoint: String) -> TransportResult<String> {
177        if endpoint.starts_with("/") {
178            return Ok(format!("{}{}", self.base_url, endpoint));
179        }
180        if let Some(endpoint_origin) = extract_origin(&endpoint) {
181            if endpoint_origin.cmp(&self.base_url) != Ordering::Equal {
182                return Err(TransportError::Configuration {
183                    message: format!(
184                    "Endpoint origin does not match connection origin. expected: {} , received: {}",
185                    self.base_url, endpoint_origin
186                ),
187                });
188            }
189            return Ok(endpoint);
190        }
191        Ok(endpoint)
192    }
193
194    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
195        let mut lock = self.message_sender.write().await;
196        *lock = Some(sender);
197    }
198
199    pub(crate) async fn set_error_stream(
200        &self,
201        error_stream: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
202    ) {
203        let mut lock = self.error_stream.write().await;
204        *lock = Some(IoStream::Readable(error_stream));
205    }
206}
207
208#[async_trait]
209impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for ClientSseTransport<M>
210where
211    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
212    S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
213    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
214    OR: Clone + Send + Sync + serde::Serialize + 'static,
215    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
216{
217    /// Starts the transport, initializing SSE and POST tasks
218    ///
219    /// Sets up the SSE stream, POST request handler, and message streams for communication.
220    ///
221    /// # Returns
222    /// * `TransportResult<(Pin<Box<dyn Stream<Item = R> + Send>>, MessageDispatcher<R>, IoStream)>`
223    ///   - The message stream, dispatcher, and error stream
224    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
225    where
226        MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
227    {
228        // Create CancellationTokenSource and token
229        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
230        let mut lock = self.shutdown_source.write().await;
231        *lock = Some(cancellation_source);
232
233        let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
234        let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
235
236        // Create oneshot channel for signaling SSE endpoint event message
237        let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::<Option<String>>();
238        let endpoint_event_tx = Some(endpoint_event_tx);
239
240        let sse_client = self.client.clone();
241        let sse_url = self.sse_url.clone();
242
243        let max_retries = self.max_retries;
244        let retry_delay = self.retry_delay;
245
246        let custom_headers = self.custom_headers.clone();
247
248        let read_stream = SseStream {
249            sse_client,
250            sse_url,
251            max_retries,
252            retry_delay,
253            read_tx,
254        };
255
256        // Spawn task to handle SSE stream with reconnection
257        let cancellation_token_sse = cancellation_token.clone();
258        let sse_task_handle = tokio::spawn(async move {
259            read_stream
260                .run(endpoint_event_tx, cancellation_token_sse, &custom_headers)
261                .await;
262        });
263        let mut sse_task_lock = self.sse_task.write().await;
264        *sse_task_lock = Some(sse_task_handle);
265
266        // Await the first SSE message, expected to receive messages endpoint from he server
267        let err =
268            || std::io::Error::other("Failed to receive 'messages' endpoint from the server.");
269        let post_url = endpoint_event_rx
270            .await
271            .map_err(|_| err())?
272            .ok_or_else(err)?;
273
274        let post_url = self.validate_message_endpoint(post_url)?;
275
276        let client_clone = self.client.clone();
277
278        let custom_headers = self.custom_headers.clone();
279
280        let cancellation_token_post = cancellation_token.clone();
281        // Spawn task to handle POST requests from writable stream
282        let post_task_handle = tokio::spawn(async move {
283            loop {
284                tokio::select! {
285
286                _ = cancellation_token_post.cancelled() =>
287                {
288                        break;
289                },
290
291                data = write_rx.recv() => {
292                    match data{
293                      Some(data) => {
294                        // trim the trailing \n before making a request
295                        let body = String::from_utf8_lossy(&data).trim().to_string();
296                          if let Err(e) = http_post(&client_clone, &post_url, body,None, custom_headers.as_ref()).await {
297                            tracing::error!("Failed to POST message: {e}");
298                      }
299                    },
300                    None => break, // Exit if channel is closed
301                    }
302                   }
303                }
304            }
305        });
306        let mut post_task_lock = self.post_task.write().await;
307        *post_task_lock = Some(post_task_handle);
308
309        // Create writable stream
310        let writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>> =
311            Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx })));
312
313        // Create readable stream
314        let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
315            Box::pin(BufReader::new(ReadableChannel {
316                read_rx,
317                buffer: Bytes::new(),
318            }));
319
320        let (stream, sender, error_stream) = MCPStream::create(
321            readable,
322            writable,
323            IoStream::Writable(Box::pin(tokio::io::stderr())),
324            self.pending_requests.clone(),
325            self.request_timeout,
326            cancellation_token,
327        );
328
329        self.set_message_sender(sender).await;
330
331        if let IoStream::Readable(error_stream) = error_stream {
332            self.set_error_stream(error_stream).await;
333        }
334
335        Ok(stream)
336    }
337
338    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
339        self.message_sender.clone() as _
340    }
341
342    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
343        &self.error_stream as _
344    }
345
346    async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
347        Err(TransportError::Internal(
348            "Invalid invocation of consume_string_payload() function for ClientSseTransport"
349                .to_string(),
350        ))
351    }
352
353    async fn keep_alive(
354        &self,
355        _: Duration,
356        _: oneshot::Sender<()>,
357    ) -> TransportResult<JoinHandle<()>> {
358        Err(TransportError::Internal(
359            "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(),
360        ))
361    }
362
363    /// Checks if the transport has been shut down
364    ///
365    /// # Returns
366    /// * `bool` - True if the transport is shut down, false otherwise
367    async fn is_shut_down(&self) -> bool {
368        let result = self.is_shut_down.lock().await;
369        *result
370    }
371
372    // Shuts down the transport, terminating any subprocess and signaling closure.
373    ///
374    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
375    ///
376    /// # Returns
377    /// A `TransportResult` indicating success or failure.
378    ///
379    /// # Errors
380    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
381    async fn shut_down(&self) -> TransportResult<()> {
382        // Trigger cancellation
383        let mut cancellation_lock = self.shutdown_source.write().await;
384        if let Some(source) = cancellation_lock.as_ref() {
385            source.cancel()?;
386        }
387        *cancellation_lock = None; // Clear cancellation_source
388
389        // Mark as shut down
390        let mut is_shut_down_lock = self.is_shut_down.lock().await;
391        *is_shut_down_lock = true;
392
393        // Get task handles
394        let sse_task = self.sse_task.write().await.take();
395        let post_task = self.post_task.write().await.take();
396
397        // Wait for tasks to complete with a timeout
398        let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS);
399        let shutdown_future = async {
400            if let Some(post_handle) = post_task {
401                let _ = post_handle.await;
402            }
403            if let Some(sse_handle) = sse_task {
404                let _ = sse_handle.await;
405            }
406            Ok::<(), TransportError>(())
407        };
408
409        tokio::select! {
410            result = shutdown_future => {
411                result // result of task completion
412            }
413            _ = tokio::time::sleep(timeout) => {
414                tracing::warn!("Shutdown timed out after {:?}", timeout);
415                Err(TransportError::ShutdownTimeout)
416            }
417        }
418    }
419
420    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
421        let mut pending_requests = self.pending_requests.lock().await;
422        pending_requests.remove(request_id)
423    }
424}
425
426#[async_trait]
427impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
428    for ClientSseTransport<ServerMessage>
429{
430    async fn send_message(
431        &self,
432        message: ClientMessages,
433        request_timeout: Option<Duration>,
434    ) -> TransportResult<Option<ServerMessages>> {
435        let sender = self.message_sender.read().await;
436        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
437        sender.send_message(message, request_timeout).await
438    }
439
440    async fn send(
441        &self,
442        message: ClientMessage,
443        request_timeout: Option<Duration>,
444    ) -> TransportResult<Option<ServerMessage>> {
445        let sender = self.message_sender.read().await;
446        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
447        sender.send(message, request_timeout).await
448    }
449
450    async fn send_batch(
451        &self,
452        message: Vec<ClientMessage>,
453        request_timeout: Option<Duration>,
454    ) -> TransportResult<Option<Vec<ServerMessage>>> {
455        let sender = self.message_sender.read().await;
456        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
457        sender.send_batch(message, request_timeout).await
458    }
459
460    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
461        let sender = self.message_sender.read().await;
462        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
463        sender.write_str(payload, skip_store).await
464    }
465}
466
467impl
468    TransportDispatcher<
469        ServerMessages,
470        MessageFromClient,
471        ServerMessage,
472        ClientMessages,
473        ClientMessage,
474    > for ClientSseTransport<ServerMessage>
475{
476}