simple_agents_router/
fallback.rs1use simple_agent_type::prelude::{
6 CompletionRequest, CompletionResponse, Provider, ProviderError, Result, SimpleAgentsError,
7};
8use std::sync::Arc;
9
/// Controls when a `FallbackRouter` moves on to the next provider after a failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FallbackRouterConfig {
    /// When `true`, only rate-limit, timeout, server and network errors trigger a fallback;
    /// any other error is returned to the caller immediately. When `false`, every error
    /// triggers a fallback to the next provider.
    pub retryable_only: bool,
}

impl Default for FallbackRouterConfig {
    /// Returns a config with `retryable_only` enabled.
    fn default() -> Self {
        Self {
            retryable_only: true,
        }
    }
}
24
25pub struct FallbackRouter {
27 providers: Vec<Arc<dyn Provider>>,
28 config: FallbackRouterConfig,
29}
30
31impl FallbackRouter {
32 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
37 Self::with_config(providers, FallbackRouterConfig::default())
38 }
39
40 pub fn with_config(
45 providers: Vec<Arc<dyn Provider>>,
46 config: FallbackRouterConfig,
47 ) -> Result<Self> {
48 if providers.is_empty() {
49 return Err(SimpleAgentsError::Routing(
50 "no providers configured".to_string(),
51 ));
52 }
53
54 Ok(Self { providers, config })
55 }
56
57 pub fn provider_count(&self) -> usize {
59 self.providers.len()
60 }
61
62 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
64 let mut last_error: Option<SimpleAgentsError> = None;
65
66 for provider in &self.providers {
67 let attempt = self.execute_provider(provider, request).await;
68 match attempt {
69 Ok(response) => return Ok(response),
70 Err(err) => {
71 if !self.should_fallback(&err) {
72 return Err(err);
73 }
74 last_error = Some(err);
75 }
76 }
77 }
78
79 Err(last_error
80 .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
81 }
82
83 async fn execute_provider(
84 &self,
85 provider: &Arc<dyn Provider>,
86 request: &CompletionRequest,
87 ) -> Result<CompletionResponse> {
88 let provider_request = provider.transform_request(request)?;
89 let provider_response = provider.execute(provider_request).await?;
90 provider.transform_response(provider_response)
91 }
92
93 fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
94 if !self.config.retryable_only {
95 return true;
96 }
97
98 matches!(
99 error,
100 SimpleAgentsError::Provider(
101 ProviderError::RateLimit { .. }
102 | ProviderError::Timeout(_)
103 | ProviderError::ServerError(_)
104 ) | SimpleAgentsError::Network(_)
105 )
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that always produces the same outcome and counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    /// The fixed outcome a `MockProvider` returns from `execute`.
    enum MockResult {
        /// `execute` succeeds with an empty 200 response.
        Ok,
        /// `execute` fails with `ProviderError::Timeout`, which the default config falls back on.
        RetryableError,
        /// `execute` fails with `ProviderError::InvalidApiKey`, which the default config does
        /// not fall back on.
        NonRetryableError,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            Self {
                name,
                attempts: AtomicUsize::new(0),
                result,
            }
        }

        /// How many times `execute` has been called on this provider.
        fn attempts(&self) -> usize {
            self.attempts.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => Err(SimpleAgentsError::Provider(
                    ProviderError::Timeout(std::time::Duration::from_secs(1)),
                )),
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Upcasts a mock to the trait object the router stores, while the test keeps its own
    /// handle so attempt counts can be inspected afterwards.
    fn shared(provider: &Arc<MockProvider>) -> Arc<dyn Provider> {
        Arc::clone(provider) as Arc<dyn Provider>
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        match FallbackRouter::new(Vec::new()) {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(other) => panic!("unexpected error type: {other:?}"),
        }
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(vec![shared(&p1), shared(&p2)]).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(vec![shared(&p1), shared(&p2)]).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::InvalidApiKey) => {}
            other => panic!("unexpected error: {other:?}"),
        }
        assert_eq!(p1.attempts(), 1);
        assert_eq!(
            p2.attempts(),
            0,
            "router must not fall back past a non-retryable error"
        );
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::with_config(vec![shared(&p1), shared(&p2)], config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn returns_last_error_when_every_provider_fails() {
        // Distinct error kinds let the assertion tell the first failure from the last one.
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::RetryableError));
        let router = FallbackRouter::with_config(vec![shared(&p1), shared(&p2)], config).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::Timeout(_)) => {}
            other => panic!("expected the last provider's error, got: {other:?}"),
        }
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }
}