simple_agents_router/
circuit_breaker.rs1use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures, counted while closed, that trips the breaker open.
    pub failure_threshold: u32,
    /// How long the breaker rejects requests after tripping before it lets a request
    /// through again in the half-open state.
    pub open_duration: Duration,
    /// Number of consecutive half-open successes needed to close the breaker again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Trips after 3 consecutive failures, stays open for 10 seconds, and closes again
    /// after a single successful half-open request.
    fn default() -> Self {
        let open_duration = Duration::from_secs(10);
        CircuitBreakerConfig {
            failure_threshold: 3,
            success_threshold: 1,
            open_duration,
        }
    }
}
30
/// Private breaker state. Unlike the public [`CircuitBreakerState`], the `Open`
/// variant remembers the moment the breaker tripped so the cool-down can be timed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InternalState {
    /// Requests pass through; failures are counted toward `failure_threshold`.
    Closed,
    /// Requests are rejected until `open_duration` has elapsed since `opened_at`.
    Open { opened_at: Instant },
    /// Post-cool-down trial: requests pass through while successes are counted.
    HalfOpen,
}
37
/// Externally visible breaker state, as reported by `CircuitBreaker::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Requests are allowed and failures are being counted.
    Closed,
    /// Requests are being rejected while the breaker cools down.
    Open,
    /// Requests are allowed on a trial basis; one counted failure re-opens the breaker.
    HalfOpen,
}
48
49#[derive(Debug)]
50struct CircuitBreakerInner {
51 state: InternalState,
52 consecutive_failures: u32,
53 consecutive_successes: u32,
54}
55
56impl CircuitBreakerInner {
57 fn new() -> Self {
58 Self {
59 state: InternalState::Closed,
60 consecutive_failures: 0,
61 consecutive_successes: 0,
62 }
63 }
64}
65
66#[derive(Debug)]
68pub struct CircuitBreaker {
69 config: CircuitBreakerConfig,
70 inner: Mutex<CircuitBreakerInner>,
71}
72
73impl CircuitBreaker {
74 pub fn new(config: CircuitBreakerConfig) -> Self {
76 Self {
77 config,
78 inner: Mutex::new(CircuitBreakerInner::new()),
79 }
80 }
81
82 pub fn allow_request(&self) -> bool {
84 let mut inner = self
85 .inner
86 .lock()
87 .unwrap_or_else(|poisoned| poisoned.into_inner());
88 match inner.state {
89 InternalState::Closed => true,
90 InternalState::Open { opened_at } => {
91 if opened_at.elapsed() >= self.config.open_duration {
92 inner.state = InternalState::HalfOpen;
93 inner.consecutive_successes = 0;
94 inner.consecutive_failures = 0;
95 true
96 } else {
97 false
98 }
99 }
100 InternalState::HalfOpen => true,
101 }
102 }
103
104 pub fn record_success(&self) {
106 let mut inner = self
107 .inner
108 .lock()
109 .unwrap_or_else(|poisoned| poisoned.into_inner());
110 match inner.state {
111 InternalState::Closed => {
112 inner.consecutive_failures = 0;
113 }
114 InternalState::HalfOpen => {
115 inner.consecutive_successes = inner.consecutive_successes.saturating_add(1);
116 if inner.consecutive_successes >= self.config.success_threshold {
117 inner.state = InternalState::Closed;
118 inner.consecutive_failures = 0;
119 inner.consecutive_successes = 0;
120 }
121 }
122 InternalState::Open { .. } => {}
123 }
124 }
125
126 pub fn record_failure(&self) {
128 let mut inner = self
129 .inner
130 .lock()
131 .unwrap_or_else(|poisoned| poisoned.into_inner());
132 match inner.state {
133 InternalState::Closed => {
134 inner.consecutive_failures = inner.consecutive_failures.saturating_add(1);
135 if inner.consecutive_failures >= self.config.failure_threshold {
136 inner.state = InternalState::Open {
137 opened_at: Instant::now(),
138 };
139 }
140 }
141 InternalState::HalfOpen => {
142 inner.state = InternalState::Open {
143 opened_at: Instant::now(),
144 };
145 inner.consecutive_failures = 1;
146 inner.consecutive_successes = 0;
147 }
148 InternalState::Open { .. } => {}
149 }
150 }
151
152 pub fn record_result(&self, result: &std::result::Result<(), SimpleAgentsError>) {
154 match result {
155 Ok(_) => self.record_success(),
156 Err(error) => {
157 if matches!(
158 error,
159 SimpleAgentsError::Provider(
160 ProviderError::RateLimit { .. }
161 | ProviderError::Timeout(_)
162 | ProviderError::ServerError(_)
163 ) | SimpleAgentsError::Network(_)
164 ) {
165 self.record_failure();
166 }
167 }
168 }
169 }
170
171 pub fn state(&self) -> CircuitBreakerState {
173 let inner = self
174 .inner
175 .lock()
176 .unwrap_or_else(|poisoned| poisoned.into_inner());
177 match inner.state {
178 InternalState::Closed => CircuitBreakerState::Closed,
179 InternalState::Open { .. } => CircuitBreakerState::Open,
180 InternalState::HalfOpen => CircuitBreakerState::HalfOpen,
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn opens_after_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        assert!(breaker.allow_request());
        breaker.record_failure();
        assert!(breaker.allow_request());
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn closes_after_success_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn reopens_on_failure_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    /// A success while closed must restart the failure streak, so failures that are
    /// not consecutive never trip the breaker.
    #[test]
    fn success_while_closed_resets_failure_streak() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        breaker.record_success();
        // The streak restarted, so this is failure 1 of 2, not 2 of 2.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    /// While open (and before the cool-down elapses), recorded outcomes are ignored
    /// and requests stay rejected.
    #[test]
    fn ignores_outcomes_while_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_success();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    /// With `success_threshold > 1`, a single half-open success is not enough to close.
    #[test]
    fn half_open_needs_success_threshold_successes() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    /// `record_result(Ok)` must behave exactly like `record_success`.
    #[test]
    fn record_result_ok_counts_as_success() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_result(&Ok(()));
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }
}