Skip to main content

vellaveto_http_proxy/proxy/
smart_fallback.rs

1// Copyright 2026 Paolo Vella
2// SPDX-License-Identifier: BUSL-1.1
3//
4// Use of this software is governed by the Business Source License
5// included in the LICENSE-BSL-1.1 file at the root of this repository.
6//
7// Change Date: Three years from the date of publication of this version.
8// Change License: MPL-2.0
9
10//! Smart cross-transport fallback chain orchestrator (Phase 29).
11//!
12//! Tries each transport in priority order, recording results in a
13//! `FallbackNegotiationHistory` for audit purposes. Uses `TransportHealthTracker`
14//! to skip transports whose circuits are open.
15
16use std::path::{Path, PathBuf};
17use std::time::{Duration, Instant};
18use vellaveto_types::command::resolve_executable;
19use vellaveto_types::{FallbackNegotiationHistory, TransportAttempt, TransportProtocol};
20
21use super::transport_health::TransportHealthTracker;
22use super::{FORWARDED_HEADERS, MAX_RESPONSE_BODY_BYTES};
23
/// Upper bound on how much stdio-subprocess stderr is kept for error
/// messages: 4 KiB. FIND-R41-010.
const MAX_STDERR_BYTES: usize = 4 * 1024;
26
27/// A single transport target to try during fallback.
28#[derive(Debug, Clone)]
29pub struct TransportTarget {
30    /// The transport protocol to use.
31    pub protocol: TransportProtocol,
32    /// The endpoint URL for this transport.
33    pub url: String,
34    /// Upstream identifier for circuit breaker tracking.
35    pub upstream_id: String,
36}
37
38/// Result of a successful smart fallback execution.
39#[derive(Debug)]
40pub struct SmartFallbackResult {
41    /// Response bytes from the upstream.
42    pub response: bytes::Bytes,
43    /// The transport that succeeded.
44    pub transport_used: TransportProtocol,
45    /// HTTP status code.
46    pub status: u16,
47    /// Full negotiation history for audit.
48    pub history: FallbackNegotiationHistory,
49}
50
51/// Errors from smart fallback execution.
52#[derive(Debug)]
53pub enum SmartFallbackError {
54    /// All transports failed after trying each one.
55    AllTransportsFailed { history: FallbackNegotiationHistory },
56    /// Total timeout budget exhausted.
57    TotalTimeoutExceeded { history: FallbackNegotiationHistory },
58    /// No targets provided.
59    NoTargets,
60}
61
62impl std::fmt::Display for SmartFallbackError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            Self::AllTransportsFailed { history } => {
66                write!(f, "all {} transport(s) failed", history.attempts.len())
67            }
68            Self::TotalTimeoutExceeded { history } => {
69                write!(
70                    f,
71                    "total timeout exceeded after {} attempt(s)",
72                    history.attempts.len()
73                )
74            }
75            Self::NoTargets => write!(f, "no transport targets provided"),
76        }
77    }
78}
79
80impl std::error::Error for SmartFallbackError {}
81
82/// Smart fallback chain orchestrator.
83///
84/// Tries each `TransportTarget` in order, checking circuit breaker health
85/// before each attempt. Records all attempts in a `FallbackNegotiationHistory`.
86pub struct SmartFallbackChain<'a> {
87    client: &'a reqwest::Client,
88    health_tracker: &'a TransportHealthTracker,
89    per_attempt_timeout: Duration,
90    total_timeout: Duration,
91    stdio_enabled: bool,
92    stdio_command: Option<String>,
93    stdio_command_resolved: Option<PathBuf>,
94}
95
96impl<'a> SmartFallbackChain<'a> {
97    /// Create a new fallback chain.
98    pub fn new(
99        client: &'a reqwest::Client,
100        health_tracker: &'a TransportHealthTracker,
101        per_attempt_timeout: Duration,
102        total_timeout: Duration,
103    ) -> Self {
104        Self {
105            client,
106            health_tracker,
107            per_attempt_timeout,
108            total_timeout,
109            stdio_enabled: false,
110            stdio_command: None,
111            stdio_command_resolved: None,
112        }
113    }
114
115    /// Enable stdio fallback with the given command.
116    pub fn with_stdio(&self, command: String) -> Result<Self, String> {
117        let resolved = resolve_executable(&command, std::env::var_os("PATH").as_deref())?;
118        Ok(Self {
119            client: self.client,
120            health_tracker: self.health_tracker,
121            per_attempt_timeout: self.per_attempt_timeout,
122            total_timeout: self.total_timeout,
123            stdio_enabled: true,
124            stdio_command: Some(command),
125            stdio_command_resolved: Some(resolved),
126        })
127    }
128
129    /// Execute the fallback chain, trying each target in order.
130    pub async fn execute(
131        &self,
132        targets: &[TransportTarget],
133        body: bytes::Bytes,
134        headers: &reqwest::header::HeaderMap,
135    ) -> Result<SmartFallbackResult, SmartFallbackError> {
136        if targets.is_empty() {
137            return Err(SmartFallbackError::NoTargets);
138        }
139
140        let chain_start = Instant::now();
141        let mut attempts = Vec::new();
142
143        for target in targets {
144            // Check total timeout budget.
145            if chain_start.elapsed() >= self.total_timeout {
146                let history = FallbackNegotiationHistory {
147                    attempts,
148                    successful_transport: None,
149                    total_duration_ms: chain_start.elapsed().as_millis() as u64,
150                };
151                return Err(SmartFallbackError::TotalTimeoutExceeded { history });
152            }
153
154            // Skip stdio if not enabled.
155            if target.protocol == TransportProtocol::Stdio && !self.stdio_enabled {
156                continue;
157            }
158
159            // Check circuit breaker.
160            if let Err(reason) = self
161                .health_tracker
162                .can_use(&target.upstream_id, target.protocol)
163            {
164                attempts.push(TransportAttempt {
165                    protocol: target.protocol,
166                    endpoint_url: target.url.clone(),
167                    succeeded: false,
168                    duration_ms: 0,
169                    error: Some(reason),
170                });
171                continue;
172            }
173
174            // Calculate remaining timeout budget.
175            let remaining = self
176                .total_timeout
177                .checked_sub(chain_start.elapsed())
178                .unwrap_or(Duration::ZERO);
179            let attempt_timeout = self.per_attempt_timeout.min(remaining);
180
181            if attempt_timeout.is_zero() {
182                let history = FallbackNegotiationHistory {
183                    attempts,
184                    successful_transport: None,
185                    total_duration_ms: chain_start.elapsed().as_millis() as u64,
186                };
187                return Err(SmartFallbackError::TotalTimeoutExceeded { history });
188            }
189
190            let attempt_start = Instant::now();
191
192            let result = match target.protocol {
193                TransportProtocol::Http => {
194                    self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
195                        .await
196                }
197                TransportProtocol::WebSocket => {
198                    self.dispatch_websocket(&target.url, body.clone(), attempt_timeout)
199                        .await
200                }
201                TransportProtocol::Grpc => {
202                    // gRPC dispatch via HTTP bridge endpoint.
203                    self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
204                        .await
205                }
206                TransportProtocol::Stdio => {
207                    if let (Some(cmd), Some(resolved)) = (
208                        self.stdio_command.as_ref(),
209                        self.stdio_command_resolved.as_ref(),
210                    ) {
211                        self.dispatch_stdio(cmd, resolved, body.clone(), attempt_timeout)
212                            .await
213                    } else {
214                        Err("stdio command not configured".to_string())
215                    }
216                }
217            };
218
219            let duration_ms = attempt_start.elapsed().as_millis() as u64;
220
221            match result {
222                Ok((response_bytes, status)) => {
223                    // SECURITY (FIND-R43-012): Treat 5xx responses as failures so
224                    // the circuit breaker can open and trigger fallback to the next
225                    // transport. A backend returning 500s must not keep the circuit
226                    // closed.
227                    if status >= 500 {
228                        self.health_tracker
229                            .record_failure(&target.upstream_id, target.protocol);
230
231                        metrics::counter!(
232                            "vellaveto_transport_fallback_total",
233                            "transport" => format!("{:?}", target.protocol),
234                            "upstream_id" => target.upstream_id.clone(),
235                            "result" => "server_error",
236                        )
237                        .increment(1);
238
239                        attempts.push(TransportAttempt {
240                            protocol: target.protocol,
241                            endpoint_url: target.url.clone(),
242                            succeeded: false,
243                            duration_ms,
244                            error: Some(format!("server error: HTTP {status}")),
245                        });
246
247                        continue;
248                    }
249
250                    self.health_tracker
251                        .record_success(&target.upstream_id, target.protocol);
252
253                    metrics::counter!(
254                        "vellaveto_transport_fallback_total",
255                        "transport" => format!("{:?}", target.protocol),
256                        "upstream_id" => target.upstream_id.clone(),
257                        "result" => "success",
258                    )
259                    .increment(1);
260
261                    attempts.push(TransportAttempt {
262                        protocol: target.protocol,
263                        endpoint_url: target.url.clone(),
264                        succeeded: true,
265                        duration_ms,
266                        error: None,
267                    });
268
269                    let history = FallbackNegotiationHistory {
270                        attempts,
271                        successful_transport: Some(target.protocol),
272                        total_duration_ms: chain_start.elapsed().as_millis() as u64,
273                    };
274
275                    return Ok(SmartFallbackResult {
276                        response: response_bytes,
277                        transport_used: target.protocol,
278                        status,
279                        history,
280                    });
281                }
282                Err(error) => {
283                    self.health_tracker
284                        .record_failure(&target.upstream_id, target.protocol);
285
286                    metrics::counter!(
287                        "vellaveto_transport_fallback_total",
288                        "transport" => format!("{:?}", target.protocol),
289                        "upstream_id" => target.upstream_id.clone(),
290                        "result" => "failure",
291                    )
292                    .increment(1);
293
294                    attempts.push(TransportAttempt {
295                        protocol: target.protocol,
296                        endpoint_url: target.url.clone(),
297                        succeeded: false,
298                        duration_ms,
299                        error: Some(error),
300                    });
301                }
302            }
303        }
304
305        let history = FallbackNegotiationHistory {
306            attempts,
307            successful_transport: None,
308            total_duration_ms: chain_start.elapsed().as_millis() as u64,
309        };
310        Err(SmartFallbackError::AllTransportsFailed { history })
311    }
312
313    /// Dispatch via HTTP POST.
314    ///
315    /// SECURITY (FIND-R41-015): Uses chunk-based reading to prevent OOM from
316    /// chunked-encoded responses that omit Content-Length. Each chunk is checked
317    /// against MAX_RESPONSE_BODY_BYTES before accumulating.
318    async fn dispatch_http(
319        &self,
320        url: &str,
321        body: bytes::Bytes,
322        headers: &reqwest::header::HeaderMap,
323        timeout: Duration,
324    ) -> Result<(bytes::Bytes, u16), String> {
325        // SECURITY (R240-PROXY-1): Enforce HTTPS for non-local upstream URLs.
326        // Parity with fallback.rs — the R239 fix only covered the legacy path.
327        super::validate_upstream_url_scheme(url)?;
328
329        let mut request = self.client.post(url).timeout(timeout);
330
331        // SECURITY (FIND-R41-001): Only forward allowlisted headers to
332        // prevent leaking Authorization, Cookie, etc. to upstream backends.
333        for (key, value) in headers {
334            let key_lower = key.as_str().to_lowercase();
335            if FORWARDED_HEADERS
336                .iter()
337                .any(|&allowed| allowed == key_lower)
338            {
339                request = request.header(key.clone(), value.clone());
340            }
341        }
342
343        let mut resp = request
344            .body(body)
345            .send()
346            .await
347            .map_err(|e| format!("HTTP request error: {e}"))?;
348
349        let status = resp.status().as_u16();
350
351        // SECURITY (FIND-R41-004): Fast-reject if Content-Length exceeds limit.
352        // SECURITY (R240-P3-PROXY-2): Compare in u64 space to avoid truncation on 32-bit.
353        if let Some(len) = resp.content_length() {
354            if len > MAX_RESPONSE_BODY_BYTES as u64 {
355                return Err(format!(
356                    "response body too large: {len} bytes (max {MAX_RESPONSE_BODY_BYTES})"
357                ));
358            }
359        }
360
361        // SECURITY (FIND-R41-015): Read body in chunks with bounded accumulation.
362        // Prevents OOM from chunked-encoded responses that omit Content-Length.
363        let capacity = std::cmp::min(
364            resp.content_length().unwrap_or(8192) as usize,
365            MAX_RESPONSE_BODY_BYTES,
366        );
367        let mut response_body = Vec::with_capacity(capacity);
368        while let Some(chunk) = resp
369            .chunk()
370            .await
371            .map_err(|e| format!("HTTP response body error: {e}"))?
372        {
373            if response_body.len().saturating_add(chunk.len()) > MAX_RESPONSE_BODY_BYTES {
374                return Err(format!(
375                    "response body too large: >{MAX_RESPONSE_BODY_BYTES} bytes (max {MAX_RESPONSE_BODY_BYTES})"
376                ));
377            }
378            response_body.extend_from_slice(&chunk);
379        }
380
381        Ok((bytes::Bytes::from(response_body), status))
382    }
383
384    /// Dispatch via WebSocket (one-shot: connect, send, receive, close).
385    ///
386    /// SECURITY (FIND-R41-011): Configures max_message_size to prevent OOM
387    /// from unbounded upstream WebSocket frames.
388    async fn dispatch_websocket(
389        &self,
390        url: &str,
391        body: bytes::Bytes,
392        timeout: Duration,
393    ) -> Result<(bytes::Bytes, u16), String> {
394        // SECURITY (R240-PROXY-1): Enforce HTTPS/WSS for non-local upstream URLs.
395        super::validate_upstream_url_scheme(url)?;
396
397        use tokio_tungstenite::tungstenite::Message;
398
399        let result = tokio::time::timeout(timeout, async {
400            // SECURITY (FIND-R41-011): Configure max message/frame size to prevent
401            // a malicious upstream from sending unbounded WebSocket frames (OOM).
402            let mut ws_config =
403                tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default();
404            ws_config.max_message_size = Some(MAX_RESPONSE_BODY_BYTES);
405            ws_config.max_frame_size = Some(MAX_RESPONSE_BODY_BYTES);
406            let (mut ws, _) =
407                tokio_tungstenite::connect_async_with_config(url, Some(ws_config), false)
408                    .await
409                    .map_err(|e| format!("WebSocket connect error: {e}"))?;
410
411            use futures_util::SinkExt;
412            // SECURITY (FIND-R43-026): Reject invalid UTF-8 instead of silently
413            // replacing bytes with U+FFFD, which would mutate the request body.
414            let body_text = String::from_utf8(body.to_vec())
415                .map_err(|_| "WebSocket body contains invalid UTF-8".to_string())?;
416            ws.send(Message::Text(body_text.into()))
417                .await
418                .map_err(|e| format!("WebSocket send error: {e}"))?;
419
420            use futures_util::StreamExt;
421            let response = ws
422                .next()
423                .await
424                .ok_or_else(|| "WebSocket closed without response".to_string())?
425                .map_err(|e| format!("WebSocket receive error: {e}"))?;
426
427            ws.close(None)
428                .await
429                .map_err(|e| format!("WebSocket close error: {e}"))?;
430
431            let response_bytes = match response {
432                Message::Text(t) => {
433                    let bytes = t.as_bytes();
434                    if bytes.len() > MAX_RESPONSE_BODY_BYTES {
435                        return Err(format!(
436                            "WebSocket response too large: {} bytes (max {})",
437                            bytes.len(),
438                            MAX_RESPONSE_BODY_BYTES
439                        ));
440                    }
441                    bytes::Bytes::from(Vec::from(bytes))
442                }
443                Message::Binary(b) => {
444                    if b.len() > MAX_RESPONSE_BODY_BYTES {
445                        return Err(format!(
446                            "WebSocket response too large: {} bytes (max {})",
447                            b.len(),
448                            MAX_RESPONSE_BODY_BYTES
449                        ));
450                    }
451                    bytes::Bytes::from(Vec::from(b.as_ref()))
452                }
453                other => {
454                    return Err(format!("unexpected WebSocket message type: {other:?}"));
455                }
456            };
457
458            Ok::<_, String>((response_bytes, 200u16))
459        })
460        .await;
461
462        match result {
463            Ok(inner) => inner,
464            Err(_) => Err("WebSocket timeout".to_string()),
465        }
466    }
467
468    /// Dispatch via stdio subprocess.
469    ///
470    /// SECURITY (FIND-R41-002): Uses direct Command::new(command) instead of
471    /// `sh -c` to prevent shell injection. The command path is validated at
472    /// config time to be an absolute path with no shell metacharacters.
473    ///
474    /// SECURITY (FIND-R41-006): Explicitly kills child on timeout to prevent
475    /// zombie process accumulation.
476    ///
477    /// SECURITY (FIND-R41-010): Captures stderr for diagnostics instead of
478    /// discarding it.
479    async fn dispatch_stdio(
480        &self,
481        _command: &str,
482        resolved_command: &Path,
483        body: bytes::Bytes,
484        timeout: Duration,
485    ) -> Result<(bytes::Bytes, u16), String> {
486        use tokio::io::{AsyncReadExt, AsyncWriteExt};
487        use tokio::process::Command;
488
489        // SECURITY (FIND-R41-002): Execute command directly, not via sh -c.
490        // SECURITY (FIND-R43-002): Clear inherited environment to prevent leaking
491        // secrets (API keys, tokens, credentials) to the subprocess.
492        // SECURITY (FIND-R43-008): kill_on_drop ensures the child is killed if
493        // the async task is cancelled, preventing orphaned subprocesses.
494        let mut child = Command::new(resolved_command)
495            .env_clear()
496            .env("PATH", "/usr/local/bin:/usr/bin:/bin")
497            .env("HOME", "/tmp")
498            .env("LANG", "C.UTF-8")
499            .kill_on_drop(true)
500            .stdin(std::process::Stdio::piped())
501            .stdout(std::process::Stdio::piped())
502            .stderr(std::process::Stdio::piped())
503            .spawn()
504            .map_err(|e| format!("stdio spawn error: {e}"))?;
505
506        // Write to stdin and drop it to signal EOF.
507        if let Some(mut stdin) = child.stdin.take() {
508            stdin
509                .write_all(&body)
510                .await
511                .map_err(|e| format!("stdio write error: {e}"))?;
512        }
513
514        // Take stdout/stderr handles before waiting, so child is not consumed.
515        let stdout_handle = child
516            .stdout
517            .take()
518            .ok_or_else(|| "stdio stdout not captured".to_string())?;
519        let mut stderr_handle = child
520            .stderr
521            .take()
522            .ok_or_else(|| "stdio stderr not captured".to_string())?;
523
524        // SECURITY (FIND-R41-006): Use select! so we can kill the child on timeout.
525        tokio::select! {
526            // SECURITY (FIND-R43-001): Read stdout concurrently with waiting for
527            // process exit to prevent deadlock. If the subprocess writes >64KB to
528            // stdout, the pipe buffer fills and blocks; reading stdout only after
529            // wait() would deadlock. Using tokio::join! ensures both proceed.
530            result = async {
531                let stdout_future = async {
532                    let mut stdout_buf = Vec::new();
533                    let read_result = stdout_handle
534                        .take((MAX_RESPONSE_BODY_BYTES as u64) + 1)
535                        .read_to_end(&mut stdout_buf)
536                        .await
537                        .map_err(|e| format!("stdio stdout read error: {e}"));
538                    read_result.map(|_| stdout_buf)
539                };
540
541                let wait_future = child.wait();
542
543                let (stdout_result, status_result) = tokio::join!(stdout_future, wait_future);
544
545                let stdout_buf = stdout_result?;
546                let status = status_result
547                    .map_err(|e| format!("stdio wait error: {e}"))?;
548
549                if stdout_buf.len() > MAX_RESPONSE_BODY_BYTES {
550                    return Err(format!(
551                        "stdio stdout too large: >{MAX_RESPONSE_BODY_BYTES} bytes (max {MAX_RESPONSE_BODY_BYTES})"
552                    ));
553                }
554
555                if !status.success() {
556                    // SECURITY (FIND-R41-010): Include truncated stderr in error.
557                    let mut stderr_buf = vec![0u8; MAX_STDERR_BYTES];
558                    let n = stderr_handle
559                        .read(&mut stderr_buf)
560                        .await
561                        .unwrap_or(0);
562                    let stderr_snippet = String::from_utf8_lossy(&stderr_buf[..n]);
563                    return Err(format!(
564                        "stdio process exited with {:?}: {}",
565                        status,
566                        stderr_snippet.trim()
567                    ));
568                }
569
570                Ok((bytes::Bytes::from(stdout_buf), 200u16))
571            } => {
572                result
573            }
574            _ = tokio::time::sleep(timeout) => {
575                // SECURITY (FIND-R41-006): Kill child on timeout to prevent zombies.
576                let _ = child.kill().await;
577                Err("stdio timeout".to_string())
578            }
579        }
580    }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    fn make_tracker() -> TransportHealthTracker {
        TransportHealthTracker::new(2, 1, 300)
    }

    fn make_targets(protos: &[(TransportProtocol, &str)]) -> Vec<TransportTarget> {
        protos
            .iter()
            .map(|(p, url)| TransportTarget {
                protocol: *p,
                url: url.to_string(),
                upstream_id: "test".to_string(),
            })
            .collect()
    }

    #[tokio::test]
    async fn test_smart_fallback_no_targets() {
        let client = reqwest::Client::new();
        let tracker = make_tracker();
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        );

        let result = chain
            .execute(&[], bytes::Bytes::new(), &reqwest::header::HeaderMap::new())
            .await;

        assert!(matches!(result, Err(SmartFallbackError::NoTargets)));
    }

    #[tokio::test]
    async fn test_smart_fallback_all_circuits_open() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        // Open both circuits.
        tracker.record_failure("test", TransportProtocol::Http);
        tracker.record_failure("test", TransportProtocol::WebSocket);

        let targets = make_targets(&[
            (TransportProtocol::Http, "http://localhost:1/mcp"),
            (TransportProtocol::WebSocket, "ws://localhost:2/mcp"),
        ]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                assert_eq!(history.attempts.len(), 2);
                assert!(history.successful_transport.is_none());
                for attempt in &history.attempts {
                    assert!(!attempt.succeeded);
                    assert!(attempt.error.as_ref().unwrap().contains("circuit open"));
                }
            }
            other => panic!("expected AllTransportsFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_skips_stdio_when_disabled() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();

        let targets = make_targets(&[(TransportProtocol::Stdio, "stdio://local")]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(1),
            Duration::from_secs(5),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        // Stdio is skipped without being attempted, so nothing is recorded
        // and no transport was available.
        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                assert!(
                    history.attempts.is_empty(),
                    "disabled stdio targets must not be attempted"
                );
                assert!(history.successful_transport.is_none());
            }
            other => panic!("expected AllTransportsFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_first_fails_second_succeeds_via_circuit() {
        // The first transport's circuit is open, so the chain must fall
        // through and actually try the second one (which then fails with
        // connection refused — the point is that it was attempted).
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        tracker.record_failure("test", TransportProtocol::Grpc);

        let targets = make_targets(&[
            (TransportProtocol::Grpc, "http://localhost:1/grpc"),
            (TransportProtocol::Http, "http://localhost:1/does-not-exist"),
        ]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_millis(500),
            Duration::from_secs(5),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        // Both should fail (gRPC circuit open, HTTP connection refused)
        // but the history should have 2 attempts.
        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                assert_eq!(history.attempts.len(), 2);
                // First attempt was circuit-open (0ms).
                assert!(!history.attempts[0].succeeded);
                assert!(history.attempts[0]
                    .error
                    .as_ref()
                    .unwrap()
                    .contains("circuit open"));
                // Second attempt tried HTTP but connection refused.
                assert_eq!(history.attempts[1].protocol, TransportProtocol::Http);
                assert!(!history.attempts[1].succeeded);
            }
            other => panic!("expected AllTransportsFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_total_timeout_zero() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();

        let targets = make_targets(&[(TransportProtocol::Http, "http://localhost:1/mcp")]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::ZERO, // Zero total timeout
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        assert!(matches!(
            result,
            Err(SmartFallbackError::TotalTimeoutExceeded { .. })
        ));
    }

    #[test]
    fn test_smart_fallback_error_display() {
        let err = SmartFallbackError::NoTargets;
        assert_eq!(format!("{err}"), "no transport targets provided");

        let err = SmartFallbackError::AllTransportsFailed {
            history: FallbackNegotiationHistory {
                attempts: vec![TransportAttempt {
                    protocol: TransportProtocol::Http,
                    endpoint_url: "http://localhost".to_string(),
                    succeeded: false,
                    duration_ms: 100,
                    error: Some("timeout".to_string()),
                }],
                successful_transport: None,
                total_duration_ms: 100,
            },
        };
        assert!(format!("{err}").contains("1 transport(s) failed"));

        let err = SmartFallbackError::TotalTimeoutExceeded {
            history: FallbackNegotiationHistory {
                attempts: Vec::new(),
                successful_transport: None,
                total_duration_ms: 0,
            },
        };
        assert_eq!(
            format!("{err}"),
            "total timeout exceeded after 0 attempt(s)"
        );
    }

    #[test]
    fn test_transport_target_construction() {
        let target = TransportTarget {
            protocol: TransportProtocol::Grpc,
            url: "http://localhost:50051".to_string(),
            upstream_id: "backend-1".to_string(),
        };
        assert_eq!(target.protocol, TransportProtocol::Grpc);
        assert_eq!(target.url, "http://localhost:50051");
        assert_eq!(target.upstream_id, "backend-1");
    }

    #[tokio::test]
    async fn test_smart_fallback_history_records_all_attempts() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        // Open first circuit.
        tracker.record_failure("test", TransportProtocol::Grpc);

        let targets = make_targets(&[
            (TransportProtocol::Grpc, "http://localhost:1/grpc"),
            (TransportProtocol::WebSocket, "ws://localhost:2/ws"),
            (TransportProtocol::Http, "http://localhost:3/http"),
        ]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_millis(200),
            Duration::from_secs(5),
        );

        let result = chain
            .execute(
                &targets,
                bytes::Bytes::new(),
                &reqwest::header::HeaderMap::new(),
            )
            .await;

        match result {
            Err(SmartFallbackError::AllTransportsFailed { history }) => {
                // Should have attempted all 3, in order: gRPC (circuit open),
                // WS (connection refused), HTTP (connection refused).
                assert_eq!(history.attempts.len(), 3);
                assert!(history.successful_transport.is_none());

                let protocols: Vec<TransportProtocol> =
                    history.attempts.iter().map(|a| a.protocol).collect();
                assert_eq!(
                    protocols,
                    vec![
                        TransportProtocol::Grpc,
                        TransportProtocol::WebSocket,
                        TransportProtocol::Http,
                    ]
                );
                assert!(history.attempts.iter().all(|a| !a.succeeded));
                assert!(history.attempts.iter().all(|a| a.error.is_some()));

                // The circuit-open attempt is recorded without dispatching.
                assert_eq!(history.attempts[0].duration_ms, 0);
                assert!(history.attempts[0]
                    .error
                    .as_ref()
                    .unwrap()
                    .contains("circuit open"));
            }
            other => panic!("expected AllTransportsFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_with_stdio_enabled() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();

        // Use 'cat' as a simple stdio command that echoes stdin to stdout.
        let targets = vec![TransportTarget {
            protocol: TransportProtocol::Stdio,
            url: "stdio://local".to_string(),
            upstream_id: "test".to_string(),
        }];

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        )
        .with_stdio("/bin/cat".to_string())
        .expect("stdio fallback command should resolve");

        let body = bytes::Bytes::from(r#"{"test": true}"#);
        let result = chain
            .execute(&targets, body.clone(), &reqwest::header::HeaderMap::new())
            .await;

        match result {
            Ok(res) => {
                assert_eq!(res.transport_used, TransportProtocol::Stdio);
                assert_eq!(res.status, 200);
                assert_eq!(res.response, body);
                assert_eq!(res.history.attempts.len(), 1);
                assert!(res.history.attempts[0].succeeded);
            }
            Err(e) => panic!("expected success, got: {e}"),
        }
    }
}