uni_plugin/
circuit_breaker.rs1use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12
13use dashmap::DashMap;
14use parking_lot::RwLock;
15
16use crate::plugin::PluginId;
17use crate::qname::QName;
18
/// Tuning knobs for [`CircuitBreaker`], applied uniformly to every
/// `(plugin, qname)` pair it tracks.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BreakerConfig {
    /// Number of consecutive failures at which a pair's breaker opens.
    pub failure_threshold: u32,
    /// How long an open breaker rejects calls before letting them through again.
    pub cooldown: Duration,
}
27
28impl Default for BreakerConfig {
29 fn default() -> Self {
30 Self {
31 failure_threshold: 10,
32 cooldown: Duration::from_secs(30),
33 }
34 }
35}
36
37#[derive(Debug)]
39struct BreakerState {
40 consecutive_failures: AtomicU64,
41 opened_at: RwLock<Option<Instant>>,
42}
43
44impl Default for BreakerState {
45 fn default() -> Self {
46 Self {
47 consecutive_failures: AtomicU64::new(0),
48 opened_at: RwLock::new(None),
49 }
50 }
51}
52
53#[derive(Debug)]
55pub struct CircuitBreaker {
56 cfg: BreakerConfig,
57 states: DashMap<(PluginId, QName), Arc<BreakerState>>,
58}
59
60impl CircuitBreaker {
61 #[must_use]
63 pub fn new(cfg: BreakerConfig) -> Self {
64 Self {
65 cfg,
66 states: DashMap::new(),
67 }
68 }
69
70 #[must_use]
73 pub fn allow(&self, plugin: &PluginId, qname: &QName) -> bool {
74 let key = (plugin.clone(), qname.clone());
75 let Some(state) = self.states.get(&key) else {
76 return true;
77 };
78 let opened_at = *state.opened_at.read();
79 match opened_at {
80 None => true,
81 Some(t) => {
82 if t.elapsed() >= self.cfg.cooldown {
83 *state.opened_at.write() = None;
86 state.consecutive_failures.store(0, Ordering::SeqCst);
87 true
88 } else {
89 false
90 }
91 }
92 }
93 }
94
95 pub fn record_success(&self, plugin: &PluginId, qname: &QName) {
97 let key = (plugin.clone(), qname.clone());
98 if let Some(state) = self.states.get(&key) {
99 state.consecutive_failures.store(0, Ordering::SeqCst);
100 *state.opened_at.write() = None;
101 }
102 }
103
104 pub fn record_failure(&self, plugin: &PluginId, qname: &QName) {
106 let key = (plugin.clone(), qname.clone());
107 let state = self
108 .states
109 .entry(key)
110 .or_insert_with(|| Arc::new(BreakerState::default()))
111 .clone();
112 let n = state.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
113 if n >= u64::from(self.cfg.failure_threshold) {
114 let mut opened = state.opened_at.write();
115 if opened.is_none() {
116 *opened = Some(Instant::now());
117 }
118 }
119 }
120
121 #[must_use]
123 pub fn failure_count(&self, plugin: &PluginId, qname: &QName) -> u64 {
124 let key = (plugin.clone(), qname.clone());
125 self.states
126 .get(&key)
127 .map(|s| s.consecutive_failures.load(Ordering::SeqCst))
128 .unwrap_or(0)
129 }
130}
131
132impl Default for CircuitBreaker {
133 fn default() -> Self {
134 Self::new(BreakerConfig::default())
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// A breaker with a low threshold (3) and a short cooldown (50 ms), plus a
    /// key to exercise it with.
    fn fixture() -> (CircuitBreaker, PluginId, QName) {
        let breaker = CircuitBreaker::new(BreakerConfig {
            failure_threshold: 3,
            cooldown: Duration::from_millis(50),
        });
        let plugin = PluginId::new("test");
        let qname = QName::builtin("doomed");
        (breaker, plugin, qname)
    }

    /// Records `times` consecutive failures for `(plugin, qname)`.
    fn fail_n(breaker: &CircuitBreaker, plugin: &PluginId, qname: &QName, times: usize) {
        (0..times).for_each(|_| breaker.record_failure(plugin, qname));
    }

    #[test]
    fn fresh_breaker_allows_calls() {
        let (breaker, plugin, qname) = fixture();
        assert!(breaker.allow(&plugin, &qname));
    }

    #[test]
    fn breaker_opens_after_threshold_failures() {
        let (breaker, plugin, qname) = fixture();
        fail_n(&breaker, &plugin, &qname, 3);
        assert!(!breaker.allow(&plugin, &qname));
    }

    #[test]
    fn success_resets_failure_count() {
        let (breaker, plugin, qname) = fixture();
        fail_n(&breaker, &plugin, &qname, 2);
        breaker.record_success(&plugin, &qname);
        assert_eq!(breaker.failure_count(&plugin, &qname), 0);
    }

    #[test]
    fn breaker_half_opens_after_cooldown() {
        let (breaker, plugin, qname) = fixture();
        fail_n(&breaker, &plugin, &qname, 3);
        assert!(!breaker.allow(&plugin, &qname));
        std::thread::sleep(Duration::from_millis(60));
        assert!(breaker.allow(&plugin, &qname));
        // The first call after the cooldown also clears the failure counter.
        assert_eq!(breaker.failure_count(&plugin, &qname), 0);
    }
}