1use super::error::{AstQueryError, Result};
28use super::types::{ContextKind, ContextualMatch};
29use lru::LruCache;
30use regex::Regex;
31use std::num::NonZeroUsize;
32use std::sync::Mutex;
33
34const MAX_REGEX_LENGTH: usize = 1000;
36
37const MAX_REPETITION_COUNT: usize = 1000;
39
40fn validate_regex_pattern(pattern: &str) -> Result<()> {
42 RegexSafetyScanner::new(pattern)?.scan()
43}
44
45struct RegexSafetyScanner<'a> {
46 pattern: &'a str,
47 chars: Vec<char>,
48 pos: usize,
49 state: RegexScanState,
50}
51
52impl<'a> RegexSafetyScanner<'a> {
53 fn new(pattern: &'a str) -> Result<Self> {
54 if pattern.len() > MAX_REGEX_LENGTH {
55 return Err(invalid_regex(
56 pattern,
57 &format!(
58 "Regex pattern too long (max {} chars, got {})",
59 MAX_REGEX_LENGTH,
60 pattern.len()
61 ),
62 ));
63 }
64
65 Ok(Self {
66 pattern,
67 chars: pattern.chars().collect(),
68 pos: 0,
69 state: RegexScanState::default(),
70 })
71 }
72
73 fn scan(mut self) -> Result<()> {
74 while self.pos < self.chars.len() {
75 let ch = self.chars[self.pos];
76
77 if self.state.consume_escape(ch) {
78 self.pos += 1;
79 continue;
80 }
81
82 if self.state.update_bracket_state(ch) || self.state.in_brackets {
83 self.pos += 1;
84 continue;
85 }
86
87 if self.is_quantified_group() {
89 self.check_group_safety()?;
90 }
91
92 if self.handle_repetition_range()? {
94 continue;
95 }
96
97 if self.is_nested_repetition() {
99 return Err(invalid_regex(
100 self.pattern,
101 "Regex pattern contains directly nested repetition operators",
102 ));
103 }
104
105 self.pos += 1;
106 }
107
108 Ok(())
111 }
112
113 fn is_quantified_group(&self) -> bool {
114 if self.chars[self.pos] != ')' || self.pos + 1 >= self.chars.len() {
115 return false;
116 }
117 matches!(self.chars[self.pos + 1], '+' | '*' | '?' | '{')
118 }
119
120 fn check_group_safety(&self) -> Result<()> {
121 if self.has_alternation_in_group() {
122 return Err(invalid_regex(
123 self.pattern,
124 "Regex pattern contains alternation inside quantified group like (a|b)+ which can cause exponential backtracking",
125 ));
126 }
127
128 if self.has_quantifier_in_group() {
129 return Err(invalid_regex(
130 self.pattern,
131 "Regex pattern contains nested quantifiers like (x+)+ or (x*)* which can cause catastrophic backtracking",
132 ));
133 }
134
135 Ok(())
136 }
137
138 fn handle_repetition_range(&mut self) -> Result<bool> {
139 if self.chars[self.pos] != '{' {
140 return Ok(false);
141 }
142
143 if let Some(end_pos) = self.find_matching_brace() {
144 let range_str: String = self.chars[self.pos + 1..end_pos].iter().collect();
145 validate_repetition_range(&range_str, self.pattern)?;
146 self.pos = end_pos + 1;
147 Ok(true)
148 } else {
149 Ok(false)
150 }
151 }
152
153 fn is_nested_repetition(&self) -> bool {
154 if self.pos + 1 >= self.chars.len() {
155 return false;
156 }
157 let ch = self.chars[self.pos];
158 let next_ch = self.chars[self.pos + 1];
159 matches!(ch, '*' | '+' | '?') && matches!(next_ch, '*' | '+' | '?')
160 }
161
162 fn has_alternation_in_group(&self) -> bool {
163 self.group_contains(false, |ch| ch == '|')
164 }
165
166 fn has_quantifier_in_group(&self) -> bool {
167 self.group_contains(true, |ch| matches!(ch, '+' | '*' | '?'))
168 }
169
170 fn group_contains<F>(&self, escape_sensitive: bool, predicate: F) -> bool
171 where
172 F: Fn(char) -> bool,
173 {
174 let Some(start) = self.find_group_start(escape_sensitive) else {
175 return false;
176 };
177
178 let scan_start = start + 1;
179 if scan_start >= self.pos {
180 return false;
181 }
182
183 self.chars[scan_start..self.pos]
184 .iter()
185 .copied()
186 .any(predicate)
187 }
188
189 fn find_group_start(&self, escape_sensitive: bool) -> Option<usize> {
190 let mut depth = 0;
191 let mut i = self.pos;
192 let mut escape_next = false;
193
194 while i > 0 {
195 i -= 1;
196 let ch = self.chars[i];
197
198 if escape_sensitive {
199 if escape_next {
200 escape_next = false;
201 continue;
202 }
203
204 if ch == '\\' {
205 escape_next = true;
206 continue;
207 }
208 }
209
210 match ch {
211 ')' => depth += 1,
212 '(' => {
213 if depth == 0 {
214 return Some(i);
215 }
216 depth -= 1;
217 }
218 _ => {}
219 }
220 }
221
222 None
223 }
224
225 fn find_matching_brace(&self) -> Option<usize> {
226 for i in self.pos + 1..self.chars.len() {
227 if self.chars[i] == '}' {
228 return Some(i);
229 }
230 if !self.chars[i].is_numeric() && self.chars[i] != ',' {
232 return None;
233 }
234 }
235 None
236 }
237}
238
/// Mutable flags carried through the forward scan of a regex pattern.
#[derive(Default)]
struct RegexScanState {
    /// Set after a backslash; the next character is then treated as escaped.
    escape_next: bool,
    /// Set between `[` and `]`.
    in_brackets: bool,
}

impl RegexScanState {
    /// Feeds `ch` to the escape tracker.
    ///
    /// Returns `true` when `ch` is a backslash that starts an escape or is the
    /// character being escaped; in both cases the caller should skip it.
    fn consume_escape(&mut self, ch: char) -> bool {
        // Reading the flag also clears it, which is what an escaped character needs.
        let pending = std::mem::take(&mut self.escape_next);
        if pending {
            return true;
        }

        self.escape_next = ch == '\\';
        self.escape_next
    }

    /// Updates the character-class flag for `ch`.
    ///
    /// Returns `true` only when `ch` is `[` or `]`.
    fn update_bracket_state(&mut self, ch: char) -> bool {
        let opens = ch == '[';
        if opens || ch == ']' {
            self.in_brackets = opens;
            true
        } else {
            false
        }
    }
}
272
273fn invalid_regex(pattern: &str, message: &str) -> AstQueryError {
274 AstQueryError::InvalidRegex {
275 pattern: pattern.to_string(),
276 source: regex::Error::Syntax(message.to_string()),
277 }
278}
279
280fn validate_repetition_range(range_str: &str, pattern: &str) -> Result<()> {
282 let parts: Vec<&str> = range_str.split(',').collect();
283
284 for part in &parts {
285 if part.is_empty() {
286 continue;
287 }
288
289 match part.trim().parse::<usize>() {
290 Ok(count) => {
291 if count > MAX_REPETITION_COUNT {
292 return Err(AstQueryError::InvalidRegex {
293 pattern: pattern.to_string(),
294 source: regex::Error::Syntax(format!(
295 "Repetition count {count} exceeds maximum of {MAX_REPETITION_COUNT}"
296 )),
297 });
298 }
299 }
300 Err(_) => {
301 return Err(AstQueryError::InvalidRegex {
302 pattern: pattern.to_string(),
303 source: regex::Error::Syntax(format!("Invalid repetition range: {range_str}")),
304 });
305 }
306 }
307 }
308
309 Ok(())
310}
311
/// Comparison applied to a nesting depth by the `depth:` predicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DepthOp {
    /// Depth equals the bound.
    Eq(usize),
    /// Depth is strictly greater than the bound.
    Gt(usize),
    /// Depth is strictly less than the bound.
    Lt(usize),
    /// Depth is greater than or equal to the bound.
    Gte(usize),
    /// Depth is less than or equal to the bound.
    Lte(usize),
}

impl DepthOp {
    /// Returns whether `depth` satisfies this comparison.
    #[must_use]
    pub fn matches(&self, depth: usize) -> bool {
        match *self {
            Self::Eq(bound) => depth == bound,
            Self::Gt(bound) => depth > bound,
            Self::Lt(bound) => depth < bound,
            Self::Gte(bound) => depth >= bound,
            Self::Lte(bound) => depth <= bound,
        }
    }
}
340
341#[derive(Debug, Clone)]
343pub enum AstPredicate {
344 Kind(ContextKind),
346 NameRegex(Regex),
348 Parent(ContextKind),
350 In(String),
352 Depth(DepthOp),
354 Path(String),
356 Lang(String),
358}
359
360impl AstPredicate {
361 #[must_use]
363 pub fn matches(&self, ctx_match: &ContextualMatch) -> bool {
364 match self {
365 AstPredicate::Kind(kind) => ctx_match.context.kind == *kind,
366 AstPredicate::NameRegex(regex) => regex.is_match(&ctx_match.name),
367 AstPredicate::Parent(kind) => {
368 if let Some(ref parent) = ctx_match.context.parent {
369 parent.kind == *kind
370 } else {
371 false
372 }
373 }
374 AstPredicate::In(name) => {
375 ctx_match.context.ancestors.iter().any(|a| a.name == *name)
377 || ctx_match
378 .context
379 .parent
380 .as_ref()
381 .is_some_and(|p| p.name == *name)
382 }
383 AstPredicate::Depth(op) => op.matches(ctx_match.context.depth()),
384 AstPredicate::Path(path) => ctx_match.context.path().contains(path),
385 AstPredicate::Lang(lang) => ctx_match.language == *lang,
386 }
387 }
388}
389
390#[derive(Debug, Clone)]
392pub enum AstExpr {
393 Predicate(AstPredicate),
395 And(Box<AstExpr>, Box<AstExpr>),
397 Or(Box<AstExpr>, Box<AstExpr>),
399 Not(Box<AstExpr>),
401}
402
403impl AstExpr {
404 #[must_use]
406 pub fn matches(&self, ctx_match: &ContextualMatch) -> bool {
407 match self {
408 AstExpr::Predicate(pred) => pred.matches(ctx_match),
409 AstExpr::And(left, right) => left.matches(ctx_match) && right.matches(ctx_match),
410 AstExpr::Or(left, right) => left.matches(ctx_match) || right.matches(ctx_match),
411 AstExpr::Not(expr) => !expr.matches(ctx_match),
412 }
413 }
414}
415
416const QUERY_CACHE_SIZE: usize = 100;
418
419static QUERY_CACHE: std::sync::LazyLock<Mutex<LruCache<String, AstExpr>>> =
421 std::sync::LazyLock::new(|| {
422 Mutex::new(LruCache::new(NonZeroUsize::new(QUERY_CACHE_SIZE).unwrap()))
423 });
424
425pub fn parse_query(input: &str) -> Result<AstExpr> {
432 {
434 match QUERY_CACHE.lock() {
435 Ok(mut cache) => {
436 if let Some(cached) = cache.get(input) {
437 return Ok(cached.clone());
438 }
439 }
440 Err(poisoned) => {
441 let mut cache = poisoned.into_inner();
442 if let Some(cached) = cache.get(input) {
443 return Ok(cached.clone());
444 }
445 }
446 }
447 }
448
449 let mut parser = QueryParser::new(input);
451 let expr = parser.parse_expr()?;
452
453 {
455 match QUERY_CACHE.lock() {
456 Ok(mut cache) => {
457 cache.put(input.to_string(), expr.clone());
458 }
459 Err(poisoned) => {
460 let mut cache = poisoned.into_inner();
461 cache.put(input.to_string(), expr.clone());
462 }
463 }
464 }
465
466 Ok(expr)
467}
468
/// Empties the query cache; compiled only for tests.
///
/// A poisoned lock is recovered and cleared like a healthy one.
// NOTE(review): nothing deprecated is visibly used here; confirm whether this `allow`
// is still needed.
#[cfg(test)]
#[allow(deprecated)]
pub fn clear_query_cache() {
    let mut cache = QUERY_CACHE
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    cache.clear();
}
478
/// Lexical tokens produced by the query tokenizer.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    /// A predicate name, a value, or a fragment of a regex pattern.
    Identifier(String),
    /// `:`
    Colon,
    /// `~=`
    TildeEquals,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// The `AND` keyword (any letter case).
    And,
    /// The `OR` keyword (any letter case).
    Or,
    /// The `NOT` keyword (any letter case).
    Not,
    /// End of input.
    Eof,
}
492
493struct QueryParser {
495 chars: Vec<char>,
496 pos: usize,
497 current_token: Token,
498}
499
500impl QueryParser {
501 fn new(input: &str) -> Self {
502 let mut parser = Self {
503 chars: input.chars().collect(),
504 pos: 0,
505 current_token: Token::Eof,
506 };
507 parser.current_token = parser.next_token();
508 parser
509 }
510
511 fn parse_expr(&mut self) -> Result<AstExpr> {
512 self.parse_or()
513 }
514
515 fn parse_or(&mut self) -> Result<AstExpr> {
516 let mut left = self.parse_and()?;
517
518 while matches!(self.current_token, Token::Or) {
519 self.advance();
520 let right = self.parse_and()?;
521 left = AstExpr::Or(Box::new(left), Box::new(right));
522 }
523
524 Ok(left)
525 }
526
527 fn parse_and(&mut self) -> Result<AstExpr> {
528 let mut left = self.parse_not()?;
529
530 while matches!(self.current_token, Token::And) {
531 self.advance();
532 let right = self.parse_not()?;
533 left = AstExpr::And(Box::new(left), Box::new(right));
534 }
535
536 Ok(left)
537 }
538
539 fn parse_not(&mut self) -> Result<AstExpr> {
540 if matches!(self.current_token, Token::Not) {
541 self.advance();
542 let expr = self.parse_not()?;
543 Ok(AstExpr::Not(Box::new(expr)))
544 } else {
545 self.parse_primary()
546 }
547 }
548
549 fn parse_primary(&mut self) -> Result<AstExpr> {
550 if matches!(self.current_token, Token::LParen) {
551 self.advance();
552 let expr = self.parse_expr()?;
553 if !matches!(self.current_token, Token::RParen) {
554 return Err(AstQueryError::UnexpectedToken {
555 expected: ")".to_string(),
556 actual: format!("{:?}", self.current_token),
557 });
558 }
559 self.advance();
560 return Ok(expr);
561 }
562
563 self.parse_predicate()
564 }
565
566 fn parse_predicate(&mut self) -> Result<AstExpr> {
567 let name = self.expect_identifier("predicate name")?;
568 self.advance();
569
570 let op = self.expect_operator()?;
571 self.advance();
572
573 let value = self.parse_predicate_value(&name, &op)?;
574
575 Self::create_predicate(&name, op, value)
576 }
577
578 fn expect_identifier(&self, expected: &str) -> Result<String> {
579 if let Token::Identifier(name) = &self.current_token {
580 Ok(name.clone())
581 } else {
582 Err(AstQueryError::UnexpectedToken {
583 expected: expected.to_string(),
584 actual: format!("{:?}", self.current_token),
585 })
586 }
587 }
588
589 fn expect_operator(&self) -> Result<Token> {
590 let op = self.current_token.clone();
591 if !matches!(op, Token::Colon | Token::TildeEquals) {
592 return Err(AstQueryError::UnexpectedToken {
593 expected: ":' or '~='".to_string(),
594 actual: format!("{op:?}"),
595 });
596 }
597 Ok(op)
598 }
599
600 fn parse_predicate_value(&mut self, name: &str, op: &Token) -> Result<String> {
601 if name == "name" && matches!(op, Token::TildeEquals) {
602 self.read_regex_value()
603 } else if let Token::Identifier(val) = &self.current_token {
604 let v = val.clone();
605 self.advance();
606 Ok(v)
607 } else {
608 Err(AstQueryError::UnexpectedToken {
609 expected: "value".to_string(),
610 actual: format!("{:?}", self.current_token),
611 })
612 }
613 }
614
615 fn create_predicate(name: &str, op: Token, value: String) -> Result<AstExpr> {
616 let predicate = match (name, op) {
617 ("kind", Token::Colon) => AstPredicate::Kind(Self::parse_context_kind(&value)?),
618 ("name", Token::TildeEquals) => {
619 validate_regex_pattern(&value)?;
620 let regex = Regex::new(&value).map_err(|e| AstQueryError::InvalidRegex {
621 pattern: value.clone(),
622 source: e,
623 })?;
624 AstPredicate::NameRegex(regex)
625 }
626 ("parent", Token::Colon) => AstPredicate::Parent(Self::parse_context_kind(&value)?),
627 ("in", Token::Colon) => AstPredicate::In(value),
628 ("depth", Token::Colon) => AstPredicate::Depth(Self::parse_depth_op(&value)?),
629 ("path", Token::Colon) => AstPredicate::Path(value),
630 ("lang", Token::Colon) => AstPredicate::Lang(value),
631 (pred, _) => {
632 return Err(AstQueryError::UnknownPredicate {
633 predicate: pred.to_string(),
634 });
635 }
636 };
637
638 Ok(AstExpr::Predicate(predicate))
639 }
640
641 fn parse_context_kind(s: &str) -> Result<ContextKind> {
642 match s {
643 "function" => Ok(ContextKind::Function),
644 "method" => Ok(ContextKind::Method),
645 "class" => Ok(ContextKind::Class),
646 "struct" => Ok(ContextKind::Struct),
647 "interface" => Ok(ContextKind::Interface),
648 "enum" => Ok(ContextKind::Enum),
649 "trait" => Ok(ContextKind::Trait),
650 "module" => Ok(ContextKind::Module),
651 "constant" => Ok(ContextKind::Constant),
652 "variable" => Ok(ContextKind::Variable),
653 "type" | "typealias" => Ok(ContextKind::TypeAlias),
654 "impl" => Ok(ContextKind::Impl),
655 _ => Err(AstQueryError::ParseError(format!(
656 "Unknown context kind: {s}"
657 ))),
658 }
659 }
660
661 fn parse_depth_op(s: &str) -> Result<DepthOp> {
662 if let Some(rest) = s.strip_prefix(">=") {
663 Self::parse_depth_num(rest, DepthOp::Gte, s)
664 } else if let Some(rest) = s.strip_prefix("<=") {
665 Self::parse_depth_num(rest, DepthOp::Lte, s)
666 } else if let Some(rest) = s.strip_prefix('>') {
667 Self::parse_depth_num(rest, DepthOp::Gt, s)
668 } else if let Some(rest) = s.strip_prefix('<') {
669 Self::parse_depth_num(rest, DepthOp::Lt, s)
670 } else {
671 Self::parse_depth_num(s, DepthOp::Eq, s)
672 }
673 }
674
675 fn parse_depth_num<F>(num_str: &str, ctor: F, original: &str) -> Result<DepthOp>
676 where
677 F: FnOnce(usize) -> DepthOp,
678 {
679 num_str
680 .parse()
681 .map(ctor)
682 .map_err(|_| AstQueryError::InvalidDepth {
683 value: original.to_string(),
684 })
685 }
686
687 fn read_regex_value(&mut self) -> Result<String> {
688 let mut value = String::new();
689 let mut paren_depth = 0;
690
691 loop {
692 match &self.current_token {
693 Token::Identifier(s) => {
694 value.push_str(s);
695 self.advance();
696 }
697 Token::LParen => {
698 value.push('(');
699 paren_depth += 1;
700 self.advance();
701 }
702 Token::RParen => {
703 if paren_depth == 0 {
704 break;
705 }
706 value.push(')');
707 paren_depth -= 1;
708 self.advance();
709 }
710 Token::And | Token::Or | Token::Eof => {
711 break;
712 }
713 _ => {
714 return Err(AstQueryError::UnexpectedToken {
715 expected: "regex value".to_string(),
716 actual: format!("{:?}", self.current_token),
717 });
718 }
719 }
720 }
721
722 if value.is_empty() {
723 return Err(AstQueryError::UnexpectedToken {
724 expected: "regex value".to_string(),
725 actual: "empty".to_string(),
726 });
727 }
728
729 Ok(value)
730 }
731
732 fn advance(&mut self) {
733 self.current_token = self.next_token();
734 }
735
736 fn next_token(&mut self) -> Token {
737 self.skip_whitespace();
738
739 if self.pos >= self.chars.len() {
740 return Token::Eof;
741 }
742
743 let ch = self.chars[self.pos];
744
745 match ch {
746 '(' => self.consume_char(Token::LParen),
747 ')' => self.consume_char(Token::RParen),
748 ':' => self.consume_char(Token::Colon),
749 '~' => self.scan_tilde(),
750 _ if Self::is_identifier_char(ch) => self.read_identifier(),
751 _ => self.consume_char(Token::Identifier(ch.to_string())),
752 }
753 }
754
755 fn consume_char(&mut self, token: Token) -> Token {
756 self.pos += 1;
757 token
758 }
759
760 fn scan_tilde(&mut self) -> Token {
761 if self.peek_char(1) == Some('=') {
762 self.pos += 2;
763 Token::TildeEquals
764 } else {
765 self.pos += 1;
766 Token::Identifier("~".to_string())
767 }
768 }
769
770 fn is_identifier_char(ch: char) -> bool {
771 ch.is_alphanumeric()
772 || matches!(
773 ch,
774 '_' | '>'
775 | '<'
776 | '='
777 | '^'
778 | '$'
779 | '.'
780 | '/'
781 | '-'
782 | '*'
783 | '+'
784 | '?'
785 | '['
786 | ']'
787 | '{'
788 | '}'
789 | '|'
790 | '\\'
791 | ','
792 )
793 }
794
795 fn read_identifier(&mut self) -> Token {
796 let start = self.pos;
797 while self.pos < self.chars.len() {
798 let ch = self.chars[self.pos];
799 if Self::is_identifier_char(ch) {
800 self.pos += 1;
801 } else {
802 break;
803 }
804 }
805
806 let text: String = self.chars[start..self.pos].iter().collect();
807 match text.to_uppercase().as_str() {
808 "AND" => Token::And,
809 "OR" => Token::Or,
810 "NOT" => Token::Not,
811 _ => Token::Identifier(text),
812 }
813 }
814
815 fn skip_whitespace(&mut self) {
816 while let Some(ch) = self.current_char() {
817 if ch.is_whitespace() {
818 self.pos += 1;
819 } else {
820 break;
821 }
822 }
823 }
824
825 fn current_char(&self) -> Option<char> {
826 self.chars.get(self.pos).copied()
827 }
828
829 fn peek_char(&self, offset: usize) -> Option<char> {
830 self.chars.get(self.pos + offset).copied()
831 }
832}
833
/// Unit tests for query parsing, predicate evaluation, regex validation and the
/// query cache.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a `ContextualMatch` named `name` of the given kind, with an optional
    /// parent and a list of further ancestors, located in a dummy `test.rs` file.
    fn make_test_match(
        name: &str,
        kind: ContextKind,
        parent: Option<(&str, ContextKind)>,
        ancestors: Vec<(&str, ContextKind)>,
        lang: &str,
    ) -> ContextualMatch {
        let immediate =
            super::super::types::ContextItem::new(name.to_string(), kind, 1, 10, 0, 100);

        let parent_item = parent.map(|(pname, pkind)| {
            super::super::types::ContextItem::new(pname.to_string(), pkind, 1, 20, 0, 200)
        });

        let ancestor_items: Vec<_> = ancestors
            .into_iter()
            .map(|(aname, akind)| {
                super::super::types::ContextItem::new(aname.to_string(), akind, 1, 30, 0, 300)
            })
            .collect();

        let context = super::super::types::Context::new(immediate, parent_item, ancestor_items);

        let location = super::super::types::ContextualMatchLocation::new(
            PathBuf::from("test.rs"),
            1,
            0,
            10,
            1,
        );
        ContextualMatch::new(name.to_string(), location, context, lang.to_string())
    }

    // ---- individual predicates ----

    #[test]
    fn test_parse_kind_predicate() {
        let query = parse_query("kind:function").unwrap();
        let test_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match));
    }

    #[test]
    fn test_parse_name_regex() {
        let query = parse_query("name~=test").unwrap();
        let test_match = make_test_match("test_func", ContextKind::Function, None, vec![], "rust");
        let other_match =
            make_test_match("other_func", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match), "Should match test_func");
        assert!(!query.matches(&other_match), "Should not match other_func");
    }

    #[test]
    fn test_parse_parent_predicate() {
        let query = parse_query("parent:class").unwrap();
        let test_match = make_test_match(
            "method",
            ContextKind::Method,
            Some(("MyClass", ContextKind::Class)),
            vec![],
            "rust",
        );
        let no_parent = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match));
        assert!(!query.matches(&no_parent));
    }

    #[test]
    fn test_parse_in_predicate() {
        let query = parse_query("in:MyClass").unwrap();
        let test_match = make_test_match(
            "method",
            ContextKind::Method,
            Some(("InnerClass", ContextKind::Class)),
            vec![("MyClass", ContextKind::Class)],
            "rust",
        );
        let not_in = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match));
        assert!(!query.matches(&not_in));
    }

    #[test]
    fn test_parse_depth_operators() {
        let eq_query = parse_query("depth:2").unwrap();
        let greater_than_query = parse_query("depth:>1").unwrap();
        let less_than_query = parse_query("depth:<3").unwrap();
        let greater_equal_query = parse_query("depth:>=2").unwrap();
        let less_equal_query = parse_query("depth:<=2").unwrap();

        let depth_2 = make_test_match(
            "method",
            ContextKind::Method,
            Some(("MyClass", ContextKind::Class)),
            vec![],
            "rust",
        );

        assert!(eq_query.matches(&depth_2));
        assert!(greater_than_query.matches(&depth_2));
        assert!(less_than_query.matches(&depth_2));
        assert!(greater_equal_query.matches(&depth_2));
        assert!(less_equal_query.matches(&depth_2));
    }

    #[test]
    fn test_parse_path_predicate() {
        let query = parse_query("path:MyClass").unwrap();
        let test_match = make_test_match(
            "method",
            ContextKind::Method,
            Some(("MyClass", ContextKind::Class)),
            vec![],
            "rust",
        );
        let no_match = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match));
        assert!(!query.matches(&no_match));
    }

    #[test]
    fn test_parse_lang_predicate() {
        let query = parse_query("lang:rust").unwrap();
        let rust_match = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        let js_match = make_test_match("func", ContextKind::Function, None, vec![], "javascript");
        assert!(query.matches(&rust_match));
        assert!(!query.matches(&js_match));
    }

    // ---- boolean combinators ----

    #[test]
    fn test_parse_and_expression() {
        let query = parse_query("kind:method AND parent:class").unwrap();
        let match_both = make_test_match(
            "method",
            ContextKind::Method,
            Some(("MyClass", ContextKind::Class)),
            vec![],
            "rust",
        );
        let match_kind_only = make_test_match("method", ContextKind::Method, None, vec![], "rust");
        assert!(query.matches(&match_both));
        assert!(!query.matches(&match_kind_only));
    }

    #[test]
    fn test_parse_or_expression() {
        let query = parse_query("kind:function OR kind:method").unwrap();
        let func_match = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        let method_match = make_test_match("method", ContextKind::Method, None, vec![], "rust");
        let class_match = make_test_match("MyClass", ContextKind::Class, None, vec![], "rust");
        assert!(query.matches(&func_match));
        assert!(query.matches(&method_match));
        assert!(!query.matches(&class_match));
    }

    #[test]
    fn test_parse_not_expression() {
        let query = parse_query("NOT kind:class").unwrap();
        let func_match = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        let class_match = make_test_match("MyClass", ContextKind::Class, None, vec![], "rust");
        assert!(query.matches(&func_match));
        assert!(!query.matches(&class_match));
    }

    #[test]
    fn test_parse_parentheses() {
        let query = parse_query("(kind:method AND parent:class) OR kind:function").unwrap();
        let method_in_class = make_test_match(
            "method",
            ContextKind::Method,
            Some(("MyClass", ContextKind::Class)),
            vec![],
            "rust",
        );
        let func = make_test_match("func", ContextKind::Function, None, vec![], "rust");
        let method_no_parent = make_test_match("method", ContextKind::Method, None, vec![], "rust");
        assert!(query.matches(&method_in_class));
        assert!(query.matches(&func));
        assert!(!query.matches(&method_no_parent));
    }

    #[test]
    fn test_parse_complex_query() {
        let query = parse_query("kind:method AND depth:>0 AND NOT in:TestClass").unwrap();
        let matching = make_test_match(
            "method",
            ContextKind::Method,
            Some(("MyClass", ContextKind::Class)),
            vec![],
            "rust",
        );
        let in_test_class = make_test_match(
            "method",
            ContextKind::Method,
            Some(("TestClass", ContextKind::Class)),
            vec![],
            "rust",
        );
        assert!(query.matches(&matching));
        assert!(!query.matches(&in_test_class));
    }

    // ---- parse errors ----

    #[test]
    fn test_parse_error_unknown_predicate() {
        let result = parse_query("unknown:value");
        assert!(matches!(
            result,
            Err(AstQueryError::UnknownPredicate { .. })
        ));
    }

    #[test]
    fn test_parse_error_invalid_regex() {
        let result = parse_query("name~=[invalid");
        assert!(matches!(result, Err(AstQueryError::InvalidRegex { .. })));
    }

    #[test]
    fn test_parse_error_invalid_depth() {
        let result = parse_query("depth:abc");
        assert!(matches!(result, Err(AstQueryError::InvalidDepth { .. })));
    }

    #[test]
    fn test_parse_error_missing_rparen() {
        let result = parse_query("(kind:function");
        assert!(matches!(result, Err(AstQueryError::UnexpectedToken { .. })));
    }

    #[test]
    fn test_depth_op_matches() {
        assert!(DepthOp::Eq(3).matches(3));
        assert!(!DepthOp::Eq(3).matches(2));

        assert!(DepthOp::Gt(2).matches(3));
        assert!(!DepthOp::Gt(2).matches(2));

        assert!(DepthOp::Lt(3).matches(2));
        assert!(!DepthOp::Lt(3).matches(3));

        assert!(DepthOp::Gte(2).matches(2));
        assert!(DepthOp::Gte(2).matches(3));
        assert!(!DepthOp::Gte(2).matches(1));

        assert!(DepthOp::Lte(3).matches(3));
        assert!(DepthOp::Lte(3).matches(2));
        assert!(!DepthOp::Lte(3).matches(4));
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let query1 = parse_query("kind:function and name~=test").unwrap();
        let query2 = parse_query("kind:function AND name~=test").unwrap();
        let query3 = parse_query("kind:function AnD name~=test").unwrap();
        let test_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        assert!(query1.matches(&test_match));
        assert!(query2.matches(&test_match));
        assert!(query3.matches(&test_match));
    }

    #[test]
    fn test_parse_unicode_identifier() {
        let unicode_identifier = "\u{6a21}\u{5757}\u{540d}";
        let query = parse_query(&format!("in:{unicode_identifier}")).unwrap();
        let test_match = make_test_match(
            "method",
            ContextKind::Method,
            Some((unicode_identifier, ContextKind::Class)),
            vec![],
            "rust",
        );
        assert!(query.matches(&test_match));
    }

    // ---- regex safety validation ----

    #[test]
    fn test_regex_validation_too_long() {
        let long_pattern = "a".repeat(1001);
        let query_str = format!("name~={long_pattern}");
        let result = parse_query(&query_str);
        assert!(matches!(result, Err(AstQueryError::InvalidRegex { .. })));
        if let Err(AstQueryError::InvalidRegex { pattern, .. }) = result {
            assert_eq!(pattern.len(), 1001);
        }
    }

    #[test]
    fn test_regex_validation_catastrophic_backtracking_nested_plus() {
        let result = validate_regex_pattern("(a+)+");
        assert!(result.is_err(), "Nested quantifiers should be rejected");
    }

    #[test]
    fn test_regex_validation_catastrophic_backtracking_nested_star() {
        let result = validate_regex_pattern("(a*)*");
        assert!(result.is_err(), "Nested quantifiers should be rejected");
    }

    #[test]
    fn test_regex_validation_nested_repetition_star_plus() {
        let result = parse_query("name~=a*+");
        assert!(matches!(result, Err(AstQueryError::InvalidRegex { .. })));
    }

    #[test]
    fn test_regex_validation_nested_repetition_plus_star() {
        let result = parse_query("name~=a+*");
        assert!(matches!(result, Err(AstQueryError::InvalidRegex { .. })));
    }

    #[test]
    fn test_regex_validation_safe_patterns_allowed() {
        let safe_patterns = vec![
            "name~=test",
            "name~=^test_.*",
            "name~=foo|bar",
            "name~=[a-z]+",
            "name~=\\w+",
            "name~=test{1,5}",
        ];
        for pattern in safe_patterns {
            let result = parse_query(pattern);
            assert!(result.is_ok(), "Safe pattern should be allowed: {pattern}");
        }
    }

    #[test]
    fn test_regex_validation_reasonable_length_allowed() {
        let pattern = "a".repeat(999);
        let query_str = format!("name~={pattern}");
        let result = parse_query(&query_str);
        assert!(result.is_ok(), "999-char pattern should be allowed");
    }

    #[test]
    fn test_validate_regex_pattern_directly() {
        assert!(validate_regex_pattern("test").is_ok());
        assert!(
            validate_regex_pattern("[a-z]+").is_ok(),
            "Character class with quantifier should be allowed"
        );
        assert!(
            validate_regex_pattern("(a+)+").is_err(),
            "Nested quantifiers should be rejected"
        );
        assert!(
            validate_regex_pattern("(a*)*").is_err(),
            "Nested quantifiers should be rejected"
        );
        assert!(
            validate_regex_pattern("a*+").is_err(),
            "Direct nested quantifiers should be rejected"
        );
    }

    #[test]
    fn test_regex_validation_alternation_explosion_simple() {
        let result = validate_regex_pattern("(a|ab)*");
        assert!(
            result.is_err(),
            "Alternation in quantified group should be rejected"
        );
    }

    #[test]
    fn test_regex_validation_alternation_explosion_plus() {
        let result = validate_regex_pattern("(foo|bar)+");
        assert!(result.is_err(), "Alternation with + should be rejected");
    }

    #[test]
    fn test_regex_validation_alternation_safe() {
        let result = validate_regex_pattern("(foo|bar)");
        assert!(
            result.is_ok(),
            "Alternation without quantifier should be allowed"
        );
    }

    #[test]
    fn test_regex_validation_alternation_in_character_class() {
        let result = validate_regex_pattern("[a|b]+");
        assert!(
            result.is_ok(),
            "Pipe inside character class should be allowed"
        );
    }

    #[test]
    fn test_regex_validation_large_repetition_range() {
        let result = validate_regex_pattern("a{1,999999}");
        assert!(result.is_err(), "Large repetition range should be rejected");
    }

    #[test]
    fn test_regex_validation_safe_repetition_range() {
        let result = validate_regex_pattern("a{1,5}");
        assert!(result.is_ok(), "Small repetition range should be allowed");
    }

    #[test]
    fn test_regex_validation_exact_max_repetition() {
        let result = validate_regex_pattern("a{1000}");
        assert!(result.is_ok(), "Repetition at max limit should be allowed");
    }

    #[test]
    fn test_regex_validation_just_over_max_repetition() {
        let result = validate_regex_pattern("a{1001}");
        assert!(
            result.is_err(),
            "Repetition over max limit should be rejected"
        );
    }

    #[test]
    fn test_regex_validation_open_ended_range() {
        let result = validate_regex_pattern("a{100,}");
        assert!(result.is_ok(), "Open-ended range should be allowed");
    }

    #[test]
    fn test_regex_validation_combined_attacks() {
        assert!(
            validate_regex_pattern("(a|ab){1,999999}").is_err(),
            "Combined alternation + large range should be rejected"
        );
        assert!(
            validate_regex_pattern("(a+|b+)*").is_err(),
            "Alternation with quantified branches in quantified group should be rejected"
        );
    }

    // ---- query cache ----

    #[test]
    fn test_query_cache_basic() {
        clear_query_cache();
        let query1 = parse_query("kind:function").unwrap();
        let query2 = parse_query("kind:function").unwrap();
        let test_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        assert!(query1.matches(&test_match));
        assert!(query2.matches(&test_match));
    }

    #[test]
    fn test_query_cache_different_queries() {
        clear_query_cache();
        let query1 = parse_query("kind:function").unwrap();
        let query2 = parse_query("kind:method").unwrap();
        let func_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        let method_match = make_test_match("test", ContextKind::Method, None, vec![], "rust");
        assert!(query1.matches(&func_match));
        assert!(!query1.matches(&method_match));
        assert!(!query2.matches(&func_match));
        assert!(query2.matches(&method_match));
    }

    #[test]
    fn test_query_cache_eviction() {
        clear_query_cache();
        // 150 distinct queries exceed the cache capacity of 100, forcing evictions.
        for i in 0..150 {
            let query_str = format!("kind:function AND name~=test{i}");
            let _ = parse_query(&query_str).unwrap();
        }
        for i in 0..150 {
            let query_str = format!("kind:function AND name~=test{i}");
            let result = parse_query(&query_str);
            assert!(result.is_ok(), "Query {i} should parse successfully");
        }
    }

    #[test]
    fn test_query_cache_with_errors() {
        clear_query_cache();
        let result1 = parse_query("invalid~syntax");
        assert!(result1.is_err());
        let result2 = parse_query("invalid~syntax");
        assert!(result2.is_err());
        let result3 = parse_query("kind:function");
        assert!(result3.is_ok());
    }

    #[test]
    fn test_query_cache_clear() {
        clear_query_cache();
        let _ = parse_query("kind:function").unwrap();
        clear_query_cache();
        let query = parse_query("kind:function").unwrap();
        let test_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match));
    }

    #[test]
    fn test_query_cache_thread_safety() {
        use std::thread;
        clear_query_cache();
        let handles: Vec<_> = (0..10)
            .map(|_| {
                thread::spawn(|| {
                    let query = parse_query("kind:function").unwrap();
                    let test_match =
                        make_test_match("test", ContextKind::Function, None, vec![], "rust");
                    assert!(query.matches(&test_match));
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_query_cache_complex_queries() {
        clear_query_cache();
        let complex_queries = vec![
            "kind:function AND name~=^test_",
            "(kind:method AND parent:class) OR kind:function",
            "depth:>3 AND NOT in:TestClass",
            "kind:function AND name~=.*helper.* AND depth:<=2",
        ];
        for query_str in &complex_queries {
            let query1 = parse_query(query_str).unwrap();
            let query2 = parse_query(query_str).unwrap();
            let test_match =
                make_test_match("test_func", ContextKind::Function, None, vec![], "rust");
            assert_eq!(query1.matches(&test_match), query2.matches(&test_match));
        }
    }

    // NOTE(review): no thread panics while holding the cache lock here, so the mutex is
    // never actually poisoned; this exercises concurrent access rather than recovery.
    #[test]
    fn test_query_cache_poison_recovery() {
        use std::sync::{Arc, Barrier};
        use std::thread;
        clear_query_cache();
        let _ = parse_query("kind:function").unwrap();
        let barrier = Arc::new(Barrier::new(5));
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    let query = parse_query("kind:function").unwrap();
                    let test_match =
                        make_test_match("test", ContextKind::Function, None, vec![], "rust");
                    assert!(query.matches(&test_match));
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let query = parse_query("kind:function").unwrap();
        let test_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&test_match));
    }
}