1use crate::event_store::EventStore;
2use crate::schema::schema_utils::{
3 ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages,
4};
5use crate::schema::RequestId;
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncWriteExt, DuplexStream};
13use tokio::sync::oneshot::Sender;
14use tokio::sync::{oneshot, Mutex};
15use tokio::task::JoinHandle;
16use tokio::time::{self, Interval};
17
18use crate::error::{TransportError, TransportResult};
19use crate::mcp_stream::MCPStream;
20use crate::message_dispatcher::MessageDispatcher;
21use crate::transport::Transport;
22use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
23use crate::{IoStream, McpDispatch, SessionId, StreamId, TransportDispatcher, TransportOptions};
24
25pub struct SseTransport<R>
26where
27 R: Clone + Send + Sync + DeserializeOwned + 'static,
28{
29 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
30 is_shut_down: Mutex<bool>,
31 read_write_streams: Mutex<Option<(DuplexStream, DuplexStream)>>,
32 receiver_tx: Mutex<DuplexStream>, options: Arc<TransportOptions>,
34 message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
35 error_stream: tokio::sync::RwLock<Option<IoStream>>,
36 pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
37 session_id: Option<SessionId>,
39 stream_id: Option<StreamId>,
40 event_store: Option<Arc<dyn EventStore>>,
41}
42
43impl<R> SseTransport<R>
45where
46 R: Clone + Send + Sync + DeserializeOwned + 'static,
47{
48 pub fn new(
61 read_rx: DuplexStream,
62 write_tx: DuplexStream,
63 receiver_tx: DuplexStream,
64 options: Arc<TransportOptions>,
65 ) -> TransportResult<Self> {
66 Ok(Self {
67 read_write_streams: Mutex::new(Some((read_rx, write_tx))),
68 options,
69 shutdown_source: tokio::sync::RwLock::new(None),
70 is_shut_down: Mutex::new(false),
71 receiver_tx: Mutex::new(receiver_tx),
72 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
73 error_stream: tokio::sync::RwLock::new(None),
74 pending_requests: Arc::new(Mutex::new(HashMap::new())),
75 session_id: None,
76 stream_id: None,
77 event_store: None,
78 })
79 }
80
81 pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String {
82 endpoint_with_session_id(endpoint, session_id)
83 }
84
85 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
86 let mut lock = self.message_sender.write().await;
87 *lock = Some(sender);
88 }
89
90 pub(crate) async fn set_error_stream(
91 &self,
92 error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
93 ) {
94 let mut lock = self.error_stream.write().await;
95 *lock = Some(IoStream::Writable(error_stream));
96 }
97
98 pub fn make_resumable(
101 &mut self,
102 session_id: SessionId,
103 stream_id: StreamId,
104 event_store: Arc<dyn EventStore>,
105 ) {
106 self.session_id = Some(session_id);
107 self.stream_id = Some(stream_id);
108 self.event_store = Some(event_store);
109 }
110}
111
112#[async_trait]
113impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
114 for SseTransport<ClientMessage>
115{
116 async fn send_message(
117 &self,
118 message: ServerMessages,
119 request_timeout: Option<Duration>,
120 ) -> TransportResult<Option<ClientMessages>> {
121 let sender = self.message_sender.read().await;
122 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
123
124 sender.send_message(message, request_timeout).await
125 }
126
127 async fn send(
128 &self,
129 message: ServerMessage,
130 request_timeout: Option<Duration>,
131 ) -> TransportResult<Option<ClientMessage>> {
132 let sender = self.message_sender.read().await;
133 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
134 sender.send(message, request_timeout).await
135 }
136
137 async fn send_batch(
138 &self,
139 message: Vec<ServerMessage>,
140 request_timeout: Option<Duration>,
141 ) -> TransportResult<Option<Vec<ClientMessage>>> {
142 let sender = self.message_sender.read().await;
143 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
144 sender.send_batch(message, request_timeout).await
145 }
146
147 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
148 let sender = self.message_sender.read().await;
149 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
150 sender.write_str(payload, skip_store).await
151 }
152}
153
154#[async_trait] impl Transport<ClientMessages, MessageFromServer, ClientMessage, ServerMessages, ServerMessage>
156 for SseTransport<ClientMessage>
157{
158 async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<ClientMessages>>
169 where
170 MessageDispatcher<ClientMessage>:
171 McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>,
172 {
173 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
175 let mut lock = self.shutdown_source.write().await;
176 *lock = Some(cancellation_source);
177
178 let mut lock = self.read_write_streams.lock().await;
179 let (read_rx, write_tx) = lock.take().ok_or_else(|| {
180 TransportError::Internal(
181 "SSE streams already taken or transport not initialized".to_string(),
182 )
183 })?;
184
185 let (stream, mut sender, error_stream) = MCPStream::create::<ClientMessages, ClientMessage>(
186 Box::pin(read_rx),
187 Mutex::new(Box::pin(write_tx)),
188 IoStream::Writable(Box::pin(tokio::io::stderr())),
189 self.pending_requests.clone(),
190 self.options.timeout,
191 cancellation_token,
192 );
193
194 if let (Some(session_id), Some(stream_id), Some(event_store)) = (
195 self.session_id.as_ref(),
196 self.stream_id.as_ref(),
197 self.event_store.as_ref(),
198 ) {
199 sender.make_resumable(
200 session_id.to_owned(),
201 stream_id.to_owned(),
202 event_store.clone(),
203 );
204 }
205
206 self.set_message_sender(sender).await;
207
208 if let IoStream::Writable(error_stream) = error_stream {
209 self.set_error_stream(error_stream).await;
210 }
211
212 Ok(stream)
213 }
214
215 async fn is_shut_down(&self) -> bool {
220 let result = self.is_shut_down.lock().await;
221 *result
222 }
223
224 fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<ClientMessage>>>> {
225 self.message_sender.clone() as _
226 }
227
228 fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
229 &self.error_stream as _
230 }
231
232 async fn consume_string_payload(&self, payload: &str) -> TransportResult<()> {
233 let mut transmit = self.receiver_tx.lock().await;
234 transmit
235 .write_all(format!("{payload}\n").as_bytes())
236 .await?;
237 transmit.flush().await?;
238 Ok(())
239 }
240
241 async fn shut_down(&self) -> TransportResult<()> {
248 let mut cancellation_lock = self.shutdown_source.write().await;
250 if let Some(source) = cancellation_lock.as_ref() {
251 source.cancel()?;
252 }
253 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
257 *is_shut_down_lock = true;
258 Ok(())
259 }
260
261 async fn keep_alive(
262 &self,
263 interval: Duration,
264 disconnect_tx: oneshot::Sender<()>,
265 ) -> TransportResult<JoinHandle<()>> {
266 let sender = self.message_sender();
267
268 let handle = tokio::spawn(async move {
269 let mut interval: Interval = time::interval(interval);
270 interval.tick().await; loop {
272 interval.tick().await;
273 let sender = sender.read().await;
274 if let Some(sender) = sender.as_ref() {
275 match sender.write_str("\n", true).await {
276 Ok(_) => {}
277 Err(TransportError::Io(error)) => {
278 if error.kind() == std::io::ErrorKind::BrokenPipe {
279 let _ = disconnect_tx.send(());
280 break;
281 }
282 }
283 _ => {}
284 }
285 }
286 }
287 });
288 Ok(handle)
289 }
290 async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<ClientMessage>> {
291 let mut pending_requests = self.pending_requests.lock().await;
292 pending_requests.remove(request_id)
293 }
294}
295
296impl
297 TransportDispatcher<
298 ClientMessages,
299 MessageFromServer,
300 ClientMessage,
301 ServerMessages,
302 ServerMessage,
303 > for SseTransport<ClientMessage>
304{
305}