rustkernel_core/resilience/
circuit_breaker.rs1use super::{ResilienceError, ResilienceResult};
30use serde::{Deserialize, Serialize};
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
33use std::time::{Duration, Instant};
34use tokio::sync::RwLock;
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum CircuitState {
40 #[default]
42 Closed,
43 Open,
45 HalfOpen,
47}
48
49impl std::fmt::Display for CircuitState {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Self::Closed => write!(f, "closed"),
53 Self::Open => write!(f, "open"),
54 Self::HalfOpen => write!(f, "half-open"),
55 }
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct CircuitBreakerConfig {
62 pub failure_threshold: u32,
64 pub success_threshold: u32,
66 pub reset_timeout: Duration,
68 pub window_size: Duration,
70 pub half_open_max_requests: u32,
72}
73
74impl Default for CircuitBreakerConfig {
75 fn default() -> Self {
76 Self {
77 failure_threshold: 5,
78 success_threshold: 2,
79 reset_timeout: Duration::from_secs(30),
80 window_size: Duration::from_secs(60),
81 half_open_max_requests: 3,
82 }
83 }
84}
85
86impl CircuitBreakerConfig {
87 pub fn production() -> Self {
89 Self {
90 failure_threshold: 5,
91 success_threshold: 3,
92 reset_timeout: Duration::from_secs(60),
93 window_size: Duration::from_secs(120),
94 half_open_max_requests: 5,
95 }
96 }
97
98 pub fn failure_threshold(mut self, threshold: u32) -> Self {
100 self.failure_threshold = threshold;
101 self
102 }
103
104 pub fn success_threshold(mut self, threshold: u32) -> Self {
106 self.success_threshold = threshold;
107 self
108 }
109
110 pub fn reset_timeout(mut self, timeout: Duration) -> Self {
112 self.reset_timeout = timeout;
113 self
114 }
115
116 pub fn window_size(mut self, size: Duration) -> Self {
118 self.window_size = size;
119 self
120 }
121
122 pub fn half_open_max_requests(mut self, max: u32) -> Self {
124 self.half_open_max_requests = max;
125 self
126 }
127}
128
129pub struct CircuitBreaker {
131 kernel_id: String,
133 config: CircuitBreakerConfig,
135 inner: Arc<CircuitBreakerInner>,
137}
138
139struct CircuitBreakerInner {
140 state: RwLock<CircuitState>,
141 failure_count: AtomicU32,
142 success_count: AtomicU32,
143 last_failure_time: RwLock<Option<Instant>>,
144 half_open_requests: AtomicU32,
145 total_requests: AtomicU64,
146 total_failures: AtomicU64,
147}
148
149impl CircuitBreaker {
150 pub fn new(kernel_id: impl Into<String>, config: CircuitBreakerConfig) -> Self {
152 Self {
153 kernel_id: kernel_id.into(),
154 config,
155 inner: Arc::new(CircuitBreakerInner {
156 state: RwLock::new(CircuitState::Closed),
157 failure_count: AtomicU32::new(0),
158 success_count: AtomicU32::new(0),
159 last_failure_time: RwLock::new(None),
160 half_open_requests: AtomicU32::new(0),
161 total_requests: AtomicU64::new(0),
162 total_failures: AtomicU64::new(0),
163 }),
164 }
165 }
166
167 pub async fn state(&self) -> CircuitState {
169 let state = *self.inner.state.read().await;
170
171 if state == CircuitState::Open {
173 if let Some(last_failure) = *self.inner.last_failure_time.read().await {
174 if last_failure.elapsed() >= self.config.reset_timeout {
175 return self.try_transition_to_half_open().await;
176 }
177 }
178 }
179
180 state
181 }
182
183 pub fn kernel_id(&self) -> &str {
185 &self.kernel_id
186 }
187
188 pub async fn is_allowed(&self) -> bool {
190 match self.state().await {
191 CircuitState::Closed => true,
192 CircuitState::Open => false,
193 CircuitState::HalfOpen => {
194 self.inner.half_open_requests.load(Ordering::Relaxed)
195 < self.config.half_open_max_requests
196 }
197 }
198 }
199
200 pub async fn execute<F, Fut, T, E>(&self, f: F) -> ResilienceResult<T>
202 where
203 F: FnOnce() -> Fut,
204 Fut: std::future::Future<Output = Result<T, E>>,
205 E: Into<crate::error::KernelError>,
206 {
207 self.inner.total_requests.fetch_add(1, Ordering::Relaxed);
208
209 let state = self.state().await;
211 match state {
212 CircuitState::Open => {
213 return Err(ResilienceError::CircuitOpen {
214 kernel_id: self.kernel_id.clone(),
215 });
216 }
217 CircuitState::HalfOpen => {
218 let current = self
220 .inner
221 .half_open_requests
222 .fetch_add(1, Ordering::Relaxed);
223 if current >= self.config.half_open_max_requests {
224 self.inner
225 .half_open_requests
226 .fetch_sub(1, Ordering::Relaxed);
227 return Err(ResilienceError::CircuitOpen {
228 kernel_id: self.kernel_id.clone(),
229 });
230 }
231 }
232 CircuitState::Closed => {}
233 }
234
235 let result = f().await;
237
238 match &result {
240 Ok(_) => self.record_success().await,
241 Err(_) => self.record_failure().await,
242 }
243
244 if state == CircuitState::HalfOpen {
246 self.inner
247 .half_open_requests
248 .fetch_sub(1, Ordering::Relaxed);
249 }
250
251 result.map_err(|e| ResilienceError::KernelError(e.into()))
252 }
253
254 pub async fn record_success(&self) {
256 let state = *self.inner.state.read().await;
257
258 match state {
259 CircuitState::Closed => {
260 self.inner.failure_count.store(0, Ordering::Relaxed);
262 }
263 CircuitState::HalfOpen => {
264 let successes = self.inner.success_count.fetch_add(1, Ordering::Relaxed) + 1;
265 if successes >= self.config.success_threshold {
266 self.transition_to_closed().await;
267 }
268 }
269 CircuitState::Open => {}
270 }
271 }
272
273 pub async fn record_failure(&self) {
275 self.inner.total_failures.fetch_add(1, Ordering::Relaxed);
276 *self.inner.last_failure_time.write().await = Some(Instant::now());
277
278 let state = *self.inner.state.read().await;
279
280 match state {
281 CircuitState::Closed => {
282 let failures = self.inner.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
283 if failures >= self.config.failure_threshold {
284 self.transition_to_open().await;
285 }
286 }
287 CircuitState::HalfOpen => {
288 self.transition_to_open().await;
290 }
291 CircuitState::Open => {}
292 }
293 }
294
295 pub async fn reset(&self) {
297 *self.inner.state.write().await = CircuitState::Closed;
298 self.inner.failure_count.store(0, Ordering::Relaxed);
299 self.inner.success_count.store(0, Ordering::Relaxed);
300 self.inner.half_open_requests.store(0, Ordering::Relaxed);
301 *self.inner.last_failure_time.write().await = None;
302 }
303
304 pub fn stats(&self) -> CircuitBreakerStats {
306 CircuitBreakerStats {
307 total_requests: self.inner.total_requests.load(Ordering::Relaxed),
308 total_failures: self.inner.total_failures.load(Ordering::Relaxed),
309 current_failures: self.inner.failure_count.load(Ordering::Relaxed),
310 }
311 }
312
313 async fn transition_to_open(&self) {
316 *self.inner.state.write().await = CircuitState::Open;
317 self.inner.success_count.store(0, Ordering::Relaxed);
318 tracing::warn!(
319 kernel_id = %self.kernel_id,
320 "Circuit breaker opened"
321 );
322 }
323
324 async fn transition_to_closed(&self) {
325 *self.inner.state.write().await = CircuitState::Closed;
326 self.inner.failure_count.store(0, Ordering::Relaxed);
327 self.inner.success_count.store(0, Ordering::Relaxed);
328 tracing::info!(
329 kernel_id = %self.kernel_id,
330 "Circuit breaker closed"
331 );
332 }
333
334 async fn try_transition_to_half_open(&self) -> CircuitState {
335 let mut state = self.inner.state.write().await;
336 if *state == CircuitState::Open {
337 *state = CircuitState::HalfOpen;
338 self.inner.success_count.store(0, Ordering::Relaxed);
339 self.inner.half_open_requests.store(0, Ordering::Relaxed);
340 tracing::info!(
341 kernel_id = %self.kernel_id,
342 "Circuit breaker half-open"
343 );
344 }
345 *state
346 }
347}
348
349impl Clone for CircuitBreaker {
350 fn clone(&self) -> Self {
351 Self {
352 kernel_id: self.kernel_id.clone(),
353 config: self.config.clone(),
354 inner: self.inner.clone(),
355 }
356 }
357}
358
/// Snapshot of a circuit breaker's counters.
///
/// The snapshot is plain data, so it is `Copy` and comparable, which lets
/// callers diff two snapshots or assert on one in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CircuitBreakerStats {
    /// Requests submitted for execution, including rejected ones.
    pub total_requests: u64,
    /// Failures recorded over the breaker's lifetime.
    pub total_failures: u64,
    /// Failures counted since the failure counter was last cleared.
    pub current_failures: u32,
}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker named "test" from the default config with the given
    /// failure threshold.
    fn breaker_with_threshold(threshold: u32) -> CircuitBreaker {
        let config = CircuitBreakerConfig::default().failure_threshold(threshold);
        CircuitBreaker::new("test", config)
    }

    /// Records `times` consecutive failures on `breaker`.
    async fn fail_times(breaker: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            breaker.record_failure().await;
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());

        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(breaker.is_allowed().await);
    }

    #[tokio::test]
    async fn test_circuit_opens_after_failures() {
        let breaker = breaker_with_threshold(3);

        // Reaching the threshold must trip the breaker.
        fail_times(&breaker, 3).await;

        assert_eq!(breaker.state().await, CircuitState::Open);
        assert!(!breaker.is_allowed().await);
    }

    #[tokio::test]
    async fn test_circuit_resets_on_success() {
        let breaker = breaker_with_threshold(3);

        // Stay one failure short of the threshold, then succeed.
        fail_times(&breaker, 2).await;
        breaker.record_success().await;

        let failures = breaker.inner.failure_count.load(Ordering::Relaxed);
        assert_eq!(failures, 0);
    }

    #[tokio::test]
    async fn test_manual_reset() {
        let breaker = breaker_with_threshold(3);

        fail_times(&breaker, 3).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        breaker.reset().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[test]
    fn test_config_builder() {
        let one_minute = Duration::from_secs(60);
        let config = CircuitBreakerConfig::default()
            .failure_threshold(10)
            .reset_timeout(one_minute)
            .success_threshold(5);

        assert_eq!(config.failure_threshold, 10);
        assert_eq!(config.reset_timeout, one_minute);
        assert_eq!(config.success_threshold, 5);
    }
}