synaptic_middleware/
model_fallback.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, SynapticError};
5
6use crate::{AgentMiddleware, BaseChatModelCaller, ModelCaller, ModelRequest, ModelResponse};
7
8pub struct ModelFallbackMiddleware {
13 fallbacks: Vec<Arc<dyn ChatModel>>,
14}
15
16impl ModelFallbackMiddleware {
17 pub fn new(fallbacks: Vec<Arc<dyn ChatModel>>) -> Self {
18 Self { fallbacks }
19 }
20}
21
22#[async_trait]
23impl AgentMiddleware for ModelFallbackMiddleware {
24 async fn wrap_model_call(
25 &self,
26 request: ModelRequest,
27 next: &dyn ModelCaller,
28 ) -> Result<ModelResponse, SynapticError> {
29 match next.call(request.clone()).await {
30 Ok(resp) => Ok(resp),
31 Err(primary_err) => {
32 for fallback in &self.fallbacks {
33 let caller = BaseChatModelCaller::new(fallback.clone());
34 match caller.call(request.clone()).await {
35 Ok(resp) => return Ok(resp),
36 Err(_) => continue,
37 }
38 }
39 Err(primary_err)
41 }
42 }
43 }
44}