Skip to main content

synaptic_middleware/
model_fallback.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, SynapticError};
5
6use crate::{AgentMiddleware, BaseChatModelCaller, ModelCaller, ModelRequest, ModelResponse};
7
8/// Falls back to alternative models when the primary model fails.
9///
10/// On error from the primary model call, the middleware tries each
11/// fallback model in order until one succeeds.
12pub struct ModelFallbackMiddleware {
13    fallbacks: Vec<Arc<dyn ChatModel>>,
14}
15
16impl ModelFallbackMiddleware {
17    pub fn new(fallbacks: Vec<Arc<dyn ChatModel>>) -> Self {
18        Self { fallbacks }
19    }
20}
21
22#[async_trait]
23impl AgentMiddleware for ModelFallbackMiddleware {
24    async fn wrap_model_call(
25        &self,
26        request: ModelRequest,
27        next: &dyn ModelCaller,
28    ) -> Result<ModelResponse, SynapticError> {
29        match next.call(request.clone()).await {
30            Ok(resp) => Ok(resp),
31            Err(primary_err) => {
32                for fallback in &self.fallbacks {
33                    let caller = BaseChatModelCaller::new(fallback.clone());
34                    match caller.call(request.clone()).await {
35                        Ok(resp) => return Ok(resp),
36                        Err(_) => continue,
37                    }
38                }
39                // All fallbacks failed; return original error
40                Err(primary_err)
41            }
42        }
43    }
44}