Skip to main content

sqry_core/query/security/
recursion_guard.rs

//! Recursion depth guards for preventing stack overflow
//!
//! This module provides guards that protect against stack overflow and runaway
//! work from deeply nested structures (AST trees, expressions, etc.).
//!
//! # Guards
//!
//! - [`RecursionGuard`]: Depth-based guard for AST traversal
//! - [`ExprFuelCounter`]: Fuel-based guard for expression evaluation
//!
//! # Usage
//!
//! Every successful [`RecursionGuard::enter`] must be matched by exactly one
//! [`RecursionGuard::exit`], including when the nested work fails. Collect the
//! result before calling `exit` instead of returning early with `?`:
//!
//! ```no_run
//! use sqry_core::query::security::RecursionGuard;
//!
//! # struct Node;
//! # impl Node { fn children(&self) -> Vec<&Node> { vec![] } }
//! # type RecursionError = Box<dyn std::error::Error>;
//! fn walk_tree(node: &Node, guard: &mut RecursionGuard) -> Result<(), RecursionError> {
//!     guard.enter()?;
//!     // Process node...
//!     let result = node
//!         .children()
//!         .into_iter()
//!         .try_for_each(|child| walk_tree(child, guard));
//!     guard.exit();
//!     result
//! }
//! ```
30
31use anyhow::{Result, bail};
32use std::sync::atomic::{AtomicUsize, Ordering};
33
/// Depth limiter for recursive walks such as AST traversal and nested file operations.
///
/// Each recursive step calls [`enter`](Self::enter) on the way in and
/// [`exit`](Self::exit) on the way out. Once nesting would go past the configured
/// maximum, `enter` fails. This stops a pathological input (for example, thousands
/// of nested function definitions) from overflowing the stack.
///
/// # Thread Safety
///
/// The depth-changing methods take `&mut self`, so a single guard cannot be driven
/// from several threads at once. Give each thread, or each independent traversal,
/// its own guard.
///
/// # Example
///
/// ```
/// use sqry_core::query::security::RecursionGuard;
///
/// fn process_node(node: &str, guard: &mut RecursionGuard) -> Result<(), Box<dyn std::error::Error>> {
///     guard.enter()?;
///     // Work on `node` here...
///     let _ = node;
///     guard.exit();
///     Ok(())
/// }
///
/// let mut guard = RecursionGuard::new(100)?;
/// process_node("example", &mut guard)?;
/// assert_eq!(guard.current_depth(), 0);
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Debug)]
pub struct RecursionGuard {
    /// Upper bound on nesting that `enter` enforces.
    max_depth: usize,
    /// Current nesting level; raised by `enter`, lowered by `exit`.
    current_depth: usize,
    /// Deepest level observed by `enter`, kept for telemetry.
    max_depth_reached: usize,
}
67
68impl RecursionGuard {
69    /// Create a new recursion guard with the specified maximum depth
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if `max_depth` is 0.
74    pub fn new(max_depth: usize) -> Result<Self> {
75        if max_depth == 0 {
76            bail!("RecursionGuard max_depth cannot be 0");
77        }
78
79        Ok(Self {
80            max_depth,
81            current_depth: 0,
82            max_depth_reached: 0,
83        })
84    }
85
86    /// Enter a new recursion level
87    ///
88    /// This must be called at the beginning of each recursive function.
89    /// Must be paired with a corresponding [`exit`](Self::exit) call.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`RecursionError::DepthLimitExceeded`] if entering would exceed the max depth.
94    pub fn enter(&mut self) -> Result<(), RecursionError> {
95        self.current_depth += 1;
96
97        // Track maximum depth reached for telemetry
98        if self.current_depth > self.max_depth_reached {
99            self.max_depth_reached = self.current_depth;
100        }
101
102        if self.current_depth > self.max_depth {
103            return Err(RecursionError::DepthLimitExceeded {
104                current: self.current_depth,
105                limit: self.max_depth,
106            });
107        }
108
109        Ok(())
110    }
111
112    /// Exit the current recursion level
113    ///
114    /// This must be called when leaving a recursive function, typically
115    /// in a `defer`-like pattern or before returning.
116    pub fn exit(&mut self) {
117        if self.current_depth > 0 {
118            self.current_depth -= 1;
119        }
120    }
121
122    /// Get the current recursion depth
123    #[must_use]
124    pub fn current_depth(&self) -> usize {
125        self.current_depth
126    }
127
128    /// Get the maximum depth reached during execution
129    ///
130    /// Useful for telemetry and understanding actual depth requirements.
131    #[must_use]
132    pub fn max_depth_reached(&self) -> usize {
133        self.max_depth_reached
134    }
135
136    /// Get the configured maximum depth limit
137    #[must_use]
138    pub fn max_depth(&self) -> usize {
139        self.max_depth
140    }
141}
142
/// Fuel budget that bounds how much work expression evaluation may do.
///
/// Every evaluation step pays for itself through [`consume`](Self::consume). Once the
/// remaining fuel cannot cover a request, evaluation fails. All steps draw from one
/// shared budget, so this limits both deep recursion (long chains of nested calls)
/// and wide recursion (many sibling calls), which a pure depth limit cannot do.
///
/// # Thread Safety
///
/// The remaining fuel is held in an [`AtomicUsize`] and updated with atomic
/// operations, so one counter can be shared between threads.
///
/// # Example
///
/// ```
/// use sqry_core::query::security::ExprFuelCounter;
///
/// fn evaluate_expr(expr: &str, fuel: &ExprFuelCounter) -> Result<(), Box<dyn std::error::Error>> {
///     // Charge one unit per whitespace-separated token.
///     fuel.consume(expr.split_whitespace().count())?;
///     Ok(())
/// }
///
/// let fuel = ExprFuelCounter::new(1000)?;
/// evaluate_expr("a AND b", &fuel)?;
/// assert_eq!(fuel.consumed(), 3);
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Debug)]
pub struct ExprFuelCounter {
    /// Fuel still available, decremented atomically by `consume`.
    fuel: AtomicUsize,
    /// Budget the counter was created with. `reset` restores `fuel` to this value.
    initial_fuel: usize,
}
173
174impl ExprFuelCounter {
175    /// Create a new fuel counter with the specified initial fuel
176    ///
177    /// # Errors
178    ///
179    /// Returns an error if `initial_fuel` is 0.
180    pub fn new(initial_fuel: usize) -> Result<Self> {
181        if initial_fuel == 0 {
182            bail!("ExprFuelCounter initial_fuel cannot be 0");
183        }
184
185        Ok(Self {
186            fuel: AtomicUsize::new(initial_fuel),
187            initial_fuel,
188        })
189    }
190
191    /// Consume the specified amount of fuel
192    ///
193    /// # Errors
194    ///
195    /// Returns [`RecursionError::FuelExhausted`] if there is not enough fuel remaining.
196    ///
197    /// # Implementation Note
198    ///
199    /// Uses `fetch_update` for atomic check-then-subtract to prevent underflow
200    /// (per `FINAL_CORRECTIONS.md` `ExprFuelCounter` bug fix).
201    pub fn consume(&self, amount: usize) -> Result<(), RecursionError> {
202        let result = self
203            .fuel
204            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
205                if current >= amount {
206                    Some(current - amount)
207                } else {
208                    None
209                }
210            });
211
212        match result {
213            Ok(_previous) => Ok(()),
214            Err(current) => Err(RecursionError::FuelExhausted {
215                remaining: current,
216                requested: amount,
217            }),
218        }
219    }
220
221    /// Get the current fuel remaining
222    #[must_use]
223    pub fn remaining(&self) -> usize {
224        self.fuel.load(Ordering::SeqCst)
225    }
226
227    /// Get the initial fuel amount
228    #[must_use]
229    pub fn initial_fuel(&self) -> usize {
230        self.initial_fuel
231    }
232
233    /// Get the amount of fuel consumed so far
234    #[must_use]
235    pub fn consumed(&self) -> usize {
236        self.initial_fuel.saturating_sub(self.remaining())
237    }
238
239    /// Check if there is enough fuel remaining
240    #[must_use]
241    pub fn has_fuel(&self, amount: usize) -> bool {
242        self.remaining() >= amount
243    }
244
245    /// Reset the fuel counter to its initial value
246    ///
247    /// Useful for reusing a fuel counter across multiple operations.
248    pub fn reset(&self) {
249        self.fuel.store(self.initial_fuel, Ordering::SeqCst);
250    }
251}
252
253/// Errors from recursion guards
254#[derive(Debug, thiserror::Error)]
255pub enum RecursionError {
256    /// Recursion depth limit exceeded
257    #[error("Recursion depth limit exceeded: depth {current} > limit {limit}")]
258    DepthLimitExceeded {
259        /// Current recursion depth
260        current: usize,
261        /// Maximum allowed depth
262        limit: usize,
263    },
264
265    /// Expression fuel exhausted
266    #[error(
267        "Expression evaluation fuel exhausted: requested {requested}, only {remaining} remaining"
268    )]
269    FuelExhausted {
270        /// Amount of fuel remaining
271        remaining: usize,
272        /// Amount of fuel requested
273        requested: usize,
274    },
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Recurses `n` levels below the first call. Every successful `enter` is paired
    /// with an `exit`, even when a deeper level fails.
    fn recursive_countdown(
        n: usize,
        guard: &mut RecursionGuard,
    ) -> Result<usize, RecursionError> {
        guard.enter()?;
        let result = if n == 0 {
            Ok(0)
        } else {
            recursive_countdown(n - 1, guard)
        };
        guard.exit();
        result
    }

    /// Stand-in for evaluating an expression tree: charges one unit of fuel per node.
    fn evaluate_tree(nodes: usize, fuel: &ExprFuelCounter) -> Result<(), RecursionError> {
        (0..nodes).try_for_each(|_| fuel.consume(1))
    }

    // RecursionGuard tests
    #[test]
    fn test_guard_new() {
        let guard = RecursionGuard::new(100).unwrap();
        assert_eq!(guard.current_depth(), 0);
        assert_eq!(guard.max_depth(), 100);
        assert_eq!(guard.max_depth_reached(), 0);
    }

    #[test]
    fn test_guard_new_zero_fails() {
        let result = RecursionGuard::new(0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be 0"));
    }

    #[test]
    fn test_guard_enter_exit() {
        let mut guard = RecursionGuard::new(10).unwrap();

        guard.enter().unwrap();
        assert_eq!(guard.current_depth(), 1);
        assert_eq!(guard.max_depth_reached(), 1);

        guard.enter().unwrap();
        assert_eq!(guard.current_depth(), 2);
        assert_eq!(guard.max_depth_reached(), 2);

        guard.exit();
        assert_eq!(guard.current_depth(), 1);
        assert_eq!(guard.max_depth_reached(), 2); // Max reached stays at 2

        guard.exit();
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_guard_depth_limit_enforced() {
        let mut guard = RecursionGuard::new(3).unwrap();

        guard.enter().unwrap(); // depth 1
        guard.enter().unwrap(); // depth 2
        guard.enter().unwrap(); // depth 3

        let err = guard.enter().unwrap_err(); // depth 4 - should fail
        assert!(matches!(
            err,
            RecursionError::DepthLimitExceeded {
                current: 4,
                limit: 3
            }
        ));
    }

    #[test]
    fn test_guard_exit_at_zero_is_safe() {
        let mut guard = RecursionGuard::new(10).unwrap();
        guard.exit(); // Should not panic or underflow
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_guard_max_depth_tracking() {
        let mut guard = RecursionGuard::new(100).unwrap();

        // Go to depth 5
        for _ in 0..5 {
            guard.enter().unwrap();
        }
        assert_eq!(guard.max_depth_reached(), 5);

        // Come back to depth 2
        for _ in 0..3 {
            guard.exit();
        }
        assert_eq!(guard.current_depth(), 2);
        assert_eq!(guard.max_depth_reached(), 5); // Max stays at 5

        // Go to depth 3
        guard.enter().unwrap();
        assert_eq!(guard.max_depth_reached(), 5); // Still 5, not 3
    }

    // ExprFuelCounter tests
    #[test]
    fn test_fuel_new() {
        let fuel = ExprFuelCounter::new(1000).unwrap();
        assert_eq!(fuel.remaining(), 1000);
        assert_eq!(fuel.initial_fuel(), 1000);
        assert_eq!(fuel.consumed(), 0);
    }

    #[test]
    fn test_fuel_new_zero_fails() {
        let result = ExprFuelCounter::new(0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be 0"));
    }

    #[test]
    fn test_fuel_consume() {
        let fuel = ExprFuelCounter::new(100).unwrap();

        fuel.consume(30).unwrap();
        assert_eq!(fuel.remaining(), 70);
        assert_eq!(fuel.consumed(), 30);

        fuel.consume(40).unwrap();
        assert_eq!(fuel.remaining(), 30);
        assert_eq!(fuel.consumed(), 70);
    }

    #[test]
    fn test_fuel_exhaustion() {
        let fuel = ExprFuelCounter::new(50).unwrap();

        fuel.consume(30).unwrap();
        assert_eq!(fuel.remaining(), 20);

        let err = fuel.consume(30).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted {
                remaining: 20,
                requested: 30
            }
        ));

        // Fuel should remain unchanged after failed consume
        assert_eq!(fuel.remaining(), 20);
    }

    #[test]
    fn test_fuel_exact_exhaustion() {
        let fuel = ExprFuelCounter::new(100).unwrap();

        fuel.consume(100).unwrap();
        assert_eq!(fuel.remaining(), 0);

        let err = fuel.consume(1).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted {
                remaining: 0,
                requested: 1
            }
        ));
    }

    #[test]
    fn test_fuel_consume_zero_always_succeeds() {
        let fuel = ExprFuelCounter::new(10).unwrap();
        fuel.consume(10).unwrap();

        // Requesting nothing never fails, even from an empty tank.
        fuel.consume(0).unwrap();
        assert_eq!(fuel.remaining(), 0);
    }

    #[test]
    fn test_fuel_has_fuel() {
        let fuel = ExprFuelCounter::new(100).unwrap();

        assert!(fuel.has_fuel(50));
        assert!(fuel.has_fuel(100));
        assert!(!fuel.has_fuel(101));

        fuel.consume(60).unwrap();
        assert!(fuel.has_fuel(40));
        assert!(!fuel.has_fuel(41));
    }

    #[test]
    fn test_fuel_reset() {
        let fuel = ExprFuelCounter::new(100).unwrap();

        fuel.consume(80).unwrap();
        assert_eq!(fuel.remaining(), 20);

        fuel.reset();
        assert_eq!(fuel.remaining(), 100);
        assert_eq!(fuel.consumed(), 0);
    }

    #[test]
    fn test_fuel_no_underflow_on_exhaustion() {
        // This test verifies the fix from FINAL_CORRECTIONS.md
        let fuel = ExprFuelCounter::new(5).unwrap();

        // Try to consume more than available
        let err = fuel.consume(10).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted {
                remaining: 5,
                requested: 10
            }
        ));

        // Fuel should still be 5, not underflowed
        assert_eq!(fuel.remaining(), 5);
    }

    #[test]
    fn test_fuel_multiple_small_consumes() {
        let fuel = ExprFuelCounter::new(100).unwrap();

        for _ in 0..10 {
            fuel.consume(10).unwrap();
        }

        assert_eq!(fuel.remaining(), 0);
        assert_eq!(fuel.consumed(), 100);
    }

    #[test]
    fn test_fuel_shared_across_threads_never_overdraws() {
        const THREADS: usize = 8;
        const ATTEMPTS_PER_THREAD: usize = 200;
        let fuel = ExprFuelCounter::new(1000).unwrap();

        let granted: usize = std::thread::scope(|scope| {
            let mut workers = Vec::with_capacity(THREADS);
            for _ in 0..THREADS {
                workers.push(scope.spawn(|| {
                    (0..ATTEMPTS_PER_THREAD)
                        .filter(|_| fuel.consume(1).is_ok())
                        .count()
                }));
            }
            workers
                .into_iter()
                .map(|worker| worker.join().expect("fuel worker panicked"))
                .sum()
        });

        // 1600 attempts against 1000 units: exactly 1000 may succeed.
        assert_eq!(granted, 1000);
        assert_eq!(fuel.remaining(), 0);
        assert_eq!(fuel.consumed(), 1000);
    }

    // RecursionError tests
    #[test]
    fn test_error_display_messages() {
        let depth = RecursionError::DepthLimitExceeded {
            current: 4,
            limit: 3,
        };
        assert_eq!(
            depth.to_string(),
            "Recursion depth limit exceeded: depth 4 > limit 3"
        );

        let fuel = RecursionError::FuelExhausted {
            remaining: 2,
            requested: 5,
        };
        assert_eq!(
            fuel.to_string(),
            "Expression evaluation fuel exhausted: requested 5, only 2 remaining"
        );
    }

    // Integration tests
    #[test]
    fn test_recursive_function_with_guard() {
        let mut guard = RecursionGuard::new(100).unwrap();
        let result = recursive_countdown(50, &mut guard);
        assert!(result.is_ok());
        assert_eq!(guard.current_depth(), 0); // Should be back to 0
        assert_eq!(guard.max_depth_reached(), 51); // 50 + initial call
    }

    #[test]
    fn test_recursive_function_exceeds_limit() {
        let mut guard = RecursionGuard::new(10).unwrap();
        let result = recursive_countdown(20, &mut guard);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecursionError::DepthLimitExceeded { .. }
        ));
    }

    #[test]
    fn test_expression_evaluation_with_fuel() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        let result = evaluate_tree(50, &fuel);
        assert!(result.is_ok());
        assert_eq!(fuel.remaining(), 50);
    }

    #[test]
    fn test_expression_evaluation_exhausts_fuel() {
        let fuel = ExprFuelCounter::new(50).unwrap();
        let result = evaluate_tree(100, &fuel);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecursionError::FuelExhausted { .. }
        ));
    }
}