rust_mcp_transport/
sse.rs

1use crate::event_store::EventStore;
2use crate::schema::schema_utils::{
3    ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages,
4};
5use crate::schema::RequestId;
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncWriteExt, DuplexStream};
13use tokio::sync::oneshot::Sender;
14use tokio::sync::{oneshot, Mutex};
15use tokio::task::JoinHandle;
16use tokio::time::{self, Interval};
17
18use crate::error::{TransportError, TransportResult};
19use crate::mcp_stream::MCPStream;
20use crate::message_dispatcher::MessageDispatcher;
21use crate::transport::Transport;
22use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
23use crate::{IoStream, McpDispatch, SessionId, StreamId, TransportDispatcher, TransportOptions};
24
25pub struct SseTransport<R>
26where
27    R: Clone + Send + Sync + DeserializeOwned + 'static,
28{
29    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
30    is_shut_down: Mutex<bool>,
31    read_write_streams: Mutex<Option<(DuplexStream, DuplexStream)>>,
32    receiver_tx: Mutex<DuplexStream>, // receiving string payload
33    options: Arc<TransportOptions>,
34    message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
35    error_stream: tokio::sync::RwLock<Option<IoStream>>,
36    pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
37    // resumability support
38    session_id: Option<SessionId>,
39    stream_id: Option<StreamId>,
40    event_store: Option<Arc<dyn EventStore>>,
41}
42
43/// Server-Sent Events (SSE) transport implementation
44impl<R> SseTransport<R>
45where
46    R: Clone + Send + Sync + DeserializeOwned + 'static,
47{
48    /// Creates a new SseTransport instance
49    ///
50    /// Initializes the transport with provided read and write duplex streams and options.
51    ///
52    /// # Arguments
53    /// * `read_rx` - Duplex stream for receiving messages
54    /// * `write_tx` - Duplex stream for sending messages
55    /// * `receiver_tx` - Duplex stream for receiving string payload
56    /// * `options` - Shared transport configuration options
57    ///
58    /// # Returns
59    /// * `TransportResult<Self>` - The initialized transport or an error
60    pub fn new(
61        read_rx: DuplexStream,
62        write_tx: DuplexStream,
63        receiver_tx: DuplexStream,
64        options: Arc<TransportOptions>,
65    ) -> TransportResult<Self> {
66        Ok(Self {
67            read_write_streams: Mutex::new(Some((read_rx, write_tx))),
68            options,
69            shutdown_source: tokio::sync::RwLock::new(None),
70            is_shut_down: Mutex::new(false),
71            receiver_tx: Mutex::new(receiver_tx),
72            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
73            error_stream: tokio::sync::RwLock::new(None),
74            pending_requests: Arc::new(Mutex::new(HashMap::new())),
75            session_id: None,
76            stream_id: None,
77            event_store: None,
78        })
79    }
80
81    pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String {
82        endpoint_with_session_id(endpoint, session_id)
83    }
84
85    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
86        let mut lock = self.message_sender.write().await;
87        *lock = Some(sender);
88    }
89
90    pub(crate) async fn set_error_stream(
91        &self,
92        error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
93    ) {
94        let mut lock = self.error_stream.write().await;
95        *lock = Some(IoStream::Writable(error_stream));
96    }
97
98    /// Supports resumability for streamable HTTP transports by setting the session ID,
99    /// stream ID, and event store.
100    pub fn make_resumable(
101        &mut self,
102        session_id: SessionId,
103        stream_id: StreamId,
104        event_store: Arc<dyn EventStore>,
105    ) {
106        self.session_id = Some(session_id);
107        self.stream_id = Some(stream_id);
108        self.event_store = Some(event_store);
109    }
110}
111
112#[async_trait]
113impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
114    for SseTransport<ClientMessage>
115{
116    async fn send_message(
117        &self,
118        message: ServerMessages,
119        request_timeout: Option<Duration>,
120    ) -> TransportResult<Option<ClientMessages>> {
121        let sender = self.message_sender.read().await;
122        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
123
124        sender.send_message(message, request_timeout).await
125    }
126
127    async fn send(
128        &self,
129        message: ServerMessage,
130        request_timeout: Option<Duration>,
131    ) -> TransportResult<Option<ClientMessage>> {
132        let sender = self.message_sender.read().await;
133        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
134        sender.send(message, request_timeout).await
135    }
136
137    async fn send_batch(
138        &self,
139        message: Vec<ServerMessage>,
140        request_timeout: Option<Duration>,
141    ) -> TransportResult<Option<Vec<ClientMessage>>> {
142        let sender = self.message_sender.read().await;
143        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
144        sender.send_batch(message, request_timeout).await
145    }
146
147    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
148        let sender = self.message_sender.read().await;
149        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
150        sender.write_str(payload, skip_store).await
151    }
152}
153
154#[async_trait] //RSMX
155impl Transport<ClientMessages, MessageFromServer, ClientMessage, ServerMessages, ServerMessage>
156    for SseTransport<ClientMessage>
157{
158    /// Starts the transport, initializing streams and message dispatcher
159    ///
160    /// Sets up the MCP stream and dispatcher using the provided duplex streams.
161    ///
162    /// # Returns
163    /// * `TransportResult<(Pin<Box<dyn Stream<Item = R> + Send>>, MessageDispatcher<R>, IoStream)>`
164    ///   - The message stream, dispatcher, and error stream
165    ///
166    /// # Errors
167    /// * Returns `TransportError` if streams are already taken or not initialized
168    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<ClientMessages>>
169    where
170        MessageDispatcher<ClientMessage>:
171            McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>,
172    {
173        // Create CancellationTokenSource and token
174        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
175        let mut lock = self.shutdown_source.write().await;
176        *lock = Some(cancellation_source);
177
178        let mut lock = self.read_write_streams.lock().await;
179        let (read_rx, write_tx) = lock.take().ok_or_else(|| {
180            TransportError::Internal(
181                "SSE streams already taken or transport not initialized".to_string(),
182            )
183        })?;
184
185        let (stream, mut sender, error_stream) = MCPStream::create::<ClientMessages, ClientMessage>(
186            Box::pin(read_rx),
187            Mutex::new(Box::pin(write_tx)),
188            IoStream::Writable(Box::pin(tokio::io::stderr())),
189            self.pending_requests.clone(),
190            self.options.timeout,
191            cancellation_token,
192        );
193
194        if let (Some(session_id), Some(stream_id), Some(event_store)) = (
195            self.session_id.as_ref(),
196            self.stream_id.as_ref(),
197            self.event_store.as_ref(),
198        ) {
199            sender.make_resumable(
200                session_id.to_owned(),
201                stream_id.to_owned(),
202                event_store.clone(),
203            );
204        }
205
206        self.set_message_sender(sender).await;
207
208        if let IoStream::Writable(error_stream) = error_stream {
209            self.set_error_stream(error_stream).await;
210        }
211
212        Ok(stream)
213    }
214
215    /// Checks if the transport has been shut down
216    ///
217    /// # Returns
218    /// * `bool` - True if the transport is shut down, false otherwise
219    async fn is_shut_down(&self) -> bool {
220        let result = self.is_shut_down.lock().await;
221        *result
222    }
223
224    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>> {
225        self.message_sender.clone() as _
226    }
227
228    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
229        &self.error_stream as _
230    }
231
232    async fn consume_string_payload(&self, payload: &str) -> TransportResult<()> {
233        let mut transmit = self.receiver_tx.lock().await;
234        transmit
235            .write_all(format!("{payload}\n").as_bytes())
236            .await?;
237        transmit.flush().await?;
238        Ok(())
239    }
240
241    /// Shuts down the transport, terminating tasks and signaling closure
242    ///
243    /// Cancels any running tasks and clears the cancellation source.
244    ///
245    /// # Returns
246    /// * `TransportResult<()>` - Ok if shutdown is successful, Err if cancellation fails
247    async fn shut_down(&self) -> TransportResult<()> {
248        // Trigger cancellation
249        let mut cancellation_lock = self.shutdown_source.write().await;
250        if let Some(source) = cancellation_lock.as_ref() {
251            source.cancel()?;
252        }
253        *cancellation_lock = None; // Clear cancellation_source
254
255        // Mark as shut down
256        let mut is_shut_down_lock = self.is_shut_down.lock().await;
257        *is_shut_down_lock = true;
258        Ok(())
259    }
260
261    async fn keep_alive(
262        &self,
263        interval: Duration,
264        disconnect_tx: oneshot::Sender<()>,
265    ) -> TransportResult<JoinHandle<()>> {
266        let sender = self.message_sender();
267
268        let handle = tokio::spawn(async move {
269            let mut interval: Interval = time::interval(interval);
270            interval.tick().await; // Skip the first immediate tick
271            loop {
272                interval.tick().await;
273                let sender = sender.read().await;
274                if let Some(sender) = sender.as_ref() {
275                    match sender.write_str("\n", true).await {
276                        Ok(_) => {}
277                        Err(TransportError::Io(error)) => {
278                            if error.kind() == std::io::ErrorKind::BrokenPipe {
279                                let _ = disconnect_tx.send(());
280                                break;
281                            }
282                        }
283                        _ => {}
284                    }
285                }
286            }
287        });
288        Ok(handle)
289    }
290    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<ClientMessage>> {
291        let mut pending_requests = self.pending_requests.lock().await;
292        pending_requests.remove(request_id)
293    }
294}
295
296impl
297    TransportDispatcher<
298        ClientMessages,
299        MessageFromServer,
300        ClientMessage,
301        ServerMessages,
302        ServerMessage,
303    > for SseTransport<ClientMessage>
304{
305}