Skip to main content

terraphim_router/
fallback.rs

//! Fallback routing for when the primary provider fails.
//!
//! This module provides fallback logic to route to alternative providers
//! when the primary choice fails (e.g., agent spawn failure, LLM API error).
5
6use crate::{Router, RoutingContext, RoutingDecision, RoutingError};
7use terraphim_types::capability::{Provider, ProviderType};
8use tracing::{info, info_span, warn, Instrument};
9
/// What a fallback-capable router does after a provider execution fails.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FallbackStrategy {
    /// Route again with a hint that excludes the provider that just failed.
    ///
    /// This is the default strategy.
    #[default]
    NextBestProvider,
    /// When the failed provider is an agent, route again with a hint that
    /// prefers an LLM provider.
    LlmFallback,
    /// Route again with the unchanged prompt.
    ///
    /// `max_attempts` is compared against the number of failed attempts so
    /// far.
    Retry { max_attempts: u32 },
    /// Stop after the first failure and report an error.
    FailFast,
}
23
24/// Router with fallback capabilities
25#[derive(Debug)]
26pub struct FallbackRouter {
27    router: Router,
28    fallback_strategy: FallbackStrategy,
29    max_fallbacks: u32,
30}
31
32impl FallbackRouter {
33    /// Create a new fallback router
34    pub fn new(router: Router) -> Self {
35        Self {
36            router,
37            fallback_strategy: FallbackStrategy::default(),
38            max_fallbacks: 3,
39        }
40    }
41
42    /// Set fallback strategy
43    pub fn with_strategy(mut self, strategy: FallbackStrategy) -> Self {
44        self.fallback_strategy = strategy;
45        self
46    }
47
48    /// Set max fallback attempts
49    pub fn with_max_fallbacks(mut self, max: u32) -> Self {
50        self.max_fallbacks = max;
51        self
52    }
53
54    /// Route with fallback on failure
55    ///
56    /// The `execute` closure receives a cloned `Provider` to avoid
57    /// lifetime issues with async closures.
58    pub async fn route_with_fallback<F, Fut>(
59        &self,
60        prompt: &str,
61        context: &RoutingContext,
62        mut execute: F,
63    ) -> Result<RoutingDecision, RoutingError>
64    where
65        F: FnMut(Provider) -> Fut,
66        Fut: std::future::Future<Output = Result<(), String>>,
67    {
68        let mut attempts = 0;
69        let mut current_prompt = prompt.to_string();
70
71        let fallback_span = info_span!(
72            "router.route_with_fallback",
73            prompt_len = prompt.len(),
74            fallback_strategy = ?self.fallback_strategy,
75            max_fallbacks = self.max_fallbacks,
76            total_attempts = tracing::field::Empty,
77            final_provider = tracing::field::Empty,
78            outcome = tracing::field::Empty,
79        );
80
81        async {
82            loop {
83                let decision = self.router.route(&current_prompt, context)?;
84                let provider = decision.provider.clone();
85
86                let attempt_span = info_span!(
87                    "router.fallback_attempt",
88                    attempt_number = attempts + 1,
89                    provider_id = provider.id.as_str(),
90                    provider_type = ?provider.provider_type,
91                    outcome = tracing::field::Empty,
92                );
93
94                let execute_result = async {
95                    info!(
96                        attempt = attempts + 1,
97                        provider_id = provider.id.as_str(),
98                        provider_name = provider.name.as_str(),
99                        "Attempting provider execution"
100                    );
101
102                    match execute(provider.clone()).await {
103                        Ok(()) => {
104                            info!(
105                                provider_id = provider.id.as_str(),
106                                "Provider execution succeeded"
107                            );
108                            tracing::Span::current().record("outcome", "success");
109                            Ok(decision.clone())
110                        }
111                        Err(error) => {
112                            warn!(
113                                provider_id = provider.id.as_str(),
114                                error = error.as_str(),
115                                "Provider execution failed"
116                            );
117                            tracing::Span::current().record("outcome", "failed");
118                            Err(error)
119                        }
120                    }
121                }
122                .instrument(attempt_span)
123                .await;
124
125                match execute_result {
126                    Ok(decision) => {
127                        tracing::Span::current().record("total_attempts", attempts + 1);
128                        tracing::Span::current()
129                            .record("final_provider", decision.provider.id.as_str());
130                        tracing::Span::current().record("outcome", "success");
131                        return Ok(decision);
132                    }
133                    Err(_error) => {
134                        attempts += 1;
135                        if attempts >= self.max_fallbacks {
136                            tracing::Span::current().record("total_attempts", attempts);
137                            tracing::Span::current().record("outcome", "exhausted");
138                            return Err(RoutingError::NoProviderFound(vec![]));
139                        }
140
141                        match self.fallback_strategy {
142                            FallbackStrategy::FailFast => {
143                                tracing::Span::current().record("outcome", "fail_fast");
144                                return Err(RoutingError::NoProviderFound(vec![]));
145                            }
146                            FallbackStrategy::Retry { max_attempts } => {
147                                if attempts >= max_attempts {
148                                    continue;
149                                }
150                                info!(delay_ms = 1000, "Retrying same provider after delay");
151                                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
152                            }
153                            FallbackStrategy::NextBestProvider => {
154                                info!("Excluding failed provider, trying next best");
155                                current_prompt = format!("{} [exclude:{}]", prompt, provider.id);
156                            }
157                            FallbackStrategy::LlmFallback => {
158                                if matches!(provider.provider_type, ProviderType::Agent { .. }) {
159                                    info!("Agent failed, falling back to LLM preference");
160                                    current_prompt = format!("{} [prefer:llm]", prompt);
161                                }
162                            }
163                        }
164                    }
165                }
166            }
167        }
168        .instrument(fallback_span)
169        .await
170    }
171
172    /// Get inner router
173    pub fn router(&self) -> &Router {
174        &self.router
175    }
176
177    /// Get mutable inner router
178    pub fn router_mut(&mut self) -> &mut Router {
179        &mut self.router
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU32, Ordering};
    use terraphim_types::capability::Capability;

    /// Build a router with one LLM provider and one agent provider, both
    /// advertising code generation.
    fn create_test_router() -> Router {
        let llm = Provider::new(
            "gpt-4",
            "GPT-4",
            ProviderType::Llm {
                model_id: "gpt-4".to_string(),
                api_endpoint: "https://api.openai.com".to_string(),
            },
            vec![Capability::CodeGeneration],
        );
        let agent = Provider::new(
            "@codex",
            "Codex",
            ProviderType::Agent {
                agent_id: "@codex".to_string(),
                cli_command: "opencode".to_string(),
                working_dir: PathBuf::from("/tmp"),
            },
            vec![Capability::CodeGeneration],
        );

        let mut router = Router::new();
        router.add_provider(llm);
        router.add_provider(agent);
        router
    }

    /// A failure on the first execution is followed by a second execution
    /// that succeeds.
    #[tokio::test]
    async fn test_fallback_to_next_provider() {
        let fallback_router = FallbackRouter::new(create_test_router())
            .with_strategy(FallbackStrategy::NextBestProvider)
            .with_max_fallbacks(2);

        let calls = AtomicU32::new(0);
        let outcome = fallback_router
            .route_with_fallback(
                "Implement a function",
                &RoutingContext::default(),
                |_provider| {
                    // Only the very first call observes a zero counter.
                    let is_first_call = calls.fetch_add(1, Ordering::SeqCst) == 0;
                    async move {
                        if is_first_call {
                            Err("First provider failed".to_string())
                        } else {
                            Ok(())
                        }
                    }
                },
            )
            .await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// With `FailFast`, an execution that always fails yields an error.
    #[tokio::test]
    async fn test_fail_fast() {
        let fallback_router =
            FallbackRouter::new(create_test_router()).with_strategy(FallbackStrategy::FailFast);

        let outcome = fallback_router
            .route_with_fallback(
                "Implement a function",
                &RoutingContext::default(),
                |_provider| async { Err("Always fails".to_string()) },
            )
            .await;

        assert!(outcome.is_err());
    }
}