1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderError, Result,
7 SimpleAgentsError,
8};
9use std::sync::Arc;
10
/// Controls when a fallback router moves on to the next provider after a failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FallbackRouterConfig {
    /// When `true`, fall back only on errors the router classifies as retryable
    /// (provider rate limits, timeouts, server errors, and network errors); any other
    /// error is returned to the caller immediately. When `false`, every error causes
    /// the next provider to be tried.
    pub retryable_only: bool,
}
17
18impl Default for FallbackRouterConfig {
19 fn default() -> Self {
20 Self {
21 retryable_only: true,
22 }
23 }
24}
25
26pub struct FallbackRouter {
28 providers: Vec<Arc<dyn Provider>>,
29 config: FallbackRouterConfig,
30}
31
32impl FallbackRouter {
33 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
38 Self::with_config(providers, FallbackRouterConfig::default())
39 }
40
41 pub fn with_config(
46 providers: Vec<Arc<dyn Provider>>,
47 config: FallbackRouterConfig,
48 ) -> Result<Self> {
49 if providers.is_empty() {
50 return Err(SimpleAgentsError::Routing(
51 "no providers configured".to_string(),
52 ));
53 }
54
55 Ok(Self { providers, config })
56 }
57
58 pub fn provider_count(&self) -> usize {
60 self.providers.len()
61 }
62
63 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
65 let mut last_error: Option<SimpleAgentsError> = None;
66
67 for provider in &self.providers {
68 let attempt = self.execute_provider(provider, request).await;
69 match attempt {
70 Ok(response) => return Ok(response),
71 Err(err) => {
72 if !self.should_fallback(&err) {
73 return Err(err);
74 }
75 last_error = Some(err);
76 }
77 }
78 }
79
80 Err(last_error
81 .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
82 }
83
84 pub async fn stream(
86 &self,
87 request: &CompletionRequest,
88 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
89 let mut last_error: Option<SimpleAgentsError> = None;
90
91 for provider in &self.providers {
92 let provider_request = provider.transform_request(request)?;
93 match provider.execute_stream(provider_request).await {
94 Ok(stream) => return Ok(stream),
95 Err(err) => {
96 if !self.should_fallback(&err) {
97 return Err(err);
98 }
99 last_error = Some(err);
100 }
101 }
102 }
103
104 Err(last_error
105 .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
106 }
107
108 async fn execute_provider(
109 &self,
110 provider: &Arc<dyn Provider>,
111 request: &CompletionRequest,
112 ) -> Result<CompletionResponse> {
113 let provider_request = provider.transform_request(request)?;
114 let provider_response = provider.execute(provider_request).await?;
115 provider.transform_response(provider_response)
116 }
117
118 fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
119 if !self.config.retryable_only {
120 return true;
121 }
122
123 matches!(
124 error,
125 SimpleAgentsError::Provider(
126 ProviderError::RateLimit { .. }
127 | ProviderError::Timeout(_)
128 | ProviderError::ServerError(_)
129 ) | SimpleAgentsError::Network(_)
130 )
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose `execute` returns a fixed outcome and counts its invocations.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    /// Outcome a [`MockProvider`] produces from `execute`.
    enum MockResult {
        /// Succeeds with an empty 200 response.
        Ok,
        /// Fails with a provider timeout, which the default config falls back on.
        RetryableError,
        /// Fails with an invalid API key, which the default config does not fall back on.
        NonRetryableError,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            Self {
                name,
                attempts: AtomicUsize::new(0),
                result,
            }
        }

        /// Number of times `execute` has been called on this provider.
        fn attempts(&self) -> usize {
            self.attempts.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::Relaxed);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => Err(SimpleAgentsError::Provider(
                    ProviderError::Timeout(std::time::Duration::from_secs(1)),
                )),
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Creates a mock provider that the test can keep inspecting after handing it to a router.
    fn mock(name: &'static str, result: MockResult) -> Arc<MockProvider> {
        Arc::new(MockProvider::new(name, result))
    }

    /// Upcasts a shared mock to the trait object the router stores.
    fn as_dyn(provider: &Arc<MockProvider>) -> Arc<dyn Provider> {
        Arc::clone(provider) as Arc<dyn Provider>
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = FallbackRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[tokio::test]
    async fn returns_first_success_without_trying_later_providers() {
        let p1 = mock("p1", MockResult::Ok);
        let p2 = mock("p2", MockResult::Ok);
        let router = FallbackRouter::new(vec![as_dyn(&p1), as_dyn(&p2)]).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p1"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 0);
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let p1 = mock("p1", MockResult::RetryableError);
        let p2 = mock("p2", MockResult::Ok);
        let router = FallbackRouter::new(vec![as_dyn(&p1), as_dyn(&p2)]).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let p1 = mock("p1", MockResult::NonRetryableError);
        let p2 = mock("p2", MockResult::Ok);
        let router = FallbackRouter::new(vec![as_dyn(&p1), as_dyn(&p2)]).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::InvalidApiKey) => {}
            _ => panic!("unexpected error"),
        }
        // The second provider must never be contacted after a non-retryable failure.
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 0);
    }

    #[tokio::test]
    async fn complete_returns_last_error_when_all_fail() {
        let p1 = mock("p1", MockResult::RetryableError);
        let p2 = mock("p2", MockResult::RetryableError);
        let router = FallbackRouter::new(vec![as_dyn(&p1), as_dyn(&p2)]).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::Timeout(_)) => {}
            _ => panic!("unexpected error"),
        }
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let p1 = mock("p1", MockResult::NonRetryableError);
        let p2 = mock("p2", MockResult::Ok);
        let router = FallbackRouter::with_config(vec![as_dyn(&p1), as_dyn(&p2)], config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn stream_returns_last_provider_error() {
        let router = FallbackRouter::new(vec![
            Arc::new(MockProvider::new("p1", MockResult::RetryableError)),
            Arc::new(MockProvider::new("p2", MockResult::RetryableError)),
        ])
        .unwrap();

        let err = match router.stream(&build_request()).await {
            Ok(_) => panic!("expected stream setup to fail"),
            Err(err) => err,
        };
        match err {
            SimpleAgentsError::Provider(ProviderError::Timeout(_)) => {}
            _ => panic!("unexpected error"),
        }
    }
}