1use crate::error::{AgentError, Result};
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::RwLock;
6use std::time::{Duration, Instant};
7use tracing::{debug, warn};
8
/// Global resource limits, as opposed to the per-session [`SessionLimits`].
#[derive(Clone, Debug)]
pub struct GlobalLimits {
    /// Maximum number of sessions (default: 100).
    pub max_sessions: usize,
    /// Maximum number of context blocks summed over all sessions (default: 100 000).
    pub max_total_context_blocks: usize,
    /// Operation rate limit, in operations per second (default: 1000).
    pub max_ops_per_second: f64,
    /// Timeout applied to a single operation (default: 30 s).
    pub operation_timeout: Duration,
}

impl Default for GlobalLimits {
    fn default() -> Self {
        let operation_timeout = Duration::from_millis(30_000);
        GlobalLimits {
            max_sessions: 100,
            max_total_context_blocks: 100_000,
            max_ops_per_second: 1_000.0,
            operation_timeout,
        }
    }
}
32
33#[derive(Debug, Clone)]
35pub struct SessionLimits {
36 pub max_context_tokens: usize,
38 pub max_context_blocks: usize,
40 pub max_expand_depth: usize,
42 pub max_results_per_operation: usize,
44 pub max_operations_before_checkpoint: usize,
46 pub session_timeout: Duration,
48 pub max_history_size: usize,
50 pub budget: OperationBudget,
52}
53
54impl Default for SessionLimits {
55 fn default() -> Self {
56 Self {
57 max_context_tokens: 8_000,
58 max_context_blocks: 200,
59 max_expand_depth: 10,
60 max_results_per_operation: 100,
61 max_operations_before_checkpoint: 1000,
62 session_timeout: Duration::from_secs(30 * 60), max_history_size: 100,
64 budget: OperationBudget::default(),
65 }
66 }
67}
68
/// Operation allowances checked by `BudgetTracker`.
#[derive(Clone, Debug)]
pub struct OperationBudget {
    /// Allowed traversal operations (default: 10 000).
    pub traversal_operations: usize,
    /// Allowed search operations (default: 100).
    pub search_operations: usize,
    /// Allowed number of blocks read (default: 50 000).
    pub blocks_read: usize,
}

impl Default for OperationBudget {
    fn default() -> Self {
        OperationBudget {
            blocks_read: 50_000,
            search_operations: 100,
            traversal_operations: 10_000,
        }
    }
}
89
90#[derive(Debug, Default)]
92pub struct BudgetTracker {
93 pub traversal_ops_used: AtomicUsize,
94 pub search_ops_used: AtomicUsize,
95 pub blocks_read_used: AtomicUsize,
96}
97
98impl BudgetTracker {
99 pub fn new() -> Self {
100 Self::default()
101 }
102
103 pub fn record_traversal(&self) {
104 self.traversal_ops_used.fetch_add(1, Ordering::Relaxed);
105 }
106
107 pub fn record_search(&self) {
108 self.search_ops_used.fetch_add(1, Ordering::Relaxed);
109 }
110
111 pub fn record_blocks_read(&self, count: usize) {
112 self.blocks_read_used.fetch_add(count, Ordering::Relaxed);
113 }
114
115 pub fn check_traversal_budget(&self, budget: &OperationBudget) -> Result<()> {
116 let used = self.traversal_ops_used.load(Ordering::Relaxed);
117 if used >= budget.traversal_operations {
118 return Err(AgentError::BudgetExhausted {
119 operation_type: "traversal".to_string(),
120 });
121 }
122 Ok(())
123 }
124
125 pub fn check_search_budget(&self, budget: &OperationBudget) -> Result<()> {
126 let used = self.search_ops_used.load(Ordering::Relaxed);
127 if used >= budget.search_operations {
128 return Err(AgentError::BudgetExhausted {
129 operation_type: "search".to_string(),
130 });
131 }
132 Ok(())
133 }
134
135 pub fn check_blocks_budget(&self, budget: &OperationBudget) -> Result<()> {
136 let used = self.blocks_read_used.load(Ordering::Relaxed);
137 if used >= budget.blocks_read {
138 return Err(AgentError::BudgetExhausted {
139 operation_type: "blocks_read".to_string(),
140 });
141 }
142 Ok(())
143 }
144
145 pub fn reset(&self) {
146 self.traversal_ops_used.store(0, Ordering::Relaxed);
147 self.search_ops_used.store(0, Ordering::Relaxed);
148 self.blocks_read_used.store(0, Ordering::Relaxed);
149 }
150}
151
/// State of a `CircuitBreaker`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests proceed and failures are counted.
    Closed,
    /// Tripped: requests are rejected until the recovery timeout has elapsed
    /// since the most recent failure.
    Open,
    /// Trial mode after recovery: requests proceed; enough successes close the
    /// circuit, while any failure re-opens it.
    HalfOpen,
}
162
163pub struct CircuitBreaker {
165 state: RwLock<CircuitState>,
166 failure_count: AtomicUsize,
167 failure_threshold: usize,
168 recovery_timeout: Duration,
169 last_failure: RwLock<Option<Instant>>,
170 success_count_in_half_open: AtomicUsize,
171 success_threshold: usize,
172}
173
174impl CircuitBreaker {
175 pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
176 Self {
177 state: RwLock::new(CircuitState::Closed),
178 failure_count: AtomicUsize::new(0),
179 failure_threshold,
180 recovery_timeout,
181 last_failure: RwLock::new(None),
182 success_count_in_half_open: AtomicUsize::new(0),
183 success_threshold: 3, }
185 }
186
187 pub fn state(&self) -> CircuitState {
188 *self.state.read().unwrap()
189 }
190
191 pub fn can_proceed(&self) -> Result<()> {
192 let state = *self.state.read().unwrap();
193
194 match state {
195 CircuitState::Closed => Ok(()),
196 CircuitState::Open => {
197 let last_failure = self.last_failure.read().unwrap();
199 if let Some(last) = *last_failure {
200 if last.elapsed() >= self.recovery_timeout {
201 drop(last_failure);
203 *self.state.write().unwrap() = CircuitState::HalfOpen;
204 self.success_count_in_half_open.store(0, Ordering::Relaxed);
205 debug!("Circuit breaker transitioning to half-open");
206 return Ok(());
207 }
208 }
209 Err(AgentError::CircuitOpen {
210 reason: "Too many failures, circuit is open".to_string(),
211 })
212 }
213 CircuitState::HalfOpen => {
214 Ok(())
216 }
217 }
218 }
219
220 pub fn record_success(&self) {
221 let state = *self.state.read().unwrap();
222
223 match state {
224 CircuitState::Closed => {
225 self.failure_count.store(0, Ordering::Relaxed);
227 }
228 CircuitState::HalfOpen => {
229 let successes = self
230 .success_count_in_half_open
231 .fetch_add(1, Ordering::Relaxed)
232 + 1;
233 if successes >= self.success_threshold {
234 *self.state.write().unwrap() = CircuitState::Closed;
236 self.failure_count.store(0, Ordering::Relaxed);
237 debug!("Circuit breaker closed after successful recovery");
238 }
239 }
240 CircuitState::Open => {
241 }
243 }
244 }
245
246 pub fn record_failure(&self) {
247 let state = *self.state.read().unwrap();
248
249 match state {
250 CircuitState::Closed => {
251 let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
252 if failures >= self.failure_threshold {
253 *self.state.write().unwrap() = CircuitState::Open;
254 *self.last_failure.write().unwrap() = Some(Instant::now());
255 warn!(
256 "Circuit breaker opened after {} failures",
257 self.failure_threshold
258 );
259 }
260 }
261 CircuitState::HalfOpen => {
262 *self.state.write().unwrap() = CircuitState::Open;
264 *self.last_failure.write().unwrap() = Some(Instant::now());
265 self.success_count_in_half_open.store(0, Ordering::Relaxed);
266 warn!("Circuit breaker re-opened after failure during half-open");
267 }
268 CircuitState::Open => {
269 *self.last_failure.write().unwrap() = Some(Instant::now());
271 }
272 }
273 }
274
275 pub fn reset(&self) {
276 *self.state.write().unwrap() = CircuitState::Closed;
277 self.failure_count.store(0, Ordering::Relaxed);
278 *self.last_failure.write().unwrap() = None;
279 self.success_count_in_half_open.store(0, Ordering::Relaxed);
280 }
281}
282
283impl Default for CircuitBreaker {
284 fn default() -> Self {
285 Self::new(5, Duration::from_secs(30))
286 }
287}
288
289pub struct DepthGuardHandle<'a> {
291 guard: &'a DepthGuard,
292}
293
294impl<'a> Drop for DepthGuardHandle<'a> {
295 fn drop(&mut self) {
296 self.guard.current.fetch_sub(1, Ordering::Relaxed);
297 }
298}
299
300pub struct DepthGuard {
302 current: AtomicUsize,
303 max: usize,
304}
305
306impl DepthGuard {
307 pub fn new(max: usize) -> Self {
308 Self {
309 current: AtomicUsize::new(0),
310 max,
311 }
312 }
313
314 pub fn try_enter(&self) -> Result<DepthGuardHandle<'_>> {
316 let current = self.current.fetch_add(1, Ordering::Relaxed);
317 if current >= self.max {
318 self.current.fetch_sub(1, Ordering::Relaxed);
319 return Err(AgentError::DepthLimitExceeded {
320 current: current + 1,
321 max: self.max,
322 });
323 }
324 Ok(DepthGuardHandle { guard: self })
325 }
326
327 pub fn current_depth(&self) -> usize {
328 self.current.load(Ordering::Relaxed)
329 }
330
331 pub fn max_depth(&self) -> usize {
332 self.max
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_tracker() {
        let tracker = BudgetTracker::new();
        let budget = OperationBudget {
            traversal_operations: 3,
            search_operations: 2,
            blocks_read: 10,
        };

        tracker.record_traversal();
        tracker.record_traversal();
        assert!(tracker.check_traversal_budget(&budget).is_ok());

        // Reaching the limit exactly exhausts the budget.
        tracker.record_traversal();
        assert!(tracker.check_traversal_budget(&budget).is_err());

        tracker.reset();
        assert!(tracker.check_traversal_budget(&budget).is_ok());
    }

    #[test]
    fn test_search_and_blocks_budgets() {
        let tracker = BudgetTracker::new();
        let budget = OperationBudget {
            traversal_operations: 3,
            search_operations: 2,
            blocks_read: 10,
        };

        tracker.record_search();
        assert!(tracker.check_search_budget(&budget).is_ok());
        tracker.record_search();
        assert!(tracker.check_search_budget(&budget).is_err());

        tracker.record_blocks_read(9);
        assert!(tracker.check_blocks_budget(&budget).is_ok());
        tracker.record_blocks_read(1);
        assert!(tracker.check_blocks_budget(&budget).is_err());

        // Each budget is tracked independently.
        assert!(tracker.check_traversal_budget(&budget).is_ok());

        tracker.reset();
        assert!(tracker.check_search_budget(&budget).is_ok());
        assert!(tracker.check_blocks_budget(&budget).is_ok());
    }

    #[test]
    fn test_circuit_breaker() {
        let cb = CircuitBreaker::new(3, Duration::from_millis(100));

        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_proceed().is_ok());

        cb.record_failure();
        cb.record_failure();
        assert!(cb.can_proceed().is_ok());

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.can_proceed().is_err());

        // After the recovery timeout the next call is let through in half-open mode.
        std::thread::sleep(Duration::from_millis(150));
        assert!(cb.can_proceed().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        cb.record_success();
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_reopens_on_half_open_failure() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(100));

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        std::thread::sleep(Duration::from_millis(150));
        assert!(cb.can_proceed().is_ok());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // A failure during the trial re-opens the breaker with a fresh timeout.
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.can_proceed().is_err());

        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_proceed().is_ok());
    }

    #[test]
    fn test_success_resets_failure_streak() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(60));

        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_depth_guard() {
        let guard = DepthGuard::new(3);

        assert_eq!(guard.current_depth(), 0);

        {
            let _h1 = guard.try_enter().unwrap();
            assert_eq!(guard.current_depth(), 1);

            {
                let _h2 = guard.try_enter().unwrap();
                assert_eq!(guard.current_depth(), 2);

                {
                    let _h3 = guard.try_enter().unwrap();
                    assert_eq!(guard.current_depth(), 3);

                    assert!(guard.try_enter().is_err());
                    // A rejected attempt must not leak a slot.
                    assert_eq!(guard.current_depth(), 3);
                }
                assert_eq!(guard.current_depth(), 2);
            }
            assert_eq!(guard.current_depth(), 1);
        }
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_depth_guard_zero_max() {
        let guard = DepthGuard::new(0);
        assert_eq!(guard.max_depth(), 0);
        assert!(guard.try_enter().is_err());
        assert_eq!(guard.current_depth(), 0);
    }
}