synaptic_middleware/
circuit_breaker.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use serde_json::Value;
7use synaptic_core::SynapticError;
8use tokio::sync::RwLock;
9
10use crate::{
11 AgentMiddleware, ModelCaller, ModelRequest, ModelResponse, ToolCallRequest, ToolCaller,
12};
13
/// State reported for a single circuit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are forwarded to the next handler.
    Closed,
    /// Calls are rejected with an error instead of being forwarded.
    Open,
    /// The circuit was opened, but the recovery timeout has elapsed since the
    /// last recorded failure; calls are forwarded again.
    HalfOpen,
}
24
25#[derive(Debug)]
27struct CircuitTracker {
28 state: CircuitState,
29 failure_count: usize,
30 last_failure: Option<Instant>,
31}
32
33impl CircuitTracker {
34 fn new() -> Self {
35 Self {
36 state: CircuitState::Closed,
37 failure_count: 0,
38 last_failure: None,
39 }
40 }
41}
42
/// Tunable parameters for [`CircuitBreakerMiddleware`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of recorded failures (with no success in between) at which a
    /// circuit is marked open.
    pub failure_threshold: usize,
    /// How long after the most recent failure an open circuit starts being
    /// reported as half-open.
    pub recovery_timeout: Duration,
}
51
52impl Default for CircuitBreakerConfig {
53 fn default() -> Self {
54 Self {
55 failure_threshold: 5,
56 recovery_timeout: Duration::from_secs(60),
57 }
58 }
59}
60
61pub struct CircuitBreakerMiddleware {
68 config: CircuitBreakerConfig,
69 circuits: Arc<RwLock<HashMap<String, CircuitTracker>>>,
70}
71
72impl CircuitBreakerMiddleware {
73 pub fn new(config: CircuitBreakerConfig) -> Self {
74 Self {
75 config,
76 circuits: Arc::new(RwLock::new(HashMap::new())),
77 }
78 }
79
80 pub async fn state_for(&self, tool_name: &str) -> CircuitState {
82 let circuits = self.circuits.read().await;
83 circuits
84 .get(tool_name)
85 .map(|t| {
86 if t.state == CircuitState::Open {
87 if let Some(last_failure) = t.last_failure {
89 if last_failure.elapsed() >= self.config.recovery_timeout {
90 return CircuitState::HalfOpen;
91 }
92 }
93 }
94 t.state
95 })
96 .unwrap_or(CircuitState::Closed)
97 }
98
99 async fn record_success(&self, tool_name: &str) {
100 let mut circuits = self.circuits.write().await;
101 let tracker = circuits
102 .entry(tool_name.to_string())
103 .or_insert_with(CircuitTracker::new);
104 tracker.state = CircuitState::Closed;
105 tracker.failure_count = 0;
106 }
107
108 async fn record_failure(&self, tool_name: &str) {
109 let mut circuits = self.circuits.write().await;
110 let tracker = circuits
111 .entry(tool_name.to_string())
112 .or_insert_with(CircuitTracker::new);
113 tracker.failure_count += 1;
114 tracker.last_failure = Some(Instant::now());
115 if tracker.failure_count >= self.config.failure_threshold {
116 tracker.state = CircuitState::Open;
117 }
118 }
119}
120
121#[async_trait]
122impl AgentMiddleware for CircuitBreakerMiddleware {
123 async fn wrap_tool_call(
124 &self,
125 request: ToolCallRequest,
126 next: &dyn ToolCaller,
127 ) -> Result<Value, SynapticError> {
128 let tool_name = &request.call.name;
129 let state = self.state_for(tool_name).await;
130
131 match state {
132 CircuitState::Open => Err(SynapticError::Tool(format!(
133 "circuit breaker open for tool '{}' — too many consecutive failures",
134 tool_name
135 ))),
136 CircuitState::HalfOpen | CircuitState::Closed => {
137 match next.call(request.clone()).await {
138 Ok(result) => {
139 self.record_success(tool_name).await;
140 Ok(result)
141 }
142 Err(e) => {
143 self.record_failure(tool_name).await;
144 Err(e)
145 }
146 }
147 }
148 }
149 }
150
151 async fn wrap_model_call(
152 &self,
153 request: ModelRequest,
154 next: &dyn ModelCaller,
155 ) -> Result<ModelResponse, SynapticError> {
156 let state = self.state_for("__model__").await;
157
158 match state {
159 CircuitState::Open => Err(SynapticError::Model(
160 "circuit breaker open for model — too many consecutive failures".to_string(),
161 )),
162 CircuitState::HalfOpen | CircuitState::Closed => match next.call(request).await {
163 Ok(result) => {
164 self.record_success("__model__").await;
165 Ok(result)
166 }
167 Err(e) => {
168 self.record_failure("__model__").await;
169 Err(e)
170 }
171 },
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker with the given threshold and recovery timeout.
    fn breaker(failure_threshold: usize, recovery_timeout: Duration) -> CircuitBreakerMiddleware {
        CircuitBreakerMiddleware::new(CircuitBreakerConfig {
            failure_threshold,
            recovery_timeout,
        })
    }

    #[tokio::test]
    async fn circuit_starts_closed() {
        let cb = CircuitBreakerMiddleware::new(CircuitBreakerConfig::default());
        assert_eq!(cb.state_for("test_tool").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn circuit_opens_after_threshold() {
        let cb = breaker(3, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);
    }

    #[tokio::test]
    async fn circuit_transitions_to_half_open() {
        let cb = breaker(1, Duration::from_millis(10));

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn success_resets_circuit() {
        let cb = breaker(2, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        cb.record_success("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);

        // If the success had not cleared the counter, this second failure
        // would reach the threshold of 2 and open the circuit.
        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn success_closes_open_circuit() {
        let cb = breaker(1, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);

        cb.record_success("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn per_tool_isolation() {
        let cb = breaker(1, Duration::from_secs(60));

        cb.record_failure("tool_a").await;
        assert_eq!(cb.state_for("tool_a").await, CircuitState::Open);
        assert_eq!(cb.state_for("tool_b").await, CircuitState::Closed);
    }
}