sqry_core/query/security/
guard.rs1use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use super::config::QuerySecurityConfig;
15
16pub struct QueryGuard {
35 config: QuerySecurityConfig,
36 start_time: Instant,
37 result_count: AtomicUsize,
38 memory_usage: AtomicUsize,
39 check_interval: usize,
40 checks_performed: AtomicUsize,
41}
42
43impl QueryGuard {
44 #[must_use]
46 pub fn new(config: QuerySecurityConfig) -> Self {
47 Self {
48 config,
49 start_time: Instant::now(),
50 result_count: AtomicUsize::new(0),
51 memory_usage: AtomicUsize::new(0),
52 check_interval: 100, checks_performed: AtomicUsize::new(0),
54 }
55 }
56
57 #[must_use]
62 pub fn with_check_interval(config: QuerySecurityConfig, interval: usize) -> Self {
63 Self {
64 check_interval: interval.max(1), ..Self::new(config)
66 }
67 }
68
69 pub fn should_continue(&self) -> Result<(), QuerySecurityError> {
80 let elapsed = self.start_time.elapsed();
82 let timeout_limit = self.config.timeout();
83 if elapsed > timeout_limit {
84 return Err(QuerySecurityError::Timeout {
85 elapsed,
86 limit: timeout_limit,
87 });
88 }
89
90 let count = self.result_count.load(Ordering::Relaxed);
92 let result_limit = self.config.result_cap();
93 if count >= result_limit {
94 return Err(QuerySecurityError::ResultCapExceeded {
95 count,
96 limit: result_limit,
97 });
98 }
99
100 let checks = self.checks_performed.fetch_add(1, Ordering::Relaxed);
102 if checks.is_multiple_of(self.check_interval) {
103 let usage = self.memory_usage.load(Ordering::Relaxed);
104 let memory_limit = self.config.memory_limit();
105 if usage >= memory_limit {
106 return Err(QuerySecurityError::MemoryLimitExceeded {
107 usage,
108 limit: memory_limit,
109 });
110 }
111 }
112
113 Ok(())
114 }
115
116 pub fn record_result(&self, estimated_size: usize) {
124 self.result_count.fetch_add(1, Ordering::Relaxed);
125 self.memory_usage
126 .fetch_add(estimated_size, Ordering::Relaxed);
127 }
128
129 #[must_use]
131 pub fn elapsed(&self) -> Duration {
132 self.start_time.elapsed()
133 }
134
135 #[must_use]
137 pub fn result_count(&self) -> usize {
138 self.result_count.load(Ordering::Relaxed)
139 }
140
141 #[must_use]
143 pub fn memory_usage(&self) -> usize {
144 self.memory_usage.load(Ordering::Relaxed)
145 }
146
147 #[must_use]
149 pub fn config(&self) -> &QuerySecurityConfig {
150 &self.config
151 }
152}
153
154#[derive(Debug, thiserror::Error)]
156pub enum QuerySecurityError {
157 #[error("Query timeout: {elapsed:?} exceeded {limit:?}")]
159 Timeout {
160 elapsed: Duration,
162 limit: Duration,
164 },
165
166 #[error("Result cap exceeded: {count} >= {limit}")]
168 ResultCapExceeded {
169 count: usize,
171 limit: usize,
173 },
174
175 #[error("Memory limit exceeded: {usage} bytes >= {limit} bytes")]
177 MemoryLimitExceeded {
178 usage: usize,
180 limit: usize,
182 },
183
184 #[error("Query cost exceeds limit: {estimated} > {limit}")]
186 CostLimitExceeded {
187 estimated: usize,
189 limit: usize,
191 },
192}
193
194impl QuerySecurityError {
195 #[must_use]
200 pub fn into_completion_status(self) -> QueryCompletionStatus {
201 match self {
202 Self::Timeout { elapsed, limit } => QueryCompletionStatus::TimedOut { elapsed, limit },
203 Self::ResultCapExceeded { count, limit } => {
204 QueryCompletionStatus::ResultCapReached { count, limit }
205 }
206 Self::MemoryLimitExceeded { usage, limit } => {
207 QueryCompletionStatus::MemoryLimitReached {
208 usage_bytes: usage,
209 limit_bytes: limit,
210 }
211 }
212 Self::CostLimitExceeded { .. } =>
213 {
216 QueryCompletionStatus::Complete
217 }
218 }
219 }
220}
221
/// How a query finished: either no runtime limit cut it short, or results
/// were truncated because one of the guard's limits was reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryCompletionStatus {
    /// No runtime limit truncated the results.
    Complete,

    /// Stopped because the number of results reached the result cap.
    ResultCapReached {
        /// Results produced when the cap was hit.
        count: usize,
        /// Configured result cap.
        limit: usize,
    },

    /// Stopped because recorded memory usage reached the memory limit.
    MemoryLimitReached {
        /// Recorded memory usage, in bytes.
        usage_bytes: usize,
        /// Configured memory limit, in bytes.
        limit_bytes: usize,
    },

    /// Stopped because the query ran past its timeout.
    TimedOut {
        /// Time spent when the timeout was detected.
        elapsed: Duration,
        /// Configured timeout.
        limit: Duration,
    },
}
256
257impl QueryCompletionStatus {
258 #[must_use]
260 pub fn is_complete(&self) -> bool {
261 matches!(self, Self::Complete)
262 }
263
264 #[must_use]
266 pub fn message(&self) -> String {
267 match self {
268 Self::Complete => "Query completed successfully".to_string(),
269 Self::ResultCapReached { count, limit } => {
270 format!(
271 "Results truncated: showing {count} of {limit}+ matches (result cap reached)"
272 )
273 }
274 Self::MemoryLimitReached {
275 usage_bytes,
276 limit_bytes,
277 } => {
278 format!(
279 "Results truncated: memory limit reached ({} of {} MB)",
280 usage_bytes / (1024 * 1024),
281 limit_bytes / (1024 * 1024)
282 )
283 }
284 Self::TimedOut { elapsed, limit } => {
285 format!(
286 "Results truncated: query timed out after {:.1}s (limit: {}s)",
287 elapsed.as_secs_f64(),
288 limit.as_secs()
289 )
290 }
291 }
292 }
293
294 #[must_use]
298 pub fn status_field(&self) -> &'static str {
299 match self {
300 Self::Complete => "complete",
301 Self::ResultCapReached { .. } => "result_cap_reached",
302 Self::MemoryLimitReached { .. } => "memory_limit_reached",
303 Self::TimedOut { .. } => "timed_out",
304 }
305 }
306
307 #[must_use]
312 pub fn exit_code(&self) -> i32 {
313 match self {
314 Self::Complete => 0,
315 _ => 2, }
317 }
318}
319
320#[derive(Debug)]
325pub struct QueryResultSet<T> {
326 pub results: Vec<T>,
328
329 pub status: QueryCompletionStatus,
331
332 pub memory_usage_bytes: usize,
334
335 pub elapsed: Duration,
337}
338
339impl<T> QueryResultSet<T> {
340 #[must_use]
342 pub fn complete(results: Vec<T>, memory_usage_bytes: usize, elapsed: Duration) -> Self {
343 Self {
344 results,
345 status: QueryCompletionStatus::Complete,
346 memory_usage_bytes,
347 elapsed,
348 }
349 }
350
351 #[must_use]
353 pub fn truncated(
354 results: Vec<T>,
355 status: QueryCompletionStatus,
356 memory_usage_bytes: usize,
357 elapsed: Duration,
358 ) -> Self {
359 Self {
360 results,
361 status,
362 memory_usage_bytes,
363 elapsed,
364 }
365 }
366
367 #[must_use]
369 pub fn is_complete(&self) -> bool {
370 self.status.is_complete()
371 }
372
373 #[must_use]
375 pub fn len(&self) -> usize {
376 self.results.len()
377 }
378
379 #[must_use]
381 pub fn is_empty(&self) -> bool {
382 self.results.is_empty()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard over the default configuration with the default check interval.
    fn default_guard() -> QueryGuard {
        QueryGuard::new(QuerySecurityConfig::default())
    }

    #[test]
    fn test_guard_initial_state() {
        let guard = default_guard();
        assert_eq!((guard.result_count(), guard.memory_usage()), (0, 0));
        assert!(guard.should_continue().is_ok());
    }

    #[test]
    fn test_guard_record_result() {
        let guard = default_guard();
        guard.record_result(1024);
        assert_eq!((guard.result_count(), guard.memory_usage()), (1, 1024));
    }

    #[test]
    fn test_guard_result_cap() {
        let guard = QueryGuard::new(QuerySecurityConfig::default().with_result_cap(5));
        (0..5).for_each(|_| guard.record_result(100));

        let outcome = guard.should_continue();
        assert!(
            matches!(
                outcome,
                Err(QuerySecurityError::ResultCapExceeded { count: 5, limit: 5 })
            ),
            "unexpected outcome: {outcome:?}"
        );
    }

    #[test]
    fn test_guard_memory_limit() {
        let limited = QuerySecurityConfig::default().with_memory_limit(1000);
        // An interval of 1 samples memory on every `should_continue` call.
        let guard = QueryGuard::with_check_interval(limited, 1);

        guard.record_result(500);
        assert!(
            guard.should_continue().is_ok(),
            "500 of 1000 bytes must be allowed"
        );

        guard.record_result(600);
        let outcome = guard.should_continue();
        assert!(
            matches!(outcome, Err(QuerySecurityError::MemoryLimitExceeded { .. })),
            "unexpected outcome: {outcome:?}"
        );
    }

    #[test]
    fn test_completion_status_messages() {
        assert_eq!(
            QueryCompletionStatus::Complete.message(),
            "Query completed successfully"
        );

        let ten_mib = 10 * 1024 * 1024;
        let expectations = [
            (
                QueryCompletionStatus::ResultCapReached {
                    count: 100,
                    limit: 100,
                },
                "100",
            ),
            (
                QueryCompletionStatus::MemoryLimitReached {
                    usage_bytes: ten_mib,
                    limit_bytes: ten_mib,
                },
                "MB",
            ),
            (
                QueryCompletionStatus::TimedOut {
                    elapsed: Duration::from_secs(5),
                    limit: Duration::from_secs(5),
                },
                "timed out",
            ),
        ];
        for (status, needle) in expectations {
            let text = status.message();
            assert!(text.contains(needle), "{text:?} should contain {needle:?}");
        }
    }

    #[test]
    fn test_completion_status_is_complete() {
        assert!(QueryCompletionStatus::Complete.is_complete());
        let capped = QueryCompletionStatus::ResultCapReached {
            count: 10,
            limit: 10,
        };
        assert!(!capped.is_complete());
    }

    #[test]
    fn test_error_to_status_conversion() {
        let timed_out = QuerySecurityError::Timeout {
            elapsed: Duration::from_secs(10),
            limit: Duration::from_secs(5),
        }
        .into_completion_status();
        assert!(matches!(timed_out, QueryCompletionStatus::TimedOut { .. }));

        let capped = QuerySecurityError::ResultCapExceeded {
            count: 100,
            limit: 50,
        }
        .into_completion_status();
        assert!(matches!(
            capped,
            QueryCompletionStatus::ResultCapReached { .. }
        ));
    }

    #[test]
    fn test_result_set_complete() {
        let set = QueryResultSet::complete(vec![1, 2, 3], 100, Duration::from_millis(10));
        assert!(set.is_complete());
        assert_eq!(set.len(), 3);
        assert!(!set.is_empty());
    }

    #[test]
    fn test_result_set_truncated() {
        let set = QueryResultSet::truncated(
            vec![1, 2],
            QueryCompletionStatus::ResultCapReached { count: 2, limit: 2 },
            50,
            Duration::from_millis(5),
        );
        assert!(!set.is_complete());
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_exit_codes() {
        assert_eq!(QueryCompletionStatus::Complete.exit_code(), 0);

        let truncated = [
            QueryCompletionStatus::ResultCapReached {
                count: 10,
                limit: 10,
            },
            QueryCompletionStatus::TimedOut {
                elapsed: Duration::from_secs(5),
                limit: Duration::from_secs(5),
            },
        ];
        for status in truncated {
            assert_eq!(status.exit_code(), 2, "{status:?}");
        }
    }
}