Skip to main content

sqry_core/ast/
query.rs

1//! Query parser and evaluator for AST queries
2//!
3//! Implements the query language for searching code by AST structure.
4//!
5//! # Query Language
6//!
7//! The query language supports 7 predicate types:
8//! - `kind:TYPE` - Match symbol kind
9//! - `name~=REGEX` - Match name with regex
10//! - `parent:TYPE` - Match parent kind
11//! - `in:NAME` - Match if inside ancestor
//! - `depth:N`, `depth:>N`, `depth:<N`, `depth:>=N` or `depth:<=N` - Match nesting depth
13//! - `path:TEXT` - Match symbol path
14//! - `lang:LANG` - Match language
15//!
16//! Boolean operators: AND, OR, NOT with parentheses for grouping.
17//!
18//! # Example
19//!
20//! ```ignore
21//! let query = parse_query("kind:function AND name~=^test_")?;
22//! let matches: Vec<_> = contextual_matches.into_iter()
23//!     .filter(|m| query.matches(m))
24//!     .collect();
25//! ```
26
27use super::error::{AstQueryError, Result};
28use super::types::{ContextKind, ContextualMatch};
29use lru::LruCache;
30use regex::Regex;
31use std::num::NonZeroUsize;
32use std::sync::Mutex;
33
/// Maximum allowed length for regex patterns (security limit).
///
/// The limit is compared against `pattern.len()`, i.e. the pattern's length in
/// bytes, before any other validation of the pattern takes place.
const MAX_REGEX_LENGTH: usize = 1000;

/// Maximum value accepted for any bound of a repetition range (e.g., `{n,m}`).
const MAX_REPETITION_COUNT: usize = 1000;
39
40/// Validate regex pattern for security (prevent `DoS` via catastrophic backtracking)
41fn validate_regex_pattern(pattern: &str) -> Result<()> {
42    RegexSafetyScanner::new(pattern)?.scan()
43}
44
45struct RegexSafetyScanner<'a> {
46    pattern: &'a str,
47    chars: Vec<char>,
48    pos: usize,
49    state: RegexScanState,
50}
51
52impl<'a> RegexSafetyScanner<'a> {
53    fn new(pattern: &'a str) -> Result<Self> {
54        if pattern.len() > MAX_REGEX_LENGTH {
55            return Err(invalid_regex(
56                pattern,
57                &format!(
58                    "Regex pattern too long (max {} chars, got {})",
59                    MAX_REGEX_LENGTH,
60                    pattern.len()
61                ),
62            ));
63        }
64
65        Ok(Self {
66            pattern,
67            chars: pattern.chars().collect(),
68            pos: 0,
69            state: RegexScanState::default(),
70        })
71    }
72
73    fn scan(mut self) -> Result<()> {
74        while self.pos < self.chars.len() {
75            let ch = self.chars[self.pos];
76
77            if self.state.consume_escape(ch) {
78                self.pos += 1;
79                continue;
80            }
81
82            if self.state.update_bracket_state(ch) || self.state.in_brackets {
83                self.pos += 1;
84                continue;
85            }
86
87            // Check for alternation inside quantified groups: (a|b)+ or (a|b)*
88            if self.is_quantified_group() {
89                self.check_group_safety()?;
90            }
91
92            // Check for repetition ranges {n,m} and validate bounds
93            if self.handle_repetition_range()? {
94                continue;
95            }
96
97            // Check for simple nested repetition (a*+, a+*, a**, a++, etc.)
98            if self.is_nested_repetition() {
99                return Err(invalid_regex(
100                    self.pattern,
101                    "Regex pattern contains directly nested repetition operators",
102                ));
103            }
104
105            self.pos += 1;
106        }
107
108        // Note: Consecutive quantifiers (e.g., .*.+, \w*\w*) are NOT validated.
109        // Rust's regex crate is immune to catastrophic backtracking.
110        Ok(())
111    }
112
113    fn is_quantified_group(&self) -> bool {
114        if self.chars[self.pos] != ')' || self.pos + 1 >= self.chars.len() {
115            return false;
116        }
117        matches!(self.chars[self.pos + 1], '+' | '*' | '?' | '{')
118    }
119
120    fn check_group_safety(&self) -> Result<()> {
121        if self.has_alternation_in_group() {
122            return Err(invalid_regex(
123                self.pattern,
124                "Regex pattern contains alternation inside quantified group like (a|b)+ which can cause exponential backtracking",
125            ));
126        }
127
128        if self.has_quantifier_in_group() {
129            return Err(invalid_regex(
130                self.pattern,
131                "Regex pattern contains nested quantifiers like (x+)+ or (x*)* which can cause catastrophic backtracking",
132            ));
133        }
134
135        Ok(())
136    }
137
138    fn handle_repetition_range(&mut self) -> Result<bool> {
139        if self.chars[self.pos] != '{' {
140            return Ok(false);
141        }
142
143        if let Some(end_pos) = self.find_matching_brace() {
144            let range_str: String = self.chars[self.pos + 1..end_pos].iter().collect();
145            validate_repetition_range(&range_str, self.pattern)?;
146            self.pos = end_pos + 1;
147            Ok(true)
148        } else {
149            Ok(false)
150        }
151    }
152
153    fn is_nested_repetition(&self) -> bool {
154        if self.pos + 1 >= self.chars.len() {
155            return false;
156        }
157        let ch = self.chars[self.pos];
158        let next_ch = self.chars[self.pos + 1];
159        matches!(ch, '*' | '+' | '?') && matches!(next_ch, '*' | '+' | '?')
160    }
161
162    fn has_alternation_in_group(&self) -> bool {
163        self.group_contains(false, |ch| ch == '|')
164    }
165
166    fn has_quantifier_in_group(&self) -> bool {
167        self.group_contains(true, |ch| matches!(ch, '+' | '*' | '?'))
168    }
169
170    fn group_contains<F>(&self, escape_sensitive: bool, predicate: F) -> bool
171    where
172        F: Fn(char) -> bool,
173    {
174        let Some(start) = self.find_group_start(escape_sensitive) else {
175            return false;
176        };
177
178        let scan_start = start + 1;
179        if scan_start >= self.pos {
180            return false;
181        }
182
183        self.chars[scan_start..self.pos]
184            .iter()
185            .copied()
186            .any(predicate)
187    }
188
189    fn find_group_start(&self, escape_sensitive: bool) -> Option<usize> {
190        let mut depth = 0;
191        let mut i = self.pos;
192        let mut escape_next = false;
193
194        while i > 0 {
195            i -= 1;
196            let ch = self.chars[i];
197
198            if escape_sensitive {
199                if escape_next {
200                    escape_next = false;
201                    continue;
202                }
203
204                if ch == '\\' {
205                    escape_next = true;
206                    continue;
207                }
208            }
209
210            match ch {
211                ')' => depth += 1,
212                '(' => {
213                    if depth == 0 {
214                        return Some(i);
215                    }
216                    depth -= 1;
217                }
218                _ => {}
219            }
220        }
221
222        None
223    }
224
225    fn find_matching_brace(&self) -> Option<usize> {
226        for i in self.pos + 1..self.chars.len() {
227            if self.chars[i] == '}' {
228                return Some(i);
229            }
230            // Only digits and comma allowed in range
231            if !self.chars[i].is_numeric() && self.chars[i] != ',' {
232                return None;
233            }
234        }
235        None
236    }
237}
238
/// Lexical context carried along while walking a regex pattern.
#[derive(Default)]
struct RegexScanState {
    /// Set after a backslash; the following character is then taken literally.
    escape_next: bool,
    /// Set between a `[` and the next `]`.
    in_brackets: bool,
}

impl RegexScanState {
    /// Feed one character into the escape tracker.
    ///
    /// Returns `true` when `ch` is either a backslash that starts an escape or
    /// the character completing a pending escape; the caller should skip it.
    fn consume_escape(&mut self, ch: char) -> bool {
        let escape_was_pending = std::mem::take(&mut self.escape_next);
        if escape_was_pending {
            return true;
        }

        self.escape_next = ch == '\\';
        self.escape_next
    }

    /// Feed one character into the character-class tracker.
    ///
    /// Returns `true` when `ch` is `[` or `]`; `in_brackets` is updated to
    /// reflect which of the two was seen.
    fn update_bracket_state(&mut self, ch: char) -> bool {
        let is_delimiter = matches!(ch, '[' | ']');
        if is_delimiter {
            self.in_brackets = ch == '[';
        }
        is_delimiter
    }
}
272
273fn invalid_regex(pattern: &str, message: &str) -> AstQueryError {
274    AstQueryError::InvalidRegex {
275        pattern: pattern.to_string(),
276        source: regex::Error::Syntax(message.to_string()),
277    }
278}
279
280/// Validate repetition range values
281fn validate_repetition_range(range_str: &str, pattern: &str) -> Result<()> {
282    let parts: Vec<&str> = range_str.split(',').collect();
283
284    for part in &parts {
285        if part.is_empty() {
286            continue;
287        }
288
289        match part.trim().parse::<usize>() {
290            Ok(count) => {
291                if count > MAX_REPETITION_COUNT {
292                    return Err(AstQueryError::InvalidRegex {
293                        pattern: pattern.to_string(),
294                        source: regex::Error::Syntax(format!(
295                            "Repetition count {count} exceeds maximum of {MAX_REPETITION_COUNT}"
296                        )),
297                    });
298                }
299            }
300            Err(_) => {
301                return Err(AstQueryError::InvalidRegex {
302                    pattern: pattern.to_string(),
303                    source: regex::Error::Syntax(format!("Invalid repetition range: {range_str}")),
304                });
305            }
306        }
307    }
308
309    Ok(())
310}
311
/// Comparison applied to a symbol's nesting depth
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DepthOp {
    /// Depth is exactly the bound (depth:3)
    Eq(usize),
    /// Depth is strictly above the bound (depth:>3)
    Gt(usize),
    /// Depth is strictly below the bound (depth:<3)
    Lt(usize),
    /// Depth is at least the bound (depth:>=3)
    Gte(usize),
    /// Depth is at most the bound (depth:<=3)
    Lte(usize),
}

impl DepthOp {
    /// Report whether `depth` satisfies this comparison against its bound
    #[must_use]
    pub fn matches(&self, depth: usize) -> bool {
        // Every variant carries exactly one bound, so pull it out once.
        let (Self::Eq(bound)
        | Self::Gt(bound)
        | Self::Lt(bound)
        | Self::Gte(bound)
        | Self::Lte(bound)) = self;

        let ordering = depth.cmp(bound);
        match self {
            Self::Eq(_) => ordering.is_eq(),
            Self::Gt(_) => ordering.is_gt(),
            Self::Lt(_) => ordering.is_lt(),
            Self::Gte(_) => ordering.is_ge(),
            Self::Lte(_) => ordering.is_le(),
        }
    }
}
340
341/// AST query predicate
342#[derive(Debug, Clone)]
343pub enum AstPredicate {
344    /// Match symbol kind (kind:function)
345    Kind(ContextKind),
346    /// Match name with regex (name~=pattern)
347    NameRegex(Regex),
348    /// Match parent kind (parent:class)
349    Parent(ContextKind),
350    /// Match if inside ancestor (in:MyClass)
351    In(String),
352    /// Match nesting depth (depth:3 or depth:>3)
353    Depth(DepthOp),
354    /// Match symbol path (`path:module::function`)
355    Path(String),
356    /// Match language (lang:rust)
357    Lang(String),
358}
359
360impl AstPredicate {
361    /// Check if a contextual match satisfies this predicate
362    #[must_use]
363    pub fn matches(&self, ctx_match: &ContextualMatch) -> bool {
364        match self {
365            AstPredicate::Kind(kind) => ctx_match.context.kind == *kind,
366            AstPredicate::NameRegex(regex) => regex.is_match(&ctx_match.name),
367            AstPredicate::Parent(kind) => {
368                if let Some(ref parent) = ctx_match.context.parent {
369                    parent.kind == *kind
370                } else {
371                    false
372                }
373            }
374            AstPredicate::In(name) => {
375                // Check if any ancestor has this name
376                ctx_match.context.ancestors.iter().any(|a| a.name == *name)
377                    || ctx_match
378                        .context
379                        .parent
380                        .as_ref()
381                        .is_some_and(|p| p.name == *name)
382            }
383            AstPredicate::Depth(op) => op.matches(ctx_match.context.depth()),
384            AstPredicate::Path(path) => ctx_match.context.path().contains(path),
385            AstPredicate::Lang(lang) => ctx_match.language == *lang,
386        }
387    }
388}
389
390/// AST query expression (Boolean combination of predicates)
391#[derive(Debug, Clone)]
392pub enum AstExpr {
393    /// Single predicate
394    Predicate(AstPredicate),
395    /// Logical AND
396    And(Box<AstExpr>, Box<AstExpr>),
397    /// Logical OR
398    Or(Box<AstExpr>, Box<AstExpr>),
399    /// Logical NOT
400    Not(Box<AstExpr>),
401}
402
403impl AstExpr {
404    /// Check if a contextual match satisfies this expression
405    #[must_use]
406    pub fn matches(&self, ctx_match: &ContextualMatch) -> bool {
407        match self {
408            AstExpr::Predicate(pred) => pred.matches(ctx_match),
409            AstExpr::And(left, right) => left.matches(ctx_match) && right.matches(ctx_match),
410            AstExpr::Or(left, right) => left.matches(ctx_match) || right.matches(ctx_match),
411            AstExpr::Not(expr) => !expr.matches(ctx_match),
412        }
413    }
414}
415
416/// Maximum number of cached queries (LRU eviction beyond this)
417const QUERY_CACHE_SIZE: usize = 100;
418
419/// Global LRU cache for parsed queries
420static QUERY_CACHE: std::sync::LazyLock<Mutex<LruCache<String, AstExpr>>> =
421    std::sync::LazyLock::new(|| {
422        Mutex::new(LruCache::new(NonZeroUsize::new(QUERY_CACHE_SIZE).unwrap()))
423    });
424
425/// Parse a query string into an AST expression
426///
427/// # Errors
428///
429/// Returns [`AstQueryError`] when the input contains unknown predicates,
430/// invalid operators, malformed regex patterns, or invalid depth expressions.
431pub fn parse_query(input: &str) -> Result<AstExpr> {
432    // Check cache first (handle poisoned lock gracefully)
433    {
434        match QUERY_CACHE.lock() {
435            Ok(mut cache) => {
436                if let Some(cached) = cache.get(input) {
437                    return Ok(cached.clone());
438                }
439            }
440            Err(poisoned) => {
441                let mut cache = poisoned.into_inner();
442                if let Some(cached) = cache.get(input) {
443                    return Ok(cached.clone());
444                }
445            }
446        }
447    }
448
449    // Parse if not cached
450    let mut parser = QueryParser::new(input);
451    let expr = parser.parse_expr()?;
452
453    // Cache the result (handle poisoned lock gracefully)
454    {
455        match QUERY_CACHE.lock() {
456            Ok(mut cache) => {
457                cache.put(input.to_string(), expr.clone());
458            }
459            Err(poisoned) => {
460                let mut cache = poisoned.into_inner();
461                cache.put(input.to_string(), expr.clone());
462            }
463        }
464    }
465
466    Ok(expr)
467}
468
/// Empty the global query cache (useful for testing).
///
/// A poisoned lock is recovered so the cache is cleared either way.
#[cfg(test)]
#[allow(deprecated)]
pub fn clear_query_cache() {
    QUERY_CACHE
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clear();
}
478
/// Lexical token produced by the query parser while scanning a query string
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    /// A run of identifier characters that is not a keyword, or a single
    /// character the scanner does not otherwise recognize
    Identifier(String),
    /// `:`
    Colon,
    TildeEquals, // ~=
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// The word `AND` (matched case-insensitively)
    And,
    /// The word `OR` (matched case-insensitively)
    Or,
    /// The word `NOT` (matched case-insensitively)
    Not,
    /// End of input
    Eof,
}
492
493/// Query parser
494struct QueryParser {
495    chars: Vec<char>,
496    pos: usize,
497    current_token: Token,
498}
499
500impl QueryParser {
501    fn new(input: &str) -> Self {
502        let mut parser = Self {
503            chars: input.chars().collect(),
504            pos: 0,
505            current_token: Token::Eof,
506        };
507        parser.current_token = parser.next_token();
508        parser
509    }
510
511    fn parse_expr(&mut self) -> Result<AstExpr> {
512        self.parse_or()
513    }
514
515    fn parse_or(&mut self) -> Result<AstExpr> {
516        let mut left = self.parse_and()?;
517
518        while matches!(self.current_token, Token::Or) {
519            self.advance();
520            let right = self.parse_and()?;
521            left = AstExpr::Or(Box::new(left), Box::new(right));
522        }
523
524        Ok(left)
525    }
526
527    fn parse_and(&mut self) -> Result<AstExpr> {
528        let mut left = self.parse_not()?;
529
530        while matches!(self.current_token, Token::And) {
531            self.advance();
532            let right = self.parse_not()?;
533            left = AstExpr::And(Box::new(left), Box::new(right));
534        }
535
536        Ok(left)
537    }
538
539    fn parse_not(&mut self) -> Result<AstExpr> {
540        if matches!(self.current_token, Token::Not) {
541            self.advance();
542            let expr = self.parse_not()?;
543            Ok(AstExpr::Not(Box::new(expr)))
544        } else {
545            self.parse_primary()
546        }
547    }
548
549    fn parse_primary(&mut self) -> Result<AstExpr> {
550        if matches!(self.current_token, Token::LParen) {
551            self.advance();
552            let expr = self.parse_expr()?;
553            if !matches!(self.current_token, Token::RParen) {
554                return Err(AstQueryError::UnexpectedToken {
555                    expected: ")".to_string(),
556                    actual: format!("{:?}", self.current_token),
557                });
558            }
559            self.advance();
560            return Ok(expr);
561        }
562
563        self.parse_predicate()
564    }
565
566    fn parse_predicate(&mut self) -> Result<AstExpr> {
567        let name = self.expect_identifier("predicate name")?;
568        self.advance();
569
570        let op = self.expect_operator()?;
571        self.advance();
572
573        let value = self.parse_predicate_value(&name, &op)?;
574
575        Self::create_predicate(&name, op, value)
576    }
577
578    fn expect_identifier(&self, expected: &str) -> Result<String> {
579        if let Token::Identifier(name) = &self.current_token {
580            Ok(name.clone())
581        } else {
582            Err(AstQueryError::UnexpectedToken {
583                expected: expected.to_string(),
584                actual: format!("{:?}", self.current_token),
585            })
586        }
587    }
588
589    fn expect_operator(&self) -> Result<Token> {
590        let op = self.current_token.clone();
591        if !matches!(op, Token::Colon | Token::TildeEquals) {
592            return Err(AstQueryError::UnexpectedToken {
593                expected: ":' or '~='".to_string(),
594                actual: format!("{op:?}"),
595            });
596        }
597        Ok(op)
598    }
599
600    fn parse_predicate_value(&mut self, name: &str, op: &Token) -> Result<String> {
601        if name == "name" && matches!(op, Token::TildeEquals) {
602            self.read_regex_value()
603        } else if let Token::Identifier(val) = &self.current_token {
604            let v = val.clone();
605            self.advance();
606            Ok(v)
607        } else {
608            Err(AstQueryError::UnexpectedToken {
609                expected: "value".to_string(),
610                actual: format!("{:?}", self.current_token),
611            })
612        }
613    }
614
615    fn create_predicate(name: &str, op: Token, value: String) -> Result<AstExpr> {
616        let predicate = match (name, op) {
617            ("kind", Token::Colon) => AstPredicate::Kind(Self::parse_context_kind(&value)?),
618            ("name", Token::TildeEquals) => {
619                validate_regex_pattern(&value)?;
620                let regex = Regex::new(&value).map_err(|e| AstQueryError::InvalidRegex {
621                    pattern: value.clone(),
622                    source: e,
623                })?;
624                AstPredicate::NameRegex(regex)
625            }
626            ("parent", Token::Colon) => AstPredicate::Parent(Self::parse_context_kind(&value)?),
627            ("in", Token::Colon) => AstPredicate::In(value),
628            ("depth", Token::Colon) => AstPredicate::Depth(Self::parse_depth_op(&value)?),
629            ("path", Token::Colon) => AstPredicate::Path(value),
630            ("lang", Token::Colon) => AstPredicate::Lang(value),
631            (pred, _) => {
632                return Err(AstQueryError::UnknownPredicate {
633                    predicate: pred.to_string(),
634                });
635            }
636        };
637
638        Ok(AstExpr::Predicate(predicate))
639    }
640
641    fn parse_context_kind(s: &str) -> Result<ContextKind> {
642        match s {
643            "function" => Ok(ContextKind::Function),
644            "method" => Ok(ContextKind::Method),
645            "class" => Ok(ContextKind::Class),
646            "struct" => Ok(ContextKind::Struct),
647            "interface" => Ok(ContextKind::Interface),
648            "enum" => Ok(ContextKind::Enum),
649            "trait" => Ok(ContextKind::Trait),
650            "module" => Ok(ContextKind::Module),
651            "constant" => Ok(ContextKind::Constant),
652            "variable" => Ok(ContextKind::Variable),
653            "type" | "typealias" => Ok(ContextKind::TypeAlias),
654            "impl" => Ok(ContextKind::Impl),
655            _ => Err(AstQueryError::ParseError(format!(
656                "Unknown context kind: {s}"
657            ))),
658        }
659    }
660
661    fn parse_depth_op(s: &str) -> Result<DepthOp> {
662        if let Some(rest) = s.strip_prefix(">=") {
663            Self::parse_depth_num(rest, DepthOp::Gte, s)
664        } else if let Some(rest) = s.strip_prefix("<=") {
665            Self::parse_depth_num(rest, DepthOp::Lte, s)
666        } else if let Some(rest) = s.strip_prefix('>') {
667            Self::parse_depth_num(rest, DepthOp::Gt, s)
668        } else if let Some(rest) = s.strip_prefix('<') {
669            Self::parse_depth_num(rest, DepthOp::Lt, s)
670        } else {
671            Self::parse_depth_num(s, DepthOp::Eq, s)
672        }
673    }
674
675    fn parse_depth_num<F>(num_str: &str, ctor: F, original: &str) -> Result<DepthOp>
676    where
677        F: FnOnce(usize) -> DepthOp,
678    {
679        num_str
680            .parse()
681            .map(ctor)
682            .map_err(|_| AstQueryError::InvalidDepth {
683                value: original.to_string(),
684            })
685    }
686
687    fn read_regex_value(&mut self) -> Result<String> {
688        let mut value = String::new();
689        let mut paren_depth = 0;
690
691        loop {
692            match &self.current_token {
693                Token::Identifier(s) => {
694                    value.push_str(s);
695                    self.advance();
696                }
697                Token::LParen => {
698                    value.push('(');
699                    paren_depth += 1;
700                    self.advance();
701                }
702                Token::RParen => {
703                    if paren_depth == 0 {
704                        break;
705                    }
706                    value.push(')');
707                    paren_depth -= 1;
708                    self.advance();
709                }
710                Token::And | Token::Or | Token::Eof => {
711                    break;
712                }
713                _ => {
714                    return Err(AstQueryError::UnexpectedToken {
715                        expected: "regex value".to_string(),
716                        actual: format!("{:?}", self.current_token),
717                    });
718                }
719            }
720        }
721
722        if value.is_empty() {
723            return Err(AstQueryError::UnexpectedToken {
724                expected: "regex value".to_string(),
725                actual: "empty".to_string(),
726            });
727        }
728
729        Ok(value)
730    }
731
732    fn advance(&mut self) {
733        self.current_token = self.next_token();
734    }
735
736    fn next_token(&mut self) -> Token {
737        self.skip_whitespace();
738
739        if self.pos >= self.chars.len() {
740            return Token::Eof;
741        }
742
743        let ch = self.chars[self.pos];
744
745        match ch {
746            '(' => self.consume_char(Token::LParen),
747            ')' => self.consume_char(Token::RParen),
748            ':' => self.consume_char(Token::Colon),
749            '~' => self.scan_tilde(),
750            _ if Self::is_identifier_char(ch) => self.read_identifier(),
751            _ => self.consume_char(Token::Identifier(ch.to_string())),
752        }
753    }
754
755    fn consume_char(&mut self, token: Token) -> Token {
756        self.pos += 1;
757        token
758    }
759
760    fn scan_tilde(&mut self) -> Token {
761        if self.peek_char(1) == Some('=') {
762            self.pos += 2;
763            Token::TildeEquals
764        } else {
765            self.pos += 1;
766            Token::Identifier("~".to_string())
767        }
768    }
769
770    fn is_identifier_char(ch: char) -> bool {
771        ch.is_alphanumeric()
772            || matches!(
773                ch,
774                '_' | '>'
775                    | '<'
776                    | '='
777                    | '^'
778                    | '$'
779                    | '.'
780                    | '/'
781                    | '-'
782                    | '*'
783                    | '+'
784                    | '?'
785                    | '['
786                    | ']'
787                    | '{'
788                    | '}'
789                    | '|'
790                    | '\\'
791                    | ','
792            )
793    }
794
795    fn read_identifier(&mut self) -> Token {
796        let start = self.pos;
797        while self.pos < self.chars.len() {
798            let ch = self.chars[self.pos];
799            if Self::is_identifier_char(ch) {
800                self.pos += 1;
801            } else {
802                break;
803            }
804        }
805
806        let text: String = self.chars[start..self.pos].iter().collect();
807        match text.to_uppercase().as_str() {
808            "AND" => Token::And,
809            "OR" => Token::Or,
810            "NOT" => Token::Not,
811            _ => Token::Identifier(text),
812        }
813    }
814
815    fn skip_whitespace(&mut self) {
816        while let Some(ch) = self.current_char() {
817            if ch.is_whitespace() {
818                self.pos += 1;
819            } else {
820                break;
821            }
822        }
823    }
824
825    fn current_char(&self) -> Option<char> {
826        self.chars.get(self.pos).copied()
827    }
828
829    fn peek_char(&self, offset: usize) -> Option<char> {
830        self.chars.get(self.pos + offset).copied()
831    }
832}
833
834#[cfg(test)]
835mod tests {
836    use super::*;
837    use std::path::PathBuf;
838
839    fn make_test_match(
840        name: &str,
841        kind: ContextKind,
842        parent: Option<(&str, ContextKind)>,
843        ancestors: Vec<(&str, ContextKind)>,
844        lang: &str,
845    ) -> ContextualMatch {
846        let immediate =
847            super::super::types::ContextItem::new(name.to_string(), kind, 1, 10, 0, 100);
848
849        let parent_item = parent.map(|(pname, pkind)| {
850            super::super::types::ContextItem::new(pname.to_string(), pkind, 1, 20, 0, 200)
851        });
852
853        let ancestor_items: Vec<_> = ancestors
854            .into_iter()
855            .map(|(aname, akind)| {
856                super::super::types::ContextItem::new(aname.to_string(), akind, 1, 30, 0, 300)
857            })
858            .collect();
859
860        let context = super::super::types::Context::new(immediate, parent_item, ancestor_items);
861
862        let location = super::super::types::ContextualMatchLocation::new(
863            PathBuf::from("test.rs"),
864            1,
865            0,
866            10,
867            1,
868        );
869        ContextualMatch::new(name.to_string(), location, context, lang.to_string())
870    }
871
872    #[test]
873    fn test_parse_kind_predicate() {
874        let query = parse_query("kind:function").unwrap();
875        let test_match = make_test_match("test", ContextKind::Function, None, vec![], "rust");
876        assert!(query.matches(&test_match));
877    }
878
879    #[test]
880    fn test_parse_name_regex() {
881        let query = parse_query("name~=test").unwrap();
882        let test_match = make_test_match("test_func", ContextKind::Function, None, vec![], "rust");
883        let other_match =
884            make_test_match("other_func", ContextKind::Function, None, vec![], "rust");
885        assert!(query.matches(&test_match), "Should match test_func");
886        assert!(!query.matches(&other_match), "Should not match other_func");
887    }
888
889    #[test]
890    fn test_parse_parent_predicate() {
891        let query = parse_query("parent:class").unwrap();
892        let test_match = make_test_match(
893            "method",
894            ContextKind::Method,
895            Some(("MyClass", ContextKind::Class)),
896            vec![],
897            "rust",
898        );
899        let no_parent = make_test_match("func", ContextKind::Function, None, vec![], "rust");
900        assert!(query.matches(&test_match));
901        assert!(!query.matches(&no_parent));
902    }
903
904    #[test]
905    fn test_parse_in_predicate() {
906        let query = parse_query("in:MyClass").unwrap();
907        let test_match = make_test_match(
908            "method",
909            ContextKind::Method,
910            Some(("InnerClass", ContextKind::Class)),
911            vec![("MyClass", ContextKind::Class)],
912            "rust",
913        );
914        let not_in = make_test_match("func", ContextKind::Function, None, vec![], "rust");
915        assert!(query.matches(&test_match));
916        assert!(!query.matches(&not_in));
917    }
918
919    #[test]
920    fn test_parse_depth_operators() {
921        let eq_query = parse_query("depth:2").unwrap();
922        let greater_than_query = parse_query("depth:>1").unwrap();
923        let less_than_query = parse_query("depth:<3").unwrap();
924        let greater_equal_query = parse_query("depth:>=2").unwrap();
925        let less_equal_query = parse_query("depth:<=2").unwrap();
926
927        let depth_2 = make_test_match(
928            "method",
929            ContextKind::Method,
930            Some(("MyClass", ContextKind::Class)),
931            vec![],
932            "rust",
933        );
934
935        assert!(eq_query.matches(&depth_2));
936        assert!(greater_than_query.matches(&depth_2));
937        assert!(less_than_query.matches(&depth_2));
938        assert!(greater_equal_query.matches(&depth_2));
939        assert!(less_equal_query.matches(&depth_2));
940    }
941
942    #[test]
943    fn test_parse_path_predicate() {
944        let query = parse_query("path:MyClass").unwrap();
945        let test_match = make_test_match(
946            "method",
947            ContextKind::Method,
948            Some(("MyClass", ContextKind::Class)),
949            vec![],
950            "rust",
951        );
952        let no_match = make_test_match("func", ContextKind::Function, None, vec![], "rust");
953        assert!(query.matches(&test_match));
954        assert!(!query.matches(&no_match));
955    }
956
/// `lang:rust` accepts a match tagged `rust` and rejects an otherwise
/// identical one tagged `javascript`.
#[test]
fn test_parse_lang_predicate() {
    let rust_only = parse_query("lang:rust").unwrap();

    let in_rust = make_test_match("func", ContextKind::Function, None, vec![], "rust");
    let in_javascript = make_test_match("func", ContextKind::Function, None, vec![], "javascript");

    for (candidate, expected) in &[(in_rust, true), (in_javascript, false)] {
        assert_eq!(rust_only.matches(candidate), *expected);
    }
}
965
/// `AND` requires both operands: a method with a class parent matches, a
/// method without a parent does not.
#[test]
fn test_parse_and_expression() {
    let method_in_class = parse_query("kind:method AND parent:class").unwrap();

    let class_parent = Some(("MyClass", ContextKind::Class));
    let satisfies_both = make_test_match("method", ContextKind::Method, class_parent, vec![], "rust");
    // Right kind, but no parent at all, so the `parent:class` operand fails.
    let satisfies_kind_only = make_test_match("method", ContextKind::Method, None, vec![], "rust");

    let both_ok = method_in_class.matches(&satisfies_both);
    assert!(both_ok);
    let kind_only_ok = method_in_class.matches(&satisfies_kind_only);
    assert!(!kind_only_ok);
}
980
/// `OR` accepts a match satisfying either operand and rejects one that
/// satisfies neither.
#[test]
fn test_parse_or_expression() {
    let function_or_method = parse_query("kind:function OR kind:method").unwrap();

    // (fixture, whether the query is expected to accept it)
    let candidates = [
        (make_test_match("func", ContextKind::Function, None, vec![], "rust"), true),
        (make_test_match("method", ContextKind::Method, None, vec![], "rust"), true),
        (make_test_match("MyClass", ContextKind::Class, None, vec![], "rust"), false),
    ];

    for (candidate, should_match) in &candidates {
        assert_eq!(function_or_method.matches(candidate), *should_match);
    }
}
991
/// `NOT kind:class` accepts a function and rejects a class.
#[test]
fn test_parse_not_expression() {
    let anything_but_class = parse_query("NOT kind:class").unwrap();

    let a_function = make_test_match("func", ContextKind::Function, None, vec![], "rust");
    let a_class = make_test_match("MyClass", ContextKind::Class, None, vec![], "rust");

    let function_accepted = anything_but_class.matches(&a_function);
    let class_accepted = anything_but_class.matches(&a_class);
    assert!(function_accepted);
    assert!(!class_accepted);
}
1000
/// Parenthesised grouping: `(A AND B) OR C` accepts a match satisfying the
/// group, accepts one satisfying `C`, and rejects one satisfying only `A`.
#[test]
fn test_parse_parentheses() {
    let grouped = parse_query("(kind:method AND parent:class) OR kind:function").unwrap();

    let class_parent = Some(("MyClass", ContextKind::Class));
    // (fixture, whether the query is expected to accept it)
    let candidates = [
        // Satisfies the parenthesised group.
        (make_test_match("method", ContextKind::Method, class_parent, vec![], "rust"), true),
        // Satisfies the right-hand `kind:function` operand.
        (make_test_match("func", ContextKind::Function, None, vec![], "rust"), true),
        // A method without a parent satisfies neither side.
        (make_test_match("method", ContextKind::Method, None, vec![], "rust"), false),
    ];

    for (candidate, should_match) in &candidates {
        assert_eq!(grouped.matches(candidate), *should_match);
    }
}
1017
/// A three-way conjunction mixing `kind`, `depth` and a negated `in`
/// predicate: a method under `MyClass` matches, one under `TestClass` does not.
#[test]
fn test_parse_complex_query() {
    let query_text = "kind:method AND depth:>0 AND NOT in:TestClass";
    let nested_non_test_method = parse_query(query_text).unwrap();

    let under_my_class = make_test_match(
        "method",
        ContextKind::Method,
        Some(("MyClass", ContextKind::Class)),
        vec![],
        "rust",
    );
    // Identical except for the parent name, which the `NOT in:` operand excludes.
    let under_test_class = make_test_match(
        "method",
        ContextKind::Method,
        Some(("TestClass", ContextKind::Class)),
        vec![],
        "rust",
    );

    let accepted = nested_non_test_method.matches(&under_my_class);
    assert!(accepted);
    let excluded = !nested_non_test_method.matches(&under_test_class);
    assert!(excluded);
}
1038
/// An unrecognised predicate key is reported as
/// `AstQueryError::UnknownPredicate`.
#[test]
fn test_parse_error_unknown_predicate() {
    let outcome = parse_query("unknown:value");
    let is_unknown_predicate = matches!(outcome, Err(AstQueryError::UnknownPredicate { .. }));
    assert!(is_unknown_predicate);
}
1047
/// A `name~=` pattern with an unclosed character class is reported as
/// `AstQueryError::InvalidRegex`.
#[test]
fn test_parse_error_invalid_regex() {
    let outcome = parse_query("name~=[invalid");
    let is_invalid_regex = matches!(outcome, Err(AstQueryError::InvalidRegex { .. }));
    assert!(is_invalid_regex);
}
1053
/// A non-numeric `depth:` value is reported as `AstQueryError::InvalidDepth`.
#[test]
fn test_parse_error_invalid_depth() {
    let outcome = parse_query("depth:abc");
    let is_invalid_depth = matches!(outcome, Err(AstQueryError::InvalidDepth { .. }));
    assert!(is_invalid_depth);
}
1059
/// An opening parenthesis that is never closed is reported as
/// `AstQueryError::UnexpectedToken`.
#[test]
fn test_parse_error_missing_rparen() {
    let outcome = parse_query("(kind:function");
    let is_unexpected_token = matches!(outcome, Err(AstQueryError::UnexpectedToken { .. }));
    assert!(is_unexpected_token);
}
1065
/// Exercises `DepthOp::matches` for every variant, probing the boundary
/// value and at least one value on the rejecting side.
#[test]
fn test_depth_op_matches() {
    // (operator, probed depth, expected verdict)
    let cases = vec![
        (DepthOp::Eq(3), 3, true),
        (DepthOp::Eq(3), 2, false),
        (DepthOp::Gt(2), 3, true),
        (DepthOp::Gt(2), 2, false),
        (DepthOp::Lt(3), 2, true),
        (DepthOp::Lt(3), 3, false),
        (DepthOp::Gte(2), 2, true),
        (DepthOp::Gte(2), 3, true),
        (DepthOp::Gte(2), 1, false),
        (DepthOp::Lte(3), 3, true),
        (DepthOp::Lte(3), 2, true),
        (DepthOp::Lte(3), 4, false),
    ];

    for (index, (op, depth, expected)) in cases.into_iter().enumerate() {
        // The index identifies the failing row without needing `Debug` on `DepthOp`.
        assert_eq!(op.matches(depth), expected, "case #{index}");
    }
}
1085
/// The `AND` keyword is accepted in lower, upper and mixed case; all three
/// spellings produce a query that matches the same function.
#[test]
fn test_case_insensitive_keywords() {
    let spellings = [
        "kind:function and name~=test",
        "kind:function AND name~=test",
        "kind:function AnD name~=test",
    ];
    let parsed: Vec<_> = spellings
        .iter()
        .map(|text| parse_query(text).unwrap())
        .collect();

    let sample = make_test_match("test", ContextKind::Function, None, vec![], "rust");

    for query in &parsed {
        assert!(query.matches(&sample));
    }
}
1096
/// An `in:` predicate whose value is a non-ASCII identifier parses and
/// matches a method whose parent carries that same name.
#[test]
fn test_parse_unicode_identifier() {
    // Three CJK code points, written as escapes to keep the source ASCII-only.
    let cjk_name = "\u{6a21}\u{5757}\u{540d}";

    let query_text = format!("in:{}", cjk_name);
    let inside_cjk = parse_query(&query_text).unwrap();

    let parent = Some((cjk_name, ContextKind::Class));
    let nested_method = make_test_match("method", ContextKind::Method, parent, vec![], "rust");

    assert!(inside_cjk.matches(&nested_method));
}
1110
/// A 1001-character pattern (one past `MAX_REGEX_LENGTH`) is rejected as
/// `InvalidRegex`, and the error carries the full offending pattern.
#[test]
fn test_regex_validation_too_long() {
    let oversized = "a".repeat(1001);
    let outcome = parse_query(&format!("name~={}", oversized));

    // One check covers both the error variant and the length of the echoed pattern.
    assert!(matches!(
        &outcome,
        Err(AstQueryError::InvalidRegex { pattern, .. }) if pattern.len() == 1001
    ));
}
1121
/// A `+`-quantified group that itself contains a `+` must fail validation.
#[test]
fn test_regex_validation_catastrophic_backtracking_nested_plus() {
    let rejected = validate_regex_pattern("(a+)+").is_err();
    assert!(rejected, "Nested quantifiers should be rejected");
}
1127
/// A `*`-quantified group that itself contains a `*` must fail validation.
#[test]
fn test_regex_validation_catastrophic_backtracking_nested_star() {
    let rejected = validate_regex_pattern("(a*)*").is_err();
    assert!(rejected, "Nested quantifiers should be rejected");
}
1133
/// A `*` immediately followed by `+` is rejected at query-parse time with
/// `AstQueryError::InvalidRegex`.
#[test]
fn test_regex_validation_nested_repetition_star_plus() {
    let outcome = parse_query("name~=a*+");
    let is_invalid_regex = matches!(outcome, Err(AstQueryError::InvalidRegex { .. }));
    assert!(is_invalid_regex);
}
1139
/// A `+` immediately followed by `*` is rejected at query-parse time with
/// `AstQueryError::InvalidRegex`.
#[test]
fn test_regex_validation_nested_repetition_plus_star() {
    let outcome = parse_query("name~=a+*");
    let is_invalid_regex = matches!(outcome, Err(AstQueryError::InvalidRegex { .. }));
    assert!(is_invalid_regex);
}
1145
/// Ordinary patterns (literals, anchors, top-level alternation, character
/// classes, escapes, small bounded repetition) must all parse successfully.
#[test]
fn test_regex_validation_safe_patterns_allowed() {
    const ALLOWED: [&str; 6] = [
        "name~=test",
        "name~=^test_.*",
        "name~=foo|bar",
        "name~=[a-z]+",
        "name~=\\w+",
        "name~=test{1,5}",
    ];

    for pattern in ALLOWED.iter().copied() {
        let parsed_ok = parse_query(pattern).is_ok();
        assert!(parsed_ok, "Safe pattern should be allowed: {pattern}");
    }
}
1161
/// A 999-character pattern, just under `MAX_REGEX_LENGTH`, is accepted.
#[test]
fn test_regex_validation_reasonable_length_allowed() {
    let query_text = format!("name~={}", "a".repeat(999));
    let parsed_ok = parse_query(&query_text).is_ok();
    assert!(parsed_ok, "999-char pattern should be allowed");
}
1169
/// Calls `validate_regex_pattern` directly (bypassing the query parser) with
/// two acceptable patterns and three nested-quantifier patterns.
#[test]
fn test_validate_regex_pattern_directly() {
    // Accepted inputs.
    assert!(validate_regex_pattern("test").is_ok());
    let char_class_ok = validate_regex_pattern("[a-z]+").is_ok();
    assert!(
        char_class_ok,
        "Character class with quantifier should be allowed"
    );

    // Rejected inputs, each paired with the message shown on failure.
    let rejected = [
        ("(a+)+", "Nested quantifiers should be rejected"),
        ("(a*)*", "Nested quantifiers should be rejected"),
        ("a*+", "Direct nested quantifiers should be rejected"),
    ];
    for (pattern, reason) in rejected.iter() {
        assert!(validate_regex_pattern(pattern).is_err(), "{}", reason);
    }
}
1190
/// A `*`-quantified group containing an alternation must fail validation.
#[test]
fn test_regex_validation_alternation_explosion_simple() {
    let rejected = validate_regex_pattern("(a|ab)*").is_err();
    assert!(
        rejected,
        "Alternation in quantified group should be rejected"
    );
}
1199
/// A `+`-quantified group containing an alternation must fail validation.
#[test]
fn test_regex_validation_alternation_explosion_plus() {
    let rejected = validate_regex_pattern("(foo|bar)+").is_err();
    assert!(rejected, "Alternation with + should be rejected");
}
1205
/// An alternation group with no quantifier applied to it passes validation.
#[test]
fn test_regex_validation_alternation_safe() {
    let accepted = validate_regex_pattern("(foo|bar)").is_ok();
    assert!(
        accepted,
        "Alternation without quantifier should be allowed"
    );
}
1214
/// A `|` inside a quantified character class passes validation.
#[test]
fn test_regex_validation_alternation_in_character_class() {
    let accepted = validate_regex_pattern("[a|b]+").is_ok();
    assert!(
        accepted,
        "Pipe inside character class should be allowed"
    );
}
1223
/// A `{n,m}` range whose upper bound is far above `MAX_REPETITION_COUNT`
/// must fail validation.
#[test]
fn test_regex_validation_large_repetition_range() {
    let rejected = validate_regex_pattern("a{1,999999}").is_err();
    assert!(rejected, "Large repetition range should be rejected");
}
1229
/// A small bounded `{n,m}` range passes validation.
#[test]
fn test_regex_validation_safe_repetition_range() {
    let accepted = validate_regex_pattern("a{1,5}").is_ok();
    assert!(accepted, "Small repetition range should be allowed");
}
1235
/// Boundary check: a repetition count of exactly 1000 (the value of
/// `MAX_REPETITION_COUNT`) is still accepted.
#[test]
fn test_regex_validation_exact_max_repetition() {
    let accepted = validate_regex_pattern("a{1000}").is_ok();
    assert!(accepted, "Repetition at max limit should be allowed");
}
1241
/// Boundary check: a repetition count of 1001 (one above
/// `MAX_REPETITION_COUNT`) is rejected.
#[test]
fn test_regex_validation_just_over_max_repetition() {
    let rejected = validate_regex_pattern("a{1001}").is_err();
    assert!(
        rejected,
        "Repetition over max limit should be rejected"
    );
}
1250
/// A `{n,}` range with no upper bound passes validation.
#[test]
fn test_regex_validation_open_ended_range() {
    let accepted = validate_regex_pattern("a{100,}").is_ok();
    assert!(accepted, "Open-ended range should be allowed");
}
1256
/// Patterns that combine several risky constructs must be rejected.
#[test]
fn test_regex_validation_combined_attacks() {
    // (pattern, message shown if the validator wrongly accepts it)
    let hostile = [
        (
            "(a|ab){1,999999}",
            "Combined alternation + large range should be rejected",
        ),
        (
            "(a+|b+)*",
            "Alternation with quantified branches in quantified group should be rejected",
        ),
    ];

    for (pattern, reason) in hostile.iter() {
        assert!(validate_regex_pattern(pattern).is_err(), "{}", reason);
    }
}
1268
/// Parsing the same query text twice after clearing the cache yields two
/// queries that both match.
#[test]
fn test_query_cache_basic() {
    clear_query_cache();

    // NOTE(review): presumably the second parse is served from the cache; that
    // is not observable from this test, which only checks the results agree.
    let first_parse = parse_query("kind:function").unwrap();
    let second_parse = parse_query("kind:function").unwrap();

    let sample = make_test_match("test", ContextKind::Function, None, vec![], "rust");
    for query in [&first_parse, &second_parse] {
        assert!(query.matches(&sample));
    }
}
1278
/// Two different query texts parsed back to back keep their own semantics:
/// each accepts only its own kind.
#[test]
fn test_query_cache_different_queries() {
    clear_query_cache();
    let wants_function = parse_query("kind:function").unwrap();
    let wants_method = parse_query("kind:method").unwrap();

    let a_function = make_test_match("test", ContextKind::Function, None, vec![], "rust");
    let a_method = make_test_match("test", ContextKind::Method, None, vec![], "rust");

    // Full query x fixture grid with the expected verdict for each pairing.
    let grid = [
        (&wants_function, &a_function, true),
        (&wants_function, &a_method, false),
        (&wants_method, &a_function, false),
        (&wants_method, &a_method, true),
    ];
    for &(query, candidate, expected) in grid.iter() {
        assert_eq!(query.matches(candidate), expected);
    }
}
1291
/// Parses 150 distinct queries, then parses the same 150 again and requires
/// every one to succeed.
///
/// NOTE(review): presumably 150 exceeds the cache capacity so that entries get
/// evicted between the two passes; the capacity is not visible here, so confirm.
#[test]
fn test_query_cache_eviction() {
    clear_query_cache();

    // Distinct query text per index.
    let nth_query = |i: usize| format!("kind:function AND name~=test{i}");

    // First pass: populate.
    for i in 0..150 {
        let _ = parse_query(&nth_query(i)).unwrap();
    }

    // Second pass: every query must still parse.
    for i in 0..150 {
        let reparsed = parse_query(&nth_query(i));
        assert!(reparsed.is_ok(), "Query {i} should parse successfully");
    }
}
1305
/// A malformed query fails on both the first and the second attempt, and a
/// valid query still parses afterwards.
#[test]
fn test_query_cache_with_errors() {
    clear_query_cache();

    // Same malformed text twice: the repeat must fail as well.
    for _attempt in 0..2 {
        assert!(parse_query("invalid~syntax").is_err());
    }

    // A failed parse must not prevent a later successful one.
    assert!(parse_query("kind:function").is_ok());
}
1316
/// Clearing the cache after a parse does not break re-parsing the same text.
#[test]
fn test_query_cache_clear() {
    clear_query_cache();
    // Parse once, then clear again before re-parsing.
    let _ = parse_query("kind:function").unwrap();
    clear_query_cache();

    let reparsed = parse_query("kind:function").unwrap();
    let sample = make_test_match("test", ContextKind::Function, None, vec![], "rust");
    let accepted = reparsed.matches(&sample);
    assert!(accepted);
}
1326
/// Ten threads parse the same query text concurrently; each must get a
/// working query and none may panic.
#[test]
fn test_query_cache_thread_safety() {
    use std::thread;

    /// Work performed on every spawned thread.
    fn parse_and_check() {
        let query = parse_query("kind:function").unwrap();
        let sample = make_test_match("test", ContextKind::Function, None, vec![], "rust");
        assert!(query.matches(&sample));
    }

    clear_query_cache();

    // Spawn every worker before joining any of them so they can overlap.
    let workers: Vec<_> = (0..10).map(|_| thread::spawn(parse_and_check)).collect();

    // `join` returns `Err` if the worker panicked, which fails this test.
    workers
        .into_iter()
        .for_each(|worker| worker.join().unwrap());
}
1345
/// For several compound queries, parsing the same text twice yields queries
/// that give the same verdict on one fixture.
#[test]
fn test_query_cache_complex_queries() {
    clear_query_cache();

    let sources = [
        "kind:function AND name~=^test_",
        "(kind:method AND parent:class) OR kind:function",
        "depth:>3 AND NOT in:TestClass",
        "kind:function AND name~=.*helper.* AND depth:<=2",
    ];

    for source in sources.iter() {
        let first_parse = parse_query(source).unwrap();
        let second_parse = parse_query(source).unwrap();

        let sample = make_test_match("test_func", ContextKind::Function, None, vec![], "rust");

        // Only consistency between the two parses is checked, not the verdict itself.
        let first_verdict = first_parse.matches(&sample);
        let second_verdict = second_parse.matches(&sample);
        assert_eq!(first_verdict, second_verdict);
    }
}
1363
/// Releases five threads through a barrier so they parse the same query at
/// once, then checks that parsing still works on the main thread.
///
/// NOTE(review): despite the name, nothing visible in this test panics while
/// holding a lock, so the cache mutex is presumably never poisoned here; as
/// written this only exercises contended access. Confirm the intent.
#[test]
fn test_query_cache_poison_recovery() {
    use std::sync::{Arc, Barrier};
    use std::thread;

    const WORKERS: usize = 5;

    clear_query_cache();
    // Parse once before any worker starts.
    let _ = parse_query("kind:function").unwrap();

    // The barrier holds every worker until all have arrived.
    let start_line = Arc::new(Barrier::new(WORKERS));
    let mut workers = Vec::with_capacity(WORKERS);
    for _ in 0..WORKERS {
        let start_line = Arc::clone(&start_line);
        workers.push(thread::spawn(move || {
            start_line.wait();
            let query = parse_query("kind:function").unwrap();
            let sample = make_test_match("test", ContextKind::Function, None, vec![], "rust");
            assert!(query.matches(&sample));
        }));
    }

    // A panicking worker surfaces here as `Err` and fails the test.
    for worker in workers {
        worker.join().unwrap();
    }

    // Parsing must still work once all workers have finished.
    let after = parse_query("kind:function").unwrap();
    let sample = make_test_match("test", ContextKind::Function, None, vec![], "rust");
    assert!(after.matches(&sample));
}
1390}