Skip to main content

simple_agents_router/
fallback.rs

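//! Fallback routing implementation.
//!
//! Attempts providers in order, falling back to the next provider on retryable
//! errors by default, or on any error when `FallbackRouterConfig::retryable_only`
//! is disabled.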
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderError, Result, SimpleAgentsError,
7};
8use std::sync::Arc;
9
/// Configuration for fallback routing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FallbackRouterConfig {
    /// If true, fall back to the next provider only on retryable errors
    /// (rate limits, timeouts, server errors, and network errors); any other
    /// error is returned immediately. If false, every error triggers a
    /// fallback to the next provider.
    pub retryable_only: bool,
}
16
17impl Default for FallbackRouterConfig {
18    fn default() -> Self {
19        Self {
20            retryable_only: true,
21        }
22    }
23}
24
25/// Router that tries providers in order and falls back on eligible errors.
26pub struct FallbackRouter {
27    providers: Vec<Arc<dyn Provider>>,
28    config: FallbackRouterConfig,
29}
30
31impl FallbackRouter {
32    /// Create a new fallback router.
33    ///
34    /// # Errors
35    /// Returns a routing error if no providers are supplied.
36    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
37        Self::with_config(providers, FallbackRouterConfig::default())
38    }
39
40    /// Create a new fallback router with custom configuration.
41    ///
42    /// # Errors
43    /// Returns a routing error if no providers are supplied.
44    pub fn with_config(
45        providers: Vec<Arc<dyn Provider>>,
46        config: FallbackRouterConfig,
47    ) -> Result<Self> {
48        if providers.is_empty() {
49            return Err(SimpleAgentsError::Routing(
50                "no providers configured".to_string(),
51            ));
52        }
53
54        Ok(Self { providers, config })
55    }
56
57    /// Return the number of configured providers.
58    pub fn provider_count(&self) -> usize {
59        self.providers.len()
60    }
61
62    /// Execute a completion request with fallback logic.
63    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
64        let mut last_error: Option<SimpleAgentsError> = None;
65
66        for provider in &self.providers {
67            let attempt = self.execute_provider(provider, request).await;
68            match attempt {
69                Ok(response) => return Ok(response),
70                Err(err) => {
71                    if !self.should_fallback(&err) {
72                        return Err(err);
73                    }
74                    last_error = Some(err);
75                }
76            }
77        }
78
79        Err(last_error
80            .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
81    }
82
83    /// Execute a streaming request with fallback logic.
84    pub async fn stream(
85        &self,
86        request: &CompletionRequest,
87    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
88        for provider in &self.providers {
89            let provider_request = provider.transform_request(request)?;
90            match provider.execute_stream(provider_request).await {
91                Ok(stream) => return Ok(stream),
92                Err(err) => {
93                    if !self.should_fallback(&err) {
94                        return Err(err);
95                    }
96                    // Continue to next provider
97                }
98            }
99        }
100
101        Err(SimpleAgentsError::Routing("no providers configured".to_string()))
102    }
103
104    async fn execute_provider(
105        &self,
106        provider: &Arc<dyn Provider>,
107        request: &CompletionRequest,
108    ) -> Result<CompletionResponse> {
109        let provider_request = provider.transform_request(request)?;
110        let provider_response = provider.execute(provider_request).await?;
111        provider.transform_response(provider_response)
112    }
113
114    fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
115        if !self.config.retryable_only {
116            return true;
117        }
118
119        matches!(
120            error,
121            SimpleAgentsError::Provider(
122                ProviderError::RateLimit { .. }
123                    | ProviderError::Timeout(_)
124                    | ProviderError::ServerError(_)
125            ) | SimpleAgentsError::Network(_)
126        )
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that always produces the same outcome and counts how many
    /// times `execute` is called.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    /// Outcome a [`MockProvider`] produces on every `execute` call.
    enum MockResult {
        /// Succeed with a placeholder response.
        Ok,
        /// Fail with `ProviderError::Timeout`, which the router treats as retryable.
        RetryableError,
        /// Fail with `ProviderError::InvalidApiKey`, which is not retryable.
        NonRetryableError,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            Self {
                name,
                attempts: AtomicUsize::new(0),
                result,
            }
        }

        /// Number of times `execute` has been called on this provider.
        fn attempt_count(&self) -> usize {
            self.attempts.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => Err(SimpleAgentsError::Provider(
                    ProviderError::Timeout(std::time::Duration::from_secs(1)),
                )),
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Upcast the mocks for the router while the test keeps typed handles for
    /// inspecting attempt counts afterwards.
    fn as_providers(mocks: &[Arc<MockProvider>]) -> Vec<Arc<dyn Provider>> {
        mocks
            .iter()
            .map(|mock| mock.clone() as Arc<dyn Provider>)
            .collect()
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = FallbackRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let mocks = [
            Arc::new(MockProvider::new("p1", MockResult::RetryableError)),
            Arc::new(MockProvider::new("p2", MockResult::Ok)),
        ];
        let router = FallbackRouter::new(as_providers(&mocks)).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(mocks[0].attempt_count(), 1);
        assert_eq!(mocks[1].attempt_count(), 1);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let mocks = [
            Arc::new(MockProvider::new("p1", MockResult::NonRetryableError)),
            Arc::new(MockProvider::new("p2", MockResult::Ok)),
        ];
        let router = FallbackRouter::new(as_providers(&mocks)).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::InvalidApiKey) => {}
            other => panic!("unexpected error: {other:?}"),
        }
        // The second provider must never be contacted after a non-retryable error.
        assert_eq!(mocks[0].attempt_count(), 1);
        assert_eq!(mocks[1].attempt_count(), 0);
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let mocks = [
            Arc::new(MockProvider::new("p1", MockResult::NonRetryableError)),
            Arc::new(MockProvider::new("p2", MockResult::Ok)),
        ];
        let router = FallbackRouter::with_config(as_providers(&mocks), config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(mocks[0].attempt_count(), 1);
        assert_eq!(mocks[1].attempt_count(), 1);
    }

    #[tokio::test]
    async fn returns_last_error_when_all_providers_fail() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        // Distinct error kinds let the test tell the first error from the last.
        let mocks = [
            Arc::new(MockProvider::new("p1", MockResult::NonRetryableError)),
            Arc::new(MockProvider::new("p2", MockResult::RetryableError)),
        ];
        let router = FallbackRouter::with_config(as_providers(&mocks), config).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::Timeout(_)) => {}
            other => panic!("expected the last provider's error, got: {other:?}"),
        }
        assert_eq!(mocks[0].attempt_count(), 1);
        assert_eq!(mocks[1].attempt_count(), 1);
    }
}