Skip to main content

ucp_agent/
safety.rs

1//! Safety mechanisms: limits, circuit breakers, and guards.
2
3use crate::error::{AgentError, Result};
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::RwLock;
6use std::time::{Duration, Instant};
7use tracing::{debug, warn};
8
/// Process-wide ceilings that apply to all sessions taken together.
#[derive(Clone, Debug)]
pub struct GlobalLimits {
    /// Upper bound on the number of sessions that may exist concurrently.
    pub max_sessions: usize,
    /// Upper bound on blocks held in memory, summed over every context.
    pub max_total_context_blocks: usize,
    /// Rate limit, expressed as operations per second.
    pub max_ops_per_second: f64,
    /// Time allowed for any one operation.
    pub operation_timeout: Duration,
}
21
22impl Default for GlobalLimits {
23    fn default() -> Self {
24        Self {
25            max_sessions: 100,
26            max_total_context_blocks: 100_000,
27            max_ops_per_second: 1000.0,
28            operation_timeout: Duration::from_secs(30),
29        }
30    }
31}
32
33/// Per-session limits.
34#[derive(Debug, Clone)]
35pub struct SessionLimits {
36    /// Maximum context window tokens.
37    pub max_context_tokens: usize,
38    /// Maximum context window blocks.
39    pub max_context_blocks: usize,
40    /// Maximum depth for single expansion.
41    pub max_expand_depth: usize,
42    /// Maximum blocks returned per operation.
43    pub max_results_per_operation: usize,
44    /// Maximum operations before forced pause.
45    pub max_operations_before_checkpoint: usize,
46    /// Session timeout (inactivity).
47    pub session_timeout: Duration,
48    /// Maximum navigation history size.
49    pub max_history_size: usize,
50    /// Budget for costly operations.
51    pub budget: OperationBudget,
52}
53
54impl Default for SessionLimits {
55    fn default() -> Self {
56        Self {
57            max_context_tokens: 8_000,
58            max_context_blocks: 200,
59            max_expand_depth: 10,
60            max_results_per_operation: 100,
61            max_operations_before_checkpoint: 1000,
62            session_timeout: Duration::from_secs(30 * 60), // 30 minutes
63            max_history_size: 100,
64            budget: OperationBudget::default(),
65        }
66    }
67}
68
/// Allowance for the operations that are expensive to perform.
#[derive(Clone, Debug)]
pub struct OperationBudget {
    /// How many traversal operations are permitted in total.
    pub traversal_operations: usize,
    /// How many search operations are permitted in total.
    pub search_operations: usize,
    /// How many blocks may be read in total.
    pub blocks_read: usize,
}
79
80impl Default for OperationBudget {
81    fn default() -> Self {
82        Self {
83            traversal_operations: 10_000,
84            search_operations: 100,
85            blocks_read: 50_000,
86        }
87    }
88}
89
/// Running totals of budgeted work, kept in atomic counters so they can be
/// updated through a shared reference.
#[derive(Default, Debug)]
pub struct BudgetTracker {
    /// Traversal operations recorded so far.
    pub traversal_ops_used: AtomicUsize,
    /// Search operations recorded so far.
    pub search_ops_used: AtomicUsize,
    /// Blocks recorded as read so far.
    pub blocks_read_used: AtomicUsize,
}
97
98impl BudgetTracker {
99    pub fn new() -> Self {
100        Self::default()
101    }
102
103    pub fn record_traversal(&self) {
104        self.traversal_ops_used.fetch_add(1, Ordering::Relaxed);
105    }
106
107    pub fn record_search(&self) {
108        self.search_ops_used.fetch_add(1, Ordering::Relaxed);
109    }
110
111    pub fn record_blocks_read(&self, count: usize) {
112        self.blocks_read_used.fetch_add(count, Ordering::Relaxed);
113    }
114
115    pub fn check_traversal_budget(&self, budget: &OperationBudget) -> Result<()> {
116        let used = self.traversal_ops_used.load(Ordering::Relaxed);
117        if used >= budget.traversal_operations {
118            return Err(AgentError::BudgetExhausted {
119                operation_type: "traversal".to_string(),
120            });
121        }
122        Ok(())
123    }
124
125    pub fn check_search_budget(&self, budget: &OperationBudget) -> Result<()> {
126        let used = self.search_ops_used.load(Ordering::Relaxed);
127        if used >= budget.search_operations {
128            return Err(AgentError::BudgetExhausted {
129                operation_type: "search".to_string(),
130            });
131        }
132        Ok(())
133    }
134
135    pub fn check_blocks_budget(&self, budget: &OperationBudget) -> Result<()> {
136        let used = self.blocks_read_used.load(Ordering::Relaxed);
137        if used >= budget.blocks_read {
138            return Err(AgentError::BudgetExhausted {
139                operation_type: "blocks_read".to_string(),
140            });
141        }
142        Ok(())
143    }
144
145    pub fn reset(&self) {
146        self.traversal_ops_used.store(0, Ordering::Relaxed);
147        self.search_ops_used.store(0, Ordering::Relaxed);
148        self.blocks_read_used.store(0, Ordering::Relaxed);
149    }
150}
151
/// The three positions a circuit breaker can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Healthy: requests flow normally.
    Closed,
    /// Tripped: requests are being rejected.
    Open,
    /// Probing: recovery is being tested.
    HalfOpen,
}
162
163/// Circuit breaker for detecting runaway operations.
164pub struct CircuitBreaker {
165    state: RwLock<CircuitState>,
166    failure_count: AtomicUsize,
167    failure_threshold: usize,
168    recovery_timeout: Duration,
169    last_failure: RwLock<Option<Instant>>,
170    success_count_in_half_open: AtomicUsize,
171    success_threshold: usize,
172}
173
174impl CircuitBreaker {
175    pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
176        Self {
177            state: RwLock::new(CircuitState::Closed),
178            failure_count: AtomicUsize::new(0),
179            failure_threshold,
180            recovery_timeout,
181            last_failure: RwLock::new(None),
182            success_count_in_half_open: AtomicUsize::new(0),
183            success_threshold: 3, // Require 3 successes to close
184        }
185    }
186
187    pub fn state(&self) -> CircuitState {
188        *self.state.read().unwrap()
189    }
190
191    pub fn can_proceed(&self) -> Result<()> {
192        let state = *self.state.read().unwrap();
193
194        match state {
195            CircuitState::Closed => Ok(()),
196            CircuitState::Open => {
197                // Check if recovery timeout has passed
198                let last_failure = self.last_failure.read().unwrap();
199                if let Some(last) = *last_failure {
200                    if last.elapsed() >= self.recovery_timeout {
201                        // Transition to half-open
202                        drop(last_failure);
203                        *self.state.write().unwrap() = CircuitState::HalfOpen;
204                        self.success_count_in_half_open.store(0, Ordering::Relaxed);
205                        debug!("Circuit breaker transitioning to half-open");
206                        return Ok(());
207                    }
208                }
209                Err(AgentError::CircuitOpen {
210                    reason: "Too many failures, circuit is open".to_string(),
211                })
212            }
213            CircuitState::HalfOpen => {
214                // Allow one request through to test
215                Ok(())
216            }
217        }
218    }
219
220    pub fn record_success(&self) {
221        let state = *self.state.read().unwrap();
222
223        match state {
224            CircuitState::Closed => {
225                // Reset failure count on success
226                self.failure_count.store(0, Ordering::Relaxed);
227            }
228            CircuitState::HalfOpen => {
229                let successes = self
230                    .success_count_in_half_open
231                    .fetch_add(1, Ordering::Relaxed)
232                    + 1;
233                if successes >= self.success_threshold {
234                    // Transition back to closed
235                    *self.state.write().unwrap() = CircuitState::Closed;
236                    self.failure_count.store(0, Ordering::Relaxed);
237                    debug!("Circuit breaker closed after successful recovery");
238                }
239            }
240            CircuitState::Open => {
241                // Shouldn't happen, but ignore
242            }
243        }
244    }
245
246    pub fn record_failure(&self) {
247        let state = *self.state.read().unwrap();
248
249        match state {
250            CircuitState::Closed => {
251                let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
252                if failures >= self.failure_threshold {
253                    *self.state.write().unwrap() = CircuitState::Open;
254                    *self.last_failure.write().unwrap() = Some(Instant::now());
255                    warn!(
256                        "Circuit breaker opened after {} failures",
257                        self.failure_threshold
258                    );
259                }
260            }
261            CircuitState::HalfOpen => {
262                // Failure during recovery - go back to open
263                *self.state.write().unwrap() = CircuitState::Open;
264                *self.last_failure.write().unwrap() = Some(Instant::now());
265                self.success_count_in_half_open.store(0, Ordering::Relaxed);
266                warn!("Circuit breaker re-opened after failure during half-open");
267            }
268            CircuitState::Open => {
269                // Update last failure time
270                *self.last_failure.write().unwrap() = Some(Instant::now());
271            }
272        }
273    }
274
275    pub fn reset(&self) {
276        *self.state.write().unwrap() = CircuitState::Closed;
277        self.failure_count.store(0, Ordering::Relaxed);
278        *self.last_failure.write().unwrap() = None;
279        self.success_count_in_half_open.store(0, Ordering::Relaxed);
280    }
281}
282
283impl Default for CircuitBreaker {
284    fn default() -> Self {
285        Self::new(5, Duration::from_secs(30))
286    }
287}
288
289/// RAII guard for depth tracking.
290pub struct DepthGuardHandle<'a> {
291    guard: &'a DepthGuard,
292}
293
294impl<'a> Drop for DepthGuardHandle<'a> {
295    fn drop(&mut self) {
296        self.guard.current.fetch_sub(1, Ordering::Relaxed);
297    }
298}
299
/// Depth guard prevents infinite recursion by capping how many levels may be
/// entered at the same time.
#[derive(Debug)]
pub struct DepthGuard {
    /// Number of levels currently entered.
    current: AtomicUsize,
    /// Largest number of levels that may be entered at once.
    max: usize,
}
305
306impl DepthGuard {
307    pub fn new(max: usize) -> Self {
308        Self {
309            current: AtomicUsize::new(0),
310            max,
311        }
312    }
313
314    /// Try to enter a deeper level. Returns a guard handle if successful.
315    pub fn try_enter(&self) -> Result<DepthGuardHandle<'_>> {
316        let current = self.current.fetch_add(1, Ordering::Relaxed);
317        if current >= self.max {
318            self.current.fetch_sub(1, Ordering::Relaxed);
319            return Err(AgentError::DepthLimitExceeded {
320                current: current + 1,
321                max: self.max,
322            });
323        }
324        Ok(DepthGuardHandle { guard: self })
325    }
326
327    pub fn current_depth(&self) -> usize {
328        self.current.load(Ordering::Relaxed)
329    }
330
331    pub fn max_depth(&self) -> usize {
332        self.max
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// The traversal budget is exhausted once usage reaches the limit and is
    /// available again after a reset.
    #[test]
    fn test_budget_tracker() {
        let budget = OperationBudget {
            traversal_operations: 3,
            search_operations: 2,
            blocks_read: 10,
        };
        let tracker = BudgetTracker::new();

        // Two of three traversals used: still within budget.
        for _ in 0..2 {
            tracker.record_traversal();
        }
        assert!(tracker.check_traversal_budget(&budget).is_ok());

        // The third traversal uses the budget up.
        tracker.record_traversal();
        assert!(tracker.check_traversal_budget(&budget).is_err());

        // A reset makes the budget available again.
        tracker.reset();
        assert!(tracker.check_traversal_budget(&budget).is_ok());
    }

    /// Walks the breaker through closed -> open -> half-open -> closed.
    #[test]
    fn test_circuit_breaker() {
        let recovery_timeout = Duration::from_millis(100);
        let cb = CircuitBreaker::new(3, recovery_timeout);

        // Starts out closed.
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_proceed().is_ok());

        // Two failures are below the threshold of three.
        cb.record_failure();
        cb.record_failure();
        assert!(cb.can_proceed().is_ok());

        // The third failure opens the breaker.
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.can_proceed().is_err());

        // Once the recovery timeout has passed, the next check moves the
        // breaker to half-open.
        std::thread::sleep(Duration::from_millis(150));
        assert!(cb.can_proceed().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // Three successes while half-open close it again.
        for _ in 0..3 {
            cb.record_success();
        }
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    /// Depth rises with each handle, is capped at the maximum, and falls as
    /// handles are dropped.
    #[test]
    fn test_depth_guard() {
        let guard = DepthGuard::new(3);
        assert_eq!(guard.current_depth(), 0);

        let first = guard.try_enter().unwrap();
        assert_eq!(guard.current_depth(), 1);

        let second = guard.try_enter().unwrap();
        assert_eq!(guard.current_depth(), 2);

        let third = guard.try_enter().unwrap();
        assert_eq!(guard.current_depth(), 3);

        // At the maximum: a further level is refused.
        assert!(guard.try_enter().is_err());

        drop(third);
        assert_eq!(guard.current_depth(), 2);

        drop(second);
        assert_eq!(guard.current_depth(), 1);

        drop(first);
        assert_eq!(guard.current_depth(), 0);
    }
}