rust_mcp_transport/
sse.rs

1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
4use rust_mcp_schema::RequestId;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::io::DuplexStream;
9use tokio::sync::Mutex;
10
11use crate::error::{TransportError, TransportResult};
12use crate::mcp_stream::MCPStream;
13use crate::message_dispatcher::MessageDispatcher;
14use crate::transport::Transport;
15use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
16use crate::{IoStream, McpDispatch, SessionId, TransportOptions};
17
18pub struct SseTransport {
19    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
20    is_shut_down: Mutex<bool>,
21    read_write_streams: Mutex<Option<(DuplexStream, DuplexStream)>>,
22    options: Arc<TransportOptions>,
23}
24
25/// Server-Sent Events (SSE) transport implementation
26impl SseTransport {
27    /// Creates a new SseTransport instance
28    ///
29    /// Initializes the transport with provided read and write duplex streams and options.
30    ///
31    /// # Arguments
32    /// * `read_rx` - Duplex stream for receiving messages
33    /// * `write_tx` - Duplex stream for sending messages
34    /// * `options` - Shared transport configuration options
35    ///
36    /// # Returns
37    /// * `TransportResult<Self>` - The initialized transport or an error
38    pub fn new(
39        read_rx: DuplexStream,
40        write_tx: DuplexStream,
41        options: Arc<TransportOptions>,
42    ) -> TransportResult<Self> {
43        Ok(Self {
44            read_write_streams: Mutex::new(Some((read_rx, write_tx))),
45            options,
46            shutdown_source: tokio::sync::RwLock::new(None),
47            is_shut_down: Mutex::new(false),
48        })
49    }
50
51    pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String {
52        endpoint_with_session_id(endpoint, session_id)
53    }
54}
55
56#[async_trait]
57impl<R, S> Transport<R, S> for SseTransport
58where
59    R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
60    S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
61{
62    /// Starts the transport, initializing streams and message dispatcher
63    ///
64    /// Sets up the MCP stream and dispatcher using the provided duplex streams.
65    ///
66    /// # Returns
67    /// * `TransportResult<(Pin<Box<dyn Stream<Item = R> + Send>>, MessageDispatcher<R>, IoStream)>`
68    ///   - The message stream, dispatcher, and error stream
69    ///
70    /// # Errors
71    /// * Returns `TransportError` if streams are already taken or not initialized
72    async fn start(
73        &self,
74    ) -> TransportResult<(
75        Pin<Box<dyn Stream<Item = R> + Send>>,
76        MessageDispatcher<R>,
77        IoStream,
78    )>
79    where
80        MessageDispatcher<R>: McpDispatch<R, S>,
81    {
82        // Create CancellationTokenSource and token
83        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
84        let mut lock = self.shutdown_source.write().await;
85        *lock = Some(cancellation_source);
86
87        let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
88            Arc::new(Mutex::new(HashMap::new()));
89
90        let mut lock = self.read_write_streams.lock().await;
91        let (read_rx, write_tx) = lock.take().ok_or_else(|| {
92            TransportError::FromString(
93                "SSE streams already taken or transport not initialized".to_string(),
94            )
95        })?;
96
97        let (stream, sender, error_stream) = MCPStream::create(
98            Box::pin(read_rx),
99            Mutex::new(Box::pin(write_tx)),
100            IoStream::Writable(Box::pin(tokio::io::stderr())),
101            pending_requests,
102            self.options.timeout,
103            cancellation_token,
104        );
105
106        Ok((stream, sender, error_stream))
107    }
108
109    /// Checks if the transport has been shut down
110    ///
111    /// # Returns
112    /// * `bool` - True if the transport is shut down, false otherwise
113    async fn is_shut_down(&self) -> bool {
114        let result = self.is_shut_down.lock().await;
115        *result
116    }
117
118    /// Shuts down the transport, terminating tasks and signaling closure
119    ///
120    /// Cancels any running tasks and clears the cancellation source.
121    ///
122    /// # Returns
123    /// * `TransportResult<()>` - Ok if shutdown is successful, Err if cancellation fails
124    async fn shut_down(&self) -> TransportResult<()> {
125        // Trigger cancellation
126        let mut cancellation_lock = self.shutdown_source.write().await;
127        if let Some(source) = cancellation_lock.as_ref() {
128            source.cancel()?;
129        }
130        *cancellation_lock = None; // Clear cancellation_source
131
132        // Mark as shut down
133        let mut is_shut_down_lock = self.is_shut_down.lock().await;
134        *is_shut_down_lock = true;
135        Ok(())
136    }
137}