Skip to main content

synaptic_middleware/
circuit_breaker.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use serde_json::Value;
7use synaptic_core::SynapticError;
8use tokio::sync::RwLock;
9
10use crate::{
11    AgentMiddleware, ModelCaller, ModelRequest, ModelResponse, ToolCallRequest, ToolCaller,
12};
13
/// The three states of a circuit breaker.
///
/// A circuit moves from `Closed` to `Open` once failures reach the configured
/// threshold, and from `Open` to `HalfOpen` once the recovery timeout elapses.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Healthy: calls are forwarded to the wrapped target.
    Closed,
    /// Tripped: calls are rejected without reaching the target.
    Open,
    /// Recovering: the recovery timeout has elapsed and a probe call may be
    /// forwarded to test whether the target is healthy again.
    HalfOpen,
}
24
25/// Per-target circuit state.
26#[derive(Debug)]
27struct CircuitTracker {
28    state: CircuitState,
29    failure_count: usize,
30    last_failure: Option<Instant>,
31}
32
33impl CircuitTracker {
34    fn new() -> Self {
35        Self {
36            state: CircuitState::Closed,
37            failure_count: 0,
38            last_failure: None,
39        }
40    }
41}
42
/// Tuning knobs for the circuit breaker middleware.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// How many consecutive failures trip the circuit from closed to open.
    pub failure_threshold: usize,
    /// How long an open circuit waits before a probe call is permitted
    /// (the open -> half-open transition).
    pub recovery_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Trips after 5 consecutive failures and waits 60 seconds before probing.
    fn default() -> Self {
        let failure_threshold = 5;
        let recovery_timeout = Duration::from_secs(60);
        Self {
            failure_threshold,
            recovery_timeout,
        }
    }
}
60
61/// Middleware that implements the circuit breaker pattern for tool calls.
62///
63/// Tracks failures per tool name and opens the circuit when failures
64/// exceed the configured threshold. After the recovery timeout, a single
65/// probe request is allowed through (half-open). If it succeeds, the
66/// circuit closes; if it fails, the circuit reopens.
67pub struct CircuitBreakerMiddleware {
68    config: CircuitBreakerConfig,
69    circuits: Arc<RwLock<HashMap<String, CircuitTracker>>>,
70}
71
72impl CircuitBreakerMiddleware {
73    pub fn new(config: CircuitBreakerConfig) -> Self {
74        Self {
75            config,
76            circuits: Arc::new(RwLock::new(HashMap::new())),
77        }
78    }
79
80    /// Get the current state for a given tool name.
81    pub async fn state_for(&self, tool_name: &str) -> CircuitState {
82        let circuits = self.circuits.read().await;
83        circuits
84            .get(tool_name)
85            .map(|t| {
86                if t.state == CircuitState::Open {
87                    // Check if recovery timeout has elapsed
88                    if let Some(last_failure) = t.last_failure {
89                        if last_failure.elapsed() >= self.config.recovery_timeout {
90                            return CircuitState::HalfOpen;
91                        }
92                    }
93                }
94                t.state
95            })
96            .unwrap_or(CircuitState::Closed)
97    }
98
99    async fn record_success(&self, tool_name: &str) {
100        let mut circuits = self.circuits.write().await;
101        let tracker = circuits
102            .entry(tool_name.to_string())
103            .or_insert_with(CircuitTracker::new);
104        tracker.state = CircuitState::Closed;
105        tracker.failure_count = 0;
106    }
107
108    async fn record_failure(&self, tool_name: &str) {
109        let mut circuits = self.circuits.write().await;
110        let tracker = circuits
111            .entry(tool_name.to_string())
112            .or_insert_with(CircuitTracker::new);
113        tracker.failure_count += 1;
114        tracker.last_failure = Some(Instant::now());
115        if tracker.failure_count >= self.config.failure_threshold {
116            tracker.state = CircuitState::Open;
117        }
118    }
119}
120
121#[async_trait]
122impl AgentMiddleware for CircuitBreakerMiddleware {
123    async fn wrap_tool_call(
124        &self,
125        request: ToolCallRequest,
126        next: &dyn ToolCaller,
127    ) -> Result<Value, SynapticError> {
128        let tool_name = &request.call.name;
129        let state = self.state_for(tool_name).await;
130
131        match state {
132            CircuitState::Open => Err(SynapticError::Tool(format!(
133                "circuit breaker open for tool '{}' — too many consecutive failures",
134                tool_name
135            ))),
136            CircuitState::HalfOpen | CircuitState::Closed => {
137                match next.call(request.clone()).await {
138                    Ok(result) => {
139                        self.record_success(tool_name).await;
140                        Ok(result)
141                    }
142                    Err(e) => {
143                        self.record_failure(tool_name).await;
144                        Err(e)
145                    }
146                }
147            }
148        }
149    }
150
151    async fn wrap_model_call(
152        &self,
153        request: ModelRequest,
154        next: &dyn ModelCaller,
155    ) -> Result<ModelResponse, SynapticError> {
156        let state = self.state_for("__model__").await;
157
158        match state {
159            CircuitState::Open => Err(SynapticError::Model(
160                "circuit breaker open for model — too many consecutive failures".to_string(),
161            )),
162            CircuitState::HalfOpen | CircuitState::Closed => match next.call(request).await {
163                Ok(result) => {
164                    self.record_success("__model__").await;
165                    Ok(result)
166                }
167                Err(e) => {
168                    self.record_failure("__model__").await;
169                    Err(e)
170                }
171            },
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a breaker with an explicit threshold and recovery timeout.
    fn breaker(failure_threshold: usize, recovery_timeout: Duration) -> CircuitBreakerMiddleware {
        CircuitBreakerMiddleware::new(CircuitBreakerConfig {
            failure_threshold,
            recovery_timeout,
        })
    }

    #[tokio::test]
    async fn circuit_starts_closed() {
        let cb = CircuitBreakerMiddleware::new(CircuitBreakerConfig::default());
        assert_eq!(cb.state_for("test_tool").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn circuit_opens_after_threshold() {
        let cb = breaker(3, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);
    }

    #[tokio::test]
    async fn circuit_transitions_to_half_open() {
        let cb = breaker(1, Duration::from_millis(10));

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn success_resets_circuit() {
        let cb = breaker(2, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        cb.record_success("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn only_consecutive_failures_count() {
        let cb = breaker(2, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        cb.record_success("tool_a").await;
        cb.record_failure("tool_a").await;
        // One failure on each side of a success: never two in a row.
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn successful_probe_closes_circuit() {
        let cb = breaker(1, Duration::from_millis(10));

        cb.record_failure("tool_a").await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::HalfOpen);

        cb.record_success("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn failed_probe_reopens_circuit() {
        // Generous timeout so the immediate re-check below cannot race it.
        let cb = breaker(1, Duration::from_millis(50));

        cb.record_failure("tool_a").await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::HalfOpen);

        // The probe fails: the circuit reopens and the recovery timer restarts.
        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);
    }

    #[tokio::test]
    async fn per_tool_isolation() {
        let cb = breaker(1, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);
        assert_eq!(cb.state_for("tool_b").await, CircuitState::Closed);
    }
}