sqry_core/query/security/
recursion_guard.rs1use anyhow::{Result, bail};
32use std::sync::atomic::{AtomicUsize, Ordering};
33
34#[derive(Debug)]
62pub struct RecursionGuard {
63 max_depth: usize,
64 current_depth: usize,
65 max_depth_reached: usize,
66}
67
68impl RecursionGuard {
69 pub fn new(max_depth: usize) -> Result<Self> {
75 if max_depth == 0 {
76 bail!("RecursionGuard max_depth cannot be 0");
77 }
78
79 Ok(Self {
80 max_depth,
81 current_depth: 0,
82 max_depth_reached: 0,
83 })
84 }
85
86 pub fn enter(&mut self) -> Result<(), RecursionError> {
95 self.current_depth += 1;
96
97 if self.current_depth > self.max_depth_reached {
99 self.max_depth_reached = self.current_depth;
100 }
101
102 if self.current_depth > self.max_depth {
103 return Err(RecursionError::DepthLimitExceeded {
104 current: self.current_depth,
105 limit: self.max_depth,
106 });
107 }
108
109 Ok(())
110 }
111
112 pub fn exit(&mut self) {
117 if self.current_depth > 0 {
118 self.current_depth -= 1;
119 }
120 }
121
122 #[must_use]
124 pub fn current_depth(&self) -> usize {
125 self.current_depth
126 }
127
128 #[must_use]
132 pub fn max_depth_reached(&self) -> usize {
133 self.max_depth_reached
134 }
135
136 #[must_use]
138 pub fn max_depth(&self) -> usize {
139 self.max_depth
140 }
141}
142
143#[derive(Debug)]
169pub struct ExprFuelCounter {
170 fuel: AtomicUsize,
171 initial_fuel: usize,
172}
173
174impl ExprFuelCounter {
175 pub fn new(initial_fuel: usize) -> Result<Self> {
181 if initial_fuel == 0 {
182 bail!("ExprFuelCounter initial_fuel cannot be 0");
183 }
184
185 Ok(Self {
186 fuel: AtomicUsize::new(initial_fuel),
187 initial_fuel,
188 })
189 }
190
191 pub fn consume(&self, amount: usize) -> Result<(), RecursionError> {
202 let result = self
203 .fuel
204 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
205 if current >= amount {
206 Some(current - amount)
207 } else {
208 None
209 }
210 });
211
212 match result {
213 Ok(_previous) => Ok(()),
214 Err(current) => Err(RecursionError::FuelExhausted {
215 remaining: current,
216 requested: amount,
217 }),
218 }
219 }
220
221 #[must_use]
223 pub fn remaining(&self) -> usize {
224 self.fuel.load(Ordering::SeqCst)
225 }
226
227 #[must_use]
229 pub fn initial_fuel(&self) -> usize {
230 self.initial_fuel
231 }
232
233 #[must_use]
235 pub fn consumed(&self) -> usize {
236 self.initial_fuel.saturating_sub(self.remaining())
237 }
238
239 #[must_use]
241 pub fn has_fuel(&self, amount: usize) -> bool {
242 self.remaining() >= amount
243 }
244
245 pub fn reset(&self) {
249 self.fuel.store(self.initial_fuel, Ordering::SeqCst);
250 }
251}
252
253#[derive(Debug, thiserror::Error)]
255pub enum RecursionError {
256 #[error("Recursion depth limit exceeded: depth {current} > limit {limit}")]
258 DepthLimitExceeded {
259 current: usize,
261 limit: usize,
263 },
264
265 #[error(
267 "Expression evaluation fuel exhausted: requested {requested}, only {remaining} remaining"
268 )]
269 FuelExhausted {
270 remaining: usize,
272 requested: usize,
274 },
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Recurses `n` levels deep, entering the guard on every level and exiting on the
    /// way back out.
    fn recursive_countdown(
        n: usize,
        guard: &mut RecursionGuard,
    ) -> Result<usize, RecursionError> {
        guard.enter()?;
        let outcome = match n {
            0 => Ok(0),
            _ => recursive_countdown(n - 1, guard),
        };
        guard.exit();
        outcome
    }

    /// Spends one unit of fuel per node, stopping at the first failure.
    fn evaluate_tree(nodes: usize, fuel: &ExprFuelCounter) -> Result<(), RecursionError> {
        (0..nodes).try_for_each(|_| fuel.consume(1))
    }

    // ----- RecursionGuard -----

    #[test]
    fn test_guard_new() {
        let guard = RecursionGuard::new(100).unwrap();
        assert_eq!(guard.max_depth(), 100);
        assert_eq!(guard.current_depth(), 0);
        assert_eq!(guard.max_depth_reached(), 0);
    }

    #[test]
    fn test_guard_new_zero_fails() {
        let message = RecursionGuard::new(0).unwrap_err().to_string();
        assert!(message.contains("cannot be 0"));
    }

    #[test]
    fn test_guard_enter_exit() {
        let mut guard = RecursionGuard::new(10).unwrap();

        for expected in 1..=2 {
            guard.enter().unwrap();
            assert_eq!(guard.current_depth(), expected);
            assert_eq!(guard.max_depth_reached(), expected);
        }

        guard.exit();
        assert_eq!(guard.current_depth(), 1);
        // The high-water mark survives an exit.
        assert_eq!(guard.max_depth_reached(), 2);

        guard.exit();
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_guard_depth_limit_enforced() {
        let mut guard = RecursionGuard::new(3).unwrap();
        (0..3).for_each(|_| guard.enter().unwrap());

        let err = guard.enter().unwrap_err();
        assert!(matches!(
            err,
            RecursionError::DepthLimitExceeded { current: 4, limit: 3 }
        ));
    }

    #[test]
    fn test_guard_exit_at_zero_is_safe() {
        let mut guard = RecursionGuard::new(10).unwrap();
        // Exiting an un-entered guard must not underflow.
        guard.exit();
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_guard_max_depth_tracking() {
        let mut guard = RecursionGuard::new(100).unwrap();

        (0..5).for_each(|_| guard.enter().unwrap());
        assert_eq!(guard.max_depth_reached(), 5);

        (0..3).for_each(|_| guard.exit());
        assert_eq!(guard.current_depth(), 2);
        assert_eq!(guard.max_depth_reached(), 5);

        // Re-entering below the previous peak leaves the peak unchanged.
        guard.enter().unwrap();
        assert_eq!(guard.max_depth_reached(), 5);
    }

    // ----- ExprFuelCounter -----

    #[test]
    fn test_fuel_new() {
        let fuel = ExprFuelCounter::new(1000).unwrap();
        assert_eq!(fuel.initial_fuel(), 1000);
        assert_eq!(fuel.remaining(), 1000);
        assert_eq!(fuel.consumed(), 0);
    }

    #[test]
    fn test_fuel_new_zero_fails() {
        let message = ExprFuelCounter::new(0).unwrap_err().to_string();
        assert!(message.contains("cannot be 0"));
    }

    #[test]
    fn test_fuel_consume() {
        let fuel = ExprFuelCounter::new(100).unwrap();

        // (amount to spend, expected remaining, expected consumed)
        for (amount, left, spent) in [(30, 70, 30), (40, 30, 70)] {
            fuel.consume(amount).unwrap();
            assert_eq!(fuel.remaining(), left);
            assert_eq!(fuel.consumed(), spent);
        }
    }

    #[test]
    fn test_fuel_exhaustion() {
        let fuel = ExprFuelCounter::new(50).unwrap();
        fuel.consume(30).unwrap();
        assert_eq!(fuel.remaining(), 20);

        let err = fuel.consume(30).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted { remaining: 20, requested: 30 }
        ));
        // A rejected request leaves the balance untouched.
        assert_eq!(fuel.remaining(), 20);
    }

    #[test]
    fn test_fuel_exact_exhaustion() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        fuel.consume(100).unwrap();
        assert_eq!(fuel.remaining(), 0);

        let err = fuel.consume(1).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted { remaining: 0, requested: 1 }
        ));
    }

    #[test]
    fn test_fuel_has_fuel() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        assert!([50, 100].iter().all(|&amount| fuel.has_fuel(amount)));
        assert!(!fuel.has_fuel(101));

        fuel.consume(60).unwrap();
        assert!(fuel.has_fuel(40));
        assert!(!fuel.has_fuel(41));
    }

    #[test]
    fn test_fuel_reset() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        fuel.consume(80).unwrap();
        assert_eq!(fuel.remaining(), 20);

        fuel.reset();
        assert_eq!(fuel.consumed(), 0);
        assert_eq!(fuel.remaining(), 100);
    }

    #[test]
    fn test_fuel_no_underflow_on_exhaustion() {
        let fuel = ExprFuelCounter::new(5).unwrap();

        // Asking for more than is available must fail instead of wrapping around.
        let err = fuel.consume(10).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted { remaining: 5, requested: 10 }
        ));
        assert_eq!(fuel.remaining(), 5);
    }

    #[test]
    fn test_fuel_multiple_small_consumes() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        (0..10).for_each(|_| fuel.consume(10).unwrap());

        assert_eq!(fuel.remaining(), 0);
        assert_eq!(fuel.consumed(), 100);
    }

    // ----- Integration-style usage -----

    #[test]
    fn test_recursive_function_with_guard() {
        let mut guard = RecursionGuard::new(100).unwrap();
        assert!(recursive_countdown(50, &mut guard).is_ok());

        // Every level exited again; 51 levels were entered (n = 50 down to 0).
        assert_eq!(guard.current_depth(), 0);
        assert_eq!(guard.max_depth_reached(), 51);
    }

    #[test]
    fn test_recursive_function_exceeds_limit() {
        let mut guard = RecursionGuard::new(10).unwrap();
        let outcome = recursive_countdown(20, &mut guard);
        assert!(matches!(
            outcome,
            Err(RecursionError::DepthLimitExceeded { .. })
        ));
    }

    #[test]
    fn test_expression_evaluation_with_fuel() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        assert!(evaluate_tree(50, &fuel).is_ok());
        assert_eq!(fuel.remaining(), 50);
    }

    #[test]
    fn test_expression_evaluation_exhausts_fuel() {
        let fuel = ExprFuelCounter::new(50).unwrap();
        assert!(matches!(
            evaluate_tree(100, &fuel),
            Err(RecursionError::FuelExhausted { .. })
        ));
    }
}