// rust_mcp_transport/client_sse.rs

1use crate::error::{TransportError, TransportResult};
2use crate::mcp_stream::MCPStream;
3use crate::message_dispatcher::MessageDispatcher;
4use crate::transport::Transport;
5use crate::utils::{
6    extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel,
7};
8use crate::{IoStream, McpDispatch, TransportOptions};
9use async_trait::async_trait;
10use bytes::Bytes;
11use futures::Stream;
12use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
13use reqwest::Client;
14use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
15use rust_mcp_schema::RequestId;
16use std::cmp::Ordering;
17use std::collections::HashMap;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::io::{BufReader, BufWriter};
22use tokio::sync::{mpsc, oneshot, Mutex};
23
/// Capacity of each internal `mpsc` channel (outgoing writes and incoming SSE payloads).
const DEFAULT_CHANNEL_CAPACITY: usize = 64;
/// Number of SSE retry attempts used when `ClientSseTransportOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRY: usize = 5;
/// Seconds between SSE retry attempts when `ClientSseTransportOptions::retry_delay` is `None`.
const DEFAULT_RETRY_TIME_SECONDS: u64 = 3;
/// Upper bound, in seconds, that `shut_down` waits for the background tasks to finish.
const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5;
28
/// Configuration options for the Client SSE Transport
///
/// Defines settings for request timeouts, retry behavior, and custom HTTP headers.
///
/// `ClientSseTransport::new` takes these options by value. `Clone` is derived so one
/// configuration can be reused for several transports. `Debug` is intentionally not derived,
/// because `custom_headers` may hold credentials (e.g. an authorization token) that should not
/// end up in logs.
#[derive(Clone)]
pub struct ClientSseTransportOptions {
    /// Timeout for MCP messages; handed to the message stream when the transport starts.
    pub request_timeout: Duration,
    /// Delay between SSE reconnection attempts; `None` falls back to 3 seconds.
    pub retry_delay: Option<Duration>,
    /// Maximum number of SSE reconnection attempts; `None` falls back to 5.
    pub max_retries: Option<usize>,
    /// Extra HTTP headers sent with the POST requests that carry outgoing messages.
    /// Names and values are validated when the transport is constructed.
    pub custom_headers: Option<HashMap<String, String>>,
}
38
39/// Provides default values for ClientSseTransportOptions
40impl Default for ClientSseTransportOptions {
41    fn default() -> Self {
42        Self {
43            request_timeout: TransportOptions::default().timeout,
44            retry_delay: None,
45            max_retries: None,
46            custom_headers: None,
47        }
48    }
49}
50
51/// Client-side Server-Sent Events (SSE) transport implementation
52///
53/// Manages SSE connections, HTTP POST requests, and message streaming for client-server communication.
54pub struct ClientSseTransport {
55    /// Optional cancellation token source for shutting down the transport
56    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
57    /// Flag indicating if the transport is shut down
58    is_shut_down: Mutex<bool>,
59    /// Timeout duration for MCP messages
60    request_timeout: Duration,
61    /// HTTP client for making requests
62    client: Client,
63    /// URL for the SSE endpoint
64    sse_url: String,
65    /// Base URL extracted from the server URL
66    base_url: String,
67    /// Delay between retry attempts
68    retry_delay: Duration,
69    /// Maximum number of retry attempts
70    max_retries: usize,
71    /// Optional custom HTTP headers
72    custom_headers: Option<HeaderMap>,
73    sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
74    post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
75}
76
77impl ClientSseTransport {
78    /// Creates a new ClientSseTransport instance
79    ///
80    /// Initializes the transport with the provided server URL and options.
81    ///
82    /// # Arguments
83    /// * `server_url` - The URL of the SSE server
84    /// * `options` - Configuration options for the transport
85    ///
86    /// # Returns
87    /// * `TransportResult<Self>` - The initialized transport or an error
88    pub fn new(server_url: &str, options: ClientSseTransportOptions) -> TransportResult<Self> {
89        let client = Client::new();
90
91        //TODO: error handling
92        let base_url = extract_origin(server_url).unwrap();
93
94        let headers = match &options.custom_headers {
95            Some(h) => Some(Self::validate_headers(h)?),
96            None => None,
97        };
98
99        Ok(Self {
100            client,
101            base_url,
102            sse_url: server_url.to_string(),
103            max_retries: options.max_retries.unwrap_or(DEFAULT_MAX_RETRY),
104            retry_delay: options
105                .retry_delay
106                .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)),
107            shutdown_source: tokio::sync::RwLock::new(None),
108            is_shut_down: Mutex::new(false),
109            request_timeout: options.request_timeout,
110            custom_headers: headers,
111            sse_task: tokio::sync::RwLock::new(None),
112            post_task: tokio::sync::RwLock::new(None),
113        })
114    }
115
116    /// Validates and converts a HashMap of headers into a HeaderMap
117    ///
118    /// # Arguments
119    /// * `headers` - The HashMap of header names and values
120    ///
121    /// # Returns
122    /// * `TransportResult<HeaderMap>` - The validated HeaderMap or an error
123    fn validate_headers(headers: &HashMap<String, String>) -> TransportResult<HeaderMap> {
124        let mut header_map = HeaderMap::new();
125
126        for (key, value) in headers {
127            let header_name = key.parse::<HeaderName>().map_err(|e| {
128                TransportError::InvalidOptions(format!("Invalid header name: {}", e))
129            })?;
130            let header_value = HeaderValue::from_str(value).map_err(|e| {
131                TransportError::InvalidOptions(format!("Invalid header value: {}", e))
132            })?;
133            header_map.insert(header_name, header_value);
134        }
135
136        Ok(header_map)
137    }
138
139    /// Validates the message endpoint URL
140    ///
141    /// Ensures the endpoint is either relative to the base URL or matches the base URL's origin.
142    ///
143    /// # Arguments
144    /// * `endpoint` - The endpoint URL to validate
145    ///
146    /// # Returns
147    /// * `TransportResult<String>` - The validated endpoint URL or an error
148    pub fn validate_message_endpoint(&self, endpoint: String) -> TransportResult<String> {
149        if endpoint.starts_with("/") {
150            return Ok(format!("{}{}", self.base_url, endpoint));
151        }
152        if let Some(endpoint_origin) = extract_origin(&endpoint) {
153            if endpoint_origin.cmp(&self.base_url) != Ordering::Equal {
154                return Err(TransportError::InvalidOptions(format!(
155                    "Endpoint origin does not match connection origin. expected: {} , received: {}",
156                    self.base_url, endpoint_origin
157                )));
158            }
159            return Ok(endpoint);
160        }
161        Ok(endpoint)
162    }
163}
164
165#[async_trait]
166impl<R, S> Transport<R, S> for ClientSseTransport
167where
168    R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
169    S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
170{
171    /// Starts the transport, initializing SSE and POST tasks
172    ///
173    /// Sets up the SSE stream, POST request handler, and message streams for communication.
174    ///
175    /// # Returns
176    /// * `TransportResult<(Pin<Box<dyn Stream<Item = R> + Send>>, MessageDispatcher<R>, IoStream)>`
177    ///   - The message stream, dispatcher, and error stream
178    async fn start(
179        &self,
180    ) -> TransportResult<(
181        Pin<Box<dyn Stream<Item = R> + Send>>,
182        MessageDispatcher<R>,
183        IoStream,
184    )>
185    where
186        MessageDispatcher<R>: McpDispatch<R, S>,
187    {
188        // Create CancellationTokenSource and token
189        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
190        let mut lock = self.shutdown_source.write().await;
191        *lock = Some(cancellation_source);
192
193        let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
194            Arc::new(Mutex::new(HashMap::new()));
195
196        let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
197        let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
198
199        // Create oneshot channel for signaling SSE endpoint event message
200        let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::<Option<String>>();
201        let endpoint_event_tx = Some(endpoint_event_tx);
202
203        let sse_client = self.client.clone();
204        let sse_url = self.sse_url.clone();
205
206        let max_retries = self.max_retries;
207        let retry_delay = self.retry_delay;
208
209        let read_stream = SseStream {
210            sse_client,
211            sse_url,
212            max_retries,
213            retry_delay,
214            read_tx,
215        };
216
217        // Spawn task to handle SSE stream with reconnection
218        let cancellation_token_sse = cancellation_token.clone();
219        let sse_task_handle = tokio::spawn(async move {
220            read_stream
221                .run(endpoint_event_tx, cancellation_token_sse)
222                .await;
223        });
224        let mut sse_task_lock = self.sse_task.write().await;
225        *sse_task_lock = Some(sse_task_handle);
226
227        // Await the first SSE message, expected to receive messages endpoint from he server
228        let err =
229            || std::io::Error::other("Failed to receive 'messages' endpoint from the server.");
230        let post_url = endpoint_event_rx
231            .await
232            .map_err(|_| err())?
233            .ok_or_else(err)?;
234
235        let post_url = self.validate_message_endpoint(post_url)?;
236
237        let client_clone = self.client.clone();
238
239        let custom_headers = self.custom_headers.clone();
240
241        let cancellation_token_post = cancellation_token.clone();
242        // Spawn task to handle POST requests from writable stream
243        let post_task_handle = tokio::spawn(async move {
244            loop {
245                tokio::select! {
246
247                _ = cancellation_token_post.cancelled() =>
248                {
249                        break;
250                },
251
252                data = write_rx.recv() => {
253                    match data{
254                      Some(data) => {
255                        // trim the trailing \n before making a request
256                        let body = String::from_utf8_lossy(&data).trim().to_string();
257                          if let Err(e) = http_post(&client_clone, &post_url, body, &custom_headers).await {
258                            eprintln!("Failed to POST message: {:?}", e);
259                      }
260                    },
261                    None => break, // Exit if channel is closed
262                    }
263                   }
264                }
265            }
266        });
267        let mut post_task_lock = self.post_task.write().await;
268        *post_task_lock = Some(post_task_handle);
269
270        // Create writable stream
271        let writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>> =
272            Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx })));
273
274        // Create readable stream
275        let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
276            Box::pin(BufReader::new(ReadableChannel {
277                read_rx,
278                buffer: Bytes::new(),
279            }));
280
281        let (stream, sender, error_stream) = MCPStream::create(
282            readable,
283            writable,
284            IoStream::Writable(Box::pin(tokio::io::stderr())),
285            pending_requests,
286            self.request_timeout,
287            cancellation_token,
288        );
289
290        Ok((stream, sender, error_stream))
291    }
292
293    /// Checks if the transport has been shut down
294    ///
295    /// # Returns
296    /// * `bool` - True if the transport is shut down, false otherwise
297    async fn is_shut_down(&self) -> bool {
298        let result = self.is_shut_down.lock().await;
299        *result
300    }
301
302    // Shuts down the transport, terminating any subprocess and signaling closure.
303    ///
304    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
305    ///
306    /// # Returns
307    /// A `TransportResult` indicating success or failure.
308    ///
309    /// # Errors
310    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
311    async fn shut_down(&self) -> TransportResult<()> {
312        // Trigger cancellation
313        let mut cancellation_lock = self.shutdown_source.write().await;
314        if let Some(source) = cancellation_lock.as_ref() {
315            source.cancel()?;
316        }
317        *cancellation_lock = None; // Clear cancellation_source
318
319        // Mark as shut down
320        let mut is_shut_down_lock = self.is_shut_down.lock().await;
321        *is_shut_down_lock = true;
322
323        // Get task handles
324        let sse_task = self.sse_task.write().await.take();
325        let post_task = self.post_task.write().await.take();
326
327        // Wait for tasks to complete with a timeout
328        let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS);
329        let shutdown_future = async {
330            if let Some(post_handle) = post_task {
331                let _ = post_handle.await;
332            }
333            if let Some(sse_handle) = sse_task {
334                let _ = sse_handle.await;
335            }
336            Ok::<(), TransportError>(())
337        };
338
339        tokio::select! {
340            result = shutdown_future => {
341                result // result of task completion
342            }
343            _ = tokio::time::sleep(timeout) => {
344                tracing::warn!("Shutdown timed out after {:?}", timeout);
345                Err(TransportError::ShutdownTimeout)
346            }
347        }
348    }
349}