Skip to main content

vtcode_core/tools/resilience/
circuit_breaker.rs

1use crate::types::CompactStr;
2use hashbrown::HashMap;
3use parking_lot::RwLock;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use vtcode_commons::ErrorCategory;
7
8use crate::metrics::MetricsCollector;
9
/// Lifecycle state of one tool's circuit. New circuits start out `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CircuitState {
    /// Normal operation: calls are allowed.
    #[default]
    Closed,
    /// Failing: calls are rejected immediately.
    Open,
    /// Testing: calls are let through to check whether the tool has recovered.
    HalfOpen,
}

impl CircuitState {
    /// Lists the states that may legally follow `self`.
    ///
    /// - `Closed` may only move to `Open`.
    /// - `Open` may only move to `HalfOpen`.
    /// - `HalfOpen` may move to `Closed` or back to `Open`.
    #[inline]
    const fn valid_transitions(&self) -> &'static [CircuitState] {
        const FROM_CLOSED: &[CircuitState] = &[CircuitState::Open];
        const FROM_OPEN: &[CircuitState] = &[CircuitState::HalfOpen];
        const FROM_HALF_OPEN: &[CircuitState] = &[CircuitState::Closed, CircuitState::Open];

        match *self {
            Self::Closed => FROM_CLOSED,
            Self::Open => FROM_OPEN,
            Self::HalfOpen => FROM_HALF_OPEN,
        }
    }

    /// Returns `true` when moving from `self` to `target` is a legal transition.
    #[inline]
    fn can_transition_to(&self, target: CircuitState) -> bool {
        self.valid_transitions().iter().any(|next| *next == target)
    }
}
35
/// Tuning knobs shared by every per-tool circuit.
#[derive(Clone)]
pub struct CircuitBreakerConfig {
    /// Number of counted failures in the `Closed` state at which the circuit opens.
    pub failure_threshold: u32,
    /// Base timeout; used as the wait time when a tool's current backoff is zero.
    pub reset_timeout: Duration,
    /// Minimum wait time; also the backoff assigned when a circuit first opens.
    pub min_backoff: Duration,
    /// Maximum wait time the backoff may grow to.
    pub max_backoff: Duration,
    /// Multiplier applied to the backoff after a failed probe (e.g. 2.0 for exponential).
    pub backoff_factor: f64,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: open after 7 failures, 60s base timeout, backoff from 5s up to
    /// 2 minutes, doubling on each failed probe.
    fn default() -> Self {
        let reset_timeout = Duration::from_secs(60);
        let min_backoff = Duration::from_secs(5);
        let max_backoff = Duration::from_secs(120);

        Self {
            failure_threshold: 7,
            reset_timeout,
            min_backoff,
            max_backoff,
            backoff_factor: 2.0,
        }
    }
}
56
57#[derive(Debug, Clone, Default)]
58struct ToolCircuitState {
59    status: CircuitState,
60    failure_count: u32,
61    last_failure_time: Option<Instant>,
62    current_backoff: Duration, // Current backoff duration for this tool
63    circuit_opened_at: Option<Instant>, // When circuit first opened (for diagnostics)
64    open_count: u32,           // How many times circuit has opened
65    denied_requests: u32,
66    last_denied_at: Option<Instant>,
67    last_error_category: Option<ErrorCategory>,
68}
69
70impl ToolCircuitState {
71    /// Transition to a new state with debug assertion for valid transitions
72    #[inline]
73    fn transition_to(&mut self, new_state: CircuitState) {
74        debug_assert!(
75            self.status.can_transition_to(new_state),
76            "Invalid circuit state transition: {:?} -> {:?}",
77            self.status,
78            new_state
79        );
80        self.status = new_state;
81    }
82
83    /// Reset state on successful recovery (from HalfOpen or Open)
84    #[inline]
85    fn reset_on_success(&mut self) {
86        self.status = CircuitState::Closed;
87        self.failure_count = 0;
88        self.last_failure_time = None;
89        self.current_backoff = Duration::ZERO;
90        self.circuit_opened_at = None;
91        self.last_error_category = None;
92    }
93}
94
95#[derive(Debug, Clone)]
96pub struct ToolCircuitDiagnostics {
97    pub tool_name: String,
98    pub status: CircuitState,
99    pub failure_count: u32,
100    pub current_backoff: Duration,
101    pub remaining_backoff: Option<Duration>,
102    pub opened_at: Option<Instant>,
103    pub open_count: u32,
104    pub is_open: bool,
105    pub denied_requests: u32,
106    pub last_denied_at: Option<Instant>,
107    pub last_error_category: Option<ErrorCategory>,
108}
109
110#[derive(Debug, Clone, Default)]
111pub struct CircuitBreakerSnapshot {
112    pub diagnostics: Vec<ToolCircuitDiagnostics>,
113    pub open_circuits: Vec<String>,
114    pub open_count: usize,
115}
116
117/// Per-tool circuit breaker that tracks failure state independently for each tool.
118/// This prevents one misbehaving tool from disabling all tools in the system.
119///
120/// Uses `parking_lot::RwLock` for better concurrent access:
121/// - Read operations (allow_request, state checks) can proceed in parallel
122/// - Write operations (record_success/failure) acquire exclusive access
123/// - `parking_lot` is more efficient for short critical sections than std::Mutex
124#[derive(Clone)]
125pub struct CircuitBreaker {
126    /// Per-tool state tracking with RwLock for better read concurrency
127    tool_states: Arc<RwLock<HashMap<CompactStr, ToolCircuitState>>>,
128    config: CircuitBreakerConfig,
129    metrics: Option<Arc<MetricsCollector>>,
130}
131
132impl CircuitBreaker {
133    pub fn new(config: CircuitBreakerConfig) -> Self {
134        Self::build(config, None)
135    }
136
137    pub fn with_metrics(config: CircuitBreakerConfig, metrics: Arc<MetricsCollector>) -> Self {
138        Self::build(config, Some(metrics))
139    }
140
141    fn build(config: CircuitBreakerConfig, metrics: Option<Arc<MetricsCollector>>) -> Self {
142        Self {
143            tool_states: Arc::new(RwLock::new(HashMap::new())),
144            config,
145            metrics,
146        }
147    }
148
149    #[inline]
150    fn record_half_open_metric(&self) {
151        if let Some(metrics) = &self.metrics {
152            metrics.record_half_open();
153        }
154    }
155
156    #[inline]
157    fn record_breaker_denial_metric(&self) {
158        if let Some(metrics) = &self.metrics {
159            metrics.record_breaker_denial();
160        }
161    }
162
163    #[inline]
164    fn record_circuit_open_metric(&self) {
165        if let Some(metrics) = &self.metrics {
166            metrics.record_circuit_open();
167        }
168    }
169
170    /// Check if a request for a specific tool is allowed to proceed.
171    /// Returns true if allowed, false if the circuit is open for this tool.
172    ///
173    /// Uses optimistic read-first approach:
174    /// 1. Try read lock first (allows concurrent reads)
175    /// 2. Only upgrade to write lock if state transition is needed
176    pub fn allow_request_for_tool(&self, tool_name: &str) -> bool {
177        {
178            let states = self.tool_states.read();
179            if let Some(state) = states.get(tool_name) {
180                match state.status {
181                    CircuitState::Closed | CircuitState::HalfOpen => return true,
182                    CircuitState::Open => {
183                        if let Some(last_failure) = state.last_failure_time {
184                            let backoff = if state.current_backoff == Duration::ZERO {
185                                self.config.reset_timeout
186                            } else {
187                                state.current_backoff
188                            };
189                            if last_failure.elapsed() >= backoff {
190                                // Fall through to the write lock so we can transition to HalfOpen.
191                            }
192                        }
193                    }
194                }
195            } else {
196                return true;
197            }
198        }
199
200        let mut states = self.tool_states.write();
201        let state = states.entry(CompactStr::from(tool_name)).or_default();
202
203        match state.status {
204            CircuitState::Closed | CircuitState::HalfOpen => true,
205            CircuitState::Open => {
206                if let Some(last_failure) = state.last_failure_time {
207                    let backoff = if state.current_backoff == Duration::ZERO {
208                        self.config.reset_timeout
209                    } else {
210                        state.current_backoff
211                    };
212
213                    if last_failure.elapsed() >= backoff {
214                        state.transition_to(CircuitState::HalfOpen);
215                        self.record_half_open_metric();
216                        return true;
217                    }
218                }
219                state.denied_requests = state.denied_requests.saturating_add(1);
220                state.last_denied_at = Some(Instant::now());
221                self.record_breaker_denial_metric();
222                false
223            }
224        }
225    }
226
227    /// Get remaining backoff time for a tool (if Open)
228    pub fn remaining_backoff(&self, tool_name: &str) -> Option<Duration> {
229        let states = self.tool_states.read();
230        let state = states.get(tool_name)?;
231
232        if state.status == CircuitState::Open
233            && let Some(last) = state.last_failure_time
234        {
235            let backoff = state.current_backoff;
236            let elapsed = last.elapsed();
237            return backoff.checked_sub(elapsed);
238        }
239        None
240    }
241
242    /// Record success for a specific tool
243    ///
244    /// State transitions on success:
245    /// - HalfOpen -> Closed (probe succeeded)
246    /// - Closed -> Closed (reset failure count)
247    /// - Open -> Closed (forced recovery, e.g., manual reset)
248    pub fn record_success_for_tool(&self, tool_name: &str) {
249        let mut states = self.tool_states.write();
250        let state = states.entry(CompactStr::from(tool_name)).or_default();
251
252        match state.status {
253            CircuitState::HalfOpen => {
254                // Probe succeeded - use batched reset
255                state.reset_on_success();
256            }
257            CircuitState::Closed => {
258                // Reset failure count on success if we want purely consecutive failures
259                state.failure_count = 0;
260            }
261            CircuitState::Open => {
262                // Should not happen theoretically unless race condition or forced reset
263                // Using direct assignment here since this is an exceptional recovery path
264                state.reset_on_success();
265            }
266        }
267    }
268
269    /// Record failure for a specific tool.
270    /// Non-retryable validation, policy, and permission failures are ignored.
271    ///
272    /// State transitions on failure:
273    /// - Closed -> Open (when threshold reached)
274    /// - HalfOpen -> Open (probe failed, increase backoff)
275    /// - Open -> Open (no change, just update timestamp)
276    pub fn record_failure_category_for_tool(&self, tool_name: &str, category: ErrorCategory) {
277        if !category.should_trip_circuit_breaker() {
278            tracing::debug!(
279                tool = %tool_name,
280                category = %category,
281                "Skipping circuit breaker failure accounting for non-circuit-breaking error"
282            );
283            return;
284        }
285
286        let mut states = self.tool_states.write();
287        let state = states.entry(CompactStr::from(tool_name)).or_default();
288        state.last_failure_time = Some(Instant::now());
289        state.last_error_category = Some(category);
290
291        match state.status {
292            CircuitState::Closed => {
293                state.failure_count += 1;
294                if state.failure_count >= self.config.failure_threshold {
295                    state.transition_to(CircuitState::Open);
296                    state.current_backoff = self.config.min_backoff;
297                    state.circuit_opened_at = Some(Instant::now());
298                    state.open_count += 1;
299                    self.record_circuit_open_metric();
300
301                    tracing::warn!(
302                        tool = %tool_name,
303                        failures = state.failure_count,
304                        backoff_sec = state.current_backoff.as_secs(),
305                        open_count = state.open_count,
306                        "Circuit breaker OPEN for tool"
307                    );
308                }
309            }
310            CircuitState::HalfOpen => {
311                // Probe failed, revert to Open and increase backoff
312                state.transition_to(CircuitState::Open);
313                state.circuit_opened_at = Some(Instant::now());
314                state.open_count += 1;
315                // Exponential backoff
316                let next_backoff = state.current_backoff.as_secs_f64() * self.config.backoff_factor;
317                state.current_backoff = Duration::try_from_secs_f64(next_backoff)
318                    .unwrap_or(self.config.max_backoff)
319                    .min(self.config.max_backoff)
320                    .max(self.config.min_backoff);
321                self.record_circuit_open_metric();
322
323                tracing::warn!(
324                    tool = %tool_name,
325                    backoff_sec = state.current_backoff.as_secs(),
326                    open_count = state.open_count,
327                    "Circuit breaker re-OPENED (probe failed)"
328                );
329            }
330            CircuitState::Open => {
331                // Already open, just update time - backoff stays same until probe attempt
332            }
333        }
334    }
335
336    /// Convenience wrapper that maps a boolean argument error flag to the appropriate ErrorCategory.
337    pub fn record_failure_for_tool(&self, tool_name: &str, is_argument_error: bool) {
338        let category = if is_argument_error {
339            ErrorCategory::InvalidParameters
340        } else {
341            ErrorCategory::ExecutionError
342        };
343        self.record_failure_category_for_tool(tool_name, category);
344    }
345
346    /// Get the circuit state for a specific tool
347    pub fn state_for_tool(&self, tool_name: &str) -> CircuitState {
348        let states = self.tool_states.read();
349        states
350            .get(tool_name)
351            .map(|s| s.status)
352            .unwrap_or(CircuitState::Closed)
353    }
354
355    /// Reset the circuit breaker state for a specific tool
356    pub fn reset_tool(&self, tool_name: &str) {
357        let mut states = self.tool_states.write();
358        states.remove(tool_name);
359    }
360
361    /// Reset all tool circuit breaker states
362    pub fn reset_all(&self) {
363        let mut states = self.tool_states.write();
364        states.clear();
365    }
366
367    /// Get list of tools with currently OPEN circuits
368    pub fn get_open_circuits(&self) -> Vec<String> {
369        self.snapshot().open_circuits
370    }
371
372    /// Get diagnostic information for a specific tool
373    pub fn get_diagnostics(&self, tool_name: &str) -> ToolCircuitDiagnostics {
374        self.snapshot()
375            .diagnostics
376            .into_iter()
377            .find(|diag| diag.tool_name == tool_name)
378            .unwrap_or_else(|| ToolCircuitDiagnostics {
379                tool_name: tool_name.to_string(),
380                status: CircuitState::Closed,
381                failure_count: 0,
382                current_backoff: Duration::ZERO,
383                remaining_backoff: None,
384                opened_at: None,
385                open_count: 0,
386                is_open: false,
387                denied_requests: 0,
388                last_denied_at: None,
389                last_error_category: None,
390            })
391    }
392
393    /// Get diagnostics for all tools
394    pub fn get_all_diagnostics(&self) -> Vec<ToolCircuitDiagnostics> {
395        self.snapshot().diagnostics
396    }
397
398    /// Get a full snapshot of circuit breaker state under a single read lock.
399    pub fn snapshot(&self) -> CircuitBreakerSnapshot {
400        let states = self.tool_states.read();
401        let diagnostics: Vec<ToolCircuitDiagnostics> = states
402            .iter()
403            .map(|(name, state)| {
404                let is_open = matches!(state.status, CircuitState::Open);
405                ToolCircuitDiagnostics {
406                    tool_name: name.to_string(),
407                    status: state.status,
408                    failure_count: state.failure_count,
409                    current_backoff: state.current_backoff,
410                    remaining_backoff: if is_open {
411                        state
412                            .last_failure_time
413                            .and_then(|last| state.current_backoff.checked_sub(last.elapsed()))
414                    } else {
415                        None
416                    },
417                    opened_at: state.circuit_opened_at,
418                    open_count: state.open_count,
419                    is_open,
420                    denied_requests: state.denied_requests,
421                    last_denied_at: state.last_denied_at,
422                    last_error_category: state.last_error_category,
423                }
424            })
425            .collect();
426
427        let open_circuits: Vec<String> = diagnostics
428            .iter()
429            .filter(|diag| diag.is_open)
430            .map(|diag| diag.tool_name.clone())
431            .collect();
432
433        CircuitBreakerSnapshot {
434            diagnostics,
435            open_count: open_circuits.len(),
436            open_circuits,
437        }
438    }
439
440    /// Check if recovery pause should be triggered based on open circuit count
441    pub fn should_pause_for_recovery(&self, max_open_circuits: usize) -> bool {
442        self.snapshot().open_count >= max_open_circuits
443    }
444
445    /// Get count of currently open circuits
446    pub fn open_circuit_count(&self) -> usize {
447        self.snapshot().open_count
448    }
449}
450
451impl Default for CircuitBreaker {
452    fn default() -> Self {
453        Self::new(CircuitBreakerConfig::default())
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::MetricsCollector;

    /// `InvalidParameters` failures are skipped entirely: even at twice the
    /// threshold the circuit stays closed and nothing is counted.
    #[test]
    fn invalid_parameters_do_not_open_circuit() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });

        breaker.record_failure_category_for_tool("read_file", ErrorCategory::InvalidParameters);
        breaker.record_failure_category_for_tool("read_file", ErrorCategory::InvalidParameters);

        assert_eq!(breaker.state_for_tool("read_file"), CircuitState::Closed);
        assert_eq!(breaker.get_diagnostics("read_file").failure_count, 0);
    }

    /// A request rejected by an open circuit shows up in the tool's diagnostics.
    #[test]
    fn denied_requests_are_recorded_for_open_circuit() {
        // Threshold of 1 opens the circuit on the first failure; the 30s backoff
        // keeps it open for the duration of the test.
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            min_backoff: Duration::from_secs(30),
            ..Default::default()
        });

        breaker.record_failure_category_for_tool("shell", ErrorCategory::ExecutionError);
        assert_eq!(breaker.state_for_tool("shell"), CircuitState::Open);
        assert!(!breaker.allow_request_for_tool("shell"));

        let diagnostics = breaker.get_diagnostics("shell");
        assert_eq!(diagnostics.denied_requests, 1);
        assert!(diagnostics.last_denied_at.is_some());
        assert_eq!(
            diagnostics.last_error_category,
            Some(ErrorCategory::ExecutionError)
        );
    }

    /// Open, denial and half-open events each reach the metrics collector once.
    #[test]
    fn metrics_record_open_half_open_and_denials() {
        let metrics = Arc::new(MetricsCollector::new());
        let breaker = CircuitBreaker::with_metrics(
            CircuitBreakerConfig {
                failure_threshold: 1,
                min_backoff: Duration::from_millis(10),
                max_backoff: Duration::from_secs(1),
                ..Default::default()
            },
            metrics.clone(),
        );

        // First failure opens the circuit; the immediate request is denied.
        breaker.record_failure_category_for_tool("shell", ErrorCategory::ExecutionError);
        assert!(!breaker.allow_request_for_tool("shell"));

        // Wait past the 10ms backoff so the next request becomes the half-open probe.
        std::thread::sleep(Duration::from_millis(20));
        assert!(breaker.allow_request_for_tool("shell"));

        let execution = metrics.get_execution_metrics();
        assert_eq!(execution.circuit_open_events, 1);
        assert_eq!(execution.breaker_denials, 1);
        assert_eq!(execution.half_open_events, 1);
    }

    /// A backoff product too large for `Duration` is clamped to `max_backoff`.
    #[test]
    fn overflowing_half_open_backoff_clamps_to_max_backoff() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            min_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(10),
            backoff_factor: f64::MAX,
            ..Default::default()
        });

        // Open the circuit, wait out the 1ms backoff, then take the half-open probe.
        breaker.record_failure_category_for_tool("shell", ErrorCategory::ExecutionError);
        std::thread::sleep(Duration::from_millis(2));
        assert!(breaker.allow_request_for_tool("shell"));

        // Failing the probe multiplies the backoff by `f64::MAX`.
        breaker.record_failure_category_for_tool("shell", ErrorCategory::ExecutionError);

        assert_eq!(
            breaker.get_diagnostics("shell").current_backoff,
            Duration::from_millis(10)
        );
    }
}