Skip to main content

vellaveto_http_proxy/proxy/
smart_fallback.rs

1// Copyright 2026 Paolo Vella
2// SPDX-License-Identifier: BUSL-1.1
3//
4// Use of this software is governed by the Business Source License
5// included in the LICENSE-BSL-1.1 file at the root of this repository.
6//
7// Change Date: Three years from the date of publication of this version.
8// Change License: MPL-2.0
9
10//! Smart cross-transport fallback chain orchestrator (Phase 29).
11//!
12//! Tries each transport in priority order, recording results in a
13//! `FallbackNegotiationHistory` for audit purposes. Uses `TransportHealthTracker`
14//! to skip transports whose circuits are open.
15
16use std::path::{Path, PathBuf};
17use std::time::{Duration, Instant};
18use vellaveto_types::command::resolve_executable;
19use vellaveto_types::{FallbackNegotiationHistory, TransportAttempt, TransportProtocol};
20
21use super::transport_health::TransportHealthTracker;
22use super::{FORWARDED_HEADERS, MAX_RESPONSE_BODY_BYTES};
23
/// Upper bound (4 KiB) on how much subprocess stderr is kept for diagnostic
/// error messages. FIND-R41-010.
const MAX_STDERR_BYTES: usize = 4 * 1024;
26
27/// A single transport target to try during fallback.
28#[derive(Debug, Clone)]
29pub struct TransportTarget {
30    /// The transport protocol to use.
31    pub protocol: TransportProtocol,
32    /// The endpoint URL for this transport.
33    pub url: String,
34    /// Upstream identifier for circuit breaker tracking.
35    pub upstream_id: String,
36}
37
38/// Result of a successful smart fallback execution.
39#[derive(Debug)]
40pub struct SmartFallbackResult {
41    /// Response bytes from the upstream.
42    pub response: bytes::Bytes,
43    /// The transport that succeeded.
44    pub transport_used: TransportProtocol,
45    /// HTTP status code.
46    pub status: u16,
47    /// Full negotiation history for audit.
48    pub history: FallbackNegotiationHistory,
49}
50
51/// Errors from smart fallback execution.
52#[derive(Debug)]
53pub enum SmartFallbackError {
54    /// All transports failed after trying each one.
55    AllTransportsFailed { history: FallbackNegotiationHistory },
56    /// Total timeout budget exhausted.
57    TotalTimeoutExceeded { history: FallbackNegotiationHistory },
58    /// No targets provided.
59    NoTargets,
60}
61
62impl std::fmt::Display for SmartFallbackError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            Self::AllTransportsFailed { history } => {
66                write!(f, "all {} transport(s) failed", history.attempts.len())
67            }
68            Self::TotalTimeoutExceeded { history } => {
69                write!(
70                    f,
71                    "total timeout exceeded after {} attempt(s)",
72                    history.attempts.len()
73                )
74            }
75            Self::NoTargets => write!(f, "no transport targets provided"),
76        }
77    }
78}
79
80impl std::error::Error for SmartFallbackError {}
81
82/// Smart fallback chain orchestrator.
83///
84/// Tries each `TransportTarget` in order, checking circuit breaker health
85/// before each attempt. Records all attempts in a `FallbackNegotiationHistory`.
86pub struct SmartFallbackChain<'a> {
87    client: &'a reqwest::Client,
88    health_tracker: &'a TransportHealthTracker,
89    per_attempt_timeout: Duration,
90    total_timeout: Duration,
91    stdio_enabled: bool,
92    stdio_command: Option<String>,
93    stdio_command_resolved: Option<PathBuf>,
94}
95
96impl<'a> SmartFallbackChain<'a> {
97    /// Create a new fallback chain.
98    pub fn new(
99        client: &'a reqwest::Client,
100        health_tracker: &'a TransportHealthTracker,
101        per_attempt_timeout: Duration,
102        total_timeout: Duration,
103    ) -> Self {
104        Self {
105            client,
106            health_tracker,
107            per_attempt_timeout,
108            total_timeout,
109            stdio_enabled: false,
110            stdio_command: None,
111            stdio_command_resolved: None,
112        }
113    }
114
115    /// Enable stdio fallback with the given command.
116    pub fn with_stdio(&self, command: String) -> Result<Self, String> {
117        let resolved = resolve_executable(&command, std::env::var_os("PATH").as_deref())?;
118        Ok(Self {
119            client: self.client,
120            health_tracker: self.health_tracker,
121            per_attempt_timeout: self.per_attempt_timeout,
122            total_timeout: self.total_timeout,
123            stdio_enabled: true,
124            stdio_command: Some(command),
125            stdio_command_resolved: Some(resolved),
126        })
127    }
128
129    /// Execute the fallback chain, trying each target in order.
130    pub async fn execute(
131        &self,
132        targets: &[TransportTarget],
133        body: bytes::Bytes,
134        headers: &reqwest::header::HeaderMap,
135    ) -> Result<SmartFallbackResult, SmartFallbackError> {
136        if targets.is_empty() {
137            return Err(SmartFallbackError::NoTargets);
138        }
139
140        let chain_start = Instant::now();
141        let mut attempts = Vec::new();
142
143        for target in targets {
144            // Check total timeout budget.
145            if chain_start.elapsed() >= self.total_timeout {
146                let history = FallbackNegotiationHistory {
147                    attempts,
148                    successful_transport: None,
149                    total_duration_ms: chain_start.elapsed().as_millis() as u64,
150                };
151                return Err(SmartFallbackError::TotalTimeoutExceeded { history });
152            }
153
154            // Skip stdio if not enabled.
155            if target.protocol == TransportProtocol::Stdio && !self.stdio_enabled {
156                continue;
157            }
158
159            // Check circuit breaker.
160            if let Err(reason) = self
161                .health_tracker
162                .can_use(&target.upstream_id, target.protocol)
163            {
164                attempts.push(TransportAttempt {
165                    protocol: target.protocol,
166                    endpoint_url: target.url.clone(),
167                    succeeded: false,
168                    duration_ms: 0,
169                    error: Some(reason),
170                });
171                continue;
172            }
173
174            // Calculate remaining timeout budget.
175            let remaining = self
176                .total_timeout
177                .checked_sub(chain_start.elapsed())
178                .unwrap_or(Duration::ZERO);
179            let attempt_timeout = self.per_attempt_timeout.min(remaining);
180
181            if attempt_timeout.is_zero() {
182                let history = FallbackNegotiationHistory {
183                    attempts,
184                    successful_transport: None,
185                    total_duration_ms: chain_start.elapsed().as_millis() as u64,
186                };
187                return Err(SmartFallbackError::TotalTimeoutExceeded { history });
188            }
189
190            let attempt_start = Instant::now();
191
192            let result = match target.protocol {
193                TransportProtocol::Http => {
194                    self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
195                        .await
196                }
197                TransportProtocol::WebSocket => {
198                    self.dispatch_websocket(&target.url, body.clone(), attempt_timeout)
199                        .await
200                }
201                TransportProtocol::Grpc => {
202                    // gRPC dispatch via HTTP bridge endpoint.
203                    self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
204                        .await
205                }
206                TransportProtocol::Stdio => {
207                    if let (Some(cmd), Some(resolved)) = (
208                        self.stdio_command.as_ref(),
209                        self.stdio_command_resolved.as_ref(),
210                    ) {
211                        self.dispatch_stdio(cmd, resolved, body.clone(), attempt_timeout)
212                            .await
213                    } else {
214                        Err("stdio command not configured".to_string())
215                    }
216                }
217            };
218
219            let duration_ms = attempt_start.elapsed().as_millis() as u64;
220
221            match result {
222                Ok((response_bytes, status)) => {
223                    // SECURITY (FIND-R43-012): Treat 5xx responses as failures so
224                    // the circuit breaker can open and trigger fallback to the next
225                    // transport. A backend returning 500s must not keep the circuit
226                    // closed.
227                    if status >= 500 {
228                        self.health_tracker
229                            .record_failure(&target.upstream_id, target.protocol);
230
231                        metrics::counter!(
232                            "vellaveto_transport_fallback_total",
233                            "transport" => format!("{:?}", target.protocol),
234                            "upstream_id" => target.upstream_id.clone(),
235                            "result" => "server_error",
236                        )
237                        .increment(1);
238
239                        attempts.push(TransportAttempt {
240                            protocol: target.protocol,
241                            endpoint_url: target.url.clone(),
242                            succeeded: false,
243                            duration_ms,
244                            error: Some(format!("server error: HTTP {}", status)),
245                        });
246
247                        continue;
248                    }
249
250                    self.health_tracker
251                        .record_success(&target.upstream_id, target.protocol);
252
253                    metrics::counter!(
254                        "vellaveto_transport_fallback_total",
255                        "transport" => format!("{:?}", target.protocol),
256                        "upstream_id" => target.upstream_id.clone(),
257                        "result" => "success",
258                    )
259                    .increment(1);
260
261                    attempts.push(TransportAttempt {
262                        protocol: target.protocol,
263                        endpoint_url: target.url.clone(),
264                        succeeded: true,
265                        duration_ms,
266                        error: None,
267                    });
268
269                    let history = FallbackNegotiationHistory {
270                        attempts,
271                        successful_transport: Some(target.protocol),
272                        total_duration_ms: chain_start.elapsed().as_millis() as u64,
273                    };
274
275                    return Ok(SmartFallbackResult {
276                        response: response_bytes,
277                        transport_used: target.protocol,
278                        status,
279                        history,
280                    });
281                }
282                Err(error) => {
283                    self.health_tracker
284                        .record_failure(&target.upstream_id, target.protocol);
285
286                    metrics::counter!(
287                        "vellaveto_transport_fallback_total",
288                        "transport" => format!("{:?}", target.protocol),
289                        "upstream_id" => target.upstream_id.clone(),
290                        "result" => "failure",
291                    )
292                    .increment(1);
293
294                    attempts.push(TransportAttempt {
295                        protocol: target.protocol,
296                        endpoint_url: target.url.clone(),
297                        succeeded: false,
298                        duration_ms,
299                        error: Some(error),
300                    });
301                }
302            }
303        }
304
305        let history = FallbackNegotiationHistory {
306            attempts,
307            successful_transport: None,
308            total_duration_ms: chain_start.elapsed().as_millis() as u64,
309        };
310        Err(SmartFallbackError::AllTransportsFailed { history })
311    }
312
313    /// Dispatch via HTTP POST.
314    ///
315    /// SECURITY (FIND-R41-015): Uses chunk-based reading to prevent OOM from
316    /// chunked-encoded responses that omit Content-Length. Each chunk is checked
317    /// against MAX_RESPONSE_BODY_BYTES before accumulating.
318    async fn dispatch_http(
319        &self,
320        url: &str,
321        body: bytes::Bytes,
322        headers: &reqwest::header::HeaderMap,
323        timeout: Duration,
324    ) -> Result<(bytes::Bytes, u16), String> {
325        let mut request = self.client.post(url).timeout(timeout);
326
327        // SECURITY (FIND-R41-001): Only forward allowlisted headers to
328        // prevent leaking Authorization, Cookie, etc. to upstream backends.
329        for (key, value) in headers {
330            let key_lower = key.as_str().to_lowercase();
331            if FORWARDED_HEADERS
332                .iter()
333                .any(|&allowed| allowed == key_lower)
334            {
335                request = request.header(key.clone(), value.clone());
336            }
337        }
338
339        let mut resp = request
340            .body(body)
341            .send()
342            .await
343            .map_err(|e| format!("HTTP request error: {}", e))?;
344
345        let status = resp.status().as_u16();
346
347        // SECURITY (FIND-R41-004): Fast-reject if Content-Length exceeds limit.
348        if let Some(len) = resp.content_length() {
349            if len as usize > MAX_RESPONSE_BODY_BYTES {
350                return Err(format!(
351                    "response body too large: {} bytes (max {})",
352                    len, MAX_RESPONSE_BODY_BYTES
353                ));
354            }
355        }
356
357        // SECURITY (FIND-R41-015): Read body in chunks with bounded accumulation.
358        // Prevents OOM from chunked-encoded responses that omit Content-Length.
359        let capacity = std::cmp::min(
360            resp.content_length().unwrap_or(8192) as usize,
361            MAX_RESPONSE_BODY_BYTES,
362        );
363        let mut response_body = Vec::with_capacity(capacity);
364        while let Some(chunk) = resp
365            .chunk()
366            .await
367            .map_err(|e| format!("HTTP response body error: {}", e))?
368        {
369            if response_body.len().saturating_add(chunk.len()) > MAX_RESPONSE_BODY_BYTES {
370                return Err(format!(
371                    "response body too large: >{} bytes (max {})",
372                    MAX_RESPONSE_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
373                ));
374            }
375            response_body.extend_from_slice(&chunk);
376        }
377
378        Ok((bytes::Bytes::from(response_body), status))
379    }
380
381    /// Dispatch via WebSocket (one-shot: connect, send, receive, close).
382    ///
383    /// SECURITY (FIND-R41-011): Configures max_message_size to prevent OOM
384    /// from unbounded upstream WebSocket frames.
385    async fn dispatch_websocket(
386        &self,
387        url: &str,
388        body: bytes::Bytes,
389        timeout: Duration,
390    ) -> Result<(bytes::Bytes, u16), String> {
391        use tokio_tungstenite::tungstenite::Message;
392
393        let result = tokio::time::timeout(timeout, async {
394            // SECURITY (FIND-R41-011): Configure max message/frame size to prevent
395            // a malicious upstream from sending unbounded WebSocket frames (OOM).
396            let mut ws_config =
397                tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default();
398            ws_config.max_message_size = Some(MAX_RESPONSE_BODY_BYTES);
399            ws_config.max_frame_size = Some(MAX_RESPONSE_BODY_BYTES);
400            let (mut ws, _) =
401                tokio_tungstenite::connect_async_with_config(url, Some(ws_config), false)
402                    .await
403                    .map_err(|e| format!("WebSocket connect error: {}", e))?;
404
405            use futures_util::SinkExt;
406            // SECURITY (FIND-R43-026): Reject invalid UTF-8 instead of silently
407            // replacing bytes with U+FFFD, which would mutate the request body.
408            let body_text = String::from_utf8(body.to_vec())
409                .map_err(|_| "WebSocket body contains invalid UTF-8".to_string())?;
410            ws.send(Message::Text(body_text.into()))
411                .await
412                .map_err(|e| format!("WebSocket send error: {}", e))?;
413
414            use futures_util::StreamExt;
415            let response = ws
416                .next()
417                .await
418                .ok_or_else(|| "WebSocket closed without response".to_string())?
419                .map_err(|e| format!("WebSocket receive error: {}", e))?;
420
421            ws.close(None)
422                .await
423                .map_err(|e| format!("WebSocket close error: {}", e))?;
424
425            let response_bytes = match response {
426                Message::Text(t) => {
427                    let bytes = t.as_bytes();
428                    if bytes.len() > MAX_RESPONSE_BODY_BYTES {
429                        return Err(format!(
430                            "WebSocket response too large: {} bytes (max {})",
431                            bytes.len(),
432                            MAX_RESPONSE_BODY_BYTES
433                        ));
434                    }
435                    bytes::Bytes::from(Vec::from(bytes))
436                }
437                Message::Binary(b) => {
438                    if b.len() > MAX_RESPONSE_BODY_BYTES {
439                        return Err(format!(
440                            "WebSocket response too large: {} bytes (max {})",
441                            b.len(),
442                            MAX_RESPONSE_BODY_BYTES
443                        ));
444                    }
445                    bytes::Bytes::from(Vec::from(b.as_ref()))
446                }
447                other => {
448                    return Err(format!("unexpected WebSocket message type: {:?}", other));
449                }
450            };
451
452            Ok::<_, String>((response_bytes, 200u16))
453        })
454        .await;
455
456        match result {
457            Ok(inner) => inner,
458            Err(_) => Err("WebSocket timeout".to_string()),
459        }
460    }
461
462    /// Dispatch via stdio subprocess.
463    ///
464    /// SECURITY (FIND-R41-002): Uses direct Command::new(command) instead of
465    /// `sh -c` to prevent shell injection. The command path is validated at
466    /// config time to be an absolute path with no shell metacharacters.
467    ///
468    /// SECURITY (FIND-R41-006): Explicitly kills child on timeout to prevent
469    /// zombie process accumulation.
470    ///
471    /// SECURITY (FIND-R41-010): Captures stderr for diagnostics instead of
472    /// discarding it.
473    async fn dispatch_stdio(
474        &self,
475        _command: &str,
476        resolved_command: &Path,
477        body: bytes::Bytes,
478        timeout: Duration,
479    ) -> Result<(bytes::Bytes, u16), String> {
480        use tokio::io::{AsyncReadExt, AsyncWriteExt};
481        use tokio::process::Command;
482
483        // SECURITY (FIND-R41-002): Execute command directly, not via sh -c.
484        // SECURITY (FIND-R43-002): Clear inherited environment to prevent leaking
485        // secrets (API keys, tokens, credentials) to the subprocess.
486        // SECURITY (FIND-R43-008): kill_on_drop ensures the child is killed if
487        // the async task is cancelled, preventing orphaned subprocesses.
488        let mut child = Command::new(resolved_command)
489            .env_clear()
490            .env("PATH", "/usr/local/bin:/usr/bin:/bin")
491            .env("HOME", "/tmp")
492            .env("LANG", "C.UTF-8")
493            .kill_on_drop(true)
494            .stdin(std::process::Stdio::piped())
495            .stdout(std::process::Stdio::piped())
496            .stderr(std::process::Stdio::piped())
497            .spawn()
498            .map_err(|e| format!("stdio spawn error: {}", e))?;
499
500        // Write to stdin and drop it to signal EOF.
501        if let Some(mut stdin) = child.stdin.take() {
502            stdin
503                .write_all(&body)
504                .await
505                .map_err(|e| format!("stdio write error: {}", e))?;
506        }
507
508        // Take stdout/stderr handles before waiting, so child is not consumed.
509        let stdout_handle = child
510            .stdout
511            .take()
512            .ok_or_else(|| "stdio stdout not captured".to_string())?;
513        let mut stderr_handle = child
514            .stderr
515            .take()
516            .ok_or_else(|| "stdio stderr not captured".to_string())?;
517
518        // SECURITY (FIND-R41-006): Use select! so we can kill the child on timeout.
519        tokio::select! {
520            // SECURITY (FIND-R43-001): Read stdout concurrently with waiting for
521            // process exit to prevent deadlock. If the subprocess writes >64KB to
522            // stdout, the pipe buffer fills and blocks; reading stdout only after
523            // wait() would deadlock. Using tokio::join! ensures both proceed.
524            result = async {
525                let stdout_future = async {
526                    let mut stdout_buf = Vec::new();
527                    let read_result = stdout_handle
528                        .take((MAX_RESPONSE_BODY_BYTES as u64) + 1)
529                        .read_to_end(&mut stdout_buf)
530                        .await
531                        .map_err(|e| format!("stdio stdout read error: {}", e));
532                    read_result.map(|_| stdout_buf)
533                };
534
535                let wait_future = child.wait();
536
537                let (stdout_result, status_result) = tokio::join!(stdout_future, wait_future);
538
539                let stdout_buf = stdout_result?;
540                let status = status_result
541                    .map_err(|e| format!("stdio wait error: {}", e))?;
542
543                if stdout_buf.len() > MAX_RESPONSE_BODY_BYTES {
544                    return Err(format!(
545                        "stdio stdout too large: >{} bytes (max {})",
546                        MAX_RESPONSE_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
547                    ));
548                }
549
550                if !status.success() {
551                    // SECURITY (FIND-R41-010): Include truncated stderr in error.
552                    let mut stderr_buf = vec![0u8; MAX_STDERR_BYTES];
553                    let n = stderr_handle
554                        .read(&mut stderr_buf)
555                        .await
556                        .unwrap_or(0);
557                    let stderr_snippet = String::from_utf8_lossy(&stderr_buf[..n]);
558                    return Err(format!(
559                        "stdio process exited with {:?}: {}",
560                        status,
561                        stderr_snippet.trim()
562                    ));
563                }
564
565                Ok((bytes::Bytes::from(stdout_buf), 200u16))
566            } => {
567                result
568            }
569            _ = tokio::time::sleep(timeout) => {
570                // SECURITY (FIND-R41-006): Kill child on timeout to prevent zombies.
571                let _ = child.kill().await;
572                Err("stdio timeout".to_string())
573            }
574        }
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    fn make_tracker() -> TransportHealthTracker {
        TransportHealthTracker::new(2, 1, 300)
    }

    fn make_targets(protos: &[(TransportProtocol, &str)]) -> Vec<TransportTarget> {
        protos
            .iter()
            .map(|(p, url)| TransportTarget {
                protocol: *p,
                url: url.to_string(),
                upstream_id: "test".to_string(),
            })
            .collect()
    }

    #[tokio::test]
    async fn test_smart_fallback_no_targets() {
        let client = reqwest::Client::new();
        let tracker = make_tracker();
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        );

        let result = chain
            .execute(&[], bytes::Bytes::new(), &reqwest::header::HeaderMap::new())
            .await;

        assert!(matches!(result, Err(SmartFallbackError::NoTargets)));
    }

    #[tokio::test]
    async fn test_smart_fallback_all_circuits_open() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        // Open both circuits.
        tracker.record_failure("test", TransportProtocol::Http);
        tracker.record_failure("test", TransportProtocol::WebSocket);

        let targets = make_targets(&[
            (TransportProtocol::Http, "http://localhost:1/mcp"),
            (TransportProtocol::WebSocket, "ws://localhost:2/mcp"),
        ]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                assert_eq!(history.attempts.len(), 2);
                assert!(history.successful_transport.is_none());
                for attempt in &history.attempts {
                    assert!(!attempt.succeeded);
                    // Circuit-open refusals are recorded without dispatching.
                    assert_eq!(attempt.duration_ms, 0);
                    assert!(attempt.error.as_ref().unwrap().contains("circuit open"));
                }
            }
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_skips_stdio_when_disabled() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();

        let targets = make_targets(&[(TransportProtocol::Stdio, "stdio://local")]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(1),
            Duration::from_secs(5),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        // Stdio is skipped, so no "available" transport succeeded. A disabled
        // stdio target is skipped without a history entry.
        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                assert!(history.attempts.is_empty());
                assert!(history.successful_transport.is_none());
            }
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_first_fails_second_succeeds_via_circuit() {
        // First transport's circuit is open, second should be tried.
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        tracker.record_failure("test", TransportProtocol::Grpc);

        let targets = make_targets(&[
            (TransportProtocol::Grpc, "http://localhost:1/grpc"),
            (TransportProtocol::Http, "http://localhost:1/does-not-exist"),
        ]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_millis(500),
            Duration::from_secs(5),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        // Both should fail (gRPC circuit open, HTTP connection refused),
        // but the history should have 2 attempts in target order.
        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                assert_eq!(history.attempts.len(), 2);
                // First attempt was circuit-open (0ms).
                assert_eq!(history.attempts[0].protocol, TransportProtocol::Grpc);
                assert!(!history.attempts[0].succeeded);
                assert!(history.attempts[0]
                    .error
                    .as_ref()
                    .unwrap()
                    .contains("circuit open"));
                // Second attempt tried HTTP but connection refused.
                assert_eq!(history.attempts[1].protocol, TransportProtocol::Http);
                assert!(!history.attempts[1].succeeded);
            }
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_total_timeout_zero() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();

        let targets = make_targets(&[(TransportProtocol::Http, "http://localhost:1/mcp")]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::ZERO, // Zero total timeout
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        // The budget is exhausted before the first target, so nothing is tried.
        match result {
            Err(SmartFallbackError::TotalTimeoutExceeded { history }) => {
                assert!(history.attempts.is_empty());
                assert!(history.successful_transport.is_none());
            }
            other => panic!("expected TotalTimeoutExceeded, got {:?}", other),
        }
    }

    #[test]
    fn test_smart_fallback_error_display() {
        let err = SmartFallbackError::NoTargets;
        assert_eq!(format!("{}", err), "no transport targets provided");

        let err = SmartFallbackError::AllTransportsFailed {
            history: FallbackNegotiationHistory {
                attempts: vec![TransportAttempt {
                    protocol: TransportProtocol::Http,
                    endpoint_url: "http://localhost".to_string(),
                    succeeded: false,
                    duration_ms: 100,
                    error: Some("timeout".to_string()),
                }],
                successful_transport: None,
                total_duration_ms: 100,
            },
        };
        assert!(format!("{}", err).contains("1 transport(s) failed"));

        let err = SmartFallbackError::TotalTimeoutExceeded {
            history: FallbackNegotiationHistory {
                attempts: Vec::new(),
                successful_transport: None,
                total_duration_ms: 0,
            },
        };
        assert_eq!(
            format!("{}", err),
            "total timeout exceeded after 0 attempt(s)"
        );
    }

    #[test]
    fn test_transport_target_construction() {
        let target = TransportTarget {
            protocol: TransportProtocol::Grpc,
            url: "http://localhost:50051".to_string(),
            upstream_id: "backend-1".to_string(),
        };
        assert_eq!(target.protocol, TransportProtocol::Grpc);
        assert_eq!(target.url, "http://localhost:50051");
    }

    #[tokio::test]
    async fn test_smart_fallback_history_records_all_attempts() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        // Open first circuit.
        tracker.record_failure("test", TransportProtocol::Grpc);

        let targets = make_targets(&[
            (TransportProtocol::Grpc, "http://localhost:1/grpc"),
            (TransportProtocol::WebSocket, "ws://localhost:2/ws"),
            (TransportProtocol::Http, "http://localhost:3/http"),
        ]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_millis(200),
            Duration::from_secs(5),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                // gRPC (circuit open), WS (connection refused), HTTP (connection refused).
                assert_eq!(history.attempts.len(), 3);
                assert!(history.successful_transport.is_none());

                let protocols: Vec<TransportProtocol> =
                    history.attempts.iter().map(|a| a.protocol).collect();
                assert_eq!(
                    protocols,
                    vec![
                        TransportProtocol::Grpc,
                        TransportProtocol::WebSocket,
                        TransportProtocol::Http,
                    ]
                );
                assert!(history.attempts.iter().all(|a| !a.succeeded));
                assert!(history.attempts.iter().all(|a| a.error.is_some()));

                // The circuit-open refusal is recorded without dispatch time.
                assert_eq!(history.attempts[0].duration_ms, 0);
                assert!(history.attempts[0]
                    .error
                    .as_ref()
                    .unwrap()
                    .contains("circuit open"));
            }
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        }
    }

    // `/bin/cat` is only guaranteed on Unix-like systems.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_smart_fallback_with_stdio_enabled() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();

        // `/bin/cat` copies stdin to stdout, so the response must equal the body.
        let targets = vec![TransportTarget {
            protocol: TransportProtocol::Stdio,
            url: "stdio://local".to_string(),
            upstream_id: "test".to_string(),
        }];

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        )
        .with_stdio("/bin/cat".to_string())
        .expect("stdio fallback command should resolve");

        let body = bytes::Bytes::from(r#"{"test": true}"#);
        let result = chain
            .execute(&targets, body.clone(), &reqwest::header::HeaderMap::new())
            .await;

        match result {
            Ok(res) => {
                assert_eq!(res.transport_used, TransportProtocol::Stdio);
                assert_eq!(res.status, 200);
                assert_eq!(res.response, body);
                assert_eq!(res.history.attempts.len(), 1);
                assert!(res.history.attempts[0].succeeded);
            }
            Err(e) => panic!("expected success, got: {}", e),
        }
    }
}