rust_mcp_transport/
sse.rs1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
4use rust_mcp_schema::RequestId;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::io::DuplexStream;
9use tokio::sync::Mutex;
10
11use crate::error::{TransportError, TransportResult};
12use crate::mcp_stream::MCPStream;
13use crate::message_dispatcher::MessageDispatcher;
14use crate::transport::Transport;
15use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
16use crate::{IoStream, McpDispatch, SessionId, TransportOptions};
17
18pub struct SseTransport {
19 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
20 is_shut_down: Mutex<bool>,
21 read_write_streams: Mutex<Option<(DuplexStream, DuplexStream)>>,
22 options: Arc<TransportOptions>,
23}
24
25impl SseTransport {
27 pub fn new(
39 read_rx: DuplexStream,
40 write_tx: DuplexStream,
41 options: Arc<TransportOptions>,
42 ) -> TransportResult<Self> {
43 Ok(Self {
44 read_write_streams: Mutex::new(Some((read_rx, write_tx))),
45 options,
46 shutdown_source: tokio::sync::RwLock::new(None),
47 is_shut_down: Mutex::new(false),
48 })
49 }
50
51 pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String {
52 endpoint_with_session_id(endpoint, session_id)
53 }
54}
55
56#[async_trait]
57impl<R, S> Transport<R, S> for SseTransport
58where
59 R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
60 S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
61{
62 async fn start(
73 &self,
74 ) -> TransportResult<(
75 Pin<Box<dyn Stream<Item = R> + Send>>,
76 MessageDispatcher<R>,
77 IoStream,
78 )>
79 where
80 MessageDispatcher<R>: McpDispatch<R, S>,
81 {
82 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
84 let mut lock = self.shutdown_source.write().await;
85 *lock = Some(cancellation_source);
86
87 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
88 Arc::new(Mutex::new(HashMap::new()));
89
90 let mut lock = self.read_write_streams.lock().await;
91 let (read_rx, write_tx) = lock.take().ok_or_else(|| {
92 TransportError::FromString(
93 "SSE streams already taken or transport not initialized".to_string(),
94 )
95 })?;
96
97 let (stream, sender, error_stream) = MCPStream::create(
98 Box::pin(read_rx),
99 Mutex::new(Box::pin(write_tx)),
100 IoStream::Writable(Box::pin(tokio::io::stderr())),
101 pending_requests,
102 self.options.timeout,
103 cancellation_token,
104 );
105
106 Ok((stream, sender, error_stream))
107 }
108
109 async fn is_shut_down(&self) -> bool {
114 let result = self.is_shut_down.lock().await;
115 *result
116 }
117
118 async fn shut_down(&self) -> TransportResult<()> {
125 let mut cancellation_lock = self.shutdown_source.write().await;
127 if let Some(source) = cancellation_lock.as_ref() {
128 source.cancel()?;
129 }
130 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
134 *is_shut_down_lock = true;
135 Ok(())
136 }
137}