1use std::path::{Path, PathBuf};
17use std::time::{Duration, Instant};
18use vellaveto_types::command::resolve_executable;
19use vellaveto_types::{FallbackNegotiationHistory, TransportAttempt, TransportProtocol};
20
21use super::transport_health::TransportHealthTracker;
22use super::{FORWARDED_HEADERS, MAX_RESPONSE_BODY_BYTES};
23
/// Size, in bytes, of the buffer used to capture stderr from a stdio child
/// that exited unsuccessfully, for inclusion in the returned error message.
const MAX_STDERR_BYTES: usize = 4096;
26
/// One candidate endpoint that a [`SmartFallbackChain`] may try.
#[derive(Debug, Clone)]
pub struct TransportTarget {
    /// Protocol that selects which dispatcher handles this target.
    pub protocol: TransportProtocol,
    /// Endpoint URL handed to the HTTP/WebSocket dispatchers and recorded as
    /// `endpoint_url` on every attempt for this target.
    pub url: String,
    /// Key used with the health tracker and as the `upstream_id` metrics label.
    pub upstream_id: String,
}
37
/// Successful outcome of [`SmartFallbackChain::execute`].
#[derive(Debug)]
pub struct SmartFallbackResult {
    /// Response payload returned by the transport that succeeded.
    pub response: bytes::Bytes,
    /// Protocol of the target that produced `response`.
    pub transport_used: TransportProtocol,
    /// Upstream HTTP status for HTTP/gRPC targets; WebSocket and stdio
    /// dispatchers always report `200`.
    pub status: u16,
    /// Every attempt recorded while walking the chain, including the
    /// successful one.
    pub history: FallbackNegotiationHistory,
}
50
/// Failure modes of [`SmartFallbackChain::execute`].
#[derive(Debug)]
pub enum SmartFallbackError {
    /// Every target was skipped or failed; `history` holds the recorded attempts.
    AllTransportsFailed { history: FallbackNegotiationHistory },
    /// The total time budget ran out before the next target could be tried.
    TotalTimeoutExceeded { history: FallbackNegotiationHistory },
    /// `execute` was called with an empty target slice.
    NoTargets,
}
61
62impl std::fmt::Display for SmartFallbackError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Self::AllTransportsFailed { history } => {
66 write!(f, "all {} transport(s) failed", history.attempts.len())
67 }
68 Self::TotalTimeoutExceeded { history } => {
69 write!(
70 f,
71 "total timeout exceeded after {} attempt(s)",
72 history.attempts.len()
73 )
74 }
75 Self::NoTargets => write!(f, "no transport targets provided"),
76 }
77 }
78}
79
// Marker impl: relies on the `Debug` derive and the `Display` impl and uses
// the default `std::error::Error` methods unchanged.
impl std::error::Error for SmartFallbackError {}
81
/// Walks an ordered list of transport targets, trying each in turn until one
/// returns a usable response or the time budget is exhausted.
pub struct SmartFallbackChain<'a> {
    /// HTTP client used for `Http` and `Grpc` targets.
    client: &'a reqwest::Client,
    /// Consulted before each attempt (`can_use`) and updated after it
    /// (`record_success` / `record_failure`).
    health_tracker: &'a TransportHealthTracker,
    /// Upper bound for a single attempt; clamped to the remaining total budget.
    per_attempt_timeout: Duration,
    /// Time budget for the whole chain.
    total_timeout: Duration,
    /// When `false`, `Stdio` targets are skipped without recording an attempt.
    stdio_enabled: bool,
    /// Command string exactly as passed to `with_stdio`.
    stdio_command: Option<String>,
    /// Result of `resolve_executable` for `stdio_command`; this is the path
    /// that is actually spawned.
    stdio_command_resolved: Option<PathBuf>,
}
95
96impl<'a> SmartFallbackChain<'a> {
97 pub fn new(
99 client: &'a reqwest::Client,
100 health_tracker: &'a TransportHealthTracker,
101 per_attempt_timeout: Duration,
102 total_timeout: Duration,
103 ) -> Self {
104 Self {
105 client,
106 health_tracker,
107 per_attempt_timeout,
108 total_timeout,
109 stdio_enabled: false,
110 stdio_command: None,
111 stdio_command_resolved: None,
112 }
113 }
114
115 pub fn with_stdio(&self, command: String) -> Result<Self, String> {
117 let resolved = resolve_executable(&command, std::env::var_os("PATH").as_deref())?;
118 Ok(Self {
119 client: self.client,
120 health_tracker: self.health_tracker,
121 per_attempt_timeout: self.per_attempt_timeout,
122 total_timeout: self.total_timeout,
123 stdio_enabled: true,
124 stdio_command: Some(command),
125 stdio_command_resolved: Some(resolved),
126 })
127 }
128
129 pub async fn execute(
131 &self,
132 targets: &[TransportTarget],
133 body: bytes::Bytes,
134 headers: &reqwest::header::HeaderMap,
135 ) -> Result<SmartFallbackResult, SmartFallbackError> {
136 if targets.is_empty() {
137 return Err(SmartFallbackError::NoTargets);
138 }
139
140 let chain_start = Instant::now();
141 let mut attempts = Vec::new();
142
143 for target in targets {
144 if chain_start.elapsed() >= self.total_timeout {
146 let history = FallbackNegotiationHistory {
147 attempts,
148 successful_transport: None,
149 total_duration_ms: chain_start.elapsed().as_millis() as u64,
150 };
151 return Err(SmartFallbackError::TotalTimeoutExceeded { history });
152 }
153
154 if target.protocol == TransportProtocol::Stdio && !self.stdio_enabled {
156 continue;
157 }
158
159 if let Err(reason) = self
161 .health_tracker
162 .can_use(&target.upstream_id, target.protocol)
163 {
164 attempts.push(TransportAttempt {
165 protocol: target.protocol,
166 endpoint_url: target.url.clone(),
167 succeeded: false,
168 duration_ms: 0,
169 error: Some(reason),
170 });
171 continue;
172 }
173
174 let remaining = self
176 .total_timeout
177 .checked_sub(chain_start.elapsed())
178 .unwrap_or(Duration::ZERO);
179 let attempt_timeout = self.per_attempt_timeout.min(remaining);
180
181 if attempt_timeout.is_zero() {
182 let history = FallbackNegotiationHistory {
183 attempts,
184 successful_transport: None,
185 total_duration_ms: chain_start.elapsed().as_millis() as u64,
186 };
187 return Err(SmartFallbackError::TotalTimeoutExceeded { history });
188 }
189
190 let attempt_start = Instant::now();
191
192 let result = match target.protocol {
193 TransportProtocol::Http => {
194 self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
195 .await
196 }
197 TransportProtocol::WebSocket => {
198 self.dispatch_websocket(&target.url, body.clone(), attempt_timeout)
199 .await
200 }
201 TransportProtocol::Grpc => {
202 self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
204 .await
205 }
206 TransportProtocol::Stdio => {
207 if let (Some(cmd), Some(resolved)) = (
208 self.stdio_command.as_ref(),
209 self.stdio_command_resolved.as_ref(),
210 ) {
211 self.dispatch_stdio(cmd, resolved, body.clone(), attempt_timeout)
212 .await
213 } else {
214 Err("stdio command not configured".to_string())
215 }
216 }
217 };
218
219 let duration_ms = attempt_start.elapsed().as_millis() as u64;
220
221 match result {
222 Ok((response_bytes, status)) => {
223 if status >= 500 {
228 self.health_tracker
229 .record_failure(&target.upstream_id, target.protocol);
230
231 metrics::counter!(
232 "vellaveto_transport_fallback_total",
233 "transport" => format!("{:?}", target.protocol),
234 "upstream_id" => target.upstream_id.clone(),
235 "result" => "server_error",
236 )
237 .increment(1);
238
239 attempts.push(TransportAttempt {
240 protocol: target.protocol,
241 endpoint_url: target.url.clone(),
242 succeeded: false,
243 duration_ms,
244 error: Some(format!("server error: HTTP {status}")),
245 });
246
247 continue;
248 }
249
250 self.health_tracker
251 .record_success(&target.upstream_id, target.protocol);
252
253 metrics::counter!(
254 "vellaveto_transport_fallback_total",
255 "transport" => format!("{:?}", target.protocol),
256 "upstream_id" => target.upstream_id.clone(),
257 "result" => "success",
258 )
259 .increment(1);
260
261 attempts.push(TransportAttempt {
262 protocol: target.protocol,
263 endpoint_url: target.url.clone(),
264 succeeded: true,
265 duration_ms,
266 error: None,
267 });
268
269 let history = FallbackNegotiationHistory {
270 attempts,
271 successful_transport: Some(target.protocol),
272 total_duration_ms: chain_start.elapsed().as_millis() as u64,
273 };
274
275 return Ok(SmartFallbackResult {
276 response: response_bytes,
277 transport_used: target.protocol,
278 status,
279 history,
280 });
281 }
282 Err(error) => {
283 self.health_tracker
284 .record_failure(&target.upstream_id, target.protocol);
285
286 metrics::counter!(
287 "vellaveto_transport_fallback_total",
288 "transport" => format!("{:?}", target.protocol),
289 "upstream_id" => target.upstream_id.clone(),
290 "result" => "failure",
291 )
292 .increment(1);
293
294 attempts.push(TransportAttempt {
295 protocol: target.protocol,
296 endpoint_url: target.url.clone(),
297 succeeded: false,
298 duration_ms,
299 error: Some(error),
300 });
301 }
302 }
303 }
304
305 let history = FallbackNegotiationHistory {
306 attempts,
307 successful_transport: None,
308 total_duration_ms: chain_start.elapsed().as_millis() as u64,
309 };
310 Err(SmartFallbackError::AllTransportsFailed { history })
311 }
312
313 async fn dispatch_http(
319 &self,
320 url: &str,
321 body: bytes::Bytes,
322 headers: &reqwest::header::HeaderMap,
323 timeout: Duration,
324 ) -> Result<(bytes::Bytes, u16), String> {
325 super::validate_upstream_url_scheme(url)?;
328
329 let mut request = self.client.post(url).timeout(timeout);
330
331 for (key, value) in headers {
334 let key_lower = key.as_str().to_lowercase();
335 if FORWARDED_HEADERS
336 .iter()
337 .any(|&allowed| allowed == key_lower)
338 {
339 request = request.header(key.clone(), value.clone());
340 }
341 }
342
343 let mut resp = request
344 .body(body)
345 .send()
346 .await
347 .map_err(|e| format!("HTTP request error: {e}"))?;
348
349 let status = resp.status().as_u16();
350
351 if let Some(len) = resp.content_length() {
354 if len > MAX_RESPONSE_BODY_BYTES as u64 {
355 return Err(format!(
356 "response body too large: {len} bytes (max {MAX_RESPONSE_BODY_BYTES})"
357 ));
358 }
359 }
360
361 let capacity = std::cmp::min(
364 resp.content_length().unwrap_or(8192) as usize,
365 MAX_RESPONSE_BODY_BYTES,
366 );
367 let mut response_body = Vec::with_capacity(capacity);
368 while let Some(chunk) = resp
369 .chunk()
370 .await
371 .map_err(|e| format!("HTTP response body error: {e}"))?
372 {
373 if response_body.len().saturating_add(chunk.len()) > MAX_RESPONSE_BODY_BYTES {
374 return Err(format!(
375 "response body too large: >{MAX_RESPONSE_BODY_BYTES} bytes (max {MAX_RESPONSE_BODY_BYTES})"
376 ));
377 }
378 response_body.extend_from_slice(&chunk);
379 }
380
381 Ok((bytes::Bytes::from(response_body), status))
382 }
383
384 async fn dispatch_websocket(
389 &self,
390 url: &str,
391 body: bytes::Bytes,
392 timeout: Duration,
393 ) -> Result<(bytes::Bytes, u16), String> {
394 super::validate_upstream_url_scheme(url)?;
396
397 use tokio_tungstenite::tungstenite::Message;
398
399 let result = tokio::time::timeout(timeout, async {
400 let mut ws_config =
403 tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default();
404 ws_config.max_message_size = Some(MAX_RESPONSE_BODY_BYTES);
405 ws_config.max_frame_size = Some(MAX_RESPONSE_BODY_BYTES);
406 let (mut ws, _) =
407 tokio_tungstenite::connect_async_with_config(url, Some(ws_config), false)
408 .await
409 .map_err(|e| format!("WebSocket connect error: {e}"))?;
410
411 use futures_util::SinkExt;
412 let body_text = String::from_utf8(body.to_vec())
415 .map_err(|_| "WebSocket body contains invalid UTF-8".to_string())?;
416 ws.send(Message::Text(body_text.into()))
417 .await
418 .map_err(|e| format!("WebSocket send error: {e}"))?;
419
420 use futures_util::StreamExt;
421 let response = ws
422 .next()
423 .await
424 .ok_or_else(|| "WebSocket closed without response".to_string())?
425 .map_err(|e| format!("WebSocket receive error: {e}"))?;
426
427 ws.close(None)
428 .await
429 .map_err(|e| format!("WebSocket close error: {e}"))?;
430
431 let response_bytes = match response {
432 Message::Text(t) => {
433 let bytes = t.as_bytes();
434 if bytes.len() > MAX_RESPONSE_BODY_BYTES {
435 return Err(format!(
436 "WebSocket response too large: {} bytes (max {})",
437 bytes.len(),
438 MAX_RESPONSE_BODY_BYTES
439 ));
440 }
441 bytes::Bytes::from(Vec::from(bytes))
442 }
443 Message::Binary(b) => {
444 if b.len() > MAX_RESPONSE_BODY_BYTES {
445 return Err(format!(
446 "WebSocket response too large: {} bytes (max {})",
447 b.len(),
448 MAX_RESPONSE_BODY_BYTES
449 ));
450 }
451 bytes::Bytes::from(Vec::from(b.as_ref()))
452 }
453 other => {
454 return Err(format!("unexpected WebSocket message type: {other:?}"));
455 }
456 };
457
458 Ok::<_, String>((response_bytes, 200u16))
459 })
460 .await;
461
462 match result {
463 Ok(inner) => inner,
464 Err(_) => Err("WebSocket timeout".to_string()),
465 }
466 }
467
468 async fn dispatch_stdio(
480 &self,
481 _command: &str,
482 resolved_command: &Path,
483 body: bytes::Bytes,
484 timeout: Duration,
485 ) -> Result<(bytes::Bytes, u16), String> {
486 use tokio::io::{AsyncReadExt, AsyncWriteExt};
487 use tokio::process::Command;
488
489 let mut child = Command::new(resolved_command)
495 .env_clear()
496 .env("PATH", "/usr/local/bin:/usr/bin:/bin")
497 .env("HOME", "/tmp")
498 .env("LANG", "C.UTF-8")
499 .kill_on_drop(true)
500 .stdin(std::process::Stdio::piped())
501 .stdout(std::process::Stdio::piped())
502 .stderr(std::process::Stdio::piped())
503 .spawn()
504 .map_err(|e| format!("stdio spawn error: {e}"))?;
505
506 if let Some(mut stdin) = child.stdin.take() {
508 stdin
509 .write_all(&body)
510 .await
511 .map_err(|e| format!("stdio write error: {e}"))?;
512 }
513
514 let stdout_handle = child
516 .stdout
517 .take()
518 .ok_or_else(|| "stdio stdout not captured".to_string())?;
519 let mut stderr_handle = child
520 .stderr
521 .take()
522 .ok_or_else(|| "stdio stderr not captured".to_string())?;
523
524 tokio::select! {
526 result = async {
531 let stdout_future = async {
532 let mut stdout_buf = Vec::new();
533 let read_result = stdout_handle
534 .take((MAX_RESPONSE_BODY_BYTES as u64) + 1)
535 .read_to_end(&mut stdout_buf)
536 .await
537 .map_err(|e| format!("stdio stdout read error: {e}"));
538 read_result.map(|_| stdout_buf)
539 };
540
541 let wait_future = child.wait();
542
543 let (stdout_result, status_result) = tokio::join!(stdout_future, wait_future);
544
545 let stdout_buf = stdout_result?;
546 let status = status_result
547 .map_err(|e| format!("stdio wait error: {e}"))?;
548
549 if stdout_buf.len() > MAX_RESPONSE_BODY_BYTES {
550 return Err(format!(
551 "stdio stdout too large: >{MAX_RESPONSE_BODY_BYTES} bytes (max {MAX_RESPONSE_BODY_BYTES})"
552 ));
553 }
554
555 if !status.success() {
556 let mut stderr_buf = vec![0u8; MAX_STDERR_BYTES];
558 let n = stderr_handle
559 .read(&mut stderr_buf)
560 .await
561 .unwrap_or(0);
562 let stderr_snippet = String::from_utf8_lossy(&stderr_buf[..n]);
563 return Err(format!(
564 "stdio process exited with {:?}: {}",
565 status,
566 stderr_snippet.trim()
567 ));
568 }
569
570 Ok((bytes::Bytes::from(stdout_buf), 200u16))
571 } => {
572 result
573 }
574 _ = tokio::time::sleep(timeout) => {
575 let _ = child.kill().await;
577 Err("stdio timeout".to_string())
578 }
579 }
580 }
581}
582
583#[cfg(test)]
584mod tests {
585 use super::*;
586
    /// Health tracker shared by tests that never pre-trip a circuit.
    // NOTE(review): the meaning of the three constructor arguments is defined
    // by `TransportHealthTracker::new` (not visible here) — confirm before
    // changing them.
    fn make_tracker() -> TransportHealthTracker {
        TransportHealthTracker::new(2, 1, 300)
    }
590
591 fn make_targets(protos: &[(TransportProtocol, &str)]) -> Vec<TransportTarget> {
592 protos
593 .iter()
594 .map(|(p, url)| TransportTarget {
595 protocol: *p,
596 url: url.to_string(),
597 upstream_id: "test".to_string(),
598 })
599 .collect()
600 }
601
602 #[tokio::test]
603 async fn test_smart_fallback_no_targets() {
604 let client = reqwest::Client::new();
605 let tracker = make_tracker();
606 let chain = SmartFallbackChain::new(
607 &client,
608 &tracker,
609 Duration::from_secs(5),
610 Duration::from_secs(10),
611 );
612
613 let result = chain
614 .execute(&[], bytes::Bytes::new(), &reqwest::header::HeaderMap::new())
615 .await;
616
617 assert!(matches!(result, Err(SmartFallbackError::NoTargets)));
618 }
619
620 #[tokio::test]
621 async fn test_smart_fallback_all_circuits_open() {
622 let tracker = TransportHealthTracker::new(1, 1, 300);
623 let client = reqwest::Client::new();
624
625 tracker.record_failure("test", TransportProtocol::Http);
627 tracker.record_failure("test", TransportProtocol::WebSocket);
628
629 let targets = make_targets(&[
630 (TransportProtocol::Http, "http://localhost:1/mcp"),
631 (TransportProtocol::WebSocket, "ws://localhost:2/mcp"),
632 ]);
633
634 let chain = SmartFallbackChain::new(
635 &client,
636 &tracker,
637 Duration::from_secs(5),
638 Duration::from_secs(10),
639 );
640
641 let result = chain
642 .execute(
643 &targets,
644 bytes::Bytes::new(),
645 &reqwest::header::HeaderMap::new(),
646 )
647 .await;
648
649 match result {
650 Err(SmartFallbackError::AllTransportsFailed { history }) => {
651 assert_eq!(history.attempts.len(), 2);
652 assert!(history.successful_transport.is_none());
653 for attempt in &history.attempts {
654 assert!(!attempt.succeeded);
655 assert!(attempt.error.as_ref().unwrap().contains("circuit open"));
656 }
657 }
658 other => panic!("expected AllTransportsFailed, got {other:?}"),
659 }
660 }
661
662 #[tokio::test]
663 async fn test_smart_fallback_skips_stdio_when_disabled() {
664 let tracker = make_tracker();
665 let client = reqwest::Client::new();
666
667 let targets = make_targets(&[(TransportProtocol::Stdio, "stdio://local")]);
668
669 let chain = SmartFallbackChain::new(
670 &client,
671 &tracker,
672 Duration::from_secs(1),
673 Duration::from_secs(5),
674 );
675
676 let result = chain
677 .execute(
678 &targets,
679 bytes::Bytes::new(),
680 &reqwest::header::HeaderMap::new(),
681 )
682 .await;
683
684 assert!(matches!(
686 result,
687 Err(SmartFallbackError::AllTransportsFailed { .. })
688 ));
689 }
690
691 #[tokio::test]
692 async fn test_smart_fallback_first_fails_second_succeeds_via_circuit() {
693 let tracker = TransportHealthTracker::new(1, 1, 300);
695 let client = reqwest::Client::new();
696
697 tracker.record_failure("test", TransportProtocol::Grpc);
698
699 let targets = make_targets(&[
700 (TransportProtocol::Grpc, "http://localhost:1/grpc"),
701 (TransportProtocol::Http, "http://localhost:1/does-not-exist"),
702 ]);
703
704 let chain = SmartFallbackChain::new(
705 &client,
706 &tracker,
707 Duration::from_millis(500),
708 Duration::from_secs(5),
709 );
710
711 let result = chain
712 .execute(
713 &targets,
714 bytes::Bytes::new(),
715 &reqwest::header::HeaderMap::new(),
716 )
717 .await;
718
719 match result {
722 Err(SmartFallbackError::AllTransportsFailed { history }) => {
723 assert_eq!(history.attempts.len(), 2);
724 assert!(!history.attempts[0].succeeded);
726 assert!(history.attempts[0]
727 .error
728 .as_ref()
729 .unwrap()
730 .contains("circuit open"));
731 assert!(!history.attempts[1].succeeded);
733 }
734 other => panic!("expected AllTransportsFailed, got {other:?}"),
735 }
736 }
737
738 #[tokio::test]
739 async fn test_smart_fallback_total_timeout_zero() {
740 let tracker = make_tracker();
741 let client = reqwest::Client::new();
742
743 let targets = make_targets(&[(TransportProtocol::Http, "http://localhost:1/mcp")]);
744
745 let chain = SmartFallbackChain::new(
746 &client,
747 &tracker,
748 Duration::from_secs(5),
749 Duration::ZERO, );
751
752 let result = chain
753 .execute(
754 &targets,
755 bytes::Bytes::new(),
756 &reqwest::header::HeaderMap::new(),
757 )
758 .await;
759
760 assert!(matches!(
761 result,
762 Err(SmartFallbackError::TotalTimeoutExceeded { .. })
763 ));
764 }
765
766 #[test]
767 fn test_smart_fallback_error_display() {
768 let err = SmartFallbackError::NoTargets;
769 assert_eq!(format!("{err}"), "no transport targets provided");
770
771 let err = SmartFallbackError::AllTransportsFailed {
772 history: FallbackNegotiationHistory {
773 attempts: vec![TransportAttempt {
774 protocol: TransportProtocol::Http,
775 endpoint_url: "http://localhost".to_string(),
776 succeeded: false,
777 duration_ms: 100,
778 error: Some("timeout".to_string()),
779 }],
780 successful_transport: None,
781 total_duration_ms: 100,
782 },
783 };
784 assert!(format!("{err}").contains("1 transport(s) failed"));
785 }
786
787 #[test]
788 fn test_transport_target_construction() {
789 let target = TransportTarget {
790 protocol: TransportProtocol::Grpc,
791 url: "http://localhost:50051".to_string(),
792 upstream_id: "backend-1".to_string(),
793 };
794 assert_eq!(target.protocol, TransportProtocol::Grpc);
795 assert_eq!(target.url, "http://localhost:50051");
796 }
797
798 #[tokio::test]
799 async fn test_smart_fallback_history_records_all_attempts() {
800 let tracker = TransportHealthTracker::new(1, 1, 300);
801 let client = reqwest::Client::new();
802
803 tracker.record_failure("test", TransportProtocol::Grpc);
805
806 let targets = make_targets(&[
807 (TransportProtocol::Grpc, "http://localhost:1/grpc"),
808 (TransportProtocol::WebSocket, "ws://localhost:2/ws"),
809 (TransportProtocol::Http, "http://localhost:3/http"),
810 ]);
811
812 let chain = SmartFallbackChain::new(
813 &client,
814 &tracker,
815 Duration::from_millis(200),
816 Duration::from_secs(5),
817 );
818
819 let result = chain
820 .execute(
821 &targets,
822 bytes::Bytes::new(),
823 &reqwest::header::HeaderMap::new(),
824 )
825 .await;
826
827 match result {
828 Err(SmartFallbackError::AllTransportsFailed { history }) => {
829 assert_eq!(history.attempts.len(), 3);
831 assert!(history.successful_transport.is_none());
832 assert!(history.total_duration_ms > 0 || history.attempts[0].duration_ms == 0);
833 }
834 other => panic!("expected AllTransportsFailed, got {other:?}"),
835 }
836 }
837
838 #[tokio::test]
839 async fn test_smart_fallback_with_stdio_enabled() {
840 let tracker = make_tracker();
841 let client = reqwest::Client::new();
842
843 let targets = vec![TransportTarget {
845 protocol: TransportProtocol::Stdio,
846 url: "stdio://local".to_string(),
847 upstream_id: "test".to_string(),
848 }];
849
850 let chain = SmartFallbackChain::new(
851 &client,
852 &tracker,
853 Duration::from_secs(5),
854 Duration::from_secs(10),
855 )
856 .with_stdio("/bin/cat".to_string())
857 .expect("stdio fallback command should resolve");
858
859 let body = bytes::Bytes::from(r#"{"test": true}"#);
860 let result = chain
861 .execute(&targets, body.clone(), &reqwest::header::HeaderMap::new())
862 .await;
863
864 match result {
865 Ok(res) => {
866 assert_eq!(res.transport_used, TransportProtocol::Stdio);
867 assert_eq!(res.status, 200);
868 assert_eq!(res.response, body);
869 assert_eq!(res.history.attempts.len(), 1);
870 assert!(res.history.attempts[0].succeeded);
871 }
872 Err(e) => panic!("expected success, got: {e}"),
873 }
874 }
875}