1use crate::{Router, RoutingContext, RoutingDecision, RoutingError};
7use terraphim_types::capability::{Provider, ProviderType};
8use tracing::{info, info_span, warn, Instrument};
9
/// Policy a [`FallbackRouter`] applies after a provider execution fails.
///
/// The default policy is [`FallbackStrategy::NextBestProvider`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FallbackStrategy {
    /// Re-route with the failed provider's id appended to the prompt as an
    /// `[exclude:<id>]` hint.
    #[default]
    NextBestProvider,
    /// When the failed provider is an agent, re-route with a `[prefer:llm]`
    /// prompt hint; otherwise re-route the prompt unchanged.
    LlmFallback,
    /// Re-route the unchanged prompt after a short delay.
    Retry {
        /// Retry budget, counted in total execution attempts.
        max_attempts: u32,
    },
    /// Give up on the first failed execution.
    FailFast,
}
23
24#[derive(Debug)]
26pub struct FallbackRouter {
27 router: Router,
28 fallback_strategy: FallbackStrategy,
29 max_fallbacks: u32,
30}
31
32impl FallbackRouter {
33 pub fn new(router: Router) -> Self {
35 Self {
36 router,
37 fallback_strategy: FallbackStrategy::default(),
38 max_fallbacks: 3,
39 }
40 }
41
42 pub fn with_strategy(mut self, strategy: FallbackStrategy) -> Self {
44 self.fallback_strategy = strategy;
45 self
46 }
47
48 pub fn with_max_fallbacks(mut self, max: u32) -> Self {
50 self.max_fallbacks = max;
51 self
52 }
53
54 pub async fn route_with_fallback<F, Fut>(
59 &self,
60 prompt: &str,
61 context: &RoutingContext,
62 mut execute: F,
63 ) -> Result<RoutingDecision, RoutingError>
64 where
65 F: FnMut(Provider) -> Fut,
66 Fut: std::future::Future<Output = Result<(), String>>,
67 {
68 let mut attempts = 0;
69 let mut current_prompt = prompt.to_string();
70
71 let fallback_span = info_span!(
72 "router.route_with_fallback",
73 prompt_len = prompt.len(),
74 fallback_strategy = ?self.fallback_strategy,
75 max_fallbacks = self.max_fallbacks,
76 total_attempts = tracing::field::Empty,
77 final_provider = tracing::field::Empty,
78 outcome = tracing::field::Empty,
79 );
80
81 async {
82 loop {
83 let decision = self.router.route(¤t_prompt, context)?;
84 let provider = decision.provider.clone();
85
86 let attempt_span = info_span!(
87 "router.fallback_attempt",
88 attempt_number = attempts + 1,
89 provider_id = provider.id.as_str(),
90 provider_type = ?provider.provider_type,
91 outcome = tracing::field::Empty,
92 );
93
94 let execute_result = async {
95 info!(
96 attempt = attempts + 1,
97 provider_id = provider.id.as_str(),
98 provider_name = provider.name.as_str(),
99 "Attempting provider execution"
100 );
101
102 match execute(provider.clone()).await {
103 Ok(()) => {
104 info!(
105 provider_id = provider.id.as_str(),
106 "Provider execution succeeded"
107 );
108 tracing::Span::current().record("outcome", "success");
109 Ok(decision.clone())
110 }
111 Err(error) => {
112 warn!(
113 provider_id = provider.id.as_str(),
114 error = error.as_str(),
115 "Provider execution failed"
116 );
117 tracing::Span::current().record("outcome", "failed");
118 Err(error)
119 }
120 }
121 }
122 .instrument(attempt_span)
123 .await;
124
125 match execute_result {
126 Ok(decision) => {
127 tracing::Span::current().record("total_attempts", attempts + 1);
128 tracing::Span::current()
129 .record("final_provider", decision.provider.id.as_str());
130 tracing::Span::current().record("outcome", "success");
131 return Ok(decision);
132 }
133 Err(_error) => {
134 attempts += 1;
135 if attempts >= self.max_fallbacks {
136 tracing::Span::current().record("total_attempts", attempts);
137 tracing::Span::current().record("outcome", "exhausted");
138 return Err(RoutingError::NoProviderFound(vec![]));
139 }
140
141 match self.fallback_strategy {
142 FallbackStrategy::FailFast => {
143 tracing::Span::current().record("outcome", "fail_fast");
144 return Err(RoutingError::NoProviderFound(vec![]));
145 }
146 FallbackStrategy::Retry { max_attempts } => {
147 if attempts >= max_attempts {
148 continue;
149 }
150 info!(delay_ms = 1000, "Retrying same provider after delay");
151 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
152 }
153 FallbackStrategy::NextBestProvider => {
154 info!("Excluding failed provider, trying next best");
155 current_prompt = format!("{} [exclude:{}]", prompt, provider.id);
156 }
157 FallbackStrategy::LlmFallback => {
158 if matches!(provider.provider_type, ProviderType::Agent { .. }) {
159 info!("Agent failed, falling back to LLM preference");
160 current_prompt = format!("{} [prefer:llm]", prompt);
161 }
162 }
163 }
164 }
165 }
166 }
167 }
168 .instrument(fallback_span)
169 .await
170 }
171
172 pub fn router(&self) -> &Router {
174 &self.router
175 }
176
177 pub fn router_mut(&mut self) -> &mut Router {
179 &mut self.router
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU32, Ordering};
    use terraphim_types::capability::Capability;

    /// Builds a router with one LLM provider and one agent provider, both
    /// advertising code generation.
    fn create_test_router() -> Router {
        let mut router = Router::new();

        router.add_provider(Provider::new(
            "gpt-4",
            "GPT-4",
            ProviderType::Llm {
                model_id: "gpt-4".to_string(),
                api_endpoint: "https://api.openai.com".to_string(),
            },
            vec![Capability::CodeGeneration],
        ));

        router.add_provider(Provider::new(
            "@codex",
            "Codex",
            ProviderType::Agent {
                agent_id: "@codex".to_string(),
                cli_command: "opencode".to_string(),
                working_dir: PathBuf::from("/tmp"),
            },
            vec![Capability::CodeGeneration],
        ));

        router
    }

    #[test]
    fn test_defaults() {
        let fallback_router = FallbackRouter::new(create_test_router());
        assert_eq!(
            fallback_router.fallback_strategy,
            FallbackStrategy::NextBestProvider
        );
        assert_eq!(fallback_router.max_fallbacks, 3);
    }

    #[test]
    fn test_builder_overrides() {
        let fallback_router = FallbackRouter::new(create_test_router())
            .with_strategy(FallbackStrategy::Retry { max_attempts: 5 })
            .with_max_fallbacks(7);
        assert_eq!(
            fallback_router.fallback_strategy,
            FallbackStrategy::Retry { max_attempts: 5 }
        );
        assert_eq!(fallback_router.max_fallbacks, 7);
    }

    #[tokio::test]
    async fn test_fallback_to_next_provider() {
        let router = create_test_router();
        let fallback_router = FallbackRouter::new(router)
            .with_strategy(FallbackStrategy::NextBestProvider)
            .with_max_fallbacks(2);

        let attempts = AtomicU32::new(0);
        let result = fallback_router
            .route_with_fallback(
                "Implement a function",
                &RoutingContext::default(),
                |_provider| {
                    let n = attempts.fetch_add(1, Ordering::SeqCst) + 1;
                    async move {
                        if n == 1 {
                            Err("First provider failed".to_string())
                        } else {
                            Ok(())
                        }
                    }
                },
            )
            .await;

        assert!(result.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_fail_fast() {
        let router = create_test_router();
        let fallback_router = FallbackRouter::new(router).with_strategy(FallbackStrategy::FailFast);

        let calls = AtomicU32::new(0);
        let result = fallback_router
            .route_with_fallback(
                "Implement a function",
                &RoutingContext::default(),
                |_provider| {
                    calls.fetch_add(1, Ordering::SeqCst);
                    async { Err("Always fails".to_string()) }
                },
            )
            .await;

        assert!(matches!(result, Err(RoutingError::NoProviderFound(_))));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "FailFast must not attempt another provider"
        );
    }
}