Skip to main content

rustic_ai/
failover.rs

1use std::error::Error;
2use std::future::Future;
3
4use tracing::{debug, warn};
5
6use crate::model_config::{ModelConfigResolver, ResolvedModelConfig};
7
8use crate::error::AgentError;
9use crate::model::ModelError;
10
11pub fn classify_error_kind(error: &(dyn Error + 'static)) -> Option<&'static str> {
12    if let Some(agent_error) = error.downcast_ref::<AgentError>() {
13        return classify_agent_error(agent_error);
14    }
15    if let Some(model_error) = error.downcast_ref::<ModelError>() {
16        return classify_model_error(model_error);
17    }
18    None
19}
20
21fn classify_agent_error(error: &AgentError) -> Option<&'static str> {
22    match error {
23        AgentError::Model(model_error) => classify_model_error(model_error),
24        _ => None,
25    }
26}
27
28fn classify_model_error(error: &ModelError) -> Option<&'static str> {
29    match error {
30        ModelError::Timeout => Some("timeout"),
31        ModelError::Transport(_) => Some("connect_error"),
32        ModelError::HttpStatus { status } => match *status {
33            401 => Some("http_401"),
34            403 => Some("http_403"),
35            429 => Some("http_429"),
36            status if status >= 500 => Some("http_5xx"),
37            _ => None,
38        },
39        ModelError::Provider(_) | ModelError::Serialization(_) => Some("model_error"),
40        ModelError::Unsupported(_) => None,
41    }
42}
43
/// Successful outcome of a failover-aware invocation.
#[derive(Debug, Clone, PartialEq)]
pub struct FailoverResult<T> {
    /// Value returned by the invocation that succeeded.
    pub value: T,
    /// Name of the model that produced `value`.
    pub model_used: String,
    /// `true` when `value` came from the backup model rather than the primary.
    pub failed_over: bool,
    /// Number of invocations made against the primary model.
    pub primary_attempts: u32,
}
51
52pub async fn run_with_failover<T, E, F, Fut>(
53    resolver: &dyn ModelConfigResolver,
54    agent_name: &str,
55    requested_model: Option<&str>,
56    environment: Option<&str>,
57    invoke: F,
58) -> Result<FailoverResult<T>, E>
59where
60    E: Error + Send + Sync + 'static,
61    F: FnMut(&str) -> Fut,
62    Fut: Future<Output = Result<T, E>>,
63{
64    run_with_failover_with_classifier(
65        resolver,
66        agent_name,
67        requested_model,
68        environment,
69        invoke,
70        |error| classify_error_kind(error),
71    )
72    .await
73}
74
75pub async fn run_with_failover_with_classifier<T, E, F, Fut, C>(
76    resolver: &dyn ModelConfigResolver,
77    agent_name: &str,
78    requested_model: Option<&str>,
79    environment: Option<&str>,
80    invoke: F,
81    classifier: C,
82) -> Result<FailoverResult<T>, E>
83where
84    E: Error + Send + Sync + 'static,
85    F: FnMut(&str) -> Fut,
86    Fut: Future<Output = Result<T, E>>,
87    C: Fn(&E) -> Option<&'static str>,
88{
89    let config = resolver.resolve_model_config(agent_name, requested_model, environment);
90    run_with_config_and_classifier(config, invoke, classifier).await
91}
92
93pub async fn run_with_utility_failover<T, E, F, Fut>(
94    resolver: &dyn ModelConfigResolver,
95    utility_name: &str,
96    environment: Option<&str>,
97    invoke: F,
98) -> Result<FailoverResult<T>, E>
99where
100    E: Error + Send + Sync + 'static,
101    F: FnMut(&str) -> Fut,
102    Fut: Future<Output = Result<T, E>>,
103{
104    run_with_utility_failover_with_classifier(
105        resolver,
106        utility_name,
107        environment,
108        invoke,
109        |error| classify_error_kind(error),
110    )
111    .await
112}
113
114pub async fn run_with_utility_failover_with_classifier<T, E, F, Fut, C>(
115    resolver: &dyn ModelConfigResolver,
116    utility_name: &str,
117    environment: Option<&str>,
118    invoke: F,
119    classifier: C,
120) -> Result<FailoverResult<T>, E>
121where
122    E: Error + Send + Sync + 'static,
123    F: FnMut(&str) -> Fut,
124    Fut: Future<Output = Result<T, E>>,
125    C: Fn(&E) -> Option<&'static str>,
126{
127    let config = resolver.resolve_utility_config(utility_name, environment);
128    run_with_config_and_classifier(config, invoke, classifier).await
129}
130
131pub async fn run_with_config<T, E, F, Fut>(
132    config: ResolvedModelConfig,
133    invoke: F,
134) -> Result<FailoverResult<T>, E>
135where
136    E: Error + Send + Sync + 'static,
137    F: FnMut(&str) -> Fut,
138    Fut: Future<Output = Result<T, E>>,
139{
140    run_with_config_and_classifier(config, invoke, |error| classify_error_kind(error)).await
141}
142
143pub async fn run_with_config_and_classifier<T, E, F, Fut, C>(
144    config: ResolvedModelConfig,
145    mut invoke: F,
146    classifier: C,
147) -> Result<FailoverResult<T>, E>
148where
149    E: Error + Send + Sync + 'static,
150    F: FnMut(&str) -> Fut,
151    Fut: Future<Output = Result<T, E>>,
152    C: Fn(&E) -> Option<&'static str>,
153{
154    let mut last_kind = None;
155    let mut last_error = None;
156
157    for attempt in 0..=config.retry_limit {
158        match invoke(&config.primary).await {
159            Ok(value) => {
160                return Ok(FailoverResult {
161                    value,
162                    model_used: config.primary.clone(),
163                    failed_over: false,
164                    primary_attempts: attempt + 1,
165                });
166            }
167            Err(error) => {
168                let kind = classifier(&error);
169                last_kind = kind;
170                if !kind.is_some_and(|kind| config.failover_on.contains(kind)) {
171                    debug!(
172                        model = config.primary.as_str(),
173                        attempt = attempt + 1,
174                        error_kind = kind.unwrap_or(""),
175                        "primary request failed without failover"
176                    );
177                    return Err(error);
178                }
179                last_error = Some(error);
180                if attempt < config.retry_limit {
181                    debug!(
182                        model = config.primary.as_str(),
183                        attempt = attempt + 1,
184                        error_kind = kind.unwrap_or(""),
185                        "primary request failed, retrying"
186                    );
187                    continue;
188                }
189                break;
190            }
191        }
192    }
193
194    let should_failover =
195        config.backup.is_some() && last_kind.is_some_and(|kind| config.failover_on.contains(kind));
196    if !should_failover && let Some(error) = last_error {
197        warn!(
198            model = config.primary.as_str(),
199            error_kind = last_kind.unwrap_or(""),
200            "primary request failed and no failover configured"
201        );
202        return Err(error);
203    }
204
205    let backup = config.backup.clone().unwrap_or_default();
206    warn!(
207        primary = config.primary.as_str(),
208        backup = backup.as_str(),
209        error_kind = last_kind.unwrap_or(""),
210        "failing over to backup model"
211    );
212    let result = invoke(&backup).await?;
213    Ok(FailoverResult {
214        value: result,
215        model_used: backup,
216        failed_over: true,
217        primary_attempts: config.retry_limit + 1,
218    })
219}