// simple_agents_router/circuit_breaker.rs
use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
use std::sync::Mutex;
use std::time::{Duration, Instant};

/// Tuning parameters for a [`CircuitBreaker`].
///
/// The `Default` implementation trips after 3 consecutive failures, stays
/// open for 10 seconds and closes again after 1 half-open success.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Consecutive counted failures, while closed, that trip the breaker open.
    pub failure_threshold: u32,
    /// How long an open breaker rejects requests before letting a trial
    /// request through and moving to half-open.
    pub open_duration: Duration,
    /// Consecutive successes, while half-open, needed to close the breaker.
    pub success_threshold: u32,
}

21impl Default for CircuitBreakerConfig {
22 fn default() -> Self {
23 Self {
24 failure_threshold: 3,
25 open_duration: Duration::from_secs(10),
26 success_threshold: 1,
27 }
28 }
29}
30
/// Private state machine of a [`CircuitBreaker`]. Unlike the public
/// [`CircuitBreakerState`], the open variant remembers when it tripped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InternalState {
    /// Requests are allowed and consecutive failures are counted.
    Closed,
    /// Requests are rejected until `open_duration` has passed since `opened_at`.
    Open {
        /// Moment the breaker tripped (`Instant::now()` at the transition).
        opened_at: Instant,
    },
    /// Requests are allowed; enough successes close the breaker and a
    /// failure reopens it.
    HalfOpen,
}

/// Externally visible state of a [`CircuitBreaker`], as returned by
/// [`CircuitBreaker::state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Requests are allowed; consecutive failures are being counted.
    Closed,
    /// The breaker has tripped. This is still reported after the open window
    /// elapses, until `allow_request` performs the switch to half-open.
    Open,
    /// Trial requests are allowed to probe whether the dependency recovered.
    HalfOpen,
}

49#[derive(Debug)]
50struct CircuitBreakerInner {
51 state: InternalState,
52 consecutive_failures: u32,
53 consecutive_successes: u32,
54}
55
56impl CircuitBreakerInner {
57 fn new() -> Self {
58 Self {
59 state: InternalState::Closed,
60 consecutive_failures: 0,
61 consecutive_successes: 0,
62 }
63 }
64}
65
66#[derive(Debug)]
68pub struct CircuitBreaker {
69 config: CircuitBreakerConfig,
70 inner: Mutex<CircuitBreakerInner>,
71}
72
73impl CircuitBreaker {
74 pub fn new(config: CircuitBreakerConfig) -> Self {
76 Self {
77 config,
78 inner: Mutex::new(CircuitBreakerInner::new()),
79 }
80 }
81
82 pub fn allow_request(&self) -> bool {
84 let mut inner = self.inner.lock().expect("circuit breaker lock poisoned");
85 match inner.state {
86 InternalState::Closed => true,
87 InternalState::Open { opened_at } => {
88 if opened_at.elapsed() >= self.config.open_duration {
89 inner.state = InternalState::HalfOpen;
90 inner.consecutive_successes = 0;
91 inner.consecutive_failures = 0;
92 true
93 } else {
94 false
95 }
96 }
97 InternalState::HalfOpen => true,
98 }
99 }
100
101 pub fn record_success(&self) {
103 let mut inner = self.inner.lock().expect("circuit breaker lock poisoned");
104 match inner.state {
105 InternalState::Closed => {
106 inner.consecutive_failures = 0;
107 }
108 InternalState::HalfOpen => {
109 inner.consecutive_successes = inner.consecutive_successes.saturating_add(1);
110 if inner.consecutive_successes >= self.config.success_threshold {
111 inner.state = InternalState::Closed;
112 inner.consecutive_failures = 0;
113 inner.consecutive_successes = 0;
114 }
115 }
116 InternalState::Open { .. } => {}
117 }
118 }
119
120 pub fn record_failure(&self) {
122 let mut inner = self.inner.lock().expect("circuit breaker lock poisoned");
123 match inner.state {
124 InternalState::Closed => {
125 inner.consecutive_failures = inner.consecutive_failures.saturating_add(1);
126 if inner.consecutive_failures >= self.config.failure_threshold {
127 inner.state = InternalState::Open {
128 opened_at: Instant::now(),
129 };
130 }
131 }
132 InternalState::HalfOpen => {
133 inner.state = InternalState::Open {
134 opened_at: Instant::now(),
135 };
136 inner.consecutive_failures = 1;
137 inner.consecutive_successes = 0;
138 }
139 InternalState::Open { .. } => {}
140 }
141 }
142
143 pub fn record_result(&self, result: &std::result::Result<(), SimpleAgentsError>) {
145 match result {
146 Ok(_) => self.record_success(),
147 Err(error) => {
148 if matches!(
149 error,
150 SimpleAgentsError::Provider(
151 ProviderError::RateLimit { .. }
152 | ProviderError::Timeout(_)
153 | ProviderError::ServerError(_)
154 ) | SimpleAgentsError::Network(_)
155 ) {
156 self.record_failure();
157 }
158 }
159 }
160 }
161
162 pub fn state(&self) -> CircuitBreakerState {
164 let inner = self.inner.lock().expect("circuit breaker lock poisoned");
165 match inner.state {
166 InternalState::Closed => CircuitBreakerState::Closed,
167 InternalState::Open { .. } => CircuitBreakerState::Open,
168 InternalState::HalfOpen => CircuitBreakerState::HalfOpen,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn opens_after_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        assert!(breaker.allow_request());
        breaker.record_failure();
        assert!(breaker.allow_request());
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn closes_after_success_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn reopens_on_failure_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn success_while_closed_resets_failure_streak() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        breaker.record_success();
        breaker.record_failure();
        // Only one failure since the last success, so the threshold of 2 is not reached.
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn half_open_needs_success_threshold_successes() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert!(breaker.allow_request());
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn outcomes_are_ignored_while_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }
}