Skip to main content

vectorless/llm/
fallback.rs

1// Copyright (c) 2026 vectorless developers
2// SPDX-License-Identifier: Apache-2.0
3
4//! Fallback and error recovery for LLM calls.
5//!
6//! This module provides graceful degradation when LLM API calls fail:
7//! - Automatic model switching (e.g., gpt-4o → gpt-4o-mini)
8//! - Automatic endpoint switching
9//! - Configurable retry and fallback behaviors
10//!
11//! # Example
12//!
13//! ```rust
14//! use vectorless::llm::fallback::{FallbackChain, FallbackConfig};
15//!
16//! let config = FallbackConfig::default();
17//! let chain = FallbackChain::new(config);
18//!
19//! // Check if fallback is enabled
20//! assert!(chain.is_enabled());
21//! ```
22
23use serde::{Deserialize, Serialize};
24use tracing::{debug, info, warn};
25
26use super::error::LlmError;
27use crate::config::{
28    FallbackBehavior, FallbackConfig as ConfigFallbackConfig, OnAllFailedBehavior,
29};
30
31/// Result from a fallback-aware LLM call.
32#[derive(Debug, Clone)]
33pub struct FallbackResult<T> {
34    /// The actual result.
35    pub result: T,
36    /// Whether the result came from a fallback model/endpoint.
37    pub degraded: bool,
38    /// The model that was ultimately used.
39    pub model: String,
40    /// The endpoint that was ultimately used.
41    pub endpoint: String,
42    /// History of fallback attempts (for debugging).
43    pub fallback_history: Vec<FallbackStep>,
44}
45
46impl<T> FallbackResult<T> {
47    /// Create a successful result without fallback.
48    pub fn success(result: T, model: String, endpoint: String) -> Self {
49        Self {
50            result,
51            degraded: false,
52            model,
53            endpoint,
54            fallback_history: Vec::new(),
55        }
56    }
57
58    /// Create a result from a fallback.
59    pub fn from_fallback(
60        result: T,
61        model: String,
62        endpoint: String,
63        history: Vec<FallbackStep>,
64    ) -> Self {
65        Self {
66            result,
67            degraded: true,
68            model,
69            endpoint,
70            fallback_history: history,
71        }
72    }
73}
74
75/// A single step in the fallback chain.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct FallbackStep {
78    /// The model we tried.
79    pub from_model: String,
80    /// The model we fell back to (if any).
81    pub to_model: Option<String>,
82    /// The endpoint we tried.
83    pub from_endpoint: String,
84    /// The endpoint we fell back to (if any).
85    pub to_endpoint: Option<String>,
86    /// The reason for fallback.
87    pub reason: String,
88}
89
90/// Fallback chain manager.
91#[derive(Debug, Clone)]
92pub struct FallbackChain {
93    config: FallbackConfig,
94}
95
96/// Runtime fallback configuration (converted from config::FallbackConfig).
97#[derive(Debug, Clone)]
98pub struct FallbackConfig {
99    /// Whether fallback is enabled.
100    pub enabled: bool,
101    /// Fallback models in priority order.
102    pub models: Vec<String>,
103    /// Fallback endpoints in priority order.
104    pub endpoints: Vec<String>,
105    /// Behavior on rate limit error.
106    pub on_rate_limit: FallbackBehavior,
107    /// Behavior on timeout error.
108    pub on_timeout: FallbackBehavior,
109    /// Behavior when all attempts fail.
110    pub on_all_failed: OnAllFailedBehavior,
111}
112
113impl Default for FallbackConfig {
114    fn default() -> Self {
115        Self {
116            enabled: true,
117            models: vec!["gpt-4o-mini".to_string(), "glm-4-flash".to_string()],
118            endpoints: vec![],
119            on_rate_limit: FallbackBehavior::RetryThenFallback,
120            on_timeout: FallbackBehavior::RetryThenFallback,
121            on_all_failed: OnAllFailedBehavior::ReturnError,
122        }
123    }
124}
125
126impl From<ConfigFallbackConfig> for FallbackConfig {
127    fn from(config: ConfigFallbackConfig) -> Self {
128        Self {
129            enabled: config.enabled,
130            models: config.models,
131            endpoints: config.endpoints,
132            on_rate_limit: config.on_rate_limit,
133            on_timeout: config.on_timeout,
134            on_all_failed: config.on_all_failed,
135        }
136    }
137}
138
139impl FallbackConfig {
140    /// Create a new fallback config.
141    pub fn new() -> Self {
142        Self::default()
143    }
144
145    /// Disable fallback.
146    pub fn disabled() -> Self {
147        Self {
148            enabled: false,
149            ..Self::default()
150        }
151    }
152}
153
154impl FallbackChain {
155    /// Create a new fallback chain with the given configuration.
156    pub fn new(config: FallbackConfig) -> Self {
157        Self { config }
158    }
159
160    /// Create a disabled fallback chain (no fallback).
161    pub fn disabled() -> Self {
162        Self::new(FallbackConfig::disabled())
163    }
164
165    /// Get the configuration.
166    pub fn config(&self) -> &FallbackConfig {
167        &self.config
168    }
169
170    /// Check if fallback is enabled.
171    pub fn is_enabled(&self) -> bool {
172        self.config.enabled
173    }
174
175    /// Determine the appropriate behavior for an error.
176    pub fn behavior_for_error(&self, error: &LlmError) -> FallbackBehavior {
177        match error {
178            LlmError::RateLimit(_) => self.config.on_rate_limit,
179            LlmError::Timeout(_) => self.config.on_timeout,
180            _ => FallbackBehavior::Fail,
181        }
182    }
183
184    /// Check if an error should trigger fallback.
185    pub fn should_fallback(&self, error: &LlmError) -> bool {
186        if !self.config.enabled {
187            return false;
188        }
189
190        match self.behavior_for_error(error) {
191            FallbackBehavior::Fallback | FallbackBehavior::RetryThenFallback => true,
192            FallbackBehavior::Retry | FallbackBehavior::Fail => false,
193        }
194    }
195
196    /// Check if an error should trigger retry.
197    pub fn should_retry(&self, error: &LlmError) -> bool {
198        if !self.config.enabled {
199            return false;
200        }
201
202        match self.behavior_for_error(error) {
203            FallbackBehavior::Retry | FallbackBehavior::RetryThenFallback => true,
204            FallbackBehavior::Fallback | FallbackBehavior::Fail => false,
205        }
206    }
207
208    /// Get the next fallback model.
209    pub fn next_model(&self, current: &str) -> Option<String> {
210        let models = &self.config.models;
211        let current_idx = models.iter().position(|m| m == current);
212
213        match current_idx {
214            // Current model is in the list, try next one
215            Some(idx) if idx + 1 < models.len() => {
216                let next = models[idx + 1].clone();
217                info!(from = current, to = %next, "Falling back to next model");
218                Some(next)
219            }
220            // Current model is the last in the list, no more fallbacks
221            Some(_) => {
222                warn!(
223                    model = current,
224                    "Already at last fallback model, no more available"
225                );
226                None
227            }
228            // Current model not in fallback list, try first fallback
229            None => {
230                if !models.is_empty() && models[0] != current {
231                    let next = models[0].clone();
232                    info!(from = current, to = %next, "Falling back to first fallback model");
233                    Some(next)
234                } else {
235                    warn!(model = current, "No more fallback models available");
236                    None
237                }
238            }
239        }
240    }
241
242    /// Get the next fallback endpoint.
243    pub fn next_endpoint(&self, current: &str) -> Option<String> {
244        let endpoints = &self.config.endpoints;
245        let current_idx = endpoints.iter().position(|e| e == current);
246
247        match current_idx {
248            // Current endpoint is in the list, try next one
249            Some(idx) if idx + 1 < endpoints.len() => {
250                let next = endpoints[idx + 1].clone();
251                info!(from = current, to = %next, "Falling back to next endpoint");
252                Some(next)
253            }
254            // Current endpoint is the last in the list, no more fallbacks
255            Some(_) => {
256                warn!(
257                    endpoint = current,
258                    "Already at last fallback endpoint, no more available"
259                );
260                None
261            }
262            // Current endpoint not in fallback list, try first fallback
263            None => {
264                if !endpoints.is_empty() && endpoints[0] != current {
265                    let next = endpoints[0].clone();
266                    info!(from = current, to = %next, "Falling back to first fallback endpoint");
267                    Some(next)
268                } else {
269                    debug!(endpoint = current, "No more fallback endpoints available");
270                    None
271                }
272            }
273        }
274    }
275
276    /// Record a fallback step.
277    pub fn record_fallback(
278        &self,
279        history: &mut Vec<FallbackStep>,
280        from_model: String,
281        to_model: Option<String>,
282        from_endpoint: String,
283        to_endpoint: Option<String>,
284        reason: String,
285    ) {
286        let step = FallbackStep {
287            from_model,
288            to_model,
289            from_endpoint,
290            to_endpoint,
291            reason,
292        };
293        debug!(?step, "Recording fallback step");
294        history.push(step);
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Chain with default settings except for the rate-limit behavior.
    fn chain_for_rate_limit(behavior: FallbackBehavior) -> FallbackChain {
        FallbackChain::new(FallbackConfig {
            on_rate_limit: behavior,
            ..FallbackConfig::default()
        })
    }

    #[test]
    fn test_fallback_config_default() {
        let config = FallbackConfig::default();
        assert!(config.enabled);
        assert!(!config.models.is_empty());
        assert!(config.endpoints.is_empty());
    }

    #[test]
    fn test_fallback_chain_disabled() {
        let chain = FallbackChain::disabled();
        assert!(!chain.is_enabled());
    }

    #[test]
    fn test_next_model() {
        let config = FallbackConfig {
            models: vec![
                "gpt-4o".to_string(),
                "gpt-4o-mini".to_string(),
                "glm-4-flash".to_string(),
            ],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        // Should get next model in chain
        assert_eq!(chain.next_model("gpt-4o"), Some("gpt-4o-mini".to_string()));
        assert_eq!(
            chain.next_model("gpt-4o-mini"),
            Some("glm-4-flash".to_string())
        );
        assert_eq!(chain.next_model("glm-4-flash"), None);
    }

    #[test]
    fn test_next_model_not_in_list() {
        let config = FallbackConfig {
            models: vec!["gpt-4o-mini".to_string()],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        // Should fall back to first model in list
        assert_eq!(
            chain.next_model("unknown-model"),
            Some("gpt-4o-mini".to_string())
        );
    }

    #[test]
    fn test_next_model_empty_list() {
        let chain = FallbackChain::new(FallbackConfig {
            models: Vec::new(),
            ..FallbackConfig::default()
        });
        assert_eq!(chain.next_model("gpt-4o"), None);
    }

    #[test]
    fn test_next_endpoint() {
        let config = FallbackConfig {
            endpoints: vec![
                "https://primary.example".to_string(),
                "https://backup.example".to_string(),
            ],
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        assert_eq!(
            chain.next_endpoint("https://primary.example"),
            Some("https://backup.example".to_string())
        );
        assert_eq!(chain.next_endpoint("https://backup.example"), None);
        // An endpoint outside the list falls back to the first configured one.
        assert_eq!(
            chain.next_endpoint("https://other.example"),
            Some("https://primary.example".to_string())
        );
    }

    #[test]
    fn test_next_endpoint_none_configured() {
        let chain = FallbackChain::new(FallbackConfig::default());
        assert_eq!(chain.next_endpoint("https://primary.example"), None);
    }

    #[test]
    fn test_behavior_for_rate_limit() {
        let config = FallbackConfig {
            on_rate_limit: FallbackBehavior::Fallback,
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        let error = LlmError::RateLimit("Rate limited".to_string());
        assert_eq!(chain.behavior_for_error(&error), FallbackBehavior::Fallback);
    }

    #[test]
    fn test_should_fallback() {
        let config = FallbackConfig {
            enabled: true,
            on_rate_limit: FallbackBehavior::RetryThenFallback,
            ..FallbackConfig::default()
        };
        let chain = FallbackChain::new(config);

        let error = LlmError::RateLimit("Rate limited".to_string());
        assert!(chain.should_fallback(&error));

        let chain_disabled = FallbackChain::disabled();
        assert!(!chain_disabled.should_fallback(&error));
    }

    #[test]
    fn test_should_retry_per_behavior() {
        let error = LlmError::RateLimit("Rate limited".to_string());

        let both = chain_for_rate_limit(FallbackBehavior::RetryThenFallback);
        assert!(both.should_retry(&error));
        assert!(both.should_fallback(&error));

        let retry_only = chain_for_rate_limit(FallbackBehavior::Retry);
        assert!(retry_only.should_retry(&error));
        assert!(!retry_only.should_fallback(&error));

        let fallback_only = chain_for_rate_limit(FallbackBehavior::Fallback);
        assert!(!fallback_only.should_retry(&error));
        assert!(fallback_only.should_fallback(&error));

        let fail = chain_for_rate_limit(FallbackBehavior::Fail);
        assert!(!fail.should_retry(&error));
        assert!(!fail.should_fallback(&error));

        assert!(!FallbackChain::disabled().should_retry(&error));
    }

    #[test]
    fn test_record_fallback() {
        let chain = FallbackChain::new(FallbackConfig::default());
        let mut history = Vec::new();

        chain.record_fallback(
            &mut history,
            "gpt-4o".to_string(),
            Some("gpt-4o-mini".to_string()),
            "https://primary.example".to_string(),
            None,
            "rate limited".to_string(),
        );

        assert_eq!(history.len(), 1);
        assert_eq!(history[0].from_model, "gpt-4o");
        assert_eq!(history[0].to_model.as_deref(), Some("gpt-4o-mini"));
        assert_eq!(history[0].from_endpoint, "https://primary.example");
        assert_eq!(history[0].to_endpoint, None);
        assert_eq!(history[0].reason, "rate limited");
    }

    #[test]
    fn test_fallback_result_constructors() {
        let ok = FallbackResult::success(
            1,
            "gpt-4o".to_string(),
            "https://primary.example".to_string(),
        );
        assert_eq!(ok.result, 1);
        assert!(!ok.degraded);
        assert!(ok.fallback_history.is_empty());

        let step = FallbackStep {
            from_model: "gpt-4o".to_string(),
            to_model: Some("gpt-4o-mini".to_string()),
            from_endpoint: "https://primary.example".to_string(),
            to_endpoint: None,
            reason: "timeout".to_string(),
        };
        let degraded = FallbackResult::from_fallback(
            2,
            "gpt-4o-mini".to_string(),
            "https://primary.example".to_string(),
            vec![step],
        );
        assert_eq!(degraded.result, 2);
        assert!(degraded.degraded);
        assert_eq!(degraded.model, "gpt-4o-mini");
        assert_eq!(degraded.fallback_history.len(), 1);
    }
}