Skip to main content

sochdb_query/
grep_executor.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Grep Lane Executor (Task 5)
19//!
20//! Exact regex / substring search as a first-class retrieval lane, built on the
21//! trigram candidate index. The pipeline is:
22//!
23//! ```text
24//! regex ──► required-literal extraction ──► trigram conjunction
25//!       ──► trigram posting intersection (candidate DocIds)
26//!       ──► ∩ AllowedSet            (filter pushdown BEFORE verification)
27//!       ──► regex verification      (linear-time, finite-automaton engine)
28//!       ──► ranked hits  OR  candidate gate
29//! ```
30//!
31//! ## Correctness over speed
32//!
//! Trigram pre-filtering is only ever used when the executor can *prove* the
//! candidate set is a superset of the true matches: either every extracted
//! literal is mandatory (present in every possible match), or the pattern is a
//! top-level alternation of plain literals, in which case the per-branch
//! candidates are unioned. For any other pattern — groups, character classes,
//! alternation that is not a plain top-level list of literals, or no literal
//! run of length ≥ 3 — it falls back to an explicit, bounded full scan rather
//! than risk a false negative. The full-scan path is capped by `max_scan`;
//! exceeding the cap is reported as [`GrepError::DegeneratePattern`] instead of
//! silently returning partial results.
41//!
42//! ## Verification engine
43//!
44//! Verification uses the `regex` crate, a finite-automaton engine with
45//! guaranteed linear-time matching, so adversarial patterns cannot turn the
46//! verify stage into a catastrophic-backtracking DoS.
47//!
48//! ## Two fusion modes
49//!
50//! Grep produces a *set*, but RRF consumes *ranked lists*. Both shapes are
51//! supported:
52//! - [`GrepMode::Rank`] scores each hit by specificity-weighted, TF-saturated,
53//!   length-pivoted relevance (BM25-flavored over the pattern's literal terms)
54//!   so it can plug into RRF as a third ranked lane **without** the
55//!   short-document / common-term bias of raw match density.
56//! - [`GrepMode::Gate`] returns the matching documents as an
57//!   [`AllowedSet`] to intersect into the other lanes (the
58//!   "find the function that contains X" cascade), via [`GrepResults::into_allowed_set`].
59
60use regex::Regex;
61
62use crate::candidate_gate::AllowedSet;
63use crate::trigram_index::{DocId, Trigram, TrigramIndex, trigrams_of};
64
/// Scan budget used by [`GrepExecutor::new`]: the largest corpus (in documents)
/// that a pattern with no indexable literal is allowed to full-scan.
pub const DEFAULT_MAX_SCAN: usize = 100_000;

/// Saturation constant `k1` for `Rank` scoring. A term matched `c` times
/// contributes `c / (c + k1)`, so repeated matches have diminishing value.
const GREP_K1: f32 = 1.2;

/// Length-pivot constant `b` for `Rank` scoring. At `0.0` document length is
/// ignored; at `1.0` the score is fully normalized by relative length.
const GREP_B: f32 = 0.75;
75
/// Selects the shape in which the fusion layer wants grep results.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GrepMode {
    /// Score and order the hits so they can be fused as a ranked lane (RRF).
    Rank,
    /// Report membership only, to be intersected into the other lanes.
    Gate,
}
84
85/// A single grep match.
86#[derive(Debug, Clone, PartialEq)]
87pub struct GrepHit {
88    /// Matching document id.
89    pub doc_id: DocId,
90    /// Rank score (higher is better): specificity-weighted, TF-saturated,
91    /// length-pivoted relevance over the pattern's literal terms.
92    pub score: f32,
93    /// Number of (non-overlapping) matches in the document.
94    pub match_count: usize,
95}
96
97/// The outcome of a grep search.
98#[derive(Debug, Clone)]
99pub struct GrepResults {
100    /// Ranked hits, best first.
101    pub hits: Vec<GrepHit>,
102    /// Whether the trigram index was used (`true`) or a full scan ran (`false`).
103    pub used_index: bool,
104}
105
106impl GrepResults {
107    /// The matching document ids as an [`AllowedSet`] for gate / cascade fusion.
108    pub fn into_allowed_set(self) -> AllowedSet {
109        AllowedSet::from_iter(self.hits.into_iter().map(|h| h.doc_id))
110    }
111}
112
/// Failure modes of the grep lane.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GrepError {
    /// The pattern failed to compile; carries the regex engine's message.
    InvalidRegex(String),
    /// No trigram could be derived from the pattern and the corpus is larger
    /// than the scan budget, so the search is refused instead of being run
    /// over only part of the corpus.
    DegeneratePattern { corpus: usize, max_scan: usize },
}

impl std::fmt::Display for GrepError {
    /// Render a human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidRegex(reason) => write!(f, "invalid regex: {reason}"),
            Self::DegeneratePattern { corpus, max_scan } => {
                write!(
                    f,
                    "degenerate pattern (no indexable literal) over a corpus of {corpus} \
                     documents exceeds the scan budget of {max_scan}"
                )
            }
        }
    }
}

impl std::error::Error for GrepError {}
137
138/// The grep executor: plans and runs regex search over a [`TrigramIndex`].
139pub struct GrepExecutor<'a> {
140    index: &'a TrigramIndex,
141    max_scan: usize,
142}
143
144impl<'a> GrepExecutor<'a> {
145    /// Create an executor over `index` with the default scan budget.
146    pub fn new(index: &'a TrigramIndex) -> Self {
147        Self {
148            index,
149            max_scan: DEFAULT_MAX_SCAN,
150        }
151    }
152
153    /// Override the full-scan document budget for degenerate patterns.
154    pub fn with_max_scan(mut self, max_scan: usize) -> Self {
155        self.max_scan = max_scan;
156        self
157    }
158
159    /// Run a grep search.
160    ///
161    /// `allowed` is applied as a pushdown filter **before** regex verification,
162    /// preserving the same `result ⊆ allowed` invariant the other lanes honor.
163    /// `limit` caps the number of returned hits (0 = unlimited).
164    pub fn search(
165        &self,
166        pattern: &str,
167        allowed: &AllowedSet,
168        limit: usize,
169        mode: GrepMode,
170    ) -> Result<GrepResults, GrepError> {
171        let re = Regex::new(pattern).map_err(|e| GrepError::InvalidRegex(e.to_string()))?;
172
173        if allowed.is_empty() {
174            return Ok(GrepResults {
175                hits: Vec::new(),
176                used_index: false,
177            });
178        }
179
180        // ---- Plan: candidate document set (safe superset, no false negatives) ----
181        //
182        // While planning we also capture each term's document-frequency estimate
183        // (its trigram-candidate count) so the ranking stage can compute IDF
184        // without a second pass over the postings.
185        //
186        // A leading whole-pattern inline-flag group like `(?i)` is stripped
187        // *for literal extraction only* (the compiled `re` above keeps the
188        // flag), so a case-insensitive alternation still drives the index and
189        // IDF instead of degrading to a full scan.
190        let extract = strip_leading_inline_flags(pattern);
191        let (terms, is_alternation) = literal_terms(extract);
192        let mut term_df: Vec<(String, usize)> = Vec::new();
193        let (candidates, used_index): (Vec<DocId>, bool) = if terms.is_empty() {
194            // No provably-mandatory literal (complex regex): bounded full scan.
195            if self.index.len() > self.max_scan {
196                return Err(GrepError::DegeneratePattern {
197                    corpus: self.index.len(),
198                    max_scan: self.max_scan,
199                });
200            }
201            (self.index.documents().map(|(id, _)| id).collect(), false)
202        } else if is_alternation {
203            // Alternation `a|b|c`: a match contains *some* branch, so the
204            // candidate set is the UNION of each branch's trigram candidates
205            // (Cox AND-of-ORs, union form). Every branch is trigram-indexable
206            // here (guaranteed by `literal_alternation`), so this stays a safe
207            // superset and the previously full-scanned `|` patterns now use the
208            // index. Each branch's candidate count doubles as its df estimate.
209            let mut union: Vec<DocId> = Vec::new();
210            for term in &terms {
211                let branch = self.index.candidates(&trigrams_of(term));
212                term_df.push((term.to_lowercase(), branch.len().max(1)));
213                union.extend(branch);
214            }
215            union.sort_unstable();
216            union.dedup();
217            (union, true)
218        } else {
219            // Conjunction of mandatory literals: AND of all their trigrams.
220            let mut trigrams: Vec<Trigram> = Vec::new();
221            for term in &terms {
222                let df = self.index.candidates(&trigrams_of(term)).len().max(1);
223                term_df.push((term.to_lowercase(), df));
224                trigrams.extend(trigrams_of(term));
225            }
226            trigrams.sort_unstable();
227            trigrams.dedup();
228            (self.index.candidates(&trigrams), true)
229        };
230
231        // ---- Gate mode: membership only, no ranking ----
232        if mode == GrepMode::Gate {
233            let mut hits: Vec<GrepHit> = Vec::new();
234            for doc_id in candidates {
235                if !allowed.contains(doc_id) {
236                    continue;
237                }
238                if let Some(text) = self.index.doc_text(doc_id) {
239                    if re.is_match(text) {
240                        hits.push(GrepHit {
241                            doc_id,
242                            score: 1.0,
243                            match_count: 1,
244                        });
245                    }
246                }
247            }
248            hits.sort_by(|a, b| a.doc_id.cmp(&b.doc_id));
249            if limit > 0 && hits.len() > limit {
250                hits.truncate(limit);
251            }
252            return Ok(GrepResults { hits, used_index });
253        }
254
255        // ---- Rank mode: specificity-weighted, TF-saturated, length-pivoted ----
256        //
257        // The old score was raw match density (`matches / doc_len`), which is
258        // IDF-blind (a hit on a common word counts as much as a rare one),
259        // linear in raw match count (50 hits == 50x one hit), and explodes for
260        // short documents — so it injected noise into RRF. The corrected score
261        // is BM25-flavored over the grep's literal terms:
262        //
263        //   idf(t)   = ln(1 + (N - df + 0.5)/(df + 0.5))         // term rarity
264        //   tf_sat   = c / (c + k1)                              // saturating TF
265        //   raw(d)   = SUM_t idf(t) * tf_sat(count_t(d))
266        //   score(d) = raw(d) / (1 - b + b*len_d/avg_len)        // pivoted length
267        //
268        // `df` is estimated index-locally as the trigram-candidate count of the
269        // term (a tight upper bound on its true document frequency), captured
270        // during planning above, so no extra corpus statistics are needed.
271        // Verification still uses the full regex, so the hit *set* is unchanged
272        // — only the ranking improves.
273        let n = self.index.len().max(1) as f32;
274        let term_idf: Vec<(String, f32)> = term_df
275            .iter()
276            .map(|(t, df)| {
277                let dff = *df as f32;
278                let idf = (1.0 + (n - dff + 0.5) / (dff + 0.5)).ln();
279                (t.clone(), idf.max(0.0))
280            })
281            .collect();
282
283        struct Pending {
284            doc_id: DocId,
285            len: f32,
286            raw: f32,
287            match_count: usize,
288        }
289        let mut pending: Vec<Pending> = Vec::new();
290        let mut total_len = 0.0f32;
291        // Reused per-term match-count buffer (alternation path) to avoid a
292        // per-document allocation.
293        let mut counts: Vec<u32> = vec![0; term_idf.len()];
294        for doc_id in candidates {
295            if !allowed.contains(doc_id) {
296                continue;
297            }
298            let Some(text) = self.index.doc_text(doc_id) else {
299                continue;
300            };
301
302            // Single regex pass over the document. For an alternation each match
303            // is exactly one branch literal, so we attribute it to its term in
304            // the SAME pass — no extra per-term substring scans, no allocation.
305            let mut match_count = 0usize;
306            if is_alternation {
307                for c in counts.iter_mut() {
308                    *c = 0;
309                }
310                for m in re.find_iter(text) {
311                    match_count += 1;
312                    let ms = m.as_str();
313                    for (i, (term_lc, _)) in term_idf.iter().enumerate() {
314                        if eq_ci_ascii(ms, term_lc) {
315                            counts[i] += 1;
316                            break;
317                        }
318                    }
319                }
320            } else {
321                match_count = re.find_iter(text).count();
322            }
323            if match_count == 0 {
324                continue;
325            }
326
327            let len = text.chars().count().max(1) as f32;
328            let raw = if term_idf.is_empty() {
329                // Complex pattern with no literal terms to weight: saturate the
330                // raw regex match count so a flood of matches can't dominate.
331                tf_saturate(match_count as f32)
332            } else if is_alternation {
333                // Per-branch counts already attributed in the single pass above.
334                let mut s = 0.0f32;
335                for (i, (_, idf)) in term_idf.iter().enumerate() {
336                    if counts[i] > 0 {
337                        s += idf * tf_saturate(counts[i] as f32);
338                    }
339                }
340                s
341            } else {
342                // Conjunction / complex literal terms (rare): the whole-pattern
343                // matches can't be attributed per term, so scan each mandatory
344                // term once (allocation-free, ASCII case-insensitive).
345                let mut s = 0.0f32;
346                for (term_lc, idf) in &term_idf {
347                    let c = count_ci_ascii(text, term_lc);
348                    if c > 0 {
349                        s += idf * tf_saturate(c as f32);
350                    }
351                }
352                s
353            };
354            total_len += len;
355            pending.push(Pending {
356                doc_id,
357                len,
358                raw,
359                match_count,
360            });
361        }
362
363        let avg_len = if pending.is_empty() {
364            1.0
365        } else {
366            (total_len / pending.len() as f32).max(1.0)
367        };
368
369        let mut hits: Vec<GrepHit> = pending
370            .into_iter()
371            .map(|p| {
372                let norm = 1.0 - GREP_B + GREP_B * (p.len / avg_len);
373                GrepHit {
374                    doc_id: p.doc_id,
375                    score: if norm > 0.0 { p.raw / norm } else { p.raw },
376                    match_count: p.match_count,
377                }
378            })
379            .collect();
380
381        // Rank: relevance descending, doc_id ascending as a stable tiebreak.
382        hits.sort_by(|a, b| {
383            b.score
384                .total_cmp(&a.score)
385                .then_with(|| a.doc_id.cmp(&b.doc_id))
386        });
387        if limit > 0 && hits.len() > limit {
388            hits.truncate(limit);
389        }
390
391        Ok(GrepResults { hits, used_index })
392    }
393}
394
395/// BM25-style saturating term frequency: `count / (count + k1)`, in `[0, 1)`.
396fn tf_saturate(count: f32) -> f32 {
397    count / (count + GREP_K1)
398}
399
/// Count how many times `needle` occurs in `hay`, ignoring ASCII case in `hay`
/// and never counting overlapping occurrences. No lowercased copy of `hay` is
/// allocated.
///
/// `needle` must already be lowercase: only the haystack bytes are folded, so
/// an uppercase ASCII letter in `needle` can never match. Bytes outside ASCII
/// are compared verbatim (no Unicode case folding). An empty `needle`, or one
/// longer than `hay`, yields 0.
fn count_ci_ascii(hay: &str, needle: &str) -> usize {
    let (hay, needle) = (hay.as_bytes(), needle.as_bytes());
    let width = needle.len();
    if width == 0 || hay.len() < width {
        return 0;
    }

    let mut found = 0;
    let mut pos = 0;
    while pos + width <= hay.len() {
        let window = &hay[pos..pos + width];
        let is_hit = window
            .iter()
            .zip(needle)
            .all(|(h, n)| h.to_ascii_lowercase() == *n);
        if is_hit {
            found += 1;
            // Jump past the whole occurrence so matches never overlap.
            pos += width;
        } else {
            pos += 1;
        }
    }
    found
}
429
/// Whether `a` equals `b` when the ASCII letters of `a` are lowercased.
/// `b` is expected to be lowercase already; it is compared as-is.
fn eq_ci_ascii(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.bytes().map(|byte| byte.to_ascii_lowercase()).eq(b.bytes())
}
437
/// Strip a leading whole-pattern inline-flag group (e.g. `(?i)`, `(?ims)`,
/// `(?i-u)`) so the remainder can be parsed for mandatory literals.
///
/// Only a pure flag setter — alphabetic flags plus an optional `-` toggle,
/// immediately closed by `)` with no `:` scoping — is stripped; scoped groups
/// like `(?i:...)` are returned unchanged.
///
/// A group that *enables* the verbose flag `x` is also returned unchanged:
/// in verbose mode whitespace and `#` comments in the pattern are not part of
/// the match, so reading the remainder as literal text would report literals
/// (e.g. `"foo bar"` for `(?x)foo bar`) that matching documents need not
/// contain. Leaving the group in place makes literal extraction bail out on
/// the `(`, so the caller falls back to a full scan.
///
/// The caller's compiled regex still carries the flags, so this only affects
/// literal extraction, never matching semantics.
fn strip_leading_inline_flags(pattern: &str) -> &str {
    let Some(rest) = pattern.strip_prefix("(?") else {
        return pattern;
    };
    let Some((flags, remainder)) = rest.split_once(')') else {
        return pattern;
    };

    let is_flag_setter =
        !flags.is_empty() && flags.bytes().all(|b| b.is_ascii_alphabetic() || b == b'-');
    if !is_flag_setter {
        return pattern;
    }

    // Flags before the first `-` are switched on; those after it are switched off.
    let enabled = flags.split('-').next().unwrap_or(flags);
    if enabled.contains('x') {
        return pattern;
    }

    remainder
}
456
457/// Literal terms used for BOTH trigram planning and specificity scoring,
458/// together with a flag indicating whether they came from a top-level
459/// alternation (union plan) versus a conjunction (AND plan).
460///
461/// - Top-level literal alternation `a|b|c` → `(vec!["a","b","c"], true)`.
462/// - Mandatory-literal conjunction (e.g. `parse.*query`) → `(runs, false)`.
463/// - Anything else (char classes, groups, no ≥3 literal run) → `(vec![], false)`
464///   so the caller falls back to a bounded full scan.
465fn literal_terms(pattern: &str) -> (Vec<String>, bool) {
466    if let Some(branches) = literal_alternation(pattern) {
467        (branches, true)
468    } else if let Some(runs) = required_literals(pattern) {
469        (runs, false)
470    } else {
471        (Vec::new(), false)
472    }
473}
474
475/// If `pattern` is a top-level alternation of plain literals — every `|` is at
476/// the top level (no grouping/classes) and each branch reduces to a single
477/// mandatory literal of length ≥ 3 — return the per-branch literals. Otherwise
478/// `None`.
479///
480/// This is conservative: a branch that is too short or contains a wildcard
481/// (multiple runs) disqualifies the whole alternation, so the union plan it
482/// drives is always a safe trigram superset of the regex's true matches.
483fn literal_alternation(pattern: &str) -> Option<Vec<String>> {
484    if !pattern.contains('|') {
485        return None;
486    }
487    // Any grouping/class could scope a `|`, so only treat `|` as top-level when
488    // none are present.
489    if pattern.contains(['(', ')', '[', ']', '{', '}']) {
490        return None;
491    }
492    let mut branches: Vec<String> = Vec::new();
493    for raw in pattern.split('|') {
494        let lits = required_literals(raw)?;
495        // A clean term branch is exactly one mandatory literal run.
496        if lits.len() != 1 {
497            return None;
498        }
499        branches.push(lits.into_iter().next().unwrap());
500    }
501    if branches.is_empty() {
502        None
503    } else {
504        Some(branches)
505    }
506}
507
508/// Extract the mandatory trigram conjunction for `pattern`, or `None` if the
509/// pattern is too complex to prove a mandatory literal (caller must full-scan).
510///
511/// Safety contract: a returned trigram set is **required** — every document
512/// matching `pattern` contains all of them — so intersecting their postings can
513/// never drop a true match. When that cannot be proven, this returns `None`.
514pub fn required_trigrams(pattern: &str) -> Option<Vec<Trigram>> {
515    let literals = required_literals(pattern)?;
516    let mut trigrams: Vec<Trigram> = Vec::new();
517    for lit in &literals {
518        trigrams.extend(trigrams_of(lit));
519    }
520    if trigrams.is_empty() {
521        return None;
522    }
523    trigrams.sort_unstable();
524    trigrams.dedup();
525    Some(trigrams)
526}
527
/// Extract literal runs that must appear in every match of `pattern`.
///
/// Conservative by design: it bails out (returns `None`) on any construct that
/// can make a literal optional or contextual — alternation `|`, groups `( )`,
/// character classes `[ ]`, counted repetition `{ }` — and on escapes that
/// take operands (`\x41`, `\u0041`, `\U00000041`, `\pL`, `\PL`), whose
/// following characters are part of the escape rather than literal text.
/// The only literals it ever reports are therefore unconditionally mandatory.
///
/// `*` and `?` make the *preceding* character optional, so that character is
/// trimmed from its run; `+` keeps it (one-or-more still requires one).
/// Escaped ASCII alphanumerics (`\d`, `\w`, `\b`, ...) and the word-boundary
/// escapes `\<` / `\>` end the current run without contributing a character;
/// any other escaped character is taken literally.
///
/// Returns `None` if no run of length ≥ 3 chars (trigram-indexable) remains.
fn required_literals(pattern: &str) -> Option<Vec<String>> {
    let mut runs: Vec<String> = Vec::new();
    let mut cur = String::new();
    let mut chars = pattern.chars();

    // `while let` (not `for`): the escape arm pulls a second char from `chars`.
    while let Some(c) = chars.next() {
        match c {
            // Constructs that defeat "mandatory literal" reasoning → full scan.
            '|' | '(' | ')' | '[' | ']' | '{' | '}' => return None,
            '\\' => match chars.next() {
                // Hex/Unicode code points and one-letter Unicode classes are
                // followed by operand characters (`\x41`, `\pL`); treating
                // those operands as literal text would invent a required
                // literal, so give up on the whole pattern.
                Some('x' | 'u' | 'U' | 'p' | 'P') => return None,
                // Escaped ASCII-alnum is a class or assertion (\d, \w, \s, \b, ...)
                // or a control character (\n, \t): a separator.
                Some(n) if n.is_ascii_alphanumeric() => flush(&mut cur, &mut runs),
                // `\<` / `\>` are word-boundary assertions, not literal angle
                // brackets: a separator.
                Some('<' | '>') => flush(&mut cur, &mut runs),
                // Escaped punctuation is a literal character (\., \+, \\, ...).
                Some(n) => cur.push(n),
                None => {}
            },
            // `*` / `?`: the preceding char becomes optional → drop it.
            '*' | '?' => {
                cur.pop();
                flush(&mut cur, &mut runs);
            }
            // Wildcard / anchors / `+` end the current literal run but keep it.
            '.' | '^' | '$' | '+' => flush(&mut cur, &mut runs),
            _ => cur.push(c),
        }
    }
    flush(&mut cur, &mut runs);

    // Only runs long enough to produce at least one trigram are usable.
    runs.retain(|run| run.chars().count() >= 3);
    if runs.is_empty() {
        None
    } else {
        Some(runs)
    }
}

/// Move a completed literal run into `runs` if non-empty, leaving `cur` empty.
fn flush(cur: &mut String, runs: &mut Vec<String>) {
    if !cur.is_empty() {
        runs.push(std::mem::take(cur));
    }
}
582
583#[cfg(test)]
584mod tests {
585    use super::*;
586
587    fn build_index() -> TrigramIndex {
588        let mut idx = TrigramIndex::new();
589        idx.insert(1, "fn parse_query(input: &str) -> Query");
590        idx.insert(2, "let parser = build();");
591        idx.insert(3, "// completely unrelated comment");
592        idx.insert(4, "error: connection timeout occurred");
593        idx.insert(5, "PARSE_MODE constant");
594        idx
595    }
596
597    #[test]
598    fn test_required_literals_extraction() {
599        // Pure literal → mandatory.
600        assert_eq!(required_literals("parse"), Some(vec!["parse".to_string()]));
601        // Wildcard splits into two mandatory runs.
602        assert_eq!(
603            required_literals("parse.*query"),
604            Some(vec!["parse".to_string(), "query".to_string()])
605        );
606        // Escaped dot is a literal, so the whole thing is one contiguous literal.
607        assert_eq!(
608            required_literals(r"config\.toml"),
609            Some(vec!["config.toml".to_string()])
610        );
611        // `?` drops the optional preceding char: "color"/"colour".
612        assert_eq!(required_literals("colou?r"), Some(vec!["colo".to_string()]));
613        // Alternation / groups / classes → cannot prove a mandatory literal.
614        assert_eq!(required_literals("cat|dog"), None);
615        assert_eq!(required_literals("(foo)bar"), None);
616        assert_eq!(required_literals("a[bc]def"), None);
617        // No literal run of length ≥ 3.
618        assert_eq!(required_literals("a.b"), None);
619    }
620
621    #[test]
622    fn test_grep_substring_uses_index() {
623        let idx = build_index();
624        let exec = GrepExecutor::new(&idx);
625        let res = exec
626            .search("parse", &AllowedSet::All, 0, GrepMode::Rank)
627            .unwrap();
628        assert!(res.used_index, "a pure literal must use the trigram index");
629        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
630        // Docs 1 (parse_query) and 2 (parser) contain the lowercase substring
631        // "parse"; doc 5 is PARSE (uppercase) and must NOT match a
632        // case-sensitive search; doc 3 is unrelated.
633        assert!(ids.contains(&1));
634        assert!(ids.contains(&2));
635        assert!(!ids.contains(&5));
636        assert!(!ids.contains(&3));
637    }
638
639    #[test]
640    fn test_grep_case_insensitive_pattern() {
641        let idx = build_index();
642        let exec = GrepExecutor::new(&idx);
643        // (?i) makes verification case-insensitive; the trigram pre-filter is a
644        // safe superset, so doc 5 (PARSE) must now appear.
645        let res = exec
646            .search("(?i)parse", &AllowedSet::All, 0, GrepMode::Rank)
647            .unwrap();
648        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
649        assert!(ids.contains(&5));
650    }
651
652    #[test]
653    fn test_grep_regex_with_wildcard() {
654        let idx = build_index();
655        let exec = GrepExecutor::new(&idx);
656        // Both "parse" and "query" are mandatory; only doc 1 has both.
657        let res = exec
658            .search("parse.*query", &AllowedSet::All, 0, GrepMode::Rank)
659            .unwrap();
660        assert!(res.used_index);
661        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
662        assert_eq!(ids, vec![1]);
663    }
664
665    #[test]
666    fn test_allowed_set_pushdown() {
667        let idx = build_index();
668        let exec = GrepExecutor::new(&idx);
669        // Restrict to docs {2} — even though doc 1 also matches "parse", the
670        // gate must exclude it: result ⊆ allowed.
671        let allowed = AllowedSet::from_iter([2u64]);
672        let res = exec.search("parse", &allowed, 0, GrepMode::Rank).unwrap();
673        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
674        assert_eq!(ids, vec![2]);
675    }
676
677    #[test]
678    fn test_gate_mode_to_allowed_set() {
679        let idx = build_index();
680        let exec = GrepExecutor::new(&idx);
681        let res = exec
682            .search("parse", &AllowedSet::All, 0, GrepMode::Gate)
683            .unwrap();
684        let gate = res.into_allowed_set();
685        assert!(gate.contains(1));
686        assert!(gate.contains(2));
687        assert!(!gate.contains(3));
688    }
689
690    #[test]
691    fn test_invalid_regex_errors() {
692        let idx = build_index();
693        let exec = GrepExecutor::new(&idx);
694        let err = exec
695            .search("(unclosed", &AllowedSet::All, 0, GrepMode::Rank)
696            .unwrap_err();
697        assert!(matches!(err, GrepError::InvalidRegex(_)));
698    }
699
700    #[test]
701    fn test_degenerate_pattern_rejected_over_budget() {
702        let idx = build_index();
703        // Budget of 1, corpus of 5, pattern "a." has no indexable trigram.
704        let exec = GrepExecutor::new(&idx).with_max_scan(1);
705        let err = exec
706            .search("a.", &AllowedSet::All, 0, GrepMode::Rank)
707            .unwrap_err();
708        assert!(matches!(err, GrepError::DegeneratePattern { .. }));
709    }
710
711    #[test]
712    fn test_degenerate_pattern_scans_within_budget() {
713        let idx = build_index();
714        // Same degenerate pattern, but the budget covers the corpus → full scan.
715        let exec = GrepExecutor::new(&idx).with_max_scan(1000);
716        let res = exec
717            .search("er.", &AllowedSet::All, 0, GrepMode::Rank)
718            .unwrap();
719        assert!(!res.used_index, "degenerate pattern must full-scan");
720        // "er" followed by any char appears in "parser"/"error"/... — at least
721        // one hit, proving the scan path actually verifies.
722        assert!(!res.hits.is_empty());
723    }
724
725    // ---- Alternation planning (Cox AND-of-ORs, union form) ----
726
727    #[test]
728    fn test_literal_alternation_extraction() {
729        // Clean top-level literal alternation.
730        assert_eq!(
731            literal_alternation("parse|timeout"),
732            Some(vec!["parse".to_string(), "timeout".to_string()])
733        );
734        // Not an alternation.
735        assert_eq!(literal_alternation("parse"), None);
736        // Grouping could scope the `|` → not provably top-level.
737        assert_eq!(literal_alternation("(parse|query)x"), None);
738        // A branch shorter than a trigram disqualifies the whole alternation.
739        assert_eq!(literal_alternation("parse|ab"), None);
740        // A branch with a wildcard is multiple runs → disqualified.
741        assert_eq!(literal_alternation("parse|foo.*bar"), None);
742    }
743
744    #[test]
745    fn test_strip_leading_inline_flags() {
746        // Whole-pattern flag setters are stripped for literal extraction.
747        assert_eq!(
748            strip_leading_inline_flags("(?i)parse|timeout"),
749            "parse|timeout"
750        );
751        assert_eq!(strip_leading_inline_flags("(?ims)parse"), "parse");
752        // Disable-toggle flags (ASCII-only case folding) are also stripped.
753        assert_eq!(strip_leading_inline_flags("(?i-u)parse|x"), "parse|x");
754        // Scoped groups must be left intact (they constrain `|` scope).
755        assert_eq!(strip_leading_inline_flags("(?i:parse|x)y"), "(?i:parse|x)y");
756        // No flag group → returned unchanged.
757        assert_eq!(strip_leading_inline_flags("parse|timeout"), "parse|timeout");
758        assert_eq!(strip_leading_inline_flags("(parse)"), "(parse)");
759    }
760
761    #[test]
762    fn test_case_insensitive_alternation_uses_index() {
763        let idx = build_index();
764        let exec = GrepExecutor::new(&idx);
765        // `(?i)` must still drive the trigram index + union, and now match the
766        // uppercase PARSE in doc 5 that the case-sensitive variant skipped.
767        let res = exec
768            .search("(?i)parse|timeout", &AllowedSet::All, 0, GrepMode::Rank)
769            .unwrap();
770        assert!(
771            res.used_index,
772            "flagged alternation must still use the index"
773        );
774        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
775        assert!(ids.contains(&1));
776        assert!(ids.contains(&4));
777        assert!(
778            ids.contains(&5),
779            "case-insensitive match must include PARSE"
780        );
781    }
782
783    #[test]
784    fn test_alternation_uses_index_and_unions_branches() {
785        let idx = build_index();
786        let exec = GrepExecutor::new(&idx);
787        // `parse|timeout` previously full-scanned (required_literals bailed on
788        // `|`); now it must use the trigram index and union both branches.
789        let res = exec
790            .search("parse|timeout", &AllowedSet::All, 0, GrepMode::Rank)
791            .unwrap();
792        assert!(res.used_index, "literal alternation must use the index");
793        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
794        // Docs 1 & 2 contain "parse" (lowercase); doc 4 contains "timeout".
795        assert!(ids.contains(&1));
796        assert!(ids.contains(&2));
797        assert!(ids.contains(&4));
798        // Doc 5 is uppercase PARSE → case-sensitive regex must not match it.
799        assert!(!ids.contains(&5));
800    }
801
802    // ---- Ranking: specificity / saturation / length pivot ----
803
804    #[test]
805    fn test_rank_prefers_rarer_term_over_common_frequent_term() {
806        // "alpha" is common (df = 8); "zeta" is rare (df = 1). A single hit on
807        // the rare term must outrank four hits on the common term — the exact
808        // pathology the old `matches / doc_len` density score got backwards.
809        let mut idx = TrigramIndex::new();
810        idx.insert(1, "alpha alpha alpha alpha");
811        for i in 2..=8u64 {
812            idx.insert(i, "alpha context");
813        }
814        idx.insert(9, "zeta marker present here");
815
816        let exec = GrepExecutor::new(&idx);
817        let res = exec
818            .search("alpha|zeta", &AllowedSet::All, 0, GrepMode::Rank)
819            .unwrap();
820        assert!(res.used_index);
821        // The top-ranked hit is the rare-term document, not the match-stuffed
822        // common-term one.
823        assert_eq!(res.hits.first().map(|h| h.doc_id), Some(9));
824        let score_rare = res.hits.iter().find(|h| h.doc_id == 9).unwrap().score;
825        let score_common = res.hits.iter().find(|h| h.doc_id == 1).unwrap().score;
826        assert!(
827            score_rare > score_common,
828            "rare-term doc {score_rare} must outrank frequent common-term doc {score_common}"
829        );
830    }
831
832    #[test]
833    fn test_rank_saturates_repeated_matches() {
834        // Two docs of equal length hit the same (equally rare) term; one has
835        // many more matches. With length held constant, TF saturation means the
836        // high-count doc scores higher, but far less than linearly.
837        let mut idx = TrigramIndex::new();
838        // 1 match, padded to the same char length as doc 2 (47 chars).
839        idx.insert(1, "zebra xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
840        // 8 matches, 47 chars.
841        idx.insert(2, "zebra zebra zebra zebra zebra zebra zebra zebra");
842        let exec = GrepExecutor::new(&idx);
843        let res = exec
844            .search("zebra", &AllowedSet::All, 0, GrepMode::Rank)
845            .unwrap();
846        let s1 = res.hits.iter().find(|h| h.doc_id == 1).unwrap().score;
847        let s2 = res.hits.iter().find(|h| h.doc_id == 2).unwrap().score;
848        // 8x the matches must score higher, but nowhere near 8x (saturation).
849        assert!(s2 > s1, "more matches should still score higher");
850        assert!(
851            s2 < 4.0 * s1,
852            "saturation must keep 8x matches well under 8x score (got {s2} vs {s1})"
853        );
854    }
855}