1use crate::error::{TransportError, TransportResult};
2use crate::mcp_stream::MCPStream;
3use crate::message_dispatcher::MessageDispatcher;
4use crate::transport::Transport;
5use crate::utils::{
6 extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel,
7};
8use crate::{IoStream, McpDispatch, TransportOptions};
9use async_trait::async_trait;
10use bytes::Bytes;
11use futures::Stream;
12use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
13use reqwest::Client;
14use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
15use rust_mcp_schema::RequestId;
16use std::cmp::Ordering;
17use std::collections::HashMap;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::io::{BufReader, BufWriter};
22use tokio::sync::{mpsc, oneshot, Mutex};
23
/// Capacity of each bounded `mpsc` channel (read side and write side) created by `start`.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;

/// Retry limit handed to the SSE reader when `ClientSseTransportOptions::max_retries` is `None`.
const DEFAULT_MAX_RETRY: usize = 5;
/// Retry delay, in seconds, used when `ClientSseTransportOptions::retry_delay` is `None`.
const DEFAULT_RETRY_TIME_SECONDS: u64 = 3;

/// Seconds `shut_down` waits for the background tasks before reporting a timeout.
const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5;
28
/// Configuration for `ClientSseTransport`.
///
/// A `None` in any optional field means "use the transport's built-in default".
/// `Clone` lets one options value be reused for several transports, and
/// `Debug` lets it be logged.
#[derive(Debug, Clone)]
pub struct ClientSseTransportOptions {
    /// Request timeout forwarded to the message stream created by `start`.
    pub request_timeout: Duration,
    /// Retry delay for the SSE reader; `None` falls back to
    /// `DEFAULT_RETRY_TIME_SECONDS` seconds.
    pub retry_delay: Option<Duration>,
    /// Retry limit for the SSE reader; `None` falls back to `DEFAULT_MAX_RETRY`.
    pub max_retries: Option<usize>,
    /// Extra HTTP headers sent with every outgoing message POST.
    /// Names and values are validated by `ClientSseTransport::new`.
    pub custom_headers: Option<HashMap<String, String>>,
}
38
39impl Default for ClientSseTransportOptions {
41 fn default() -> Self {
42 Self {
43 request_timeout: TransportOptions::default().timeout,
44 retry_delay: None,
45 max_retries: None,
46 custom_headers: None,
47 }
48 }
49}
50
/// Client-side transport that reads server messages from an SSE stream and
/// sends client messages to the server with HTTP POST requests.
pub struct ClientSseTransport {
    /// Source of the cancellation token shared by the tasks spawned in `start`.
    /// Set by `start`; reset to `None` by `shut_down`.
    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
    /// Set to `true` by `shut_down`.
    is_shut_down: Mutex<bool>,
    /// Request timeout forwarded to `MCPStream::create` in `start`.
    request_timeout: Duration,
    /// HTTP client used by both the SSE reader and the POST sender.
    client: Client,
    /// URL of the SSE stream (the `server_url` passed to `new`).
    sse_url: String,
    /// Origin extracted from `sse_url`; message endpoints must share it.
    base_url: String,
    /// Retry delay handed to the SSE reader (`SseStream`).
    retry_delay: Duration,
    /// Retry limit handed to the SSE reader (`SseStream`).
    max_retries: usize,
    /// Validated custom headers forwarded to `http_post` for each outgoing message.
    custom_headers: Option<HeaderMap>,
    /// Handle of the SSE reader task spawned by `start`; awaited by `shut_down`.
    sse_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
    /// Handle of the POST sender task spawned by `start`; awaited by `shut_down`.
    post_task: tokio::sync::RwLock<Option<tokio::task::JoinHandle<()>>>,
}
76
77impl ClientSseTransport {
78 pub fn new(server_url: &str, options: ClientSseTransportOptions) -> TransportResult<Self> {
89 let client = Client::new();
90
91 let base_url = extract_origin(server_url).unwrap();
93
94 let headers = match &options.custom_headers {
95 Some(h) => Some(Self::validate_headers(h)?),
96 None => None,
97 };
98
99 Ok(Self {
100 client,
101 base_url,
102 sse_url: server_url.to_string(),
103 max_retries: options.max_retries.unwrap_or(DEFAULT_MAX_RETRY),
104 retry_delay: options
105 .retry_delay
106 .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)),
107 shutdown_source: tokio::sync::RwLock::new(None),
108 is_shut_down: Mutex::new(false),
109 request_timeout: options.request_timeout,
110 custom_headers: headers,
111 sse_task: tokio::sync::RwLock::new(None),
112 post_task: tokio::sync::RwLock::new(None),
113 })
114 }
115
116 fn validate_headers(headers: &HashMap<String, String>) -> TransportResult<HeaderMap> {
124 let mut header_map = HeaderMap::new();
125
126 for (key, value) in headers {
127 let header_name = key.parse::<HeaderName>().map_err(|e| {
128 TransportError::InvalidOptions(format!("Invalid header name: {}", e))
129 })?;
130 let header_value = HeaderValue::from_str(value).map_err(|e| {
131 TransportError::InvalidOptions(format!("Invalid header value: {}", e))
132 })?;
133 header_map.insert(header_name, header_value);
134 }
135
136 Ok(header_map)
137 }
138
139 pub fn validate_message_endpoint(&self, endpoint: String) -> TransportResult<String> {
149 if endpoint.starts_with("/") {
150 return Ok(format!("{}{}", self.base_url, endpoint));
151 }
152 if let Some(endpoint_origin) = extract_origin(&endpoint) {
153 if endpoint_origin.cmp(&self.base_url) != Ordering::Equal {
154 return Err(TransportError::InvalidOptions(format!(
155 "Endpoint origin does not match connection origin. expected: {} , received: {}",
156 self.base_url, endpoint_origin
157 )));
158 }
159 return Ok(endpoint);
160 }
161 Ok(endpoint)
162 }
163}
164
165#[async_trait]
166impl<R, S> Transport<R, S> for ClientSseTransport
167where
168 R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
169 S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
170{
171 async fn start(
179 &self,
180 ) -> TransportResult<(
181 Pin<Box<dyn Stream<Item = R> + Send>>,
182 MessageDispatcher<R>,
183 IoStream,
184 )>
185 where
186 MessageDispatcher<R>: McpDispatch<R, S>,
187 {
188 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
190 let mut lock = self.shutdown_source.write().await;
191 *lock = Some(cancellation_source);
192
193 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
194 Arc::new(Mutex::new(HashMap::new()));
195
196 let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
197 let (read_tx, read_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_CAPACITY);
198
199 let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::<Option<String>>();
201 let endpoint_event_tx = Some(endpoint_event_tx);
202
203 let sse_client = self.client.clone();
204 let sse_url = self.sse_url.clone();
205
206 let max_retries = self.max_retries;
207 let retry_delay = self.retry_delay;
208
209 let read_stream = SseStream {
210 sse_client,
211 sse_url,
212 max_retries,
213 retry_delay,
214 read_tx,
215 };
216
217 let cancellation_token_sse = cancellation_token.clone();
219 let sse_task_handle = tokio::spawn(async move {
220 read_stream
221 .run(endpoint_event_tx, cancellation_token_sse)
222 .await;
223 });
224 let mut sse_task_lock = self.sse_task.write().await;
225 *sse_task_lock = Some(sse_task_handle);
226
227 let err =
229 || std::io::Error::other("Failed to receive 'messages' endpoint from the server.");
230 let post_url = endpoint_event_rx
231 .await
232 .map_err(|_| err())?
233 .ok_or_else(err)?;
234
235 let post_url = self.validate_message_endpoint(post_url)?;
236
237 let client_clone = self.client.clone();
238
239 let custom_headers = self.custom_headers.clone();
240
241 let cancellation_token_post = cancellation_token.clone();
242 let post_task_handle = tokio::spawn(async move {
244 loop {
245 tokio::select! {
246
247 _ = cancellation_token_post.cancelled() =>
248 {
249 break;
250 },
251
252 data = write_rx.recv() => {
253 match data{
254 Some(data) => {
255 let body = String::from_utf8_lossy(&data).trim().to_string();
257 if let Err(e) = http_post(&client_clone, &post_url, body, &custom_headers).await {
258 eprintln!("Failed to POST message: {:?}", e);
259 }
260 },
261 None => break, }
263 }
264 }
265 }
266 });
267 let mut post_task_lock = self.post_task.write().await;
268 *post_task_lock = Some(post_task_handle);
269
270 let writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>> =
272 Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx })));
273
274 let readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>> =
276 Box::pin(BufReader::new(ReadableChannel {
277 read_rx,
278 buffer: Bytes::new(),
279 }));
280
281 let (stream, sender, error_stream) = MCPStream::create(
282 readable,
283 writable,
284 IoStream::Writable(Box::pin(tokio::io::stderr())),
285 pending_requests,
286 self.request_timeout,
287 cancellation_token,
288 );
289
290 Ok((stream, sender, error_stream))
291 }
292
293 async fn is_shut_down(&self) -> bool {
298 let result = self.is_shut_down.lock().await;
299 *result
300 }
301
302 async fn shut_down(&self) -> TransportResult<()> {
312 let mut cancellation_lock = self.shutdown_source.write().await;
314 if let Some(source) = cancellation_lock.as_ref() {
315 source.cancel()?;
316 }
317 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
321 *is_shut_down_lock = true;
322
323 let sse_task = self.sse_task.write().await.take();
325 let post_task = self.post_task.write().await.take();
326
327 let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS);
329 let shutdown_future = async {
330 if let Some(post_handle) = post_task {
331 let _ = post_handle.await;
332 }
333 if let Some(sse_handle) = sse_task {
334 let _ = sse_handle.await;
335 }
336 Ok::<(), TransportError>(())
337 };
338
339 tokio::select! {
340 result = shutdown_future => {
341 result }
343 _ = tokio::time::sleep(timeout) => {
344 tracing::warn!("Shutdown timed out after {:?}", timeout);
345 Err(TransportError::ShutdownTimeout)
346 }
347 }
348 }
349}