// synaptic_middleware/tool_retry.rs
use std::time::Duration;

use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::SynapticError;

use crate::{AgentMiddleware, ToolCallRequest, ToolCaller};

/// Agent middleware that retries failed tool calls with exponential backoff.
///
/// A tool call is attempted once and, if it fails, retried up to `max_retries`
/// more times. The wait before each retry starts at `base_delay` and doubles
/// for every subsequent retry.
#[derive(Debug, Clone)]
pub struct ToolRetryMiddleware {
    /// Number of retries allowed after the initial attempt (`0` disables retrying).
    max_retries: usize,
    /// Delay before the first retry; each later retry waits twice as long.
    base_delay: Duration,
}

15impl ToolRetryMiddleware {
16 pub fn new(max_retries: usize) -> Self {
17 Self {
18 max_retries,
19 base_delay: Duration::from_millis(100),
20 }
21 }
22
23 pub fn with_base_delay(mut self, delay: Duration) -> Self {
24 self.base_delay = delay;
25 self
26 }
27}
28
29#[async_trait]
30impl AgentMiddleware for ToolRetryMiddleware {
31 async fn wrap_tool_call(
32 &self,
33 request: ToolCallRequest,
34 next: &dyn ToolCaller,
35 ) -> Result<Value, SynapticError> {
36 let mut last_err = None;
37 for attempt in 0..=self.max_retries {
38 match next.call(request.clone()).await {
39 Ok(val) => return Ok(val),
40 Err(e) => {
41 last_err = Some(e);
42 if attempt < self.max_retries {
43 let delay = self.base_delay * 2u32.saturating_pow(attempt as u32);
44 tokio::time::sleep(delay).await;
45 }
46 }
47 }
48 }
49 Err(last_err.unwrap())
50 }
51}