Skip to main content

turbomcp_http/
transport.rs

1//! MCP 2025-11-25 Compliant Streamable HTTP Client - Standard Implementation
2//!
3//! This client provides **strict MCP 2025-11-25 specification compliance** with:
4//! - Single MCP endpoint for all communication
5//! - Accept header negotiation (application/json, text/event-stream)
6//! - Handles SSE responses from POST requests
7//! - Backward-compatible handling for legacy SSE "endpoint" events
8//! - Auto-reconnect with exponential backoff
9//! - Last-Event-ID resumability
10//! - Session management with Mcp-Session-Id
11//! - Protocol version headers
12
13use bytes::Bytes;
14use futures::StreamExt;
15use reqwest::{Client as HttpClient, header};
16use std::collections::HashMap;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{Mutex, RwLock, mpsc};
22use tracing::{debug, error, info, warn};
23
24use turbomcp_protocol::MessageId;
25use turbomcp_transport_traits::{
26    LimitsConfig, TlsConfig, TlsVersion, Transport, TransportCapabilities, TransportError,
27    TransportEventEmitter, TransportMessage, TransportMetrics, TransportResult, TransportState,
28    TransportType, validate_request_size, validate_response_size,
29};
30
/// Strategy deciding whether, and after how long, a failed connection is retried.
#[derive(Debug, Clone)]
pub enum RetryPolicy {
    /// Wait the same amount of time before every retry.
    Fixed {
        /// Pause inserted before each retry attempt.
        interval: Duration,
        /// Upper bound on the number of retry attempts; `None` means unlimited.
        max_attempts: Option<u32>,
    },
    /// Grow the wait between retries exponentially.
    Exponential {
        /// Starting delay that the backoff calculation scales up from.
        base: Duration,
        /// Cap applied to the computed backoff delay.
        max_delay: Duration,
        /// Upper bound on the number of retry attempts; `None` means unlimited.
        max_attempts: Option<u32>,
    },
    /// Do not retry at all.
    Never,
}
53
54impl Default for RetryPolicy {
55    fn default() -> Self {
56        Self::Exponential {
57            base: Duration::from_secs(1),
58            max_delay: Duration::from_secs(60),
59            max_attempts: Some(10),
60        }
61    }
62}
63
64impl RetryPolicy {
65    pub(crate) fn delay(&self, attempt: u32) -> Option<Duration> {
66        match self {
67            Self::Fixed {
68                interval,
69                max_attempts,
70            } => {
71                if let Some(max) = max_attempts
72                    && attempt >= *max
73                {
74                    return None;
75                }
76                Some(*interval)
77            }
78            Self::Exponential {
79                base,
80                max_delay,
81                max_attempts,
82            } => {
83                if let Some(max) = max_attempts
84                    && attempt >= *max
85                {
86                    return None;
87                }
88                let base_delay = base.as_millis() as u64 * 2u64.pow(attempt);
89                let max_delay_ms = max_delay.as_millis() as u64;
90                let capped = base_delay.min(max_delay_ms);
91                // Add ±25% jitter to prevent thundering herd. Sourced per-instance
92                // from `fastrand` so concurrent clients on the same attempt number
93                // do not produce identical delays.
94                let jitter_range = capped / 4;
95                let jitter_offset = if jitter_range > 0 {
96                    fastrand::u64(0..jitter_range * 2)
97                } else {
98                    0
99                };
100                let final_delay = capped
101                    .saturating_sub(jitter_range)
102                    .saturating_add(jitter_offset);
103                Some(Duration::from_millis(final_delay))
104            }
105            Self::Never => None,
106        }
107    }
108}
109
110/// Streamable HTTP client configuration
111#[derive(Clone, Debug)]
112pub struct StreamableHttpClientConfig {
113    /// Base URL (e.g., <https://api.example.com>)
114    pub base_url: String,
115
116    /// MCP endpoint path (e.g., "/mcp")
117    pub endpoint_path: String,
118
119    /// Request timeout
120    pub timeout: Duration,
121
122    /// Auto-reconnect policy
123    pub retry_policy: RetryPolicy,
124
125    /// Authentication token
126    pub auth_token: Option<String>,
127
128    /// Custom headers
129    pub headers: HashMap<String, String>,
130
131    /// User agent string (set to None to disable User-Agent header)
132    ///
133    /// Default: `TurboMCP-Client/{version}`
134    ///
135    /// # Security Note
136    ///
137    /// The User-Agent header can expose client version information. Consider:
138    /// - Setting to `None` to disable User-Agent header entirely
139    /// - Using a generic string like "MCP-Client" to minimize fingerprinting
140    /// - Keeping the default to aid server-side debugging and analytics
141    pub user_agent: Option<String>,
142
143    /// Protocol version to use
144    pub protocol_version: String,
145
146    /// Size limits for requests and responses (v2.2.0+)
147    pub limits: LimitsConfig,
148
149    /// TLS/HTTPS configuration (v2.2.0+)
150    pub tls: TlsConfig,
151
152    /// Idle timeout between SSE chunks.
153    ///
154    /// Guards against a silent TCP half-open where the server stops writing
155    /// without closing the connection. If no chunk arrives within this window,
156    /// the SSE task breaks and the reconnect loop takes over. Set generously —
157    /// the SSE protocol tolerates long idle periods between events. Default: 5 minutes.
158    pub sse_read_timeout: Duration,
159}
160
161impl Default for StreamableHttpClientConfig {
162    fn default() -> Self {
163        Self {
164            base_url: "http://localhost:8080".to_string(),
165            endpoint_path: "/mcp".to_string(),
166            timeout: Duration::from_secs(30),
167            retry_policy: RetryPolicy::default(),
168            auth_token: None,
169            headers: HashMap::new(),
170            user_agent: Some(format!("TurboMCP-Client/{}", env!("CARGO_PKG_VERSION"))),
171            protocol_version: "2025-11-25".to_string(),
172            limits: LimitsConfig::default(),
173            tls: TlsConfig::default(),
174            sse_read_timeout: Duration::from_secs(300),
175        }
176    }
177}
178
/// Streamable HTTP client transport
///
/// Holds the HTTP client, the connection/session state shared with the background
/// SSE task, and the channels through which received messages are handed over.
pub struct StreamableHttpClientTransport {
    /// Configuration this transport was created with.
    config: StreamableHttpClientConfig,
    /// HTTP client built in `new` from the TLS, redirect, timeout and User-Agent settings.
    http_client: HttpClient,
    /// Current connection state; starts as `Disconnected` and is updated by the SSE task.
    state: Arc<RwLock<TransportState>>,
    /// Capabilities of this transport, fixed at construction time.
    capabilities: TransportCapabilities,
    /// Transport metrics, initialised to their default values.
    metrics: Arc<RwLock<TransportMetrics>>,
    /// Event emitter created in `new` (the second value returned by
    /// `TransportEventEmitter::new()` is discarded there).
    _event_emitter: TransportEventEmitter,

    /// Legacy SSE message endpoint if a server sends an `endpoint` event.
    ///
    /// MCP 2025-11-25 Streamable HTTP uses a single MCP endpoint for POST and GET.
    /// The `endpoint` SSE event belongs to the older HTTP+SSE transport, but keeping
    /// this optional override lets the client interoperate with legacy servers.
    message_endpoint: Arc<RwLock<Option<String>>>,

    /// Session ID from server
    session_id: Arc<RwLock<Option<String>>>,

    /// Last event ID for resumability
    last_event_id: Arc<RwLock<Option<String>>>,

    /// Channel for incoming SSE messages (created with a capacity of 1000 in `new`)
    sse_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
    sse_sender: mpsc::Sender<TransportMessage>,

    /// Channel for immediate JSON responses from POST requests
    /// (created with a capacity of 100 in `new`)
    response_receiver: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
    response_sender: mpsc::Sender<TransportMessage>,

    /// SSE connection task handle
    sse_task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
212
213impl std::fmt::Debug for StreamableHttpClientTransport {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        f.debug_struct("StreamableHttpClientTransport")
216            .field("base_url", &self.config.base_url)
217            .field("endpoint", &self.config.endpoint_path)
218            .finish()
219    }
220}
221
222impl StreamableHttpClientTransport {
223    /// Create a new streamable HTTP client transport.
224    ///
225    /// Returns an error if the underlying HTTP client cannot be built — most often a
226    /// bad TLS configuration (e.g., custom CA certificates that won't load against the
227    /// platform verifier). Pre-3.1 this was an `expect` and would panic the calling
228    /// process; v3.1 propagates it instead.
229    pub fn new(config: StreamableHttpClientConfig) -> TransportResult<Self> {
230        let (sse_tx, sse_rx) = mpsc::channel(1000);
231        let (response_tx, response_rx) = mpsc::channel(100);
232        let (event_emitter, _) = TransportEventEmitter::new();
233
234        // Emit insecurity warning if certificate validation is disabled
235        if config.tls.is_insecure() {
236            warn!(
237                "Certificate validation is disabled. This is insecure and should only be used \
238                 for testing or in secure mTLS mesh environments. \
239                 See https://turbomcp.org/docs/security/tls#certificate-validation"
240            );
241        }
242
243        // Build HTTP client with TLS configuration
244        // IMPORTANT: Must explicitly call use_rustls_tls() because cargo features are additive
245        // and other dependencies may bring in native-tls. Without this, TLS 1.3 minimum fails.
246        // See: https://github.com/seanmonstar/reqwest/issues/1314
247        let mut client_builder = HttpClient::builder()
248            .use_rustls_tls()
249            .timeout(config.timeout);
250
251        // Redirect policy: when carrying a bearer token, only follow same-origin redirects
252        // so the `Authorization: Bearer …` header (preserved by reqwest across redirects)
253        // cannot leak to a third-party host. Without an auth token we keep the default
254        // redirect behaviour (up to 10 follows) for compatibility with bog-standard HTTP.
255        if config.auth_token.is_some() {
256            client_builder =
257                client_builder.redirect(reqwest::redirect::Policy::custom(|attempt| {
258                    if attempt.previous().len() >= 10 {
259                        return attempt.error("too many redirects");
260                    }
261                    let prev_origin = attempt.previous().last().map(reqwest::Url::origin);
262                    if prev_origin.as_ref() == Some(&attempt.url().origin()) {
263                        attempt.follow()
264                    } else {
265                        // Stop the redirect chain; surface a 3xx to the caller so they can
266                        // re-authenticate against the new origin if appropriate.
267                        attempt.stop()
268                    }
269                }));
270        }
271
272        // Set User-Agent header if configured
273        if let Some(ref user_agent) = config.user_agent {
274            client_builder = client_builder.user_agent(user_agent);
275        }
276
277        // Configure TLS version (TLS 1.3 only in v3.0)
278        client_builder = match config.tls.min_version {
279            TlsVersion::Tls13 => client_builder.min_tls_version(reqwest::tls::Version::TLS_1_3),
280        };
281
282        // Configure certificate validation with security gate
283        if !config.tls.validate_certificates {
284            // SECURITY: Require explicit env var opt-in for insecure TLS
285            // This prevents accidental deployment of insecure configurations
286            const INSECURE_TLS_ENV_VAR: &str = "TURBOMCP_ALLOW_INSECURE_TLS";
287
288            if std::env::var(INSECURE_TLS_ENV_VAR).is_err() {
289                error!(
290                    "SECURITY: Certificate validation disabled but {} not set. \
291                     Overriding to validate_certificates=true for safety. \
292                     Set {}=1 to allow insecure TLS.",
293                    INSECURE_TLS_ENV_VAR, INSECURE_TLS_ENV_VAR
294                );
295                // Override: force secure config instead of panicking
296                // Don't apply danger_accept_invalid_certs
297            } else {
298                warn!(
299                    "SECURITY WARNING: TLS certificate validation is DISABLED. \
300                     This configuration is INSECURE and should ONLY be used: \
301                     (1) In development/testing environments, or \
302                     (2) In secure mTLS mesh where validation happens elsewhere. \
303                     NEVER use in production connecting to untrusted servers."
304                );
305
306                client_builder = client_builder.danger_accept_invalid_certs(true);
307            }
308        }
309
310        // Add custom CA certificates if provided
311        if let Some(ca_certs) = &config.tls.custom_ca_certs {
312            let mut loaded = 0usize;
313            let total = ca_certs.len();
314            for cert_bytes in ca_certs {
315                // Try to parse as PEM or DER
316                if let Ok(cert) = reqwest::Certificate::from_pem(cert_bytes) {
317                    client_builder = client_builder.add_root_certificate(cert);
318                    loaded += 1;
319                } else if let Ok(cert) = reqwest::Certificate::from_der(cert_bytes) {
320                    client_builder = client_builder.add_root_certificate(cert);
321                    loaded += 1;
322                } else {
323                    warn!(
324                        "Failed to parse custom CA certificate ({}/{}), skipping",
325                        loaded + 1,
326                        total
327                    );
328                }
329            }
330            if loaded == 0 && total > 0 {
331                error!("All {} custom CA certificates failed to parse", total);
332                // Don't panic - but log at error level. The connection will likely fail with TLS errors.
333            }
334            if loaded > 0 {
335                info!("Loaded {}/{} custom CA certificates", loaded, total);
336            }
337        }
338
339        let http_client = client_builder.build().map_err(|e| {
340            TransportError::ConfigurationError(format!(
341                "Failed to build HTTP client (likely bad TLS configuration): {e}"
342            ))
343        })?;
344
345        Ok(Self {
346            config,
347            http_client,
348            state: Arc::new(RwLock::new(TransportState::Disconnected)),
349            capabilities: TransportCapabilities {
350                max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
351                supports_compression: false,
352                supports_streaming: true,
353                supports_bidirectional: true,
354                supports_multiplexing: false,
355                compression_algorithms: Vec::new(),
356                custom: HashMap::new(),
357            },
358            metrics: Arc::new(RwLock::new(TransportMetrics::default())),
359            _event_emitter: event_emitter,
360            message_endpoint: Arc::new(RwLock::new(None)),
361            session_id: Arc::new(RwLock::new(None)),
362            last_event_id: Arc::new(RwLock::new(None)),
363            sse_receiver: Arc::new(Mutex::new(sse_rx)),
364            sse_sender: sse_tx,
365            response_receiver: Arc::new(Mutex::new(response_rx)),
366            response_sender: response_tx,
367            sse_task_handle: Arc::new(Mutex::new(None)),
368        })
369    }
370
371    /// Get full endpoint URL
372    fn get_endpoint_url(&self) -> String {
373        format!("{}{}", self.config.base_url, self.config.endpoint_path)
374    }
375
376    /// Get message endpoint URL (discovered or default)
377    async fn get_message_endpoint_url(&self) -> String {
378        let discovered = self.message_endpoint.read().await;
379        if let Some(endpoint) = discovered.as_ref() {
380            if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
381                endpoint.clone()
382            } else if endpoint.starts_with('/') {
383                format!("{}{}", self.config.base_url, endpoint)
384            } else {
385                format!("{}/{}", self.config.base_url, endpoint)
386            }
387        } else {
388            self.get_endpoint_url()
389        }
390    }
391
392    /// Build request headers
393    async fn build_headers(&self, accept: &str) -> header::HeaderMap {
394        let mut headers = header::HeaderMap::new();
395
396        // Use safe header value construction - skip invalid headers rather than panic
397        if let Ok(accept_value) = header::HeaderValue::from_str(accept) {
398            headers.insert(header::ACCEPT, accept_value);
399        }
400
401        if let Ok(protocol_value) = header::HeaderValue::from_str(&self.config.protocol_version) {
402            headers.insert("MCP-Protocol-Version", protocol_value);
403        }
404
405        if let Some(session_id) = self.session_id.read().await.as_ref()
406            && let Ok(session_value) = header::HeaderValue::from_str(session_id)
407        {
408            headers.insert("Mcp-Session-Id", session_value);
409        }
410
411        if let Some(last_event_id) = self.last_event_id.read().await.as_ref()
412            && let Ok(event_value) = header::HeaderValue::from_str(last_event_id)
413        {
414            headers.insert("Last-Event-ID", event_value);
415        }
416
417        if let Some(token) = &self.config.auth_token
418            && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
419        {
420            headers.insert(header::AUTHORIZATION, auth_value);
421        }
422
423        for (key, value) in &self.config.headers {
424            if let (Ok(k), Ok(v)) = (
425                header::HeaderName::from_bytes(key.as_bytes()),
426                header::HeaderValue::from_str(value),
427            ) {
428                headers.insert(k, v);
429            }
430        }
431
432        headers
433    }
434
435    /// Start SSE connection task
436    async fn start_sse_connection(&self) -> TransportResult<()> {
437        if self.session_id.read().await.is_none() {
438            debug!("Deferring SSE connection until server provides a session ID");
439            return Ok(());
440        }
441
442        let mut task_handle = self.sse_task_handle.lock().await;
443        if let Some(handle) = task_handle.as_ref()
444            && !handle.is_finished()
445        {
446            debug!("SSE connection task already running");
447            return Ok(());
448        }
449
450        info!("Starting SSE connection to {}", self.get_endpoint_url());
451
452        let endpoint_url = self.get_endpoint_url();
453        let config = self.config.clone();
454        let http_client = self.http_client.clone();
455        let state = Arc::clone(&self.state);
456        let sse_sender = self.sse_sender.clone();
457        let session_id = Arc::clone(&self.session_id);
458        let last_event_id = Arc::clone(&self.last_event_id);
459        let message_endpoint = Arc::clone(&self.message_endpoint);
460
461        let task = tokio::spawn(async move {
462            Self::sse_connection_task(
463                endpoint_url,
464                config,
465                http_client,
466                state,
467                sse_sender,
468                session_id,
469                last_event_id,
470                message_endpoint,
471            )
472            .await;
473        });
474
475        *task_handle = Some(task);
476
477        Ok(())
478    }
479
480    /// SSE connection task with auto-reconnect
481    #[allow(clippy::too_many_arguments)]
482    async fn sse_connection_task(
483        endpoint_url: String,
484        config: StreamableHttpClientConfig,
485        http_client: HttpClient,
486        state: Arc<RwLock<TransportState>>,
487        sse_sender: mpsc::Sender<TransportMessage>,
488        session_id: Arc<RwLock<Option<String>>>,
489        last_event_id: Arc<RwLock<Option<String>>>,
490        message_endpoint: Arc<RwLock<Option<String>>>,
491    ) {
492        let mut attempt = 0u32;
493
494        loop {
495            // Check if we should retry
496            if let Some(delay) = config.retry_policy.delay(attempt) {
497                if attempt > 0 {
498                    warn!("Reconnecting in {:?} (attempt {})", delay, attempt + 1);
499                    tokio::time::sleep(delay).await;
500                }
501            } else {
502                error!("Max retry attempts reached, giving up");
503                *state.write().await = TransportState::Disconnected;
504                break;
505            }
506
507            // Build request with proper headers
508            let mut headers = header::HeaderMap::new();
509            headers.insert(
510                header::ACCEPT,
511                header::HeaderValue::from_static("text/event-stream"),
512            );
513
514            if let Ok(protocol_value) = header::HeaderValue::from_str(&config.protocol_version) {
515                headers.insert("MCP-Protocol-Version", protocol_value);
516            }
517
518            if let Some(sid) = session_id.read().await.as_ref()
519                && let Ok(session_value) = header::HeaderValue::from_str(sid)
520            {
521                headers.insert("Mcp-Session-Id", session_value);
522            }
523
524            if let Some(last_id) = last_event_id.read().await.as_ref()
525                && let Ok(event_value) = header::HeaderValue::from_str(last_id)
526            {
527                headers.insert("Last-Event-ID", event_value);
528            }
529
530            if let Some(token) = &config.auth_token
531                && let Ok(auth_value) = header::HeaderValue::from_str(&format!("Bearer {}", token))
532            {
533                headers.insert(header::AUTHORIZATION, auth_value);
534            }
535
536            // Connect to SSE endpoint
537            match http_client.get(&endpoint_url).headers(headers).send().await {
538                Ok(response) => {
539                    if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
540                        info!(
541                            "Server returned HTTP 405 for GET {}. Continuing without standalone SSE polling.",
542                            endpoint_url
543                        );
544                        break;
545                    }
546
547                    if !response.status().is_success() {
548                        error!("SSE connection failed: {}", response.status());
549                        attempt += 1;
550                        continue;
551                    }
552
553                    // Extract session ID from response headers
554                    if let Some(sid) = response
555                        .headers()
556                        .get("Mcp-Session-Id")
557                        .and_then(|v| v.to_str().ok())
558                    {
559                        *session_id.write().await = Some(sid.to_string());
560                        info!("Received session ID: {}", sid);
561                    }
562
563                    info!("SSE connection established");
564                    *state.write().await = TransportState::Connected;
565                    attempt = 0; // Reset attempt counter on success
566
567                    // Process SSE stream
568                    let mut stream = response.bytes_stream();
569                    let mut buffer = String::new();
570                    let read_timeout = config.sse_read_timeout;
571                    // Cap a single SSE event's accumulated buffer at the response-size limit so
572                    // a server that streams indefinitely without ever emitting `\n\n` cannot
573                    // OOM the client. `None` keeps the historical "no cap" behaviour.
574                    let buffer_cap = config
575                        .limits
576                        .enforce_on_streams
577                        .then_some(config.limits.max_response_size)
578                        .flatten();
579
580                    'sse_loop: loop {
581                        let chunk_result =
582                            match tokio::time::timeout(read_timeout, stream.next()).await {
583                                Ok(Some(r)) => r,
584                                Ok(None) => break,
585                                Err(_) => {
586                                    warn!(
587                                        "SSE read idle for {:?}; closing stream to reconnect",
588                                        read_timeout
589                                    );
590                                    break;
591                                }
592                            };
593                        match chunk_result {
594                            Ok(chunk) => {
595                                let chunk_str = String::from_utf8_lossy(&chunk);
596                                buffer.push_str(&chunk_str);
597
598                                // Process complete events
599                                while let Some(pos) = buffer.find("\n\n") {
600                                    let event_str = buffer[..pos].to_string();
601                                    buffer = buffer[pos + 2..].to_string();
602
603                                    if let Err(e) = Self::process_sse_event(
604                                        &event_str,
605                                        &sse_sender,
606                                        &last_event_id,
607                                        &message_endpoint,
608                                    )
609                                    .await
610                                    {
611                                        warn!("Failed to process SSE event: {}", e);
612                                    }
613                                }
614
615                                if let Some(cap) = buffer_cap
616                                    && buffer.len() > cap
617                                {
618                                    error!(
619                                        "SSE event buffer exceeded {} bytes without an event \
620                                         boundary; closing stream to avoid OOM",
621                                        cap
622                                    );
623                                    break 'sse_loop;
624                                }
625                            }
626                            Err(e) => {
627                                error!("Error reading SSE stream: {}", e);
628                                break;
629                            }
630                        }
631                    }
632
633                    warn!("SSE stream ended");
634                    *state.write().await = TransportState::Disconnected;
635                }
636                Err(e) => {
637                    error!("Failed to connect: {}", e);
638                    attempt += 1;
639                }
640            }
641        }
642    }
643
644    /// Process an SSE event from the standalone GET stream.
645    async fn process_sse_event(
646        event_str: &str,
647        sse_sender: &mpsc::Sender<TransportMessage>,
648        last_event_id: &Arc<RwLock<Option<String>>>,
649        message_endpoint: &Arc<RwLock<Option<String>>>,
650    ) -> TransportResult<()> {
651        let lines: Vec<&str> = event_str.lines().collect();
652        let mut event_type: Option<String> = None;
653        let mut event_data: Vec<String> = Vec::new();
654        let mut event_id: Option<String> = None;
655
656        for line in lines {
657            if line.is_empty() {
658                continue;
659            }
660
661            if let Some(colon_pos) = line.find(':') {
662                let field = &line[..colon_pos];
663                let value = line[colon_pos + 1..].trim_start();
664
665                match field {
666                    "event" => event_type = Some(value.to_string()),
667                    "data" => event_data.push(value.to_string()),
668                    "id" => event_id = Some(value.to_string()),
669                    _ => {}
670                }
671            }
672        }
673
674        // Save event ID
675        if let Some(id) = event_id {
676            *last_event_id.write().await = Some(id);
677        }
678
679        if event_data.is_empty() {
680            return Ok(());
681        }
682
683        let data_str = event_data.join("\n");
684
685        // Handle different event types
686        match event_type.as_deref() {
687            Some("endpoint") => {
688                // Legacy HTTP+SSE transport compatibility. Streamable HTTP
689                // (MCP 2025-11-25) uses a single endpoint, so connect/send must not
690                // depend on this event.
691                //
692                // The event data may be either:
693                // 1. A JSON object: {"uri":"http://..."}
694                // 2. A plain string: "http://..."
695                let endpoint_uri = if data_str.trim().starts_with('{') {
696                    // Parse JSON object and extract uri field
697                    let endpoint_json: serde_json::Value = serde_json::from_str(&data_str)
698                        .map_err(|e| {
699                            TransportError::SerializationFailed(format!(
700                                "Invalid endpoint JSON: {}",
701                                e
702                            ))
703                        })?;
704                    endpoint_json["uri"]
705                        .as_str()
706                        .ok_or_else(|| {
707                            TransportError::SerializationFailed(
708                                "Endpoint event missing 'uri' field".to_string(),
709                            )
710                        })?
711                        .to_string()
712                } else {
713                    // Plain string format
714                    data_str.clone()
715                };
716
717                info!("Discovered message endpoint: {}", endpoint_uri);
718                *message_endpoint.write().await = Some(endpoint_uri);
719                Ok(())
720            }
721            Some("message") | None => {
722                // Skip empty or whitespace-only events (keep-alive, malformed events)
723                // This is defensive against server sending empty data events
724                if data_str.trim().is_empty() {
725                    debug!("Skipping empty SSE event");
726                    return Ok(());
727                }
728
729                // Parse as JSON-RPC message
730                let json_value: serde_json::Value =
731                    serde_json::from_str(&data_str).map_err(|e| {
732                        TransportError::SerializationFailed(format!("Invalid JSON: {}", e))
733                    })?;
734
735                let message = TransportMessage::new(
736                    MessageId::from("sse-message".to_string()),
737                    Bytes::from(
738                        serde_json::to_vec(&json_value)
739                            .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
740                    ),
741                );
742
743                sse_sender
744                    .send(message)
745                    .await
746                    .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
747
748                debug!("Received SSE message");
749                Ok(())
750            }
751            Some(other) => {
752                debug!("Ignoring unknown event type: {}", other);
753                Ok(())
754            }
755        }
756    }
757
758    /// Process SSE event from POST response
759    async fn process_post_sse_event(
760        event_str: &str,
761        response_sender: &mpsc::Sender<TransportMessage>,
762        last_event_id: &Arc<RwLock<Option<String>>>,
763    ) -> TransportResult<()> {
764        let lines: Vec<&str> = event_str.lines().collect();
765        let mut event_data: Vec<String> = Vec::new();
766        let mut event_id: Option<String> = None;
767
768        for line in lines {
769            if line.is_empty() {
770                continue;
771            }
772
773            if let Some(colon_pos) = line.find(':') {
774                let field = &line[..colon_pos];
775                let value = line[colon_pos + 1..].trim_start();
776
777                match field {
778                    "data" => event_data.push(value.to_string()),
779                    "id" => event_id = Some(value.to_string()),
780                    "event" => {
781                        // Event type field - we primarily care about "message" events
782                        // but we'll process any event with data
783                    }
784                    _ => {}
785                }
786            }
787        }
788
789        // Save event ID
790        if let Some(id) = event_id {
791            *last_event_id.write().await = Some(id);
792        }
793
794        if event_data.is_empty() {
795            return Ok(());
796        }
797
798        let data_str = event_data.join("\n");
799        if data_str.trim().is_empty() {
800            debug!("Skipping empty POST SSE event");
801            return Ok(());
802        }
803
804        // Parse as JSON-RPC message
805        let json_value: serde_json::Value = serde_json::from_str(&data_str).map_err(|e| {
806            TransportError::SerializationFailed(format!("Invalid JSON in POST SSE: {}", e))
807        })?;
808
809        let message = TransportMessage::new(
810            MessageId::from("post-sse-response".to_string()),
811            Bytes::from(
812                serde_json::to_vec(&json_value)
813                    .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
814            ),
815        );
816
817        response_sender
818            .send(message.clone())
819            .await
820            .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
821
822        debug!(
823            "Queued message from POST SSE stream: {}",
824            String::from_utf8_lossy(&message.payload)
825        );
826        Ok(())
827    }
828
829    /// Await the next inbound message.
830    ///
831    /// Unlike [`Transport::receive`] — which is non-blocking by contract and
832    /// returns `None` immediately when no message is queued — this inherent
833    /// method awaits on both the response and SSE channels and returns when
834    /// one produces a message. This is the ergonomic choice for client code
835    /// that wants a blocking `recv` without building a select loop around
836    /// `receive().await`.
837    pub async fn recv_async(&self) -> TransportResult<TransportMessage> {
838        let mut response_receiver = self.response_receiver.lock().await;
839        let mut sse_receiver = self.sse_receiver.lock().await;
840        let message = tokio::select! {
841            biased;
842            // Prefer the response queue so synchronous POST replies land before
843            // server-push SSE messages when both are ready simultaneously.
844            msg = response_receiver.recv() => msg.ok_or_else(|| {
845                TransportError::ConnectionLost("Response channel disconnected".to_string())
846            })?,
847            msg = sse_receiver.recv() => msg.ok_or_else(|| {
848                TransportError::ConnectionLost("SSE channel disconnected".to_string())
849            })?,
850        };
851        let mut metrics = self.metrics.write().await;
852        metrics.messages_received += 1;
853        metrics.bytes_received += message.payload.len() as u64;
854        Ok(message)
855    }
856}
857
858impl Transport for StreamableHttpClientTransport {
859    fn send(
860        &self,
861        message: TransportMessage,
862    ) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
863        Box::pin(async move {
864            debug!("Sending message via HTTP POST");
865
866            // Validate request size against configured limits (v2.2.0+)
867            validate_request_size(message.payload.len(), &self.config.limits)?;
868
869            // Get message endpoint (discovered or default)
870            let url = self.get_message_endpoint_url().await;
871
872            // Build headers with proper Accept negotiation
873            let headers = self
874                .build_headers("application/json, text/event-stream")
875                .await;
876
877            // Send POST request
878            let response = self
879                .http_client
880                .post(&url)
881                .headers(headers)
882                .header(header::CONTENT_TYPE, "application/json")
883                .body(message.payload.to_vec())
884                .send()
885                .await
886                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
887
888            if !response.status().is_success() {
889                return Err(TransportError::ConnectionFailed(format!(
890                    "POST failed: {}",
891                    response.status()
892                )));
893            }
894
895            // Update session ID if provided
896            if let Some(session_id) = response
897                .headers()
898                .get("Mcp-Session-Id")
899                .and_then(|v| v.to_str().ok())
900            {
901                *self.session_id.write().await = Some(session_id.to_string());
902                self.start_sse_connection().await?;
903            }
904
905            // MCP 2025-11-25: HTTP 202 Accepted means notification/response was accepted (no body)
906            if response.status() == reqwest::StatusCode::ACCEPTED {
907                debug!("Received HTTP 202 Accepted (no response body expected)");
908                // Update metrics
909                {
910                    let mut metrics = self.metrics.write().await;
911                    metrics.messages_sent += 1;
912                    metrics.bytes_sent += message.payload.len() as u64;
913                }
914                return Ok(());
915            }
916
917            // Check response content type and handle accordingly
918            let content_type = response
919                .headers()
920                .get(header::CONTENT_TYPE)
921                .and_then(|v| v.to_str().ok())
922                .unwrap_or("");
923
924            if content_type.contains("application/json") {
925                // MCP 2025-11-25: Server returned immediate JSON response
926                debug!("Received JSON response from POST");
927
928                let response_bytes = response
929                    .bytes()
930                    .await
931                    .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
932
933                // Validate response size against configured limits (v2.2.0+)
934                validate_response_size(response_bytes.len(), &self.config.limits)?;
935
936                let response_message = TransportMessage::new(
937                    MessageId::from("http-response".to_string()),
938                    response_bytes,
939                );
940
941                // Queue the response for the next receive() call
942                self.response_sender
943                    .send(response_message)
944                    .await
945                    .map_err(|e| TransportError::ConnectionLost(e.to_string()))?;
946
947                debug!("JSON response queued successfully");
948            } else if content_type.contains("text/event-stream") {
949                // MCP 2025-11-25: Server returned SSE stream response from POST
950                // Process the stream synchronously to ensure responses are available
951                debug!("Received SSE stream response from POST, processing events");
952
953                let response_sender = self.response_sender.clone();
954                let last_event_id = Arc::clone(&self.last_event_id);
955
956                // Process SSE stream inline (not spawned) to ensure proper ordering
957                let mut stream = response.bytes_stream();
958                let mut buffer = String::new();
959                // Same buffer cap as the GET SSE loop — a buggy or malicious server that
960                // streams without ever closing an event must not OOM the client.
961                let buffer_cap = self
962                    .config
963                    .limits
964                    .enforce_on_streams
965                    .then_some(self.config.limits.max_response_size)
966                    .flatten();
967
968                'post_sse_loop: while let Some(chunk_result) = stream.next().await {
969                    match chunk_result {
970                        Ok(chunk) => {
971                            let chunk_str = String::from_utf8_lossy(&chunk);
972                            buffer.push_str(&chunk_str);
973
974                            // Process complete events
975                            while let Some(pos) = buffer.find("\n\n") {
976                                let event_str = buffer[..pos].to_string();
977                                buffer = buffer[pos + 2..].to_string();
978
979                                if let Err(e) = Self::process_post_sse_event(
980                                    &event_str,
981                                    &response_sender,
982                                    &last_event_id,
983                                )
984                                .await
985                                {
986                                    warn!("Failed to process POST SSE event: {}", e);
987                                }
988                            }
989
990                            if let Some(cap) = buffer_cap
991                                && buffer.len() > cap
992                            {
993                                error!(
994                                    "POST SSE event buffer exceeded {} bytes without an event \
995                                     boundary; closing stream to avoid OOM",
996                                    cap
997                                );
998                                break 'post_sse_loop;
999                            }
1000                        }
1001                        Err(e) => {
1002                            warn!("Error reading POST SSE stream: {}", e);
1003                            break;
1004                        }
1005                    }
1006                }
1007                debug!("POST SSE stream processing completed");
1008            }
1009
1010            // Update metrics
1011            {
1012                let mut metrics = self.metrics.write().await;
1013                metrics.messages_sent += 1;
1014                metrics.bytes_sent += message.payload.len() as u64;
1015            }
1016
1017            debug!("Message sent successfully");
1018            Ok(())
1019        })
1020    }
1021
1022    /// Non-blocking receive.
1023    ///
1024    /// Returns `Ok(None)` immediately when no message is queued. This is the
1025    /// `Transport` trait contract (polled from a select loop); it does **not**
1026    /// wait for the next message. Use [`Self::recv_async`] when you want to
1027    /// await the next message.
1028    fn receive(
1029        &self,
1030    ) -> Pin<Box<dyn Future<Output = TransportResult<Option<TransportMessage>>> + Send + '_>> {
1031        Box::pin(async move {
1032            // CRITICAL: Check response queue FIRST (for immediate JSON responses from POST)
1033            // This ensures request-response pattern works correctly per MCP 2025-11-25
1034            {
1035                let mut response_receiver = self.response_receiver.lock().await;
1036                match response_receiver.try_recv() {
1037                    Ok(message) => {
1038                        debug!("Received queued JSON response");
1039                        // Update metrics
1040                        {
1041                            let mut metrics = self.metrics.write().await;
1042                            metrics.messages_received += 1;
1043                            metrics.bytes_received += message.payload.len() as u64;
1044                        }
1045                        return Ok(Some(message));
1046                    }
1047                    Err(mpsc::error::TryRecvError::Empty) => {
1048                        // No queued responses, continue to check SSE channel
1049                    }
1050                    Err(mpsc::error::TryRecvError::Disconnected) => {
1051                        return Err(TransportError::ConnectionLost(
1052                            "Response channel disconnected".to_string(),
1053                        ));
1054                    }
1055                }
1056            }
1057
1058            // Check SSE channel for server-initiated messages
1059            let mut sse_receiver = self.sse_receiver.lock().await;
1060            match sse_receiver.try_recv() {
1061                Ok(message) => {
1062                    debug!("Received SSE message");
1063                    // Update metrics
1064                    {
1065                        let mut metrics = self.metrics.write().await;
1066                        metrics.messages_received += 1;
1067                        metrics.bytes_received += message.payload.len() as u64;
1068                    }
1069                    Ok(Some(message))
1070                }
1071                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
1072                Err(mpsc::error::TryRecvError::Disconnected) => Err(
1073                    TransportError::ConnectionLost("SSE channel disconnected".to_string()),
1074                ),
1075            }
1076        })
1077    }
1078
1079    fn capabilities(&self) -> &TransportCapabilities {
1080        &self.capabilities
1081    }
1082
1083    fn state(&self) -> Pin<Box<dyn Future<Output = TransportState> + Send + '_>> {
1084        Box::pin(async move { self.state.read().await.clone() })
1085    }
1086
1087    fn transport_type(&self) -> TransportType {
1088        TransportType::Http
1089    }
1090
1091    fn metrics(&self) -> Pin<Box<dyn Future<Output = TransportMetrics> + Send + '_>> {
1092        Box::pin(async move { self.metrics.read().await.clone() })
1093    }
1094
1095    fn connect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
1096        Box::pin(async move {
1097            info!("Connecting to {}", self.get_endpoint_url());
1098
1099            *self.state.write().await = TransportState::Connecting;
1100
1101            // Start SSE connection task
1102            self.start_sse_connection().await?;
1103
1104            *self.state.write().await = TransportState::Connected;
1105
1106            info!("Connected successfully");
1107            Ok(())
1108        })
1109    }
1110
1111    fn disconnect(&self) -> Pin<Box<dyn Future<Output = TransportResult<()>> + Send + '_>> {
1112        Box::pin(async move {
1113            info!("Disconnecting");
1114
1115            *self.state.write().await = TransportState::Disconnecting;
1116
1117            // Cancel SSE task
1118            if let Some(handle) = self.sse_task_handle.lock().await.take() {
1119                handle.abort();
1120            }
1121
1122            // Send DELETE to terminate session
1123            if let Some(session_id) = self.session_id.read().await.as_ref() {
1124                let url = self.get_endpoint_url();
1125                let mut headers = header::HeaderMap::new();
1126                if let Ok(session_value) = header::HeaderValue::from_str(session_id) {
1127                    headers.insert("Mcp-Session-Id", session_value);
1128                }
1129
1130                let _ = self.http_client.delete(&url).headers(headers).send().await;
1131            }
1132
1133            *self.state.write().await = TransportState::Disconnected;
1134
1135            info!("Disconnected");
1136            Ok(())
1137        })
1138    }
1139}
1140
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-local copy of the endpoint-event decoding rule: a JSON object yields its
    /// `uri` member, any other text is taken verbatim.
    fn endpoint_uri_from_event_data(data_str: &str) -> String {
        if !data_str.trim().starts_with('{') {
            return data_str.to_string();
        }

        let endpoint_json: serde_json::Value =
            serde_json::from_str(data_str).expect("Failed to parse endpoint JSON");
        endpoint_json["uri"]
            .as_str()
            .expect("Missing uri field")
            .to_string()
    }

    #[test]
    fn test_retry_policy_fixed() {
        let policy = RetryPolicy::Fixed {
            interval: Duration::from_secs(5),
            max_attempts: Some(3),
        };

        // Every attempt below the limit waits the same interval.
        for attempt in 0..3 {
            assert_eq!(policy.delay(attempt), Some(Duration::from_secs(5)));
        }
        // The limit itself is out of attempts.
        assert_eq!(policy.delay(3), None);
    }

    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::Exponential {
            base: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            max_attempts: None,
        };

        // (attempt, lower bound ms, upper bound ms): nominal delays of 1s, 2s, 4s, 8s
        // with ±25% jitter, and attempt 10 capped at max_delay (60s) with jitter.
        let expectations = [
            (0, 750, 1250),
            (1, 1500, 2500),
            (2, 3000, 5000),
            (3, 6000, 10000),
            (10, 45000, 75000),
        ];

        for (attempt, min_ms, max_ms) in expectations {
            let delay = policy.delay(attempt).unwrap();
            assert!(delay >= Duration::from_millis(min_ms) && delay <= Duration::from_millis(max_ms));
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = StreamableHttpClientTransport::new(StreamableHttpClientConfig::default())
            .expect("default config builds");

        assert_eq!(client.transport_type(), TransportType::Http);

        let capabilities = client.capabilities();
        assert!(capabilities.supports_streaming);
        assert!(capabilities.supports_bidirectional);
    }

    #[tokio::test]
    async fn test_endpoint_event_json_parsing() {
        // Legacy HTTP+SSE compatibility: verify JSON endpoint events still parse.
        // Bug: Client was storing entire JSON string {"uri":"..."} instead of extracting URI.
        let message_endpoint = Arc::new(RwLock::new(None::<String>));

        // Simulate a legacy endpoint event with JSON format.
        let data_str = [r#"{"uri":"http://127.0.0.1:8080/mcp"}"#.to_string()].join("\n");

        *message_endpoint.write().await = Some(endpoint_uri_from_event_data(&data_str));

        // Verify URI was extracted correctly
        let stored = message_endpoint.read().await;
        let uri = stored.as_deref().unwrap();
        assert_eq!(uri, "http://127.0.0.1:8080/mcp");
        assert!(uri.starts_with("http://"));

        // Verify it's a valid URL
        assert!(uri.parse::<url::Url>().is_ok());
    }

    #[tokio::test]
    async fn test_endpoint_event_plain_string_parsing() {
        // Legacy HTTP+SSE compatibility with plain string endpoint events.
        let message_endpoint = Arc::new(RwLock::new(None::<String>));

        // Simulate endpoint event with plain string format
        let data_str = ["http://127.0.0.1:8080/mcp".to_string()].join("\n");

        // Not JSON, so the text must be used as-is.
        *message_endpoint.write().await = Some(endpoint_uri_from_event_data(&data_str));

        // Verify plain string was stored correctly
        let stored = message_endpoint.read().await;
        let uri = stored.as_deref().unwrap();
        assert_eq!(uri, "http://127.0.0.1:8080/mcp");
        assert!(uri.starts_with("http://"));
    }

    #[tokio::test]
    async fn test_post_sse_whitespace_data_event_is_ignored() {
        let (sender, mut receiver) = mpsc::channel(1);
        let last_event_id = Arc::new(RwLock::new(None));

        let outcome = StreamableHttpClientTransport::process_post_sse_event(
            "id: primer-1\nevent: message\ndata:    \n",
            &sender,
            &last_event_id,
        )
        .await;
        outcome.expect("whitespace POST SSE event should be ignored");

        // The ID is still recorded, but nothing reaches the queue.
        assert_eq!(last_event_id.read().await.as_deref(), Some("primer-1"));
        assert!(receiver.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_post_sse_json_event_is_queued() {
        let (sender, mut receiver) = mpsc::channel(1);
        let last_event_id = Arc::new(RwLock::new(None));

        let outcome = StreamableHttpClientTransport::process_post_sse_event(
            "id: msg-1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\n",
            &sender,
            &last_event_id,
        )
        .await;
        outcome.expect("valid POST SSE event should be queued");

        assert_eq!(last_event_id.read().await.as_deref(), Some("msg-1"));

        let queued = receiver.try_recv().expect("queued message");
        let value: serde_json::Value =
            serde_json::from_slice(&queued.payload).expect("valid queued JSON");
        assert_eq!(value["jsonrpc"], "2.0");
    }
}