1use serde::{Deserialize, Serialize};
24use tracing::{debug, info, warn};
25
26use super::error::LlmError;
27use crate::config::{
28 FallbackBehavior, FallbackConfig as ConfigFallbackConfig, OnAllFailedBehavior,
29};
30
/// A value obtained through the fallback machinery, together with the model and
/// endpoint it is associated with and the fallback steps that led to it.
#[derive(Debug, Clone)]
pub struct FallbackResult<T> {
    /// The wrapped value.
    pub result: T,
    /// `false` when built via `success`, `true` when built via `from_fallback`.
    pub degraded: bool,
    /// Model name associated with `result` (presumably the one that produced it).
    pub model: String,
    /// Endpoint associated with `result` (presumably the one that served it).
    pub endpoint: String,
    /// Recorded fallback steps; empty when built via `success`.
    pub fallback_history: Vec<FallbackStep>,
}
45
46impl<T> FallbackResult<T> {
47 pub fn success(result: T, model: String, endpoint: String) -> Self {
49 Self {
50 result,
51 degraded: false,
52 model,
53 endpoint,
54 fallback_history: Vec::new(),
55 }
56 }
57
58 pub fn from_fallback(
60 result: T,
61 model: String,
62 endpoint: String,
63 history: Vec<FallbackStep>,
64 ) -> Self {
65 Self {
66 result,
67 degraded: true,
68 model,
69 endpoint,
70 fallback_history: history,
71 }
72 }
73}
74
/// One recorded hop in a fallback sequence, as built by
/// `FallbackChain::record_fallback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackStep {
    /// Model in use when the fallback was triggered.
    pub from_model: String,
    /// Model switched to; `None` presumably means no replacement model was
    /// chosen (e.g. the model list was exhausted) — confirm against callers.
    pub to_model: Option<String>,
    /// Endpoint in use when the fallback was triggered.
    pub from_endpoint: String,
    /// Endpoint switched to; `None` presumably means the endpoint was not
    /// changed or none was available — confirm against callers.
    pub to_endpoint: Option<String>,
    /// Free-form, caller-supplied description of why the fallback happened.
    pub reason: String,
}
89
/// Decides whether an error should trigger a retry or a fallback, and walks the
/// configured model and endpoint lists. All state lives in its [`FallbackConfig`].
#[derive(Debug, Clone)]
pub struct FallbackChain {
    config: FallbackConfig,
}
95
/// Runtime fallback settings used by [`FallbackChain`]; can be built from the
/// crate-level `config` section via `From`.
#[derive(Debug, Clone)]
pub struct FallbackConfig {
    /// Master switch: when `false`, `FallbackChain::should_fallback` and
    /// `FallbackChain::should_retry` always return `false`.
    pub enabled: bool,
    /// Ordered fallback models; `FallbackChain::next_model` walks it front to back.
    pub models: Vec<String>,
    /// Ordered fallback endpoints; `FallbackChain::next_endpoint` walks it front to back.
    pub endpoints: Vec<String>,
    /// Behavior applied to `LlmError::RateLimit` errors.
    pub on_rate_limit: FallbackBehavior,
    /// Behavior applied to `LlmError::Timeout` errors.
    pub on_timeout: FallbackBehavior,
    /// Not consulted in this module; presumably read by the caller once every
    /// fallback option has failed — confirm against callers.
    pub on_all_failed: OnAllFailedBehavior,
}
112
113impl Default for FallbackConfig {
114 fn default() -> Self {
115 Self {
116 enabled: true,
117 models: vec!["gpt-4o-mini".to_string(), "glm-4-flash".to_string()],
118 endpoints: vec![],
119 on_rate_limit: FallbackBehavior::RetryThenFallback,
120 on_timeout: FallbackBehavior::RetryThenFallback,
121 on_all_failed: OnAllFailedBehavior::ReturnError,
122 }
123 }
124}
125
126impl From<ConfigFallbackConfig> for FallbackConfig {
127 fn from(config: ConfigFallbackConfig) -> Self {
128 Self {
129 enabled: config.enabled,
130 models: config.models,
131 endpoints: config.endpoints,
132 on_rate_limit: config.on_rate_limit,
133 on_timeout: config.on_timeout,
134 on_all_failed: config.on_all_failed,
135 }
136 }
137}
138
139impl FallbackConfig {
140 pub fn new() -> Self {
142 Self::default()
143 }
144
145 pub fn disabled() -> Self {
147 Self {
148 enabled: false,
149 ..Self::default()
150 }
151 }
152}
153
154impl FallbackChain {
155 pub fn new(config: FallbackConfig) -> Self {
157 Self { config }
158 }
159
160 pub fn disabled() -> Self {
162 Self::new(FallbackConfig::disabled())
163 }
164
165 pub fn config(&self) -> &FallbackConfig {
167 &self.config
168 }
169
170 pub fn is_enabled(&self) -> bool {
172 self.config.enabled
173 }
174
175 pub fn behavior_for_error(&self, error: &LlmError) -> FallbackBehavior {
177 match error {
178 LlmError::RateLimit(_) => self.config.on_rate_limit,
179 LlmError::Timeout(_) => self.config.on_timeout,
180 _ => FallbackBehavior::Fail,
181 }
182 }
183
184 pub fn should_fallback(&self, error: &LlmError) -> bool {
186 if !self.config.enabled {
187 return false;
188 }
189
190 match self.behavior_for_error(error) {
191 FallbackBehavior::Fallback | FallbackBehavior::RetryThenFallback => true,
192 FallbackBehavior::Retry | FallbackBehavior::Fail => false,
193 }
194 }
195
196 pub fn should_retry(&self, error: &LlmError) -> bool {
198 if !self.config.enabled {
199 return false;
200 }
201
202 match self.behavior_for_error(error) {
203 FallbackBehavior::Retry | FallbackBehavior::RetryThenFallback => true,
204 FallbackBehavior::Fallback | FallbackBehavior::Fail => false,
205 }
206 }
207
208 pub fn next_model(&self, current: &str) -> Option<String> {
210 let models = &self.config.models;
211 let current_idx = models.iter().position(|m| m == current);
212
213 match current_idx {
214 Some(idx) if idx + 1 < models.len() => {
216 let next = models[idx + 1].clone();
217 info!(from = current, to = %next, "Falling back to next model");
218 Some(next)
219 }
220 Some(_) => {
222 warn!(
223 model = current,
224 "Already at last fallback model, no more available"
225 );
226 None
227 }
228 None => {
230 if !models.is_empty() && models[0] != current {
231 let next = models[0].clone();
232 info!(from = current, to = %next, "Falling back to first fallback model");
233 Some(next)
234 } else {
235 warn!(model = current, "No more fallback models available");
236 None
237 }
238 }
239 }
240 }
241
242 pub fn next_endpoint(&self, current: &str) -> Option<String> {
244 let endpoints = &self.config.endpoints;
245 let current_idx = endpoints.iter().position(|e| e == current);
246
247 match current_idx {
248 Some(idx) if idx + 1 < endpoints.len() => {
250 let next = endpoints[idx + 1].clone();
251 info!(from = current, to = %next, "Falling back to next endpoint");
252 Some(next)
253 }
254 Some(_) => {
256 warn!(
257 endpoint = current,
258 "Already at last fallback endpoint, no more available"
259 );
260 None
261 }
262 None => {
264 if !endpoints.is_empty() && endpoints[0] != current {
265 let next = endpoints[0].clone();
266 info!(from = current, to = %next, "Falling back to first fallback endpoint");
267 Some(next)
268 } else {
269 debug!(endpoint = current, "No more fallback endpoints available");
270 None
271 }
272 }
273 }
274 }
275
276 pub fn record_fallback(
278 &self,
279 history: &mut Vec<FallbackStep>,
280 from_model: String,
281 to_model: Option<String>,
282 from_endpoint: String,
283 to_endpoint: Option<String>,
284 reason: String,
285 ) {
286 let step = FallbackStep {
287 from_model,
288 to_model,
289 from_endpoint,
290 to_endpoint,
291 reason,
292 };
293 debug!(?step, "Recording fallback step");
294 history.push(step);
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fallback_config_default() {
        let config = FallbackConfig::default();
        assert!(config.enabled);
        assert!(!config.models.is_empty());
    }

    #[test]
    fn test_fallback_chain_disabled() {
        let chain = FallbackChain::disabled();
        assert!(!chain.is_enabled());
    }

    #[test]
    fn test_next_model() {
        let config = FallbackConfig {
            models: vec![
                "gpt-4o".to_string(),
                "gpt-4o-mini".to_string(),
                "glm-4-flash".to_string(),
            ],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        assert_eq!(chain.next_model("gpt-4o"), Some("gpt-4o-mini".to_string()));
        assert_eq!(
            chain.next_model("gpt-4o-mini"),
            Some("glm-4-flash".to_string())
        );
        assert_eq!(chain.next_model("glm-4-flash"), None);
    }

    #[test]
    fn test_next_model_not_in_list() {
        let config = FallbackConfig {
            models: vec!["gpt-4o-mini".to_string()],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        // A model outside the list (e.g. the primary) starts at the list head.
        assert_eq!(
            chain.next_model("unknown-model"),
            Some("gpt-4o-mini".to_string())
        );
    }

    #[test]
    fn test_next_model_empty_list() {
        let config = FallbackConfig {
            models: vec![],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        assert_eq!(chain.next_model("gpt-4o"), None);
    }

    #[test]
    fn test_next_endpoint() {
        let config = FallbackConfig {
            endpoints: vec![
                "https://primary.example".to_string(),
                "https://backup.example".to_string(),
            ],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        assert_eq!(
            chain.next_endpoint("https://other.example"),
            Some("https://primary.example".to_string())
        );
        assert_eq!(
            chain.next_endpoint("https://primary.example"),
            Some("https://backup.example".to_string())
        );
        assert_eq!(chain.next_endpoint("https://backup.example"), None);
    }

    #[test]
    fn test_next_endpoint_default_is_empty() {
        let chain = FallbackChain::new(FallbackConfig::default());
        assert_eq!(chain.next_endpoint("https://primary.example"), None);
    }

    #[test]
    fn test_behavior_for_rate_limit() {
        let config = FallbackConfig {
            on_rate_limit: FallbackBehavior::Fallback,
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        let error = LlmError::RateLimit("Rate limited".to_string());
        assert_eq!(chain.behavior_for_error(&error), FallbackBehavior::Fallback);
    }

    #[test]
    fn test_should_fallback() {
        let config = FallbackConfig {
            enabled: true,
            on_rate_limit: FallbackBehavior::RetryThenFallback,
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        let error = LlmError::RateLimit("Rate limited".to_string());
        assert!(chain.should_fallback(&error));

        let chain_disabled = FallbackChain::disabled();
        assert!(!chain_disabled.should_fallback(&error));
    }

    #[test]
    fn test_should_retry() {
        let error = LlmError::RateLimit("Rate limited".to_string());

        let retrying = FallbackChain::new(FallbackConfig {
            on_rate_limit: FallbackBehavior::RetryThenFallback,
            ..FallbackConfig::default()
        });
        assert!(retrying.should_retry(&error));

        let fallback_only = FallbackChain::new(FallbackConfig {
            on_rate_limit: FallbackBehavior::Fallback,
            ..FallbackConfig::default()
        });
        assert!(!fallback_only.should_retry(&error));
        assert!(fallback_only.should_fallback(&error));

        assert!(!FallbackChain::disabled().should_retry(&error));
    }

    #[test]
    fn test_record_fallback_and_result_constructors() {
        let chain = FallbackChain::new(FallbackConfig::default());
        let mut history = Vec::new();
        chain.record_fallback(
            &mut history,
            "gpt-4o".to_string(),
            Some("gpt-4o-mini".to_string()),
            "primary".to_string(),
            None,
            "rate limited".to_string(),
        );

        assert_eq!(history.len(), 1);
        assert_eq!(history[0].from_model, "gpt-4o");
        assert_eq!(history[0].to_model.as_deref(), Some("gpt-4o-mini"));
        assert!(history[0].to_endpoint.is_none());

        let ok = FallbackResult::success(1, "gpt-4o".to_string(), "primary".to_string());
        assert!(!ok.degraded);
        assert!(ok.fallback_history.is_empty());

        let degraded = FallbackResult::from_fallback(
            2,
            "gpt-4o-mini".to_string(),
            "primary".to_string(),
            history,
        );
        assert!(degraded.degraded);
        assert_eq!(degraded.fallback_history.len(), 1);
    }
}