1use bytes::Bytes;
14use futures::StreamExt;
15use reqwest::{Client as HttpClient, header};
16use std::collections::HashMap;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{Mutex, RwLock, mpsc};
22use tracing::{debug, error, info, warn};
23
24use turbomcp_protocol::MessageId;
25use turbomcp_transport_traits::{
26 LimitsConfig, TlsConfig, TlsVersion, Transport, TransportCapabilities, TransportError,
27 TransportEventEmitter, TransportMessage, TransportMetrics, TransportResult, TransportState,
28 TransportType, validate_request_size, validate_response_size,
29};
30
/// Policy that decides whether, and after how long, another connection attempt is made.
///
/// The actual delays are computed by [`RetryPolicy::delay`], which takes a zero-based
/// attempt counter.
#[derive(Clone, Debug)]
pub enum RetryPolicy {
    /// Wait the same `interval` before every attempt.
    Fixed {
        /// Delay returned for every allowed attempt.
        interval: Duration,
        /// `delay` returns `None` once `attempt >= max_attempts`; `None` means unlimited.
        max_attempts: Option<u32>,
    },
    /// Exponential backoff: `base * 2^attempt`, capped at `max_delay`, with ±25% jitter.
    Exponential {
        /// Delay before jitter for attempt 0.
        base: Duration,
        /// Upper bound applied before jitter is added.
        max_delay: Duration,
        /// `delay` returns `None` once `attempt >= max_attempts`; `None` means unlimited.
        max_attempts: Option<u32>,
    },
    /// `delay` always returns `None`.
    Never,
}
53
54impl Default for RetryPolicy {
55 fn default() -> Self {
56 Self::Exponential {
57 base: Duration::from_secs(1),
58 max_delay: Duration::from_secs(60),
59 max_attempts: Some(10),
60 }
61 }
62}
63
64impl RetryPolicy {
65 pub(crate) fn delay(&self, attempt: u32) -> Option<Duration> {
66 match self {
67 Self::Fixed {
68 interval,
69 max_attempts,
70 } => {
71 if let Some(max) = max_attempts
72 && attempt >= *max
73 {
74 return None;
75 }
76 Some(*interval)
77 }
78 Self::Exponential {
79 base,
80 max_delay,
81 max_attempts,
82 } => {
83 if let Some(max) = max_attempts
84 && attempt >= *max
85 {
86 return None;
87 }
88 let base_delay = base.as_millis() as u64 * 2u64.pow(attempt);
89 let max_delay_ms = max_delay.as_millis() as u64;
90 let capped = base_delay.min(max_delay_ms);
91 let jitter_range = capped / 4;
95 let jitter_offset = if jitter_range > 0 {
96 fastrand::u64(0..jitter_range * 2)
97 } else {
98 0
99 };
100 let final_delay = capped
101 .saturating_sub(jitter_range)
102 .saturating_add(jitter_offset);
103 Some(Duration::from_millis(final_delay))
104 }
105 Self::Never => None,
106 }
107 }
108}
109
110#[derive(Clone, Debug)]
112pub struct StreamableHttpClientConfig {
113 pub base_url: String,
115
116 pub endpoint_path: String,
118
119 pub timeout: Duration,
121
122 pub retry_policy: RetryPolicy,
124
125 pub auth_token: Option<String>,
127
128 pub headers: HashMap<String, String>,
130
131 pub user_agent: Option<String>,
142
143 pub protocol_version: String,
145
146 pub limits: LimitsConfig,
148
149 pub tls: TlsConfig,
151
152 pub sse_read_timeout: Duration,
159}
160
161impl Default for StreamableHttpClientConfig {
162 fn default() -> Self {
163 Self {
164 base_url: "http://localhost:8080".to_string(),
165 endpoint_path: "/mcp".to_string(),
166 timeout: Duration::from_secs(30),
167 retry_policy: RetryPolicy::default(),
168 auth_token: None,
169 headers: HashMap::new(),
170 user_agent: Some(format!("TurboMCP-Client/{}", env!("CARGO_PKG_VERSION"))),
171 protocol_version: "2025-11-25".to_string(),
172 limits: LimitsConfig::default(),
173 tls: TlsConfig::default(),
174 sse_read_timeout: Duration::from_secs(300),
175 }
176 }
177}
178
/// Client side of the MCP Streamable HTTP transport.
///
/// Outgoing messages are HTTP POSTs. Their replies (JSON bodies or SSE streams) are queued
/// on the response channel, while messages from the standalone SSE GET stream are queued on
/// the SSE channel. The `Arc`-wrapped fields are shared with that background SSE task.
pub struct StreamableHttpClientTransport {
    /// Settings supplied to `new`.
    config: StreamableHttpClientConfig,
    /// Shared `reqwest` client built in `new` from `config`.
    http_client: HttpClient,
    /// Connection state; also written by the SSE background task.
    state: Arc<RwLock<TransportState>>,
    /// Static capabilities returned by `capabilities()`.
    capabilities: TransportCapabilities,
    /// Message and byte counters updated on send and receive.
    metrics: Arc<RwLock<TransportMetrics>>,
    /// Created in `new`; not otherwise used by the code in this file.
    _event_emitter: TransportEventEmitter,

    /// Endpoint announced by an SSE `endpoint` event; when set, POSTs are sent there
    /// instead of `base_url + endpoint_path`.
    message_endpoint: Arc<RwLock<Option<String>>>,

    /// Session id taken from the server's `Mcp-Session-Id` response header and echoed
    /// on later requests.
    session_id: Arc<RwLock<Option<String>>>,

    /// `id` of the most recent SSE event, sent back in the `Last-Event-ID` header.
    last_event_id: Arc<RwLock<Option<String>>>,

    /// Messages received on the standalone SSE GET stream.
    sse_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
    /// Sending half handed to the SSE background task.
    sse_sender: mpsc::Sender<TransportMessage>,

    /// Replies to POST requests (JSON bodies and events from POST SSE streams).
    response_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
    /// Sending half used while processing POST responses.
    response_sender: mpsc::Sender<TransportMessage>,

    /// Handle of the spawned SSE background task, if any; aborted by `disconnect`.
    sse_task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
212
213impl std::fmt::Debug for StreamableHttpClientTransport {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 f.debug_struct("StreamableHttpClientTransport")
216 .field("base_url", &self.config.base_url)
217 .field("endpoint", &self.config.endpoint_path)
218 .finish()
219 }
220}
221
222impl StreamableHttpClientTransport {
223 pub fn new(config: StreamableHttpClientConfig) -> TransportResult<Self> {
230 let (sse_tx, sse_rx) = mpsc::channel(1000);
231 let (response_tx, response_rx) = mpsc::channel(100);
232 let (event_emitter, _) = TransportEventEmitter::new();
233
234 if config.tls.is_insecure() {
236 warn!(
237 "Certificate validation is disabled. This is insecure and should only be used \
238 for testing or in secure mTLS mesh environments. \
239 See https://turbomcp.org/docs/security/tls#certificate-validation"
240 );
241 }
242
243 let mut client_builder = HttpClient::builder()
248 .use_rustls_tls()
249 .timeout(config.timeout);
250
251 if config.auth_token.is_some() {
256 client_builder =
257 client_builder.redirect(reqwest::redirect::Policy::custom(|attempt| {
258 if attempt.previous().len() >= 10 {
259 return attempt.error("too many redirects");
260 }
261 let prev_origin = attempt.previous().last().map(reqwest::Url::origin);
262 if prev_origin.as_ref() == Some(&attempt.url().origin()) {
263 attempt.follow()
264 } else {
265 attempt.stop()
268 }
269 }));
270 }
271
272 if let Some(ref user_agent) = config.user_agent {
274 client_builder = client_builder.user_agent(user_agent);
275 }
276
277 client_builder = match config.tls.min_version {
279 TlsVersion::Tls13 => client_builder.min_tls_version(reqwest::tls::Version::TLS_1_3),
280 };
281
282 if !config.tls.validate_certificates {
284 const INSECURE_TLS_ENV_VAR: &str = "TURBOMCP_ALLOW_INSECURE_TLS";
287
288 if std::env::var(INSECURE_TLS_ENV_VAR).is_err() {
289 error!(
290 "SECURITY: Certificate validation disabled but {} not set. \
291 Overriding to validate_certificates=true for safety. \
292 Set {}=1 to allow insecure TLS.",
293 INSECURE_TLS_ENV_VAR, INSECURE_TLS_ENV_VAR
294 );
295 } else {
298 warn!(
299 "SECURITY WARNING: TLS certificate validation is DISABLED. \
300 This configuration is INSECURE and should ONLY be used: \
301 (1) In development/testing environments, or \
302 (2) In secure mTLS mesh where validation happens elsewhere. \
303 NEVER use in production connecting to untrusted servers."
304 );
305
306 client_builder = client_builder.danger_accept_invalid_certs(true);
307 }
308 }
309
310 if let Some(ca_certs) = &config.tls.custom_ca_certs {
312 let mut loaded = 0usize;
313 let total = ca_certs.len();
314 for cert_bytes in ca_certs {
315 if let Ok(cert) = reqwest::Certificate::from_pem(cert_bytes) {
317 client_builder = client_builder.add_root_certificate(cert);
318 loaded += 1;
319 } else if let Ok(cert) = reqwest::Certificate::from_der(cert_bytes) {
320 client_builder = client_builder.add_root_certificate(cert);
321 loaded += 1;
322 } else {
323 warn!(
324 "Failed to parse custom CA certificate ({}/{}), skipping",
325 loaded + 1,
326 total
327 );
328 }
329 }
330 if loaded == 0 && total > 0 {
331 error!("All {} custom CA certificates failed to parse", total);
332 }
334 if loaded > 0 {
335 info!("Loaded {}/{} custom CA certificates", loaded, total);
336 }
337 }
338
339 let http_client = client_builder.build().map_err(|e| {
340 TransportError::ConfigurationError(format!(
341 "Failed to build HTTP client (likely bad TLS configuration): {e}"
342 ))
343 })?;
344
345 Ok(Self {
346 config,
347 http_client,
348 state: Arc::new(RwLock::new(TransportState::Disconnected)),
349 capabilities: TransportCapabilities {
350 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
351 supports_compression: false,
352 supports_streaming: true,
353 supports_bidirectional: true,
354 supports_multiplexing: false,
355 compression_algorithms: Vec::new(),
356 custom: HashMap::new(),
357 },
358 metrics: Arc::new(RwLock::new(TransportMetrics::default())),
359 _event_emitter: event_emitter,
360 message_endpoint: Arc::new(RwLock::new(None)),
361 session_id: Arc::new(RwLock::new(None)),
362 last_event_id: Arc::new(RwLock::new(None)),
363 sse_receiver: Arc::new(Mutex::new(sse_rx)),
364 sse_sender: sse_tx,
365 response_receiver: Arc::new(Mutex::new(response_rx)),
366 response_sender: response_tx,
367 sse_task_handle: Arc::new(Mutex::new(None)),
368 })
369 }
370
371 fn get_endpoint_url(&self) -> String {
373 format!("{}{}", self.config.base_url, self.config.endpoint_path)
374 }
375
376 async fn get_message_endpoint_url(&self) -> String {
378 let discovered = self.message_endpoint.read().await;
379 if let Some(endpoint) = discovered.as_ref() {
380 if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
381 endpoint.clone()
382 } else if endpoint.starts_with('/') {
383 format!("{}{}", self.config.base_url, endpoint)
384 } else {
385 format!("{}/{}", self.config.base_url, endpoint)
386 }
387 } else {
388 self.get_endpoint_url()
389 }
390 }
391
392 async fn build_headers(&self, accept: &str) -> header::HeaderMap {
394 let mut headers = header::HeaderMap::new();
395
396 if let Ok(accept_value) = header::HeaderValue::from_str(accept) {
398 headers.insert(header::ACCEPT, accept_value);
399 }
400
401 if let Ok(protocol_value) = header::HeaderValue::from_str(&self.config.protocol_version) {
402 headers.insert("MCP-Protocol-Version", protocol_value);
403 }
404
405 if let Some(session_id) = self.session_id.read().await.as_ref()
406 && let Ok(session_value) = header::HeaderValue::from_str(session_id)
407 {
408 headers.insert("Mcp-Session-Id", session_value);
409 }
410
411 if let Some(last_event_id) = self.last_event_id.read().await.as_ref()
412 && let Ok(event_value) = header::HeaderValue::from_str(last_event_id)
413 {
414 headers.insert("Last-Event-ID", event_value);
415 }
416
417 if let Some(token) = &self.config.auth_token
418 && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
419 {
420 headers.insert(header::AUTHORIZATION, auth_value);
421 }
422
423 for (key, value) in &self.config.headers {
424 if let (Ok(k), Ok(v)) = (
425 header::HeaderName::from_bytes(key.as_bytes()),
426 header::HeaderValue::from_str(value),
427 ) {
428 headers.insert(k, v);
429 }
430 }
431
432 headers
433 }
434
435 async fn start_sse_connection(&self) -> TransportResult<()> {
437 if self.session_id.read().await.is_none() {
438 debug!("Deferring SSE connection until server provides a session ID");
439 return Ok(());
440 }
441
442 let mut task_handle = self.sse_task_handle.lock().await;
443 if let Some(handle) = task_handle.as_ref()
444 && !handle.is_finished()
445 {
446 debug!("SSE connection task already running");
447 return Ok(());
448 }
449
450 info!("Starting SSE connection to {}", self.get_endpoint_url());
451
452 let endpoint_url = self.get_endpoint_url();
453 let config = self.config.clone();
454 let http_client = self.http_client.clone();
455 let state = Arc::clone(&self.state);
456 let sse_sender = self.sse_sender.clone();
457 let session_id = Arc::clone(&self.session_id);
458 let last_event_id = Arc::clone(&self.last_event_id);
459 let message_endpoint = Arc::clone(&self.message_endpoint);
460
461 let task = tokio::spawn(async move {
462 Self::sse_connection_task(
463 endpoint_url,
464 config,
465 http_client,
466 state,
467 sse_sender,
468 session_id,
469 last_event_id,
470 message_endpoint,
471 )
472 .await;
473 });
474
475 *task_handle = Some(task);
476
477 Ok(())
478 }
479
480 #[allow(clippy::too_many_arguments)]
482 async fn sse_connection_task(
483 endpoint_url: String,
484 config: StreamableHttpClientConfig,
485 http_client: HttpClient,
486 state: Arc<RwLock<TransportState>>,
487 sse_sender: mpsc::Sender<TransportMessage>,
488 session_id: Arc<RwLock<Option<String>>>,
489 last_event_id: Arc<RwLock<Option<String>>>,
490 message_endpoint: Arc<RwLock<Option<String>>>,
491 ) {
492 let mut attempt = 0u32;
493
494 loop {
495 if let Some(delay) = config.retry_policy.delay(attempt) {
497 if attempt > 0 {
498 warn!("Reconnecting in {:?} (attempt {})", delay, attempt + 1);
499 tokio::time::sleep(delay).await;
500 }
501 } else {
502 error!("Max retry attempts reached, giving up");
503 *state.write().await = TransportState::Disconnected;
504 break;
505 }
506
507 let mut headers = header::HeaderMap::new();
509 headers.insert(
510 header::ACCEPT,
511 header::HeaderValue::from_static("text/event-stream"),
512 );
513
514 if let Ok(protocol_value) = header::HeaderValue::from_str(&config.protocol_version) {
515 headers.insert("MCP-Protocol-Version", protocol_value);
516 }
517
518 if let Some(sid) = session_id.read().await.as_ref()
519 && let Ok(session_value) = header::HeaderValue::from_str(sid)
520 {
521 headers.insert("Mcp-Session-Id", session_value);
522 }
523
524 if let Some(last_id) = last_event_id.read().await.as_ref()
525 && let Ok(event_value) = header::HeaderValue::from_str(last_id)
526 {
527 headers.insert("Last-Event-ID", event_value);
528 }
529
530 if let Some(token) = &config.auth_token
531 && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
532 {
533 headers.insert(header::AUTHORIZATION, auth_value);
534 }
535
536 match http_client.get(&endpoint_url).headers(headers).send().await {
538 Ok(response) => {
539 if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
540 info!(
541 "Server returned HTTP 405 for GET {}. Continuing without standalone SSE polling.",
542 endpoint_url
543 );
544 break;
545 }
546
547 if !response.status().is_success() {
548 error!("SSE connection failed: {}", response.status());
549 attempt += 1;
550 continue;
551 }
552
553 if let Some(sid) = response
555 .headers()
556 .get("Mcp-Session-Id")
557 .and_then(|v| v.to_str().ok())
558 {
559 *session_id.write().await = Some(sid.to_string());
560 info!("Received session ID: {}", sid);
561 }
562
563 info!("SSE connection established");
564 *state.write().await = TransportState::Connected;
565 attempt = 0; let mut stream = response.bytes_stream();
569 let mut buffer = String::new();
570 let read_timeout = config.sse_read_timeout;
571 let buffer_cap = config
575 .limits
576 .enforce_on_streams
577 .then_some(config.limits.max_response_size)
578 .flatten();
579
580 'sse_loop: loop {
581 let chunk_result =
582 match tokio::time::timeout(read_timeout, stream.next()).await {
583 Ok(Some(r)) => r,
584 Ok(None) => break,
585 Err(_) => {
586 warn!(
587 "SSE read idle for {:?}; closing stream to reconnect",
588 read_timeout
589 );
590 break;
591 }
592 };
593 match chunk_result {
594 Ok(chunk) => {
595 let chunk_str = String::from_utf8_lossy(&chunk);
596 buffer.push_str(&chunk_str);
597
598 while let Some(pos) = buffer.find("\n\n") {
600 let event_str = buffer[..pos].to_string();
601 buffer = buffer[pos + 2..].to_string();
602
603 if let Err(e) = Self::process_sse_event(
604 &event_str,
605 &sse_sender,
606 &last_event_id,
607 &message_endpoint,
608 )
609 .await
610 {
611 warn!("Failed to process SSE event: {}", e);
612 }
613 }
614
615 if let Some(cap) = buffer_cap
616 && buffer.len() > cap
617 {
618 error!(
619 "SSE event buffer exceeded {} bytes without an event \
620 boundary; closing stream to avoid OOM",
621 cap
622 );
623 break 'sse_loop;
624 }
625 }
626 Err(e) => {
627 error!("Error reading SSE stream: {}", e);
628 break;
629 }
630 }
631 }
632
633 warn!("SSE stream ended");
634 *state.write().await = TransportState::Disconnected;
635 }
636 Err(e) => {
637 error!("Failed to connect: {}", e);
638 attempt += 1;
639 }
640 }
641 }
642 }
643
644 async fn process_sse_event(
646 event_str: &str,
647 sse_sender: &mpsc::Sender<TransportMessage>,
648 last_event_id: &Arc<RwLock<Option<String>>>,
649 message_endpoint: &Arc<RwLock<Option<String>>>,
650 ) -> TransportResult<()> {
651 let lines: Vec<&str> = event_str.lines().collect();
652 let mut event_type: Option<String> = None;
653 let mut event_data: Vec<String> = Vec::new();
654 let mut event_id: Option<String> = None;
655
656 for line in lines {
657 if line.is_empty() {
658 continue;
659 }
660
661 if let Some(colon_pos) = line.find(':') {
662 let field = &line[..colon_pos];
663 let value = line[colon_pos + 1..].trim_start();
664
665 match field {
666 "event" => event_type = Some(value.to_string()),
667 "data" => event_data.push(value.to_string()),
668 "id" => event_id = Some(value.to_string()),
669 _ => {}
670 }
671 }
672 }
673
674 if let Some(id) = event_id {
676 *last_event_id.write().await = Some(id);
677 }
678
679 if event_data.is_empty() {
680 return Ok(());
681 }
682
683 let data_str = event_data.join("\n");
684
685 match event_type.as_deref() {
687 Some("endpoint") => {
688 let endpoint_uri = if data_str.trim().starts_with('{') {
696 let endpoint_json: serde_json::Value = serde_json::from_str(&data_str)
698 .map_err(|e| {
699 TransportError::SerializationFailed(format!(
700 "Invalid endpoint JSON: {}",
701 e
702 ))
703 })?;
704 endpoint_json["uri"]
705 .as_str()
706 .ok_or_else(|| {
707 TransportError::SerializationFailed(
708 "Endpoint event missing 'uri' field".to_string(),
709 )
710 })?
711 .to_string()
712 } else {
713 data_str.clone()
715 };
716
717 info!("Discovered message endpoint: {}", endpoint_uri);
718 *message_endpoint.write().await = Some(endpoint_uri);
719 Ok(())
720 }
721 Some("message") | None => {
722 if data_str.trim().is_empty() {
725 debug!("Skipping empty SSE event");
726 return Ok(());
727 }
728
729 let json_value: serde_json::Value =
731 serde_json::from_str(&data_str).map_err(|e| {
732 TransportError::SerializationFailed(format!("Invalid JSON: {}", e))
733 })?;
734
735 let message = TransportMessage::new(
736 MessageId::from("sse-message".to_string()),
737 Bytes::from(
738 serde_json::to_vec(&json_value)
739 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
740 ),
741 );
742
743 sse_sender
744 .send(message)
745 .await
746 .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
747
748 debug!("Received SSE message");
749 Ok(())
750 }
751 Some(other) => {
752 debug!("Ignoring unknown event type: {}", other);
753 Ok(())
754 }
755 }
756 }
757
758 async fn process_post_sse_event(
760 event_str: &str,
761 response_sender: &mpsc::Sender<TransportMessage>,
762 last_event_id: &Arc<RwLock<Option<String>>>,
763 ) -> TransportResult<()> {
764 let lines: Vec<&str> = event_str.lines().collect();
765 let mut event_data: Vec<String> = Vec::new();
766 let mut event_id: Option<String> = None;
767
768 for line in lines {
769 if line.is_empty() {
770 continue;
771 }
772
773 if let Some(colon_pos) = line.find(':') {
774 let field = &line[..colon_pos];
775 let value = line[colon_pos + 1..].trim_start();
776
777 match field {
778 "data" => event_data.push(value.to_string()),
779 "id" => event_id = Some(value.to_string()),
780 "event" => {
781 }
784 _ => {}
785 }
786 }
787 }
788
789 if let Some(id) = event_id {
791 *last_event_id.write().await = Some(id);
792 }
793
794 if event_data.is_empty() {
795 return Ok(());
796 }
797
798 let data_str = event_data.join("\n");
799 if data_str.trim().is_empty() {
800 debug!("Skipping empty POST SSE event");
801 return Ok(());
802 }
803
804 let json_value: serde_json::Value = serde_json::from_str(&data_str).map_err(|e| {
806 TransportError::SerializationFailed(format!("Invalid JSON in POST SSE: {}", e))
807 })?;
808
809 let message = TransportMessage::new(
810 MessageId::from("post-sse-response".to_string()),
811 Bytes::from(
812 serde_json::to_vec(&json_value)
813 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
814 ),
815 );
816
817 response_sender
818 .send(message.clone())
819 .await
820 .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
821
822 debug!(
823 "Queued message from POST SSE stream: {}",
824 String::from_utf8_lossy(&message.payload)
825 );
826 Ok(())
827 }
828
829 pub async fn recv_async(&self) -> TransportResult<TransportMessage> {
838 let mut response_receiver = self.response_receiver.lock().await;
839 let mut sse_receiver = self.sse_receiver.lock().await;
840 let message = tokio::select! {
841 biased;
842 msg = response_receiver.recv() => msg.ok_or_else(|| {
845 TransportError::ConnectionLost("Response channel disconnected".to_string())
846 })?,
847 msg = sse_receiver.recv() => msg.ok_or_else(|| {
848 TransportError::ConnectionLost("SSE channel disconnected".to_string())
849 })?,
850 };
851 let mut metrics = self.metrics.write().await;
852 metrics.messages_received += 1;
853 metrics.bytes_received += message.payload.len() as u64;
854 Ok(message)
855 }
856}
857
858impl Transport for StreamableHttpClientTransport {
859 fn send(
860 &self,
861 message: TransportMessage,
862 ) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
863 Box::pin(async move {
864 debug!("Sending message via HTTP POST");
865
866 validate_request_size(message.payload.len(), &self.config.limits)?;
868
869 let url = self.get_message_endpoint_url().await;
871
872 let headers = self
874 .build_headers("application/json, text/event-stream")
875 .await;
876
877 let response = self
879 .http_client
880 .post(&url)
881 .headers(headers)
882 .header(header::CONTENT_TYPE, "application/json")
883 .body(message.payload.to_vec())
884 .send()
885 .await
886 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
887
888 if !response.status().is_success() {
889 return Err(TransportError::ConnectionFailed(format!(
890 "POST failed: {}",
891 response.status()
892 )));
893 }
894
895 if let Some(session_id) = response
897 .headers()
898 .get("Mcp-Session-Id")
899 .and_then(|v| v.to_str().ok())
900 {
901 *self.session_id.write().await = Some(session_id.to_string());
902 self.start_sse_connection().await?;
903 }
904
905 if response.status() == reqwest::StatusCode::ACCEPTED {
907 debug!("Received HTTP 202 Accepted (no response body expected)");
908 {
910 let mut metrics = self.metrics.write().await;
911 metrics.messages_sent += 1;
912 metrics.bytes_sent += message.payload.len() as u64;
913 }
914 return Ok(());
915 }
916
917 let content_type = response
919 .headers()
920 .get(header::CONTENT_TYPE)
921 .and_then(|v| v.to_str().ok())
922 .unwrap_or("");
923
924 if content_type.contains("application/json") {
925 debug!("Received JSON response from POST");
927
928 let response_bytes = response
929 .bytes()
930 .await
931 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
932
933 validate_response_size(response_bytes.len(), &self.config.limits)?;
935
936 let response_message = TransportMessage::new(
937 MessageId::from("http-response".to_string()),
938 response_bytes,
939 );
940
941 self.response_sender
943 .send(response_message)
944 .await
945 .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
946
947 debug!("JSON response queued successfully");
948 } else if content_type.contains("text/event-stream") {
949 debug!("Received SSE stream response from POST, processing events");
952
953 let response_sender = self.response_sender.clone();
954 let last_event_id = Arc::clone(&self.last_event_id);
955
956 let mut stream = response.bytes_stream();
958 let mut buffer = String::new();
959 let buffer_cap = self
962 .config
963 .limits
964 .enforce_on_streams
965 .then_some(self.config.limits.max_response_size)
966 .flatten();
967
968 'post_sse_loop: while let Some(chunk_result) = stream.next().await {
969 match chunk_result {
970 Ok(chunk) => {
971 let chunk_str = String::from_utf8_lossy(&chunk);
972 buffer.push_str(&chunk_str);
973
974 while let Some(pos) = buffer.find("\n\n") {
976 let event_str = buffer[..pos].to_string();
977 buffer = buffer[pos + 2..].to_string();
978
979 if let Err(e) = Self::process_post_sse_event(
980 &event_str,
981 &response_sender,
982 &last_event_id,
983 )
984 .await
985 {
986 warn!("Failed to process POST SSE event: {}", e);
987 }
988 }
989
990 if let Some(cap) = buffer_cap
991 && buffer.len() > cap
992 {
993 error!(
994 "POST SSE event buffer exceeded {} bytes without an event \
995 boundary; closing stream to avoid OOM",
996 cap
997 );
998 break 'post_sse_loop;
999 }
1000 }
1001 Err(e) => {
1002 warn!("Error reading POST SSE stream: {}", e);
1003 break;
1004 }
1005 }
1006 }
1007 debug!("POST SSE stream processing completed");
1008 }
1009
1010 {
1012 let mut metrics = self.metrics.write().await;
1013 metrics.messages_sent += 1;
1014 metrics.bytes_sent += message.payload.len() as u64;
1015 }
1016
1017 debug!("Message sent successfully");
1018 Ok(())
1019 })
1020 }
1021
1022 fn receive(
1029 &self,
1030 ) -> Pin<Box<dyn Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_>> {
1031 Box::pin(async move {
1032 {
1035 let mut response_receiver = self.response_receiver.lock().await;
1036 match response_receiver.try_recv() {
1037 Ok(message) => {
1038 debug!("Received queued JSON response");
1039 {
1041 let mut metrics = self.metrics.write().await;
1042 metrics.messages_received += 1;
1043 metrics.bytes_received += message.payload.len() as u64;
1044 }
1045 return Ok(Some(message));
1046 }
1047 Err(mpsc::error::TryRecvError::Empty) => {
1048 }
1050 Err(mpsc::error::TryRecvError::Disconnected) => {
1051 return Err(TransportError::ConnectionLost(
1052 "Response channel disconnected".to_string(),
1053 ));
1054 }
1055 }
1056 }
1057
1058 let mut sse_receiver = self.sse_receiver.lock().await;
1060 match sse_receiver.try_recv() {
1061 Ok(message) => {
1062 debug!("Received SSE message");
1063 {
1065 let mut metrics = self.metrics.write().await;
1066 metrics.messages_received += 1;
1067 metrics.bytes_received += message.payload.len() as u64;
1068 }
1069 Ok(Some(message))
1070 }
1071 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
1072 Err(mpsc::error::TryRecvError::Disconnected) => Err(
1073 TransportError::ConnectionLost("SSE channel disconnected".to_string()),
1074 ),
1075 }
1076 })
1077 }
1078
1079 fn capabilities(&self) -> &TransportCapabilities {
1080 &self.capabilities
1081 }
1082
1083 fn state(&self) -> Pin<Box<dyn Future<Output = TransportState> + Send + '_>> {
1084 Box::pin(async move { self.state.read().await.clone() })
1085 }
1086
1087 fn transport_type(&self) -> TransportType {
1088 TransportType::Http
1089 }
1090
1091 fn metrics(&self) -> Pin<Box<dyn Future<Output = TransportMetrics> + Send + '_>> {
1092 Box::pin(async move { self.metrics.read().await.clone() })
1093 }
1094
1095 fn connect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
1096 Box::pin(async move {
1097 info!("Connecting to {}", self.get_endpoint_url());
1098
1099 *self.state.write().await = TransportState::Connecting;
1100
1101 self.start_sse_connection().await?;
1103
1104 *self.state.write().await = TransportState::Connected;
1105
1106 info!("Connected successfully");
1107 Ok(())
1108 })
1109 }
1110
1111 fn disconnect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
1112 Box::pin(async move {
1113 info!("Disconnecting");
1114
1115 *self.state.write().await = TransportState::Disconnecting;
1116
1117 if let Some(handle) = self.sse_task_handle.lock().await.take() {
1119 handle.abort();
1120 }
1121
1122 if let Some(session_id) = self.session_id.read().await.as_ref() {
1124 let url = self.get_endpoint_url();
1125 let mut headers = header::HeaderMap::new();
1126 if let Ok(session_value) = header::HeaderValue::from_str(session_id) {
1127 headers.insert("Mcp-Session-Id", session_value);
1128 }
1129
1130 let _ = self.http_client.delete(&url).headers(headers).send().await;
1131 }
1132
1133 *self.state.write().await = TransportState::Disconnected;
1134
1135 info!("Disconnected");
1136 Ok(())
1137 })
1138 }
1139}
1140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_policy_fixed() {
        let policy = RetryPolicy::Fixed {
            interval: Duration::from_secs(5),
            max_attempts: Some(3),
        };

        assert_eq!(policy.delay(0), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(1), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(2), Some(Duration::from_secs(5)));
        assert_eq!(policy.delay(3), None);
    }

    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::Exponential {
            base: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            max_attempts: None,
        };

        // Each delay is base * 2^attempt (capped at 60 s) with ±25% jitter.
        let delay0 = policy.delay(0).unwrap();
        assert!(delay0 >= Duration::from_millis(750) && delay0 <= Duration::from_millis(1250));

        let delay1 = policy.delay(1).unwrap();
        assert!(delay1 >= Duration::from_millis(1500) && delay1 <= Duration::from_millis(2500));

        let delay2 = policy.delay(2).unwrap();
        assert!(delay2 >= Duration::from_millis(3000) && delay2 <= Duration::from_millis(5000));

        let delay3 = policy.delay(3).unwrap();
        assert!(delay3 >= Duration::from_millis(6000) && delay3 <= Duration::from_millis(10000));

        let delay10 = policy.delay(10).unwrap();
        assert!(delay10 >= Duration::from_millis(45000) && delay10 <= Duration::from_millis(75000));
    }

    #[tokio::test]
    async fn test_client_creation() {
        let config = StreamableHttpClientConfig::default();
        let client = StreamableHttpClientTransport::new(config).expect("default config builds");

        assert_eq!(client.transport_type(), TransportType::Http);
        assert!(client.capabilities().supports_streaming);
        assert!(client.capabilities().supports_bidirectional);
    }

    /// Feeds `event` to the real standalone-SSE handler with fresh shared state and
    /// returns (discovered endpoint, last event id, message queued on the SSE channel).
    async fn run_sse_event(
        event: &str,
    ) -> (Option<String>, Option<String>, Option<TransportMessage>) {
        let (tx, mut rx) = mpsc::channel(1);
        let last_event_id = Arc::new(RwLock::new(None));
        let message_endpoint = Arc::new(RwLock::new(None));

        StreamableHttpClientTransport::process_sse_event(
            event,
            &tx,
            &last_event_id,
            &message_endpoint,
        )
        .await
        .expect("SSE event should be processed");

        let endpoint = message_endpoint.read().await.clone();
        let event_id = last_event_id.read().await.clone();
        (endpoint, event_id, rx.try_recv().ok())
    }

    #[tokio::test]
    async fn test_endpoint_event_json_parsing() {
        let (endpoint, _, queued) =
            run_sse_event("event: endpoint\ndata: {\"uri\":\"http://127.0.0.1:8080/mcp\"}\n")
                .await;

        let endpoint = endpoint.expect("endpoint should be discovered");
        assert_eq!(endpoint, "http://127.0.0.1:8080/mcp");
        assert!(endpoint.parse::<url::Url>().is_ok());
        // Endpoint events are configuration, not messages.
        assert!(queued.is_none());
    }

    #[tokio::test]
    async fn test_endpoint_event_plain_string_parsing() {
        let (endpoint, _, queued) =
            run_sse_event("event: endpoint\ndata: http://127.0.0.1:8080/mcp\n").await;

        assert_eq!(endpoint.as_deref(), Some("http://127.0.0.1:8080/mcp"));
        assert!(queued.is_none());
    }

    #[tokio::test]
    async fn test_endpoint_event_missing_uri_is_error() {
        let (tx, _rx) = mpsc::channel(1);
        let last_event_id = Arc::new(RwLock::new(None));
        let message_endpoint = Arc::new(RwLock::new(None::<String>));

        let result = StreamableHttpClientTransport::process_sse_event(
            "event: endpoint\ndata: {\"url\":\"http://127.0.0.1:8080/mcp\"}\n",
            &tx,
            &last_event_id,
            &message_endpoint,
        )
        .await;

        assert!(result.is_err());
        assert!(message_endpoint.read().await.is_none());
    }

    #[tokio::test]
    async fn test_sse_message_event_is_queued() {
        let (endpoint, event_id, queued) =
            run_sse_event("id: 7\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"ping\"}\n").await;

        assert!(endpoint.is_none());
        assert_eq!(event_id.as_deref(), Some("7"));
        let message = queued.expect("message should be queued");
        let value: serde_json::Value =
            serde_json::from_slice(&message.payload).expect("valid queued JSON");
        assert_eq!(value["method"], "ping");
    }

    #[tokio::test]
    async fn test_message_endpoint_url_resolution() {
        let transport = StreamableHttpClientTransport::new(StreamableHttpClientConfig::default())
            .expect("default config builds");

        assert_eq!(
            transport.get_message_endpoint_url().await,
            "http://localhost:8080/mcp"
        );

        *transport.message_endpoint.write().await = Some("/messages?sessionId=1".to_string());
        assert_eq!(
            transport.get_message_endpoint_url().await,
            "http://localhost:8080/messages?sessionId=1"
        );

        *transport.message_endpoint.write().await = Some("messages".to_string());
        assert_eq!(
            transport.get_message_endpoint_url().await,
            "http://localhost:8080/messages"
        );

        *transport.message_endpoint.write().await = Some("https://example.com/rpc".to_string());
        assert_eq!(
            transport.get_message_endpoint_url().await,
            "https://example.com/rpc"
        );
    }

    #[tokio::test]
    async fn test_post_sse_whitespace_data_event_is_ignored() {
        let (tx, mut rx) = mpsc::channel(1);
        let last_event_id = Arc::new(RwLock::new(None));

        StreamableHttpClientTransport::process_post_sse_event(
            "id: primer-1\nevent: message\ndata: \n",
            &tx,
            &last_event_id,
        )
        .await
        .expect("whitespace POST SSE event should be ignored");

        assert_eq!(last_event_id.read().await.as_deref(), Some("primer-1"));
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_post_sse_json_event_is_queued() {
        let (tx, mut rx) = mpsc::channel(1);
        let last_event_id = Arc::new(RwLock::new(None));

        StreamableHttpClientTransport::process_post_sse_event(
            "id: msg-1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\n",
            &tx,
            &last_event_id,
        )
        .await
        .expect("valid POST SSE event should be queued");

        assert_eq!(last_event_id.read().await.as_deref(), Some("msg-1"));
        let message = rx.try_recv().expect("queued message");
        let value: serde_json::Value =
            serde_json::from_slice(&message.payload).expect("valid queued JSON");
        assert_eq!(value["jsonrpc"], "2.0");
    }
}