1use bytes::Bytes;
14use futures::StreamExt;
15use reqwest::{Client as HttpClient, header};
16use std::collections::HashMap;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{Mutex, RwLock, mpsc};
22use tracing::{debug, error, info, warn};
23
24use turbomcp_protocol::MessageId;
25use turbomcp_transport_traits::{
26 LimitsConfig, TlsConfig, TlsVersion, Transport, TransportCapabilities, TransportError,
27 TransportEventEmitter, TransportMessage, TransportMetrics, TransportResult, TransportState,
28 TransportType, validate_request_size, validate_response_size,
29};
30
/// Reconnection back-off strategy, queried through `RetryPolicy::delay` with a
/// 0-based attempt number.
#[derive(Clone, Debug)]
pub enum RetryPolicy {
    /// Wait the same `interval` before every attempt.
    Fixed {
        /// Delay returned for every permitted attempt.
        interval: Duration,
        /// `Some(n)` permits attempts `0..n`; `None` never runs out.
        max_attempts: Option<u32>,
    },
    /// Double the delay on every attempt, starting from `base` and capped at
    /// `max_delay`, with random jitter applied.
    Exponential {
        /// Delay for attempt 0 before jitter.
        base: Duration,
        /// Upper bound on the pre-jitter delay.
        max_delay: Duration,
        /// `Some(n)` permits attempts `0..n`; `None` never runs out.
        max_attempts: Option<u32>,
    },
    /// Never retry: `delay` returns `None` for every attempt.
    Never,
}
53
54impl Default for RetryPolicy {
55 fn default() -> Self {
56 Self::Exponential {
57 base: Duration::from_secs(1),
58 max_delay: Duration::from_secs(60),
59 max_attempts: Some(10),
60 }
61 }
62}
63
64impl RetryPolicy {
65 pub(crate) fn delay(&self, attempt: u32) -> Option<Duration> {
66 match self {
67 Self::Fixed {
68 interval,
69 max_attempts,
70 } => {
71 if let Some(max) = max_attempts
72 && attempt >= *max
73 {
74 return None;
75 }
76 Some(*interval)
77 }
78 Self::Exponential {
79 base,
80 max_delay,
81 max_attempts,
82 } => {
83 if let Some(max) = max_attempts
84 && attempt >= *max
85 {
86 return None;
87 }
88 let base_delay = base.as_millis() as u64 * 2u64.pow(attempt);
89 let max_delay_ms = max_delay.as_millis() as u64;
90 let capped = base_delay.min(max_delay_ms);
91 let jitter_range = capped / 4;
95 let jitter_offset = if jitter_range > 0 {
96 fastrand::u64(0..jitter_range * 2)
97 } else {
98 0
99 };
100 let final_delay = capped
101 .saturating_sub(jitter_range)
102 .saturating_add(jitter_offset);
103 Some(Duration::from_millis(final_delay))
104 }
105 Self::Never => None,
106 }
107 }
108}
109
110#[derive(Clone, Debug)]
112pub struct StreamableHttpClientConfig {
113 pub base_url: String,
115
116 pub endpoint_path: String,
118
119 pub timeout: Duration,
121
122 pub retry_policy: RetryPolicy,
124
125 pub auth_token: Option<String>,
127
128 pub headers: HashMap<String, String>,
130
131 pub user_agent: Option<String>,
142
143 pub protocol_version: String,
145
146 pub limits: LimitsConfig,
148
149 pub tls: TlsConfig,
151
152 pub sse_read_timeout: Duration,
159}
160
161impl Default for StreamableHttpClientConfig {
162 fn default() -> Self {
163 Self {
164 base_url: "http://localhost:8080".to_string(),
165 endpoint_path: "/mcp".to_string(),
166 timeout: Duration::from_secs(30),
167 retry_policy: RetryPolicy::default(),
168 auth_token: None,
169 headers: HashMap::new(),
170 user_agent: Some(format!("TurboMCP-Client/{}", env!("CARGO_PKG_VERSION"))),
171 protocol_version: "2025-11-25".to_string(),
172 limits: LimitsConfig::default(),
173 tls: TlsConfig::default(),
174 sse_read_timeout: Duration::from_secs(300),
175 }
176 }
177}
178
179pub struct StreamableHttpClientTransport {
181 config: StreamableHttpClientConfig,
182 http_client: HttpClient,
183 state: Arc<RwLock<TransportState>>,
184 capabilities: TransportCapabilities,
185 metrics: Arc<RwLock<TransportMetrics>>,
186 _event_emitter: TransportEventEmitter,
187
188 message_endpoint: Arc<RwLock<Option<String>>>,
190
191 session_id: Arc<RwLock<Option<String>>>,
193
194 last_event_id: Arc<RwLock<Option<String>>>,
196
197 sse_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
199 sse_sender: mpsc::Sender<TransportMessage>,
200
201 response_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
203 response_sender: mpsc::Sender<TransportMessage>,
204
205 sse_task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
207
208 endpoint_ready_tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
214 endpoint_ready_rx: Arc<Mutex<Option<tokio::sync::oneshot::Receiver<()>>>>,
215}
216
217impl std::fmt::Debug for StreamableHttpClientTransport {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("StreamableHttpClientTransport")
220 .field("base_url", &self.config.base_url)
221 .field("endpoint", &self.config.endpoint_path)
222 .finish()
223 }
224}
225
226impl StreamableHttpClientTransport {
227 pub fn new(config: StreamableHttpClientConfig) -> TransportResult<Self> {
234 let (sse_tx, sse_rx) = mpsc::channel(1000);
235 let (response_tx, response_rx) = mpsc::channel(100);
236 let (event_emitter, _) = TransportEventEmitter::new();
237 let (endpoint_ready_tx, endpoint_ready_rx) = tokio::sync::oneshot::channel();
238
239 if config.tls.is_insecure() {
241 warn!(
242 "Certificate validation is disabled. This is insecure and should only be used \
243 for testing or in secure mTLS mesh environments. \
244 See https://turbomcp.org/docs/security/tls#certificate-validation"
245 );
246 }
247
248 let mut client_builder = HttpClient::builder()
253 .use_rustls_tls()
254 .timeout(config.timeout);
255
256 if config.auth_token.is_some() {
261 client_builder =
262 client_builder.redirect(reqwest::redirect::Policy::custom(|attempt| {
263 if attempt.previous().len() >= 10 {
264 return attempt.error("too many redirects");
265 }
266 let prev_origin = attempt.previous().last().map(reqwest::Url::origin);
267 if prev_origin.as_ref() == Some(&attempt.url().origin()) {
268 attempt.follow()
269 } else {
270 attempt.stop()
273 }
274 }));
275 }
276
277 if let Some(ref user_agent) = config.user_agent {
279 client_builder = client_builder.user_agent(user_agent);
280 }
281
282 client_builder = match config.tls.min_version {
284 TlsVersion::Tls13 => client_builder.min_tls_version(reqwest::tls::Version::TLS_1_3),
285 };
286
287 if !config.tls.validate_certificates {
289 const INSECURE_TLS_ENV_VAR: &str = "TURBOMCP_ALLOW_INSECURE_TLS";
292
293 if std::env::var(INSECURE_TLS_ENV_VAR).is_err() {
294 error!(
295 "SECURITY: Certificate validation disabled but {} not set. \
296 Overriding to validate_certificates=true for safety. \
297 Set {}=1 to allow insecure TLS.",
298 INSECURE_TLS_ENV_VAR, INSECURE_TLS_ENV_VAR
299 );
300 } else {
303 warn!(
304 "SECURITY WARNING: TLS certificate validation is DISABLED. \
305 This configuration is INSECURE and should ONLY be used: \
306 (1) In development/testing environments, or \
307 (2) In secure mTLS mesh where validation happens elsewhere. \
308 NEVER use in production connecting to untrusted servers."
309 );
310
311 client_builder = client_builder.danger_accept_invalid_certs(true);
312 }
313 }
314
315 if let Some(ca_certs) = &config.tls.custom_ca_certs {
317 let mut loaded = 0usize;
318 let total = ca_certs.len();
319 for cert_bytes in ca_certs {
320 if let Ok(cert) = reqwest::Certificate::from_pem(cert_bytes) {
322 client_builder = client_builder.add_root_certificate(cert);
323 loaded += 1;
324 } else if let Ok(cert) = reqwest::Certificate::from_der(cert_bytes) {
325 client_builder = client_builder.add_root_certificate(cert);
326 loaded += 1;
327 } else {
328 warn!(
329 "Failed to parse custom CA certificate ({}/{}), skipping",
330 loaded + 1,
331 total
332 );
333 }
334 }
335 if loaded == 0 && total > 0 {
336 error!("All {} custom CA certificates failed to parse", total);
337 }
339 if loaded > 0 {
340 info!("Loaded {}/{} custom CA certificates", loaded, total);
341 }
342 }
343
344 let http_client = client_builder.build().map_err(|e| {
345 TransportError::ConfigurationError(format!(
346 "Failed to build HTTP client (likely bad TLS configuration): {e}"
347 ))
348 })?;
349
350 Ok(Self {
351 config,
352 http_client,
353 state: Arc::new(RwLock::new(TransportState::Disconnected)),
354 capabilities: TransportCapabilities {
355 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
356 supports_compression: false,
357 supports_streaming: true,
358 supports_bidirectional: true,
359 supports_multiplexing: false,
360 compression_algorithms: Vec::new(),
361 custom: HashMap::new(),
362 },
363 metrics: Arc::new(RwLock::new(TransportMetrics::default())),
364 _event_emitter: event_emitter,
365 message_endpoint: Arc::new(RwLock::new(None)),
366 session_id: Arc::new(RwLock::new(None)),
367 last_event_id: Arc::new(RwLock::new(None)),
368 sse_receiver: Arc::new(Mutex::new(sse_rx)),
369 sse_sender: sse_tx,
370 response_receiver: Arc::new(Mutex::new(response_rx)),
371 response_sender: response_tx,
372 sse_task_handle: Arc::new(Mutex::new(None)),
373 endpoint_ready_tx: Arc::new(Mutex::new(Some(endpoint_ready_tx))),
374 endpoint_ready_rx: Arc::new(Mutex::new(Some(endpoint_ready_rx))),
375 })
376 }
377
378 fn get_endpoint_url(&self) -> String {
380 format!("{}{}", self.config.base_url, self.config.endpoint_path)
381 }
382
383 async fn get_message_endpoint_url(&self) -> String {
385 let discovered = self.message_endpoint.read().await;
386 if let Some(endpoint) = discovered.as_ref() {
387 if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
388 endpoint.clone()
389 } else if endpoint.starts_with('/') {
390 format!("{}{}", self.config.base_url, endpoint)
391 } else {
392 format!("{}/{}", self.config.base_url, endpoint)
393 }
394 } else {
395 self.get_endpoint_url()
396 }
397 }
398
399 async fn build_headers(&self, accept: &str) -> header::HeaderMap {
401 let mut headers = header::HeaderMap::new();
402
403 if let Ok(accept_value) = header::HeaderValue::from_str(accept) {
405 headers.insert(header::ACCEPT, accept_value);
406 }
407
408 if let Ok(protocol_value) = header::HeaderValue::from_str(&self.config.protocol_version) {
409 headers.insert("MCP-Protocol-Version", protocol_value);
410 }
411
412 if let Some(session_id) = self.session_id.read().await.as_ref()
413 && let Ok(session_value) = header::HeaderValue::from_str(session_id)
414 {
415 headers.insert("Mcp-Session-Id", session_value);
416 }
417
418 if let Some(last_event_id) = self.last_event_id.read().await.as_ref()
419 && let Ok(event_value) = header::HeaderValue::from_str(last_event_id)
420 {
421 headers.insert("Last-Event-ID", event_value);
422 }
423
424 if let Some(token) = &self.config.auth_token
425 && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
426 {
427 headers.insert(header::AUTHORIZATION, auth_value);
428 }
429
430 for (key, value) in &self.config.headers {
431 if let (Ok(k), Ok(v)) = (
432 header::HeaderName::from_bytes(key.as_bytes()),
433 header::HeaderValue::from_str(value),
434 ) {
435 headers.insert(k, v);
436 }
437 }
438
439 headers
440 }
441
442 async fn start_sse_connection(&self) -> TransportResult<()> {
444 info!("Starting SSE connection to {}", self.get_endpoint_url());
445
446 let endpoint_url = self.get_endpoint_url();
447 let config = self.config.clone();
448 let http_client = self.http_client.clone();
449 let state = Arc::clone(&self.state);
450 let sse_sender = self.sse_sender.clone();
451 let session_id = Arc::clone(&self.session_id);
452 let last_event_id = Arc::clone(&self.last_event_id);
453 let message_endpoint = Arc::clone(&self.message_endpoint);
454 let endpoint_ready_tx = self.endpoint_ready_tx.lock().await.take();
457
458 let task = tokio::spawn(async move {
459 Self::sse_connection_task(
460 endpoint_url,
461 config,
462 http_client,
463 state,
464 sse_sender,
465 session_id,
466 last_event_id,
467 message_endpoint,
468 endpoint_ready_tx,
469 )
470 .await;
471 });
472
473 *self.sse_task_handle.lock().await = Some(task);
474
475 Ok(())
476 }
477
478 #[allow(clippy::too_many_arguments)]
480 async fn sse_connection_task(
481 endpoint_url: String,
482 config: StreamableHttpClientConfig,
483 http_client: HttpClient,
484 state: Arc<RwLock<TransportState>>,
485 sse_sender: mpsc::Sender<TransportMessage>,
486 session_id: Arc<RwLock<Option<String>>>,
487 last_event_id: Arc<RwLock<Option<String>>>,
488 message_endpoint: Arc<RwLock<Option<String>>>,
489 mut endpoint_ready_tx: Option<tokio::sync::oneshot::Sender<()>>,
490 ) {
491 let mut attempt = 0u32;
492
493 loop {
494 if let Some(delay) = config.retry_policy.delay(attempt) {
496 if attempt > 0 {
497 warn!("Reconnecting in {:?} (attempt {})", delay, attempt + 1);
498 tokio::time::sleep(delay).await;
499 }
500 } else {
501 error!("Max retry attempts reached, giving up");
502 *state.write().await = TransportState::Disconnected;
503 break;
504 }
505
506 let mut headers = header::HeaderMap::new();
508 headers.insert(
509 header::ACCEPT,
510 header::HeaderValue::from_static("text/event-stream"),
511 );
512
513 if let Ok(protocol_value) = header::HeaderValue::from_str(&config.protocol_version) {
514 headers.insert("MCP-Protocol-Version", protocol_value);
515 }
516
517 if let Some(sid) = session_id.read().await.as_ref()
518 && let Ok(session_value) = header::HeaderValue::from_str(sid)
519 {
520 headers.insert("Mcp-Session-Id", session_value);
521 }
522
523 if let Some(last_id) = last_event_id.read().await.as_ref()
524 && let Ok(event_value) = header::HeaderValue::from_str(last_id)
525 {
526 headers.insert("Last-Event-ID", event_value);
527 }
528
529 if let Some(token) = &config.auth_token
530 && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
531 {
532 headers.insert(header::AUTHORIZATION, auth_value);
533 }
534
535 match http_client.get(&endpoint_url).headers(headers).send().await {
537 Ok(response) => {
538 if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
539 info!(
540 "Server returned HTTP 405 for GET {}. Continuing without standalone SSE polling.",
541 endpoint_url
542 );
543 break;
544 }
545
546 if !response.status().is_success() {
547 error!("SSE connection failed: {}", response.status());
548 attempt += 1;
549 continue;
550 }
551
552 if let Some(sid) = response
554 .headers()
555 .get("Mcp-Session-Id")
556 .and_then(|v| v.to_str().ok())
557 {
558 *session_id.write().await = Some(sid.to_string());
559 info!("Received session ID: {}", sid);
560 }
561
562 info!("SSE connection established");
563 *state.write().await = TransportState::Connected;
564 attempt = 0; let mut stream = response.bytes_stream();
568 let mut buffer = String::new();
569 let read_timeout = config.sse_read_timeout;
570 let buffer_cap = config
574 .limits
575 .enforce_on_streams
576 .then_some(config.limits.max_response_size)
577 .flatten();
578
579 'sse_loop: loop {
580 let chunk_result =
581 match tokio::time::timeout(read_timeout, stream.next()).await {
582 Ok(Some(r)) => r,
583 Ok(None) => break,
584 Err(_) => {
585 warn!(
586 "SSE read idle for {:?}; closing stream to reconnect",
587 read_timeout
588 );
589 break;
590 }
591 };
592 match chunk_result {
593 Ok(chunk) => {
594 let chunk_str = String::from_utf8_lossy(&chunk);
595 buffer.push_str(&chunk_str);
596
597 while let Some(pos) = buffer.find("\n\n") {
599 let event_str = buffer[..pos].to_string();
600 buffer = buffer[pos + 2..].to_string();
601
602 if let Err(e) = Self::process_sse_event(
603 &event_str,
604 &sse_sender,
605 &last_event_id,
606 &message_endpoint,
607 &mut endpoint_ready_tx,
608 )
609 .await
610 {
611 warn!("Failed to process SSE event: {}", e);
612 }
613 }
614
615 if let Some(cap) = buffer_cap
616 && buffer.len() > cap
617 {
618 error!(
619 "SSE event buffer exceeded {} bytes without an event \
620 boundary; closing stream to avoid OOM",
621 cap
622 );
623 break 'sse_loop;
624 }
625 }
626 Err(e) => {
627 error!("Error reading SSE stream: {}", e);
628 break;
629 }
630 }
631 }
632
633 warn!("SSE stream ended");
634 *state.write().await = TransportState::Disconnected;
635 }
636 Err(e) => {
637 error!("Failed to connect: {}", e);
638 attempt += 1;
639 }
640 }
641 }
642 }
643
644 async fn process_sse_event(
650 event_str: &str,
651 sse_sender: &mpsc::Sender<TransportMessage>,
652 last_event_id: &Arc<RwLock<Option<String>>>,
653 message_endpoint: &Arc<RwLock<Option<String>>>,
654 endpoint_ready_tx: &mut Option<tokio::sync::oneshot::Sender<()>>,
655 ) -> TransportResult<()> {
656 let lines: Vec<&str> = event_str.lines().collect();
657 let mut event_type: Option<String> = None;
658 let mut event_data: Vec<String> = Vec::new();
659 let mut event_id: Option<String> = None;
660
661 for line in lines {
662 if line.is_empty() {
663 continue;
664 }
665
666 if let Some(colon_pos) = line.find(':') {
667 let field = &line[..colon_pos];
668 let value = line[colon_pos + 1..].trim_start();
669
670 match field {
671 "event" => event_type = Some(value.to_string()),
672 "data" => event_data.push(value.to_string()),
673 "id" => event_id = Some(value.to_string()),
674 _ => {}
675 }
676 }
677 }
678
679 if let Some(id) = event_id {
681 *last_event_id.write().await = Some(id);
682 }
683
684 if event_data.is_empty() {
685 return Ok(());
686 }
687
688 let data_str = event_data.join("\n");
689
690 match event_type.as_deref() {
692 Some("endpoint") => {
693 let endpoint_uri = if data_str.trim().starts_with('{') {
698 let endpoint_json: serde_json::Value = serde_json::from_str(&data_str)
700 .map_err(|e| {
701 TransportError::SerializationFailed(format!(
702 "Invalid endpoint JSON: {}",
703 e
704 ))
705 })?;
706 endpoint_json["uri"]
707 .as_str()
708 .ok_or_else(|| {
709 TransportError::SerializationFailed(
710 "Endpoint event missing 'uri' field".to_string(),
711 )
712 })?
713 .to_string()
714 } else {
715 data_str.clone()
717 };
718
719 info!("Discovered message endpoint: {}", endpoint_uri);
720 *message_endpoint.write().await = Some(endpoint_uri);
721 if let Some(tx) = endpoint_ready_tx.take() {
725 let _ = tx.send(());
726 }
727 Ok(())
728 }
729 Some("message") | None => {
730 if data_str.trim().is_empty() {
733 debug!("Skipping empty SSE event");
734 return Ok(());
735 }
736
737 let json_value: serde_json::Value =
739 serde_json::from_str(&data_str).map_err(|e| {
740 TransportError::SerializationFailed(format!("Invalid JSON: {}", e))
741 })?;
742
743 let message = TransportMessage::new(
744 MessageId::from("sse-message".to_string()),
745 Bytes::from(
746 serde_json::to_vec(&json_value)
747 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
748 ),
749 );
750
751 sse_sender
752 .send(message)
753 .await
754 .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
755
756 debug!("Received SSE message");
757 Ok(())
758 }
759 Some(other) => {
760 debug!("Ignoring unknown event type: {}", other);
761 Ok(())
762 }
763 }
764 }
765
766 async fn process_post_sse_event(
768 event_str: &str,
769 response_sender: &mpsc::Sender<TransportMessage>,
770 last_event_id: &Arc<RwLock<Option<String>>>,
771 ) -> TransportResult<()> {
772 let lines: Vec<&str> = event_str.lines().collect();
773 let mut event_data: Vec<String> = Vec::new();
774 let mut event_id: Option<String> = None;
775
776 for line in lines {
777 if line.is_empty() {
778 continue;
779 }
780
781 if let Some(colon_pos) = line.find(':') {
782 let field = &line[..colon_pos];
783 let value = line[colon_pos + 1..].trim_start();
784
785 match field {
786 "data" => event_data.push(value.to_string()),
787 "id" => event_id = Some(value.to_string()),
788 "event" => {
789 }
792 _ => {}
793 }
794 }
795 }
796
797 if let Some(id) = event_id {
799 *last_event_id.write().await = Some(id);
800 }
801
802 if event_data.is_empty() {
803 return Ok(());
804 }
805
806 let data_str = event_data.join("\n");
807
808 let json_value: serde_json::Value = serde_json::from_str(&data_str).map_err(|e| {
810 TransportError::SerializationFailed(format!("Invalid JSON in POST SSE: {}", e))
811 })?;
812
813 let message = TransportMessage::new(
814 MessageId::from("post-sse-response".to_string()),
815 Bytes::from(
816 serde_json::to_vec(&json_value)
817 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
818 ),
819 );
820
821 response_sender
822 .send(message.clone())
823 .await
824 .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
825
826 debug!(
827 "Queued message from POST SSE stream: {}",
828 String::from_utf8_lossy(&message.payload)
829 );
830 Ok(())
831 }
832
833 pub async fn recv_async(&self) -> TransportResult<TransportMessage> {
842 let mut response_receiver = self.response_receiver.lock().await;
843 let mut sse_receiver = self.sse_receiver.lock().await;
844 let message = tokio::select! {
845 biased;
846 msg = response_receiver.recv() => msg.ok_or_else(|| {
849 TransportError::ConnectionLost("Response channel disconnected".to_string())
850 })?,
851 msg = sse_receiver.recv() => msg.ok_or_else(|| {
852 TransportError::ConnectionLost("SSE channel disconnected".to_string())
853 })?,
854 };
855 let mut metrics = self.metrics.write().await;
856 metrics.messages_received += 1;
857 metrics.bytes_received += message.payload.len() as u64;
858 Ok(message)
859 }
860}
861
862impl Transport for StreamableHttpClientTransport {
863 fn send(
864 &self,
865 message: TransportMessage,
866 ) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
867 Box::pin(async move {
868 debug!("Sending message via HTTP POST");
869
870 validate_request_size(message.payload.len(), &self.config.limits)?;
872
873 let url = self.get_message_endpoint_url().await;
875
876 let headers = self
878 .build_headers("application/json, text/event-stream")
879 .await;
880
881 let response = self
883 .http_client
884 .post(&url)
885 .headers(headers)
886 .header(header::CONTENT_TYPE, "application/json")
887 .body(message.payload.to_vec())
888 .send()
889 .await
890 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
891
892 if !response.status().is_success() {
893 return Err(TransportError::ConnectionFailed(format!(
894 "POST failed: {}",
895 response.status()
896 )));
897 }
898
899 if let Some(session_id) = response
901 .headers()
902 .get("Mcp-Session-Id")
903 .and_then(|v| v.to_str().ok())
904 {
905 *self.session_id.write().await = Some(session_id.to_string());
906 }
907
908 if response.status() == reqwest::StatusCode::ACCEPTED {
910 debug!("Received HTTP 202 Accepted (no response body expected)");
911 {
913 let mut metrics = self.metrics.write().await;
914 metrics.messages_sent += 1;
915 metrics.bytes_sent += message.payload.len() as u64;
916 }
917 return Ok(());
918 }
919
920 let content_type = response
922 .headers()
923 .get(header::CONTENT_TYPE)
924 .and_then(|v| v.to_str().ok())
925 .unwrap_or("");
926
927 if content_type.contains("application/json") {
928 debug!("Received JSON response from POST");
930
931 let response_bytes = response
932 .bytes()
933 .await
934 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
935
936 validate_response_size(response_bytes.len(), &self.config.limits)?;
938
939 let response_message = TransportMessage::new(
940 MessageId::from("http-response".to_string()),
941 response_bytes,
942 );
943
944 self.response_sender
946 .send(response_message)
947 .await
948 .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
949
950 debug!("JSON response queued successfully");
951 } else if content_type.contains("text/event-stream") {
952 debug!("Received SSE stream response from POST, processing events");
955
956 let response_sender = self.response_sender.clone();
957 let last_event_id = Arc::clone(&self.last_event_id);
958
959 let mut stream = response.bytes_stream();
961 let mut buffer = String::new();
962 let buffer_cap = self
965 .config
966 .limits
967 .enforce_on_streams
968 .then_some(self.config.limits.max_response_size)
969 .flatten();
970
971 'post_sse_loop: while let Some(chunk_result) = stream.next().await {
972 match chunk_result {
973 Ok(chunk) => {
974 let chunk_str = String::from_utf8_lossy(&chunk);
975 buffer.push_str(&chunk_str);
976
977 while let Some(pos) = buffer.find("\n\n") {
979 let event_str = buffer[..pos].to_string();
980 buffer = buffer[pos + 2..].to_string();
981
982 if let Err(e) = Self::process_post_sse_event(
983 &event_str,
984 &response_sender,
985 &last_event_id,
986 )
987 .await
988 {
989 warn!("Failed to process POST SSE event: {}", e);
990 }
991 }
992
993 if let Some(cap) = buffer_cap
994 && buffer.len() > cap
995 {
996 error!(
997 "POST SSE event buffer exceeded {} bytes without an event \
998 boundary; closing stream to avoid OOM",
999 cap
1000 );
1001 break 'post_sse_loop;
1002 }
1003 }
1004 Err(e) => {
1005 warn!("Error reading POST SSE stream: {}", e);
1006 break;
1007 }
1008 }
1009 }
1010 debug!("POST SSE stream processing completed");
1011 }
1012
1013 {
1015 let mut metrics = self.metrics.write().await;
1016 metrics.messages_sent += 1;
1017 metrics.bytes_sent += message.payload.len() as u64;
1018 }
1019
1020 debug!("Message sent successfully");
1021 Ok(())
1022 })
1023 }
1024
1025 fn receive(
1032 &self,
1033 ) -> Pin<Box<dyn Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_>> {
1034 Box::pin(async move {
1035 {
1038 let mut response_receiver = self.response_receiver.lock().await;
1039 match response_receiver.try_recv() {
1040 Ok(message) => {
1041 debug!("Received queued JSON response");
1042 {
1044 let mut metrics = self.metrics.write().await;
1045 metrics.messages_received += 1;
1046 metrics.bytes_received += message.payload.len() as u64;
1047 }
1048 return Ok(Some(message));
1049 }
1050 Err(mpsc::error::TryRecvError::Empty) => {
1051 }
1053 Err(mpsc::error::TryRecvError::Disconnected) => {
1054 return Err(TransportError::ConnectionLost(
1055 "Response channel disconnected".to_string(),
1056 ));
1057 }
1058 }
1059 }
1060
1061 let mut sse_receiver = self.sse_receiver.lock().await;
1063 match sse_receiver.try_recv() {
1064 Ok(message) => {
1065 debug!("Received SSE message");
1066 {
1068 let mut metrics = self.metrics.write().await;
1069 metrics.messages_received += 1;
1070 metrics.bytes_received += message.payload.len() as u64;
1071 }
1072 Ok(Some(message))
1073 }
1074 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
1075 Err(mpsc::error::TryRecvError::Disconnected) => Err(
1076 TransportError::ConnectionLost("SSE channel disconnected".to_string()),
1077 ),
1078 }
1079 })
1080 }
1081
1082 fn capabilities(&self) -> &TransportCapabilities {
1083 &self.capabilities
1084 }
1085
1086 fn state(&self) -> Pin<Box<dyn Future<Output = TransportState> + Send + '_>> {
1087 Box::pin(async move { self.state.read().await.clone() })
1088 }
1089
1090 fn transport_type(&self) -> TransportType {
1091 TransportType::Http
1092 }
1093
1094 fn metrics(&self) -> Pin<Box<dyn Future<Output = TransportMetrics> + Send + '_>> {
1095 Box::pin(async move { self.metrics.read().await.clone() })
1096 }
1097
1098 fn connect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
1099 Box::pin(async move {
1100 info!("Connecting to {}", self.get_endpoint_url());
1101
1102 *self.state.write().await = TransportState::Connecting;
1103
1104 self.start_sse_connection().await?;
1106
1107 let rx = self.endpoint_ready_rx.lock().await.take();
1113 if let Some(rx) = rx {
1114 match tokio::time::timeout(self.config.timeout, rx).await {
1115 Ok(_) => {
1116 }
1120 Err(_) => {
1121 return Err(TransportError::ConnectionFailed(format!(
1122 "SSE endpoint discovery timed out after {:?}",
1123 self.config.timeout
1124 )));
1125 }
1126 }
1127 }
1128
1129 *self.state.write().await = TransportState::Connected;
1130
1131 info!("Connected successfully");
1132 Ok(())
1133 })
1134 }
1135
1136 fn disconnect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
1137 Box::pin(async move {
1138 info!("Disconnecting");
1139
1140 *self.state.write().await = TransportState::Disconnecting;
1141
1142 if let Some(handle) = self.sse_task_handle.lock().await.take() {
1144 handle.abort();
1145 }
1146
1147 if let Some(session_id) = self.session_id.read().await.as_ref() {
1149 let url = self.get_endpoint_url();
1150 let mut headers = header::HeaderMap::new();
1151 if let Ok(session_value) = header::HeaderValue::from_str(session_id) {
1152 headers.insert("Mcp-Session-Id", session_value);
1153 }
1154
1155 let _ = self.http_client.delete(&url).headers(headers).send().await;
1156 }
1157
1158 *self.state.write().await = TransportState::Disconnected;
1159
1160 info!("Disconnected");
1161 Ok(())
1162 })
1163 }
1164}
1165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_policy_fixed() {
        let policy = RetryPolicy::Fixed {
            interval: Duration::from_secs(5),
            max_attempts: Some(3),
        };

        assert_eq!(policy.delay(0), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(1), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(2), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(3), None);
    }

    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::Exponential {
            base: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            max_attempts: None,
        };

        // Each delay is the doubled base (capped at 60 s) ±25% jitter.
        let delay0 = policy.delay(0).unwrap();
        assert!(delay0 >= Duration::from_millis(750) && delay0 <= Duration::from_millis(1250));

        let delay1 = policy.delay(1).unwrap();
        assert!(delay1 >= Duration::from_millis(1500) && delay1 <= Duration::from_millis(2500));

        let delay2 = policy.delay(2).unwrap();
        assert!(delay2 >= Duration::from_millis(3000) && delay2 <= Duration::from_millis(5000));

        let delay3 = policy.delay(3).unwrap();
        assert!(delay3 >= Duration::from_millis(6000) && delay3 <= Duration::from_millis(10000));

        let delay10 = policy.delay(10).unwrap();
        assert!(delay10 >= Duration::from_millis(45000) && delay10 <= Duration::from_millis(75000));
    }

    #[tokio::test]
    async fn test_client_creation() {
        let config = StreamableHttpClientConfig::default();
        let client = StreamableHttpClientTransport::new(config).expect("default config builds");

        assert_eq!(client.transport_type(), TransportType::Http);
        assert!(client.capabilities().supports_streaming);
        assert!(client.capabilities().supports_bidirectional);
    }

    /// State that `process_sse_event` reads and writes, plus the receiving
    /// ends needed to observe its effects.
    struct EventHarness {
        sse_tx: mpsc::Sender<TransportMessage>,
        sse_rx: mpsc::Receiver<TransportMessage>,
        last_event_id: Arc<RwLock<Option<String>>>,
        message_endpoint: Arc<RwLock<Option<String>>>,
        ready_tx: Option<tokio::sync::oneshot::Sender<()>>,
        ready_rx: tokio::sync::oneshot::Receiver<()>,
    }

    impl EventHarness {
        fn new() -> Self {
            let (sse_tx, sse_rx) = mpsc::channel(8);
            let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
            Self {
                sse_tx,
                sse_rx,
                last_event_id: Arc::new(RwLock::new(None)),
                message_endpoint: Arc::new(RwLock::new(None)),
                ready_tx: Some(ready_tx),
                ready_rx,
            }
        }

        /// Feeds one raw SSE event through the real event handler.
        async fn process(&mut self, event: &str) -> TransportResult<()> {
            StreamableHttpClientTransport::process_sse_event(
                event,
                &self.sse_tx,
                &self.last_event_id,
                &self.message_endpoint,
                &mut self.ready_tx,
            )
            .await
        }
    }

    #[tokio::test]
    async fn test_endpoint_event_json_parsing() {
        let mut harness = EventHarness::new();
        harness
            .process("event: endpoint\ndata: {\"uri\":\"http://127.0.0.1:8080/mcp\"}")
            .await
            .expect("valid endpoint event");

        let stored = harness
            .message_endpoint
            .read()
            .await
            .clone()
            .expect("endpoint stored");
        assert_eq!(stored, "http://127.0.0.1:8080/mcp");
        assert!(stored.parse::<url::Url>().is_ok());

        // Discovery fires the endpoint-ready one-shot exactly once.
        assert!(harness.ready_tx.is_none());
        assert!(harness.ready_rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_endpoint_event_plain_string_parsing() {
        let mut harness = EventHarness::new();
        harness
            .process("event: endpoint\ndata: http://127.0.0.1:8080/mcp")
            .await
            .expect("valid endpoint event");

        let stored = harness
            .message_endpoint
            .read()
            .await
            .clone()
            .expect("endpoint stored");
        assert_eq!(stored, "http://127.0.0.1:8080/mcp");
        assert!(stored.starts_with("http://"));
        assert!(harness.ready_rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_endpoint_event_without_uri_is_rejected() {
        let mut harness = EventHarness::new();
        assert!(
            harness
                .process("event: endpoint\ndata: {\"url\":\"/mcp\"}")
                .await
                .is_err()
        );
        assert!(harness.message_endpoint.read().await.is_none());
        assert!(harness.ready_tx.is_some());
    }

    #[tokio::test]
    async fn test_message_event_is_forwarded_and_id_recorded() {
        let mut harness = EventHarness::new();
        // Two `data:` lines are joined with a newline before JSON parsing.
        harness
            .process("id: 42\ndata: {\"jsonrpc\":\"2.0\",\ndata: \"id\":1}")
            .await
            .expect("valid message event");

        assert_eq!(harness.last_event_id.read().await.as_deref(), Some("42"));
        let message = harness.sse_rx.try_recv().expect("message queued");
        let payload: serde_json::Value =
            serde_json::from_slice(&message.payload).expect("payload is JSON");
        assert_eq!(payload, serde_json::json!({"jsonrpc": "2.0", "id": 1}));
    }

    #[tokio::test]
    async fn test_empty_comment_and_unknown_events_are_skipped() {
        let mut harness = EventHarness::new();
        harness.process("id: 7\ndata:").await.expect("empty data is fine");
        harness
            .process("event: ping\ndata: {}")
            .await
            .expect("unknown type is fine");
        harness
            .process(": keep-alive comment")
            .await
            .expect("comment is fine");

        assert_eq!(harness.last_event_id.read().await.as_deref(), Some("7"));
        assert!(harness.sse_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_invalid_message_json_is_an_error() {
        let mut harness = EventHarness::new();
        assert!(harness.process("data: not json").await.is_err());
        assert!(harness.sse_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_message_endpoint_resolution() {
        let client = StreamableHttpClientTransport::new(StreamableHttpClientConfig::default())
            .expect("default config builds");

        assert_eq!(
            client.get_message_endpoint_url().await,
            "http://localhost:8080/mcp"
        );

        *client.message_endpoint.write().await = Some("/messages?sessionId=abc".to_string());
        assert_eq!(
            client.get_message_endpoint_url().await,
            "http://localhost:8080/messages?sessionId=abc"
        );

        *client.message_endpoint.write().await = Some("messages".to_string());
        assert_eq!(
            client.get_message_endpoint_url().await,
            "http://localhost:8080/messages"
        );

        *client.message_endpoint.write().await = Some("https://other.example/rpc".to_string());
        assert_eq!(
            client.get_message_endpoint_url().await,
            "https://other.example/rpc"
        );
    }

    #[tokio::test]
    async fn test_build_headers() {
        let mut config = StreamableHttpClientConfig::default();
        config.auth_token = Some("secret".to_string());
        config
            .headers
            .insert("X-Api-Key".to_string(), "k".to_string());
        let client = StreamableHttpClientTransport::new(config).expect("config builds");
        *client.session_id.write().await = Some("sess-1".to_string());

        let headers = client.build_headers("application/json").await;
        assert_eq!(headers.get(header::ACCEPT).unwrap(), "application/json");
        assert_eq!(headers.get("mcp-protocol-version").unwrap(), "2025-11-25");
        assert_eq!(headers.get("mcp-session-id").unwrap(), "sess-1");
        assert_eq!(headers.get(header::AUTHORIZATION).unwrap(), "Bearer secret");
        assert_eq!(headers.get("x-api-key").unwrap(), "k");
        assert!(headers.get("last-event-id").is_none());
    }
}