// Source: sqry_core/query/security/guard.rs

//! Query execution guard with timeout, memory, and resource limits
//!
//! This module provides the runtime guard that enforces security limits
//! during query execution. The guard tracks:
//! - Elapsed time vs timeout limit
//! - Result count vs result cap
//! - Estimated memory usage vs memory limit
//!
//! All limits are NON-NEGOTIABLE per the security requirements.
10
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use super::config::QuerySecurityConfig;
15
16/// Query execution guard with timeout, memory, and resource limits
17///
18/// **MEMORY ENFORCEMENT** (per Codex review):
19/// Uses a tracked allocation approach where memory usage is estimated
20/// based on result sizes and checked at regular intervals.
21///
22/// # Example
23///
24/// ```
25/// use sqry_core::query::security::{QuerySecurityConfig, QueryGuard};
26///
27/// let config = QuerySecurityConfig::default();
28/// let guard = QueryGuard::new(config);
29///
30/// // During query execution:
31/// guard.should_continue().expect("should not fail initially");
32/// guard.record_result(128); // Record a result with estimated size
33/// ```
34pub struct QueryGuard {
35    config: QuerySecurityConfig,
36    start_time: Instant,
37    result_count: AtomicUsize,
38    memory_usage: AtomicUsize,
39    check_interval: usize,
40    checks_performed: AtomicUsize,
41}
42
43impl QueryGuard {
44    /// Create a new query guard
45    #[must_use]
46    pub fn new(config: QuerySecurityConfig) -> Self {
47        Self {
48            config,
49            start_time: Instant::now(),
50            result_count: AtomicUsize::new(0),
51            memory_usage: AtomicUsize::new(0),
52            check_interval: 100, // Check every 100 results
53            checks_performed: AtomicUsize::new(0),
54        }
55    }
56
57    /// Create a new query guard with custom check interval
58    ///
59    /// The check interval controls how often memory limits are checked.
60    /// Lower values mean more frequent checks but higher overhead.
61    #[must_use]
62    pub fn with_check_interval(config: QuerySecurityConfig, interval: usize) -> Self {
63        Self {
64            check_interval: interval.max(1), // At least check every result
65            ..Self::new(config)
66        }
67    }
68
69    /// Check if query should continue
70    ///
71    /// # Errors
72    ///
73    /// Returns `QuerySecurityError` if any limit is exceeded:
74    /// - `Timeout`: Query has run longer than the configured timeout
75    /// - `ResultCapExceeded`: More results collected than the result cap
76    /// - `MemoryLimitExceeded`: Estimated memory usage exceeds limit
77    ///
78    /// **NOTE** (per Codex iter6): Uses getter methods since fields are private.
79    pub fn should_continue(&self) -> Result<(), QuerySecurityError> {
80        // Check timeout - use getter since field is private
81        let elapsed = self.start_time.elapsed();
82        let timeout_limit = self.config.timeout();
83        if elapsed > timeout_limit {
84            return Err(QuerySecurityError::Timeout {
85                elapsed,
86                limit: timeout_limit,
87            });
88        }
89
90        // Check result cap - use getter since field is private
91        let count = self.result_count.load(Ordering::Relaxed);
92        let result_limit = self.config.result_cap();
93        if count >= result_limit {
94            return Err(QuerySecurityError::ResultCapExceeded {
95                count,
96                limit: result_limit,
97            });
98        }
99
100        // Check memory (periodically to reduce overhead) - use getter
101        let checks = self.checks_performed.fetch_add(1, Ordering::Relaxed);
102        if checks.is_multiple_of(self.check_interval) {
103            let usage = self.memory_usage.load(Ordering::Relaxed);
104            let memory_limit = self.config.memory_limit();
105            if usage >= memory_limit {
106                return Err(QuerySecurityError::MemoryLimitExceeded {
107                    usage,
108                    limit: memory_limit,
109                });
110            }
111        }
112
113        Ok(())
114    }
115
116    /// Record a result and its estimated memory footprint
117    ///
118    /// This should be called for each result added to the result set.
119    /// The estimated size should include:
120    /// - The Node/TraitImpl struct size
121    /// - String allocations for names, paths, etc.
122    /// - Any metadata stored with the result
123    pub fn record_result(&self, estimated_size: usize) {
124        self.result_count.fetch_add(1, Ordering::Relaxed);
125        self.memory_usage
126            .fetch_add(estimated_size, Ordering::Relaxed);
127    }
128
129    /// Get elapsed time since query started
130    #[must_use]
131    pub fn elapsed(&self) -> Duration {
132        self.start_time.elapsed()
133    }
134
135    /// Get current result count
136    #[must_use]
137    pub fn result_count(&self) -> usize {
138        self.result_count.load(Ordering::Relaxed)
139    }
140
141    /// Get current memory usage estimate
142    #[must_use]
143    pub fn memory_usage(&self) -> usize {
144        self.memory_usage.load(Ordering::Relaxed)
145    }
146
147    /// Get the security configuration
148    #[must_use]
149    pub fn config(&self) -> &QuerySecurityConfig {
150        &self.config
151    }
152}
153
154/// Security errors from query execution
155#[derive(Debug, thiserror::Error)]
156pub enum QuerySecurityError {
157    /// Query execution exceeded the timeout limit
158    #[error("Query timeout: {elapsed:?} exceeded {limit:?}")]
159    Timeout {
160        /// How long the query ran before being stopped
161        elapsed: Duration,
162        /// The configured timeout limit
163        limit: Duration,
164    },
165
166    /// Query returned more results than the result cap
167    #[error("Result cap exceeded: {count} >= {limit}")]
168    ResultCapExceeded {
169        /// Number of results collected
170        count: usize,
171        /// The configured result cap
172        limit: usize,
173    },
174
175    /// Query memory usage exceeded the memory limit
176    #[error("Memory limit exceeded: {usage} bytes >= {limit} bytes")]
177    MemoryLimitExceeded {
178        /// Estimated memory usage in bytes
179        usage: usize,
180        /// The configured memory limit in bytes
181        limit: usize,
182    },
183
184    /// Pre-execution cost estimate exceeds the cost limit
185    #[error("Query cost exceeds limit: {estimated} > {limit}")]
186    CostLimitExceeded {
187        /// Estimated cost of the query
188        estimated: usize,
189        /// The configured cost limit
190        limit: usize,
191    },
192}
193
194impl QuerySecurityError {
195    /// Convert security error to completion status for partial results (per Codex iter10)
196    ///
197    /// When a limit is exceeded during execution, this converts the error
198    /// into a status indicator that can be returned with partial results.
199    #[must_use]
200    pub fn into_completion_status(self) -> QueryCompletionStatus {
201        match self {
202            Self::Timeout { elapsed, limit } => QueryCompletionStatus::TimedOut { elapsed, limit },
203            Self::ResultCapExceeded { count, limit } => {
204                QueryCompletionStatus::ResultCapReached { count, limit }
205            }
206            Self::MemoryLimitExceeded { usage, limit } => {
207                QueryCompletionStatus::MemoryLimitReached {
208                    usage_bytes: usage,
209                    limit_bytes: limit,
210                }
211            }
212            Self::CostLimitExceeded { .. } =>
213            // Cost limit is checked before execution, not during
214            // If we somehow hit this, treat as complete (no partial results scenario)
215            {
216                QueryCompletionStatus::Complete
217            }
218        }
219    }
220}
221
/// Outcome marker attached to a set of query results (per Codex iter10)
///
/// Tells the caller whether every matching result was returned or whether
/// collection stopped early because one of the limits was hit, in which case
/// the result set is only partial.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryCompletionStatus {
    /// Nothing was cut off: every matching result is present
    Complete,

    /// Collection stopped because the result cap was reached
    ResultCapReached {
        /// How many results were returned
        count: usize,
        /// Result cap that was in effect
        limit: usize,
    },

    /// Collection stopped because the memory limit was reached
    MemoryLimitReached {
        /// Tracked memory usage, in bytes
        usage_bytes: usize,
        /// Memory limit that was in effect, in bytes
        limit_bytes: usize,
    },

    /// Collection stopped because the query ran out of time
    TimedOut {
        /// Time the query had been running
        elapsed: Duration,
        /// Timeout that was in effect
        limit: Duration,
    },
}
256
257impl QueryCompletionStatus {
258    /// Returns true if the result set is complete
259    #[must_use]
260    pub fn is_complete(&self) -> bool {
261        matches!(self, Self::Complete)
262    }
263
264    /// Returns a user-friendly message for CLI output
265    #[must_use]
266    pub fn message(&self) -> String {
267        match self {
268            Self::Complete => "Query completed successfully".to_string(),
269            Self::ResultCapReached { count, limit } => {
270                format!(
271                    "Results truncated: showing {count} of {limit}+ matches (result cap reached)"
272                )
273            }
274            Self::MemoryLimitReached {
275                usage_bytes,
276                limit_bytes,
277            } => {
278                format!(
279                    "Results truncated: memory limit reached ({} of {} MB)",
280                    usage_bytes / (1024 * 1024),
281                    limit_bytes / (1024 * 1024)
282                )
283            }
284            Self::TimedOut { elapsed, limit } => {
285                format!(
286                    "Results truncated: query timed out after {:.1}s (limit: {}s)",
287                    elapsed.as_secs_f64(),
288                    limit.as_secs()
289                )
290            }
291        }
292    }
293
294    /// Returns the JSON field name for this status type
295    ///
296    /// Used for JSON output format consistency.
297    #[must_use]
298    pub fn status_field(&self) -> &'static str {
299        match self {
300            Self::Complete => "complete",
301            Self::ResultCapReached { .. } => "result_cap_reached",
302            Self::MemoryLimitReached { .. } => "memory_limit_reached",
303            Self::TimedOut { .. } => "timed_out",
304        }
305    }
306
307    /// Returns the CLI exit code for this status
308    ///
309    /// - Complete: 0 (success)
310    /// - Truncated results: 2 (partial success, distinct from errors)
311    #[must_use]
312    pub fn exit_code(&self) -> i32 {
313        match self {
314            Self::Complete => 0,
315            _ => 2, // Partial results - distinct from error (1)
316        }
317    }
318}
319
320/// Result set with completion status (per Codex iter10)
321///
322/// Wraps query results with a status indicator so callers know whether
323/// the results are complete or truncated.
324#[derive(Debug)]
325pub struct QueryResultSet<T> {
326    /// The results (may be partial if status != Complete)
327    pub results: Vec<T>,
328
329    /// Completion status indicating if results are complete or truncated
330    pub status: QueryCompletionStatus,
331
332    /// Actual memory usage tracked during execution
333    pub memory_usage_bytes: usize,
334
335    /// Actual elapsed time
336    pub elapsed: Duration,
337}
338
339impl<T> QueryResultSet<T> {
340    /// Create a complete result set
341    #[must_use]
342    pub fn complete(results: Vec<T>, memory_usage_bytes: usize, elapsed: Duration) -> Self {
343        Self {
344            results,
345            status: QueryCompletionStatus::Complete,
346            memory_usage_bytes,
347            elapsed,
348        }
349    }
350
351    /// Create a truncated result set
352    #[must_use]
353    pub fn truncated(
354        results: Vec<T>,
355        status: QueryCompletionStatus,
356        memory_usage_bytes: usize,
357        elapsed: Duration,
358    ) -> Self {
359        Self {
360            results,
361            status,
362            memory_usage_bytes,
363            elapsed,
364        }
365    }
366
367    /// Returns true if all results were returned
368    #[must_use]
369    pub fn is_complete(&self) -> bool {
370        self.status.is_complete()
371    }
372
373    /// Get the number of results
374    #[must_use]
375    pub fn len(&self) -> usize {
376        self.results.len()
377    }
378
379    /// Check if the result set is empty
380    #[must_use]
381    pub fn is_empty(&self) -> bool {
382        self.results.is_empty()
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh guard has recorded nothing and lets the query proceed.
    #[test]
    fn test_guard_initial_state() {
        let fresh = QueryGuard::new(QuerySecurityConfig::default());
        let (count, usage) = (fresh.result_count(), fresh.memory_usage());
        assert_eq!(count, 0);
        assert_eq!(usage, 0);
        assert!(fresh.should_continue().is_ok());
    }

    /// Recording one result bumps both the count and the memory estimate.
    #[test]
    fn test_guard_record_result() {
        let tracker = QueryGuard::new(QuerySecurityConfig::default());
        tracker.record_result(1024);
        assert_eq!(tracker.result_count(), 1);
        assert_eq!(tracker.memory_usage(), 1024);
    }

    /// Reaching the result cap makes the guard report `ResultCapExceeded`.
    #[test]
    fn test_guard_result_cap() {
        let guard = QueryGuard::new(QuerySecurityConfig::default().with_result_cap(5));

        (0..5).for_each(|_| guard.record_result(100));

        match guard.should_continue() {
            Err(QuerySecurityError::ResultCapExceeded { count: 5, limit: 5 }) => {}
            other => panic!("expected ResultCapExceeded {{ 5, 5 }}, got {other:?}"),
        }
    }

    /// Crossing the memory limit makes the guard report `MemoryLimitExceeded`.
    #[test]
    fn test_guard_memory_limit() {
        // An interval of 1 evaluates the memory limit on every check.
        let limited = QuerySecurityConfig::default().with_memory_limit(1000);
        let guard = QueryGuard::with_check_interval(limited, 1);

        // Still below the limit after the first result...
        guard.record_result(500);
        assert!(guard.should_continue().is_ok());

        // ...and above it after the second.
        guard.record_result(600);
        match guard.should_continue() {
            Err(QuerySecurityError::MemoryLimitExceeded { .. }) => {}
            other => panic!("expected MemoryLimitExceeded, got {other:?}"),
        }
    }

    /// Every status renders a message mentioning its key detail.
    #[test]
    fn test_completion_status_messages() {
        let complete_text = QueryCompletionStatus::Complete.message();
        assert_eq!(complete_text, "Query completed successfully");

        let cap_text = QueryCompletionStatus::ResultCapReached {
            count: 100,
            limit: 100,
        }
        .message();
        assert!(cap_text.contains("100"));

        let ten_mb = 10 * 1024 * 1024;
        let mem_text = QueryCompletionStatus::MemoryLimitReached {
            usage_bytes: ten_mb,
            limit_bytes: ten_mb,
        }
        .message();
        assert!(mem_text.contains("MB"));

        let five_secs = Duration::from_secs(5);
        let timeout_text = QueryCompletionStatus::TimedOut {
            elapsed: five_secs,
            limit: five_secs,
        }
        .message();
        assert!(timeout_text.contains("timed out"));
    }

    /// Only the `Complete` status counts as complete.
    #[test]
    fn test_completion_status_is_complete() {
        let capped = QueryCompletionStatus::ResultCapReached {
            count: 10,
            limit: 10,
        };
        assert!(QueryCompletionStatus::Complete.is_complete());
        assert!(!capped.is_complete());
    }

    /// Limit errors convert into the matching truncated status.
    #[test]
    fn test_error_to_status_conversion() {
        let from_timeout = QuerySecurityError::Timeout {
            elapsed: Duration::from_secs(10),
            limit: Duration::from_secs(5),
        }
        .into_completion_status();
        assert!(matches!(
            from_timeout,
            QueryCompletionStatus::TimedOut { .. }
        ));

        let from_cap = QuerySecurityError::ResultCapExceeded {
            count: 100,
            limit: 50,
        }
        .into_completion_status();
        assert!(matches!(
            from_cap,
            QueryCompletionStatus::ResultCapReached { .. }
        ));
    }

    /// `complete` builds a non-empty, complete result set.
    #[test]
    fn test_result_set_complete() {
        let set = QueryResultSet::complete(vec![1, 2, 3], 100, Duration::from_millis(10));
        assert!(set.is_complete());
        assert_eq!(set.len(), 3);
        assert!(!set.is_empty());
    }

    /// `truncated` keeps the supplied status and results.
    #[test]
    fn test_result_set_truncated() {
        let set = QueryResultSet::truncated(
            vec![1, 2],
            QueryCompletionStatus::ResultCapReached { count: 2, limit: 2 },
            50,
            Duration::from_millis(5),
        );
        assert!(!set.is_complete());
        assert_eq!(set.len(), 2);
    }

    /// Complete results exit with 0, truncated results with 2.
    #[test]
    fn test_exit_codes() {
        let complete = QueryCompletionStatus::Complete;
        assert_eq!(complete.exit_code(), 0);

        let capped = QueryCompletionStatus::ResultCapReached {
            count: 10,
            limit: 10,
        };
        assert_eq!(capped.exit_code(), 2);

        let timed_out = QueryCompletionStatus::TimedOut {
            elapsed: Duration::from_secs(5),
            limit: Duration::from_secs(5),
        };
        assert_eq!(timed_out.exit_code(), 2);
    }
}