1use std::path::{Path, PathBuf};
17use std::time::{Duration, Instant};
18use vellaveto_types::command::resolve_executable;
19use vellaveto_types::{FallbackNegotiationHistory, TransportAttempt, TransportProtocol};
20
21use super::transport_health::TransportHealthTracker;
22use super::{FORWARDED_HEADERS, MAX_RESPONSE_BODY_BYTES};
23
/// Upper bound on how many bytes of a failed stdio child's stderr are read and
/// included in the resulting error message.
const MAX_STDERR_BYTES: usize = 4096;
26
/// One candidate endpoint that the fallback chain may try.
#[derive(Debug, Clone)]
pub struct TransportTarget {
    /// Transport used to reach this target; selects the dispatch routine.
    pub protocol: TransportProtocol,
    /// Endpoint URL. For stdio targets it is only recorded in the attempt history;
    /// the spawned command comes from the chain's stdio configuration.
    pub url: String,
    /// Upstream identifier used as the health-tracker key and as a metrics label.
    pub upstream_id: String,
}
37
/// Outcome of a successful fallback negotiation.
#[derive(Debug)]
pub struct SmartFallbackResult {
    /// Raw response body returned by the winning transport.
    pub response: bytes::Bytes,
    /// Transport that produced `response`.
    pub transport_used: TransportProtocol,
    /// HTTP status for HTTP/gRPC attempts; the WebSocket and stdio dispatchers
    /// always report 200 on success.
    pub status: u16,
    /// Every attempt made (including skipped-by-circuit ones), ending with the success.
    pub history: FallbackNegotiationHistory,
}
50
/// Reasons the fallback chain can fail to produce a response.
#[derive(Debug)]
pub enum SmartFallbackError {
    /// Every eligible target was refused by the health tracker or failed.
    AllTransportsFailed { history: FallbackNegotiationHistory },
    /// The chain's total time budget ran out before another attempt could start.
    TotalTimeoutExceeded { history: FallbackNegotiationHistory },
    /// `execute` was called with an empty target list.
    NoTargets,
}
61
62impl std::fmt::Display for SmartFallbackError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Self::AllTransportsFailed { history } => {
66 write!(f, "all {} transport(s) failed", history.attempts.len())
67 }
68 Self::TotalTimeoutExceeded { history } => {
69 write!(
70 f,
71 "total timeout exceeded after {} attempt(s)",
72 history.attempts.len()
73 )
74 }
75 Self::NoTargets => write!(f, "no transport targets provided"),
76 }
77 }
78}
79
80impl std::error::Error for SmartFallbackError {}
81
/// Tries a list of transport targets in order until one yields a usable response,
/// consulting and updating a shared [`TransportHealthTracker`] along the way.
pub struct SmartFallbackChain<'a> {
    /// HTTP client used for both HTTP and gRPC attempts.
    client: &'a reqwest::Client,
    /// Queried (`can_use`) before each attempt and updated after each dispatched one.
    health_tracker: &'a TransportHealthTracker,
    /// Upper bound for a single attempt.
    per_attempt_timeout: Duration,
    /// Budget for the whole chain; each attempt is capped by what remains of it.
    total_timeout: Duration,
    /// Stdio targets are skipped unless this is set (only `with_stdio` sets it).
    stdio_enabled: bool,
    /// Stdio command exactly as configured by the caller.
    stdio_command: Option<String>,
    /// Executable path resolved from `stdio_command`; this is what gets spawned.
    stdio_command_resolved: Option<PathBuf>,
}
95
96impl<'a> SmartFallbackChain<'a> {
97 pub fn new(
99 client: &'a reqwest::Client,
100 health_tracker: &'a TransportHealthTracker,
101 per_attempt_timeout: Duration,
102 total_timeout: Duration,
103 ) -> Self {
104 Self {
105 client,
106 health_tracker,
107 per_attempt_timeout,
108 total_timeout,
109 stdio_enabled: false,
110 stdio_command: None,
111 stdio_command_resolved: None,
112 }
113 }
114
115 pub fn with_stdio(&self, command: String) -> Result<Self, String> {
117 let resolved = resolve_executable(&command, std::env::var_os("PATH").as_deref())?;
118 Ok(Self {
119 client: self.client,
120 health_tracker: self.health_tracker,
121 per_attempt_timeout: self.per_attempt_timeout,
122 total_timeout: self.total_timeout,
123 stdio_enabled: true,
124 stdio_command: Some(command),
125 stdio_command_resolved: Some(resolved),
126 })
127 }
128
129 pub async fn execute(
131 &self,
132 targets: &[TransportTarget],
133 body: bytes::Bytes,
134 headers: &reqwest::header::HeaderMap,
135 ) -> Result<SmartFallbackResult, SmartFallbackError> {
136 if targets.is_empty() {
137 return Err(SmartFallbackError::NoTargets);
138 }
139
140 let chain_start = Instant::now();
141 let mut attempts = Vec::new();
142
143 for target in targets {
144 if chain_start.elapsed() >= self.total_timeout {
146 let history = FallbackNegotiationHistory {
147 attempts,
148 successful_transport: None,
149 total_duration_ms: chain_start.elapsed().as_millis() as u64,
150 };
151 return Err(SmartFallbackError::TotalTimeoutExceeded { history });
152 }
153
154 if target.protocol == TransportProtocol::Stdio && !self.stdio_enabled {
156 continue;
157 }
158
159 if let Err(reason) = self
161 .health_tracker
162 .can_use(&target.upstream_id, target.protocol)
163 {
164 attempts.push(TransportAttempt {
165 protocol: target.protocol,
166 endpoint_url: target.url.clone(),
167 succeeded: false,
168 duration_ms: 0,
169 error: Some(reason),
170 });
171 continue;
172 }
173
174 let remaining = self
176 .total_timeout
177 .checked_sub(chain_start.elapsed())
178 .unwrap_or(Duration::ZERO);
179 let attempt_timeout = self.per_attempt_timeout.min(remaining);
180
181 if attempt_timeout.is_zero() {
182 let history = FallbackNegotiationHistory {
183 attempts,
184 successful_transport: None,
185 total_duration_ms: chain_start.elapsed().as_millis() as u64,
186 };
187 return Err(SmartFallbackError::TotalTimeoutExceeded { history });
188 }
189
190 let attempt_start = Instant::now();
191
192 let result = match target.protocol {
193 TransportProtocol::Http => {
194 self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
195 .await
196 }
197 TransportProtocol::WebSocket => {
198 self.dispatch_websocket(&target.url, body.clone(), attempt_timeout)
199 .await
200 }
201 TransportProtocol::Grpc => {
202 self.dispatch_http(&target.url, body.clone(), headers, attempt_timeout)
204 .await
205 }
206 TransportProtocol::Stdio => {
207 if let (Some(cmd), Some(resolved)) = (
208 self.stdio_command.as_ref(),
209 self.stdio_command_resolved.as_ref(),
210 ) {
211 self.dispatch_stdio(cmd, resolved, body.clone(), attempt_timeout)
212 .await
213 } else {
214 Err("stdio command not configured".to_string())
215 }
216 }
217 };
218
219 let duration_ms = attempt_start.elapsed().as_millis() as u64;
220
221 match result {
222 Ok((response_bytes, status)) => {
223 if status >= 500 {
228 self.health_tracker
229 .record_failure(&target.upstream_id, target.protocol);
230
231 metrics::counter!(
232 "vellaveto_transport_fallback_total",
233 "transport" => format!("{:?}", target.protocol),
234 "upstream_id" => target.upstream_id.clone(),
235 "result" => "server_error",
236 )
237 .increment(1);
238
239 attempts.push(TransportAttempt {
240 protocol: target.protocol,
241 endpoint_url: target.url.clone(),
242 succeeded: false,
243 duration_ms,
244 error: Some(format!("server error: HTTP {}", status)),
245 });
246
247 continue;
248 }
249
250 self.health_tracker
251 .record_success(&target.upstream_id, target.protocol);
252
253 metrics::counter!(
254 "vellaveto_transport_fallback_total",
255 "transport" => format!("{:?}", target.protocol),
256 "upstream_id" => target.upstream_id.clone(),
257 "result" => "success",
258 )
259 .increment(1);
260
261 attempts.push(TransportAttempt {
262 protocol: target.protocol,
263 endpoint_url: target.url.clone(),
264 succeeded: true,
265 duration_ms,
266 error: None,
267 });
268
269 let history = FallbackNegotiationHistory {
270 attempts,
271 successful_transport: Some(target.protocol),
272 total_duration_ms: chain_start.elapsed().as_millis() as u64,
273 };
274
275 return Ok(SmartFallbackResult {
276 response: response_bytes,
277 transport_used: target.protocol,
278 status,
279 history,
280 });
281 }
282 Err(error) => {
283 self.health_tracker
284 .record_failure(&target.upstream_id, target.protocol);
285
286 metrics::counter!(
287 "vellaveto_transport_fallback_total",
288 "transport" => format!("{:?}", target.protocol),
289 "upstream_id" => target.upstream_id.clone(),
290 "result" => "failure",
291 )
292 .increment(1);
293
294 attempts.push(TransportAttempt {
295 protocol: target.protocol,
296 endpoint_url: target.url.clone(),
297 succeeded: false,
298 duration_ms,
299 error: Some(error),
300 });
301 }
302 }
303 }
304
305 let history = FallbackNegotiationHistory {
306 attempts,
307 successful_transport: None,
308 total_duration_ms: chain_start.elapsed().as_millis() as u64,
309 };
310 Err(SmartFallbackError::AllTransportsFailed { history })
311 }
312
313 async fn dispatch_http(
319 &self,
320 url: &str,
321 body: bytes::Bytes,
322 headers: &reqwest::header::HeaderMap,
323 timeout: Duration,
324 ) -> Result<(bytes::Bytes, u16), String> {
325 let mut request = self.client.post(url).timeout(timeout);
326
327 for (key, value) in headers {
330 let key_lower = key.as_str().to_lowercase();
331 if FORWARDED_HEADERS
332 .iter()
333 .any(|&allowed| allowed == key_lower)
334 {
335 request = request.header(key.clone(), value.clone());
336 }
337 }
338
339 let mut resp = request
340 .body(body)
341 .send()
342 .await
343 .map_err(|e| format!("HTTP request error: {}", e))?;
344
345 let status = resp.status().as_u16();
346
347 if let Some(len) = resp.content_length() {
349 if len as usize > MAX_RESPONSE_BODY_BYTES {
350 return Err(format!(
351 "response body too large: {} bytes (max {})",
352 len, MAX_RESPONSE_BODY_BYTES
353 ));
354 }
355 }
356
357 let capacity = std::cmp::min(
360 resp.content_length().unwrap_or(8192) as usize,
361 MAX_RESPONSE_BODY_BYTES,
362 );
363 let mut response_body = Vec::with_capacity(capacity);
364 while let Some(chunk) = resp
365 .chunk()
366 .await
367 .map_err(|e| format!("HTTP response body error: {}", e))?
368 {
369 if response_body.len().saturating_add(chunk.len()) > MAX_RESPONSE_BODY_BYTES {
370 return Err(format!(
371 "response body too large: >{} bytes (max {})",
372 MAX_RESPONSE_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
373 ));
374 }
375 response_body.extend_from_slice(&chunk);
376 }
377
378 Ok((bytes::Bytes::from(response_body), status))
379 }
380
381 async fn dispatch_websocket(
386 &self,
387 url: &str,
388 body: bytes::Bytes,
389 timeout: Duration,
390 ) -> Result<(bytes::Bytes, u16), String> {
391 use tokio_tungstenite::tungstenite::Message;
392
393 let result = tokio::time::timeout(timeout, async {
394 let mut ws_config =
397 tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default();
398 ws_config.max_message_size = Some(MAX_RESPONSE_BODY_BYTES);
399 ws_config.max_frame_size = Some(MAX_RESPONSE_BODY_BYTES);
400 let (mut ws, _) =
401 tokio_tungstenite::connect_async_with_config(url, Some(ws_config), false)
402 .await
403 .map_err(|e| format!("WebSocket connect error: {}", e))?;
404
405 use futures_util::SinkExt;
406 let body_text = String::from_utf8(body.to_vec())
409 .map_err(|_| "WebSocket body contains invalid UTF-8".to_string())?;
410 ws.send(Message::Text(body_text.into()))
411 .await
412 .map_err(|e| format!("WebSocket send error: {}", e))?;
413
414 use futures_util::StreamExt;
415 let response = ws
416 .next()
417 .await
418 .ok_or_else(|| "WebSocket closed without response".to_string())?
419 .map_err(|e| format!("WebSocket receive error: {}", e))?;
420
421 ws.close(None)
422 .await
423 .map_err(|e| format!("WebSocket close error: {}", e))?;
424
425 let response_bytes = match response {
426 Message::Text(t) => {
427 let bytes = t.as_bytes();
428 if bytes.len() > MAX_RESPONSE_BODY_BYTES {
429 return Err(format!(
430 "WebSocket response too large: {} bytes (max {})",
431 bytes.len(),
432 MAX_RESPONSE_BODY_BYTES
433 ));
434 }
435 bytes::Bytes::from(Vec::from(bytes))
436 }
437 Message::Binary(b) => {
438 if b.len() > MAX_RESPONSE_BODY_BYTES {
439 return Err(format!(
440 "WebSocket response too large: {} bytes (max {})",
441 b.len(),
442 MAX_RESPONSE_BODY_BYTES
443 ));
444 }
445 bytes::Bytes::from(Vec::from(b.as_ref()))
446 }
447 other => {
448 return Err(format!("unexpected WebSocket message type: {:?}", other));
449 }
450 };
451
452 Ok::<_, String>((response_bytes, 200u16))
453 })
454 .await;
455
456 match result {
457 Ok(inner) => inner,
458 Err(_) => Err("WebSocket timeout".to_string()),
459 }
460 }
461
462 async fn dispatch_stdio(
474 &self,
475 _command: &str,
476 resolved_command: &Path,
477 body: bytes::Bytes,
478 timeout: Duration,
479 ) -> Result<(bytes::Bytes, u16), String> {
480 use tokio::io::{AsyncReadExt, AsyncWriteExt};
481 use tokio::process::Command;
482
483 let mut child = Command::new(resolved_command)
489 .env_clear()
490 .env("PATH", "/usr/local/bin:/usr/bin:/bin")
491 .env("HOME", "/tmp")
492 .env("LANG", "C.UTF-8")
493 .kill_on_drop(true)
494 .stdin(std::process::Stdio::piped())
495 .stdout(std::process::Stdio::piped())
496 .stderr(std::process::Stdio::piped())
497 .spawn()
498 .map_err(|e| format!("stdio spawn error: {}", e))?;
499
500 if let Some(mut stdin) = child.stdin.take() {
502 stdin
503 .write_all(&body)
504 .await
505 .map_err(|e| format!("stdio write error: {}", e))?;
506 }
507
508 let stdout_handle = child
510 .stdout
511 .take()
512 .ok_or_else(|| "stdio stdout not captured".to_string())?;
513 let mut stderr_handle = child
514 .stderr
515 .take()
516 .ok_or_else(|| "stdio stderr not captured".to_string())?;
517
518 tokio::select! {
520 result = async {
525 let stdout_future = async {
526 let mut stdout_buf = Vec::new();
527 let read_result = stdout_handle
528 .take((MAX_RESPONSE_BODY_BYTES as u64) + 1)
529 .read_to_end(&mut stdout_buf)
530 .await
531 .map_err(|e| format!("stdio stdout read error: {}", e));
532 read_result.map(|_| stdout_buf)
533 };
534
535 let wait_future = child.wait();
536
537 let (stdout_result, status_result) = tokio::join!(stdout_future, wait_future);
538
539 let stdout_buf = stdout_result?;
540 let status = status_result
541 .map_err(|e| format!("stdio wait error: {}", e))?;
542
543 if stdout_buf.len() > MAX_RESPONSE_BODY_BYTES {
544 return Err(format!(
545 "stdio stdout too large: >{} bytes (max {})",
546 MAX_RESPONSE_BODY_BYTES, MAX_RESPONSE_BODY_BYTES
547 ));
548 }
549
550 if !status.success() {
551 let mut stderr_buf = vec![0u8; MAX_STDERR_BYTES];
553 let n = stderr_handle
554 .read(&mut stderr_buf)
555 .await
556 .unwrap_or(0);
557 let stderr_snippet = String::from_utf8_lossy(&stderr_buf[..n]);
558 return Err(format!(
559 "stdio process exited with {:?}: {}",
560 status,
561 stderr_snippet.trim()
562 ));
563 }
564
565 Ok((bytes::Bytes::from(stdout_buf), 200u16))
566 } => {
567 result
568 }
569 _ = tokio::time::sleep(timeout) => {
570 let _ = child.kill().await;
572 Err("stdio timeout".to_string())
573 }
574 }
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream id shared by every target these tests build.
    const UPSTREAM: &str = "test";

    fn make_tracker() -> TransportHealthTracker {
        TransportHealthTracker::new(2, 1, 300)
    }

    fn make_targets(protos: &[(TransportProtocol, &str)]) -> Vec<TransportTarget> {
        let mut targets = Vec::with_capacity(protos.len());
        for &(protocol, url) in protos {
            targets.push(TransportTarget {
                protocol,
                url: url.to_owned(),
                upstream_id: UPSTREAM.to_owned(),
            });
        }
        targets
    }

    fn no_headers() -> reqwest::header::HeaderMap {
        reqwest::header::HeaderMap::new()
    }

    /// Executes `chain` against `targets` with no forwarded headers.
    async fn run(
        chain: &SmartFallbackChain<'_>,
        targets: &[TransportTarget],
        body: bytes::Bytes,
    ) -> Result<SmartFallbackResult, SmartFallbackError> {
        chain.execute(targets, body, &no_headers()).await
    }

    #[tokio::test]
    async fn test_smart_fallback_no_targets() {
        let client = reqwest::Client::new();
        let tracker = make_tracker();
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        );

        let outcome = run(&chain, &[], bytes::Bytes::new()).await;
        assert!(matches!(outcome, Err(SmartFallbackError::NoTargets)));
    }

    #[tokio::test]
    async fn test_smart_fallback_all_circuits_open() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();

        // One failure is enough to open each circuit with a threshold of 1.
        for protocol in [TransportProtocol::Http, TransportProtocol::WebSocket] {
            tracker.record_failure(UPSTREAM, protocol);
        }

        let targets = make_targets(&[
            (TransportProtocol::Http, "http://localhost:1/mcp"),
            (TransportProtocol::WebSocket, "ws://localhost:2/mcp"),
        ]);
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        );

        let history = match run(&chain, &targets, bytes::Bytes::new()).await {
            Err(SmartFallbackError::AllTransportsFailed { history }) => history,
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        };

        assert_eq!(history.attempts.len(), 2);
        assert!(history.successful_transport.is_none());
        for attempt in &history.attempts {
            assert!(!attempt.succeeded);
            let reason = attempt.error.as_ref().unwrap();
            assert!(reason.contains("circuit open"));
        }
    }

    #[tokio::test]
    async fn test_smart_fallback_skips_stdio_when_disabled() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();
        let targets = make_targets(&[(TransportProtocol::Stdio, "stdio://local")]);
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(1),
            Duration::from_secs(5),
        );

        let outcome = run(&chain, &targets, bytes::Bytes::new()).await;
        assert!(matches!(
            outcome,
            Err(SmartFallbackError::AllTransportsFailed { .. })
        ));
    }

    #[tokio::test]
    async fn test_smart_fallback_first_fails_second_succeeds_via_circuit() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();
        tracker.record_failure(UPSTREAM, TransportProtocol::Grpc);

        let targets = make_targets(&[
            (TransportProtocol::Grpc, "http://localhost:1/grpc"),
            (TransportProtocol::Http, "http://localhost:1/does-not-exist"),
        ]);
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_millis(500),
            Duration::from_secs(5),
        );

        let history = match run(&chain, &targets, bytes::Bytes::new()).await {
            Err(SmartFallbackError::AllTransportsFailed { history }) => history,
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        };

        assert_eq!(history.attempts.len(), 2);
        let (circuit_skip, http_try) = (&history.attempts[0], &history.attempts[1]);
        assert!(!circuit_skip.succeeded);
        assert!(circuit_skip
            .error
            .as_ref()
            .unwrap()
            .contains("circuit open"));
        assert!(!http_try.succeeded);
    }

    #[tokio::test]
    async fn test_smart_fallback_total_timeout_zero() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();
        let targets = make_targets(&[(TransportProtocol::Http, "http://localhost:1/mcp")]);
        let chain =
            SmartFallbackChain::new(&client, &tracker, Duration::from_secs(5), Duration::ZERO);

        let outcome = run(&chain, &targets, bytes::Bytes::new()).await;
        assert!(matches!(
            outcome,
            Err(SmartFallbackError::TotalTimeoutExceeded { .. })
        ));
    }

    #[test]
    fn test_smart_fallback_error_display() {
        assert_eq!(
            SmartFallbackError::NoTargets.to_string(),
            "no transport targets provided"
        );

        let single_failure = TransportAttempt {
            protocol: TransportProtocol::Http,
            endpoint_url: "http://localhost".to_string(),
            succeeded: false,
            duration_ms: 100,
            error: Some("timeout".to_string()),
        };
        let err = SmartFallbackError::AllTransportsFailed {
            history: FallbackNegotiationHistory {
                attempts: vec![single_failure],
                successful_transport: None,
                total_duration_ms: 100,
            },
        };
        assert!(err.to_string().contains("1 transport(s) failed"));
    }

    #[test]
    fn test_transport_target_construction() {
        let target = TransportTarget {
            protocol: TransportProtocol::Grpc,
            url: "http://localhost:50051".to_string(),
            upstream_id: "backend-1".to_string(),
        };
        assert_eq!(target.protocol, TransportProtocol::Grpc);
        assert_eq!(target.url, "http://localhost:50051");
    }

    #[tokio::test]
    async fn test_smart_fallback_history_records_all_attempts() {
        let tracker = TransportHealthTracker::new(1, 1, 300);
        let client = reqwest::Client::new();
        tracker.record_failure(UPSTREAM, TransportProtocol::Grpc);

        let targets = make_targets(&[
            (TransportProtocol::Grpc, "http://localhost:1/grpc"),
            (TransportProtocol::WebSocket, "ws://localhost:2/ws"),
            (TransportProtocol::Http, "http://localhost:3/http"),
        ]);
        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_millis(200),
            Duration::from_secs(5),
        );

        let history = match run(&chain, &targets, bytes::Bytes::new()).await {
            Err(SmartFallbackError::AllTransportsFailed { history }) => history,
            other => panic!("expected AllTransportsFailed, got {:?}", other),
        };

        assert_eq!(history.attempts.len(), 3);
        assert!(history.successful_transport.is_none());
        assert!(history.total_duration_ms > 0 || history.attempts[0].duration_ms == 0);
    }

    #[tokio::test]
    async fn test_smart_fallback_with_stdio_enabled() {
        let tracker = make_tracker();
        let client = reqwest::Client::new();
        let targets = make_targets(&[(TransportProtocol::Stdio, "stdio://local")]);

        let chain = SmartFallbackChain::new(
            &client,
            &tracker,
            Duration::from_secs(5),
            Duration::from_secs(10),
        )
        .with_stdio("/bin/cat".to_string())
        .expect("stdio fallback command should resolve");

        // `cat` echoes stdin, so the response must equal the request body.
        let body = bytes::Bytes::from(r#"{"test": true}"#);
        let res = run(&chain, &targets, body.clone())
            .await
            .unwrap_or_else(|e| panic!("expected success, got: {}", e));

        assert_eq!(res.transport_used, TransportProtocol::Stdio);
        assert_eq!(res.status, 200);
        assert_eq!(res.response, body);
        assert_eq!(res.history.attempts.len(), 1);
        assert!(res.history.attempts[0].succeeded);
    }
}