Skip to main content

synaptic_middleware/
tool_retry.rs

1use std::time::Duration;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use synaptic_core::SynapticError;
6
7use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};
8
/// Middleware that retries failed tool calls with exponential backoff.
///
/// A failing call is re-attempted up to `max_retries` additional times, so a
/// tool is invoked at most `max_retries + 1` times per request. Build one with
/// [`ToolRetryMiddleware::new`] and optionally adjust the initial delay with
/// [`ToolRetryMiddleware::with_base_delay`].
#[derive(Debug, Clone)]
pub struct ToolRetryMiddleware {
    /// Number of retries after the first attempt (`0` disables retrying).
    max_retries: usize,
    /// Delay before the first retry; each subsequent retry waits twice as long
    /// as the previous one.
    base_delay: Duration,
}
14
15impl ToolRetryMiddleware {
16    pub fn new(max_retries: usize) -> Self {
17        Self {
18            max_retries,
19            base_delay: Duration::from_millis(100),
20        }
21    }
22
23    pub fn with_base_delay(mut self, delay: Duration) -> Self {
24        self.base_delay = delay;
25        self
26    }
27}
28
29#[async_trait]
30impl AgentMiddleware for ToolRetryMiddleware {
31    async fn wrap_tool_call(
32        &self,
33        request: ToolCallRequest,
34        next: &dyn ToolCaller,
35    ) -> Result<Value, SynapticError> {
36        let mut last_err = None;
37        for attempt in 0..=self.max_retries {
38            match next.call(request.clone()).await {
39                Ok(val) => return Ok(val),
40                Err(e) => {
41                    last_err = Some(e);
42                    if attempt < self.max_retries {
43                        let delay = self.base_delay * 2u32.saturating_pow(attempt as u32);
44                        tokio::time::sleep(delay).await;
45                    }
46                }
47            }
48        }
49        Err(last_err.unwrap())
50    }
51}