sochdb_query/grep_executor.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Grep Lane Executor (Task 5)
19//!
20//! Exact regex / substring search as a first-class retrieval lane, built on the
21//! trigram candidate index. The pipeline is:
22//!
23//! ```text
24//! regex ──► required-literal extraction ──► trigram conjunction
25//! ──► trigram posting intersection (candidate DocIds)
26//! ──► ∩ AllowedSet (filter pushdown BEFORE verification)
27//! ──► regex verification (linear-time, finite-automaton engine)
28//! ──► ranked hits OR candidate gate
29//! ```
30//!
31//! ## Correctness over speed
32//!
33//! Trigram pre-filtering is only ever used when the executor can *prove* the
34//! extracted literals are mandatory (present in every possible match). For any
35//! pattern it cannot prove this for — alternation, groups, character classes,
36//! or no literal run of length ≥ 3 — it falls back to an explicit, bounded
37//! full scan rather than risk a false negative. The full-scan path is capped by
38//! `max_scan`; exceeding the cap is reported as
39//! [`GrepError::DegeneratePattern`] instead of silently returning partial
40//! results.
41//!
42//! ## Verification engine
43//!
44//! Verification uses the `regex` crate, a finite-automaton engine with
45//! guaranteed linear-time matching, so adversarial patterns cannot turn the
46//! verify stage into a catastrophic-backtracking DoS.
47//!
48//! ## Two fusion modes
49//!
50//! Grep produces a *set*, but RRF consumes *ranked lists*. Both shapes are
51//! supported:
52//! - [`GrepMode::Rank`] scores each hit by specificity-weighted, TF-saturated,
53//! length-pivoted relevance (BM25-flavored over the pattern's literal terms)
54//! so it can plug into RRF as a third ranked lane **without** the
55//! short-document / common-term bias of raw match density.
56//! - [`GrepMode::Gate`] returns the matching documents as an
57//! [`AllowedSet`] to intersect into the other lanes (the
58//! "find the function that contains X" cascade), via [`GrepResults::into_allowed_set`].
59
60use regex::Regex;
61
62use crate::candidate_gate::AllowedSet;
63use crate::trigram_index::{DocId, Trigram, TrigramIndex, trigrams_of};
64
/// Default cap on documents verified by a degenerate (no-trigram) full scan.
///
/// `GrepExecutor::new` installs this value as the executor's scan budget; it
/// can be overridden per executor with `GrepExecutor::with_max_scan`.
pub const DEFAULT_MAX_SCAN: usize = 100_000;

/// BM25-style term-frequency saturation constant for grep `Rank` scoring.
/// Bounds the marginal value of additional matches of the same term: the
/// saturated frequency is `count / (count + GREP_K1)`, which stays below 1.
const GREP_K1: f32 = 1.2;

/// BM25-style length-normalization (pivot) constant for grep `Rank` scoring.
/// `0.0` disables length normalization; `1.0` applies it fully. It is the `b`
/// in the divisor `1 - b + b * len / avg_len` applied to each hit's raw score.
const GREP_B: f32 = 0.75;
75
/// How grep results should be consumed by the fusion layer.
///
/// The mode only changes how verified hits are scored and ordered; the set of
/// matching documents is determined by the regex in both modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrepMode {
    /// Produce a ranked list (for RRF as a third lane). Hits are ordered by
    /// relevance score, best first.
    Rank,
    /// Produce a candidate gate (intersect into the other lanes). Hits carry a
    /// constant score and are ordered by document id.
    Gate,
}
84
/// A single grep match: one document that the regex verified.
#[derive(Debug, Clone, PartialEq)]
pub struct GrepHit {
    /// Matching document id.
    pub doc_id: DocId,
    /// Rank score (higher is better): specificity-weighted, TF-saturated,
    /// length-pivoted relevance over the pattern's literal terms.
    /// In `GrepMode::Gate` this is always `1.0`.
    pub score: f32,
    /// Number of (non-overlapping) matches in the document.
    /// In `GrepMode::Gate` matches are not counted and this is always `1`.
    pub match_count: usize,
}
96
97/// The outcome of a grep search.
98#[derive(Debug, Clone)]
99pub struct GrepResults {
100 /// Ranked hits, best first.
101 pub hits: Vec<GrepHit>,
102 /// Whether the trigram index was used (`true`) or a full scan ran (`false`).
103 pub used_index: bool,
104}
105
106impl GrepResults {
107 /// The matching document ids as an [`AllowedSet`] for gate / cascade fusion.
108 pub fn into_allowed_set(self) -> AllowedSet {
109 AllowedSet::from_iter(self.hits.into_iter().map(|h| h.doc_id))
110 }
111}
112
/// Failure modes of the grep lane.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepError {
    /// The pattern failed to compile; the payload is the regex engine's
    /// error message.
    InvalidRegex(String),
    /// No usable trigram could be derived from the pattern and the corpus is
    /// larger than the scan budget, so the search is refused outright rather
    /// than answered from a partial scan.
    DegeneratePattern { corpus: usize, max_scan: usize },
}

impl std::fmt::Display for GrepError {
    /// Render a human-readable, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidRegex(reason) => {
                f.write_str("invalid regex: ")?;
                f.write_str(reason)
            }
            Self::DegeneratePattern { corpus, max_scan } => write!(
                f,
                "degenerate pattern (no indexable literal) over a corpus of {corpus} documents \
                 exceeds the scan budget of {max_scan}"
            ),
        }
    }
}

impl std::error::Error for GrepError {}
137
138/// The grep executor: plans and runs regex search over a [`TrigramIndex`].
139pub struct GrepExecutor<'a> {
140 index: &'a TrigramIndex,
141 max_scan: usize,
142}
143
144impl<'a> GrepExecutor<'a> {
145 /// Create an executor over `index` with the default scan budget.
146 pub fn new(index: &'a TrigramIndex) -> Self {
147 Self {
148 index,
149 max_scan: DEFAULT_MAX_SCAN,
150 }
151 }
152
153 /// Override the full-scan document budget for degenerate patterns.
154 pub fn with_max_scan(mut self, max_scan: usize) -> Self {
155 self.max_scan = max_scan;
156 self
157 }
158
159 /// Run a grep search.
160 ///
161 /// `allowed` is applied as a pushdown filter **before** regex verification,
162 /// preserving the same `result ⊆ allowed` invariant the other lanes honor.
163 /// `limit` caps the number of returned hits (0 = unlimited).
164 pub fn search(
165 &self,
166 pattern: &str,
167 allowed: &AllowedSet,
168 limit: usize,
169 mode: GrepMode,
170 ) -> Result<GrepResults, GrepError> {
171 let re = Regex::new(pattern).map_err(|e| GrepError::InvalidRegex(e.to_string()))?;
172
173 if allowed.is_empty() {
174 return Ok(GrepResults {
175 hits: Vec::new(),
176 used_index: false,
177 });
178 }
179
180 // ---- Plan: candidate document set (safe superset, no false negatives) ----
181 //
182 // While planning we also capture each term's document-frequency estimate
183 // (its trigram-candidate count) so the ranking stage can compute IDF
184 // without a second pass over the postings.
185 //
186 // A leading whole-pattern inline-flag group like `(?i)` is stripped
187 // *for literal extraction only* (the compiled `re` above keeps the
188 // flag), so a case-insensitive alternation still drives the index and
189 // IDF instead of degrading to a full scan.
190 let extract = strip_leading_inline_flags(pattern);
191 let (terms, is_alternation) = literal_terms(extract);
192 let mut term_df: Vec<(String, usize)> = Vec::new();
193 let (candidates, used_index): (Vec<DocId>, bool) = if terms.is_empty() {
194 // No provably-mandatory literal (complex regex): bounded full scan.
195 if self.index.len() > self.max_scan {
196 return Err(GrepError::DegeneratePattern {
197 corpus: self.index.len(),
198 max_scan: self.max_scan,
199 });
200 }
201 (self.index.documents().map(|(id, _)| id).collect(), false)
202 } else if is_alternation {
203 // Alternation `a|b|c`: a match contains *some* branch, so the
204 // candidate set is the UNION of each branch's trigram candidates
205 // (Cox AND-of-ORs, union form). Every branch is trigram-indexable
206 // here (guaranteed by `literal_alternation`), so this stays a safe
207 // superset and the previously full-scanned `|` patterns now use the
208 // index. Each branch's candidate count doubles as its df estimate.
209 let mut union: Vec<DocId> = Vec::new();
210 for term in &terms {
211 let branch = self.index.candidates(&trigrams_of(term));
212 term_df.push((term.to_lowercase(), branch.len().max(1)));
213 union.extend(branch);
214 }
215 union.sort_unstable();
216 union.dedup();
217 (union, true)
218 } else {
219 // Conjunction of mandatory literals: AND of all their trigrams.
220 let mut trigrams: Vec<Trigram> = Vec::new();
221 for term in &terms {
222 let df = self.index.candidates(&trigrams_of(term)).len().max(1);
223 term_df.push((term.to_lowercase(), df));
224 trigrams.extend(trigrams_of(term));
225 }
226 trigrams.sort_unstable();
227 trigrams.dedup();
228 (self.index.candidates(&trigrams), true)
229 };
230
231 // ---- Gate mode: membership only, no ranking ----
232 if mode == GrepMode::Gate {
233 let mut hits: Vec<GrepHit> = Vec::new();
234 for doc_id in candidates {
235 if !allowed.contains(doc_id) {
236 continue;
237 }
238 if let Some(text) = self.index.doc_text(doc_id) {
239 if re.is_match(text) {
240 hits.push(GrepHit {
241 doc_id,
242 score: 1.0,
243 match_count: 1,
244 });
245 }
246 }
247 }
248 hits.sort_by(|a, b| a.doc_id.cmp(&b.doc_id));
249 if limit > 0 && hits.len() > limit {
250 hits.truncate(limit);
251 }
252 return Ok(GrepResults { hits, used_index });
253 }
254
255 // ---- Rank mode: specificity-weighted, TF-saturated, length-pivoted ----
256 //
257 // The old score was raw match density (`matches / doc_len`), which is
258 // IDF-blind (a hit on a common word counts as much as a rare one),
259 // linear in raw match count (50 hits == 50x one hit), and explodes for
260 // short documents — so it injected noise into RRF. The corrected score
261 // is BM25-flavored over the grep's literal terms:
262 //
263 // idf(t) = ln(1 + (N - df + 0.5)/(df + 0.5)) // term rarity
264 // tf_sat = c / (c + k1) // saturating TF
265 // raw(d) = SUM_t idf(t) * tf_sat(count_t(d))
266 // score(d) = raw(d) / (1 - b + b*len_d/avg_len) // pivoted length
267 //
268 // `df` is estimated index-locally as the trigram-candidate count of the
269 // term (a tight upper bound on its true document frequency), captured
270 // during planning above, so no extra corpus statistics are needed.
271 // Verification still uses the full regex, so the hit *set* is unchanged
272 // — only the ranking improves.
273 let n = self.index.len().max(1) as f32;
274 let term_idf: Vec<(String, f32)> = term_df
275 .iter()
276 .map(|(t, df)| {
277 let dff = *df as f32;
278 let idf = (1.0 + (n - dff + 0.5) / (dff + 0.5)).ln();
279 (t.clone(), idf.max(0.0))
280 })
281 .collect();
282
283 struct Pending {
284 doc_id: DocId,
285 len: f32,
286 raw: f32,
287 match_count: usize,
288 }
289 let mut pending: Vec<Pending> = Vec::new();
290 let mut total_len = 0.0f32;
291 // Reused per-term match-count buffer (alternation path) to avoid a
292 // per-document allocation.
293 let mut counts: Vec<u32> = vec![0; term_idf.len()];
294 for doc_id in candidates {
295 if !allowed.contains(doc_id) {
296 continue;
297 }
298 let Some(text) = self.index.doc_text(doc_id) else {
299 continue;
300 };
301
302 // Single regex pass over the document. For an alternation each match
303 // is exactly one branch literal, so we attribute it to its term in
304 // the SAME pass — no extra per-term substring scans, no allocation.
305 let mut match_count = 0usize;
306 if is_alternation {
307 for c in counts.iter_mut() {
308 *c = 0;
309 }
310 for m in re.find_iter(text) {
311 match_count += 1;
312 let ms = m.as_str();
313 for (i, (term_lc, _)) in term_idf.iter().enumerate() {
314 if eq_ci_ascii(ms, term_lc) {
315 counts[i] += 1;
316 break;
317 }
318 }
319 }
320 } else {
321 match_count = re.find_iter(text).count();
322 }
323 if match_count == 0 {
324 continue;
325 }
326
327 let len = text.chars().count().max(1) as f32;
328 let raw = if term_idf.is_empty() {
329 // Complex pattern with no literal terms to weight: saturate the
330 // raw regex match count so a flood of matches can't dominate.
331 tf_saturate(match_count as f32)
332 } else if is_alternation {
333 // Per-branch counts already attributed in the single pass above.
334 let mut s = 0.0f32;
335 for (i, (_, idf)) in term_idf.iter().enumerate() {
336 if counts[i] > 0 {
337 s += idf * tf_saturate(counts[i] as f32);
338 }
339 }
340 s
341 } else {
342 // Conjunction / complex literal terms (rare): the whole-pattern
343 // matches can't be attributed per term, so scan each mandatory
344 // term once (allocation-free, ASCII case-insensitive).
345 let mut s = 0.0f32;
346 for (term_lc, idf) in &term_idf {
347 let c = count_ci_ascii(text, term_lc);
348 if c > 0 {
349 s += idf * tf_saturate(c as f32);
350 }
351 }
352 s
353 };
354 total_len += len;
355 pending.push(Pending {
356 doc_id,
357 len,
358 raw,
359 match_count,
360 });
361 }
362
363 let avg_len = if pending.is_empty() {
364 1.0
365 } else {
366 (total_len / pending.len() as f32).max(1.0)
367 };
368
369 let mut hits: Vec<GrepHit> = pending
370 .into_iter()
371 .map(|p| {
372 let norm = 1.0 - GREP_B + GREP_B * (p.len / avg_len);
373 GrepHit {
374 doc_id: p.doc_id,
375 score: if norm > 0.0 { p.raw / norm } else { p.raw },
376 match_count: p.match_count,
377 }
378 })
379 .collect();
380
381 // Rank: relevance descending, doc_id ascending as a stable tiebreak.
382 hits.sort_by(|a, b| {
383 b.score
384 .total_cmp(&a.score)
385 .then_with(|| a.doc_id.cmp(&b.doc_id))
386 });
387 if limit > 0 && hits.len() > limit {
388 hits.truncate(limit);
389 }
390
391 Ok(GrepResults { hits, used_index })
392 }
393}
394
395/// BM25-style saturating term frequency: `count / (count + k1)`, in `[0, 1)`.
396fn tf_saturate(count: f32) -> f32 {
397 count / (count + GREP_K1)
398}
399
/// Number of non-overlapping occurrences of `needle` in `hay`, folding ASCII
/// case on the haystack side only and never allocating.
///
/// `needle` must already be lowercased: haystack bytes are ASCII-lowercased
/// and compared against the needle bytes verbatim, so an uppercase ASCII byte
/// in `needle` can never match. Non-ASCII bytes are compared as-is (no Unicode
/// case folding). An empty needle, or one longer than the haystack, yields 0.
///
/// This only feeds the *ranking* signal of documents the full regex already
/// verified, so the approximation never changes which documents are returned.
fn count_ci_ascii(hay: &str, needle: &str) -> usize {
    let hay_bytes = hay.as_bytes();
    let pat = needle.as_bytes();
    let width = pat.len();
    if width == 0 || width > hay_bytes.len() {
        return 0;
    }

    let mut found = 0usize;
    let mut pos = 0usize;
    while pos + width <= hay_bytes.len() {
        let window = &hay_bytes[pos..pos + width];
        let is_hit = window
            .iter()
            .zip(pat)
            .all(|(h, p)| h.to_ascii_lowercase() == *p);
        if is_hit {
            found += 1;
            // Jump past the whole occurrence so matches never overlap.
            pos += width;
        } else {
            pos += 1;
        }
    }
    found
}
429
/// Whether `a` equals `b` once the ASCII letters of `a` are lowercased.
///
/// `b` is expected to be lowercased already: only the bytes of `a` are folded,
/// so an uppercase ASCII byte in `b` makes the comparison fail. Non-ASCII
/// bytes must match exactly.
fn eq_ci_ascii(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    for (lhs, rhs) in a.bytes().zip(b.bytes()) {
        if lhs.to_ascii_lowercase() != rhs {
            return false;
        }
    }
    true
}
437
/// Strip a leading whole-pattern inline-flag group (e.g. `(?i)`, `(?ims)`,
/// `(?i-u)`) so the remainder can be parsed for mandatory literals.
///
/// Only a pure flag setter — alphabetic flags plus an optional `-` toggle,
/// immediately closed by `)` with no `:` scoping — is stripped; scoped groups
/// like `(?i:...)` are left intact (the original pattern is returned).
///
/// A group that *enables* `x` (verbose mode) is also left intact: under `x`
/// the engine ignores whitespace and `#` comments in the pattern, so reading
/// the remainder character-by-character would report literals that are not
/// actually required by a match. Returning the pattern unchanged makes literal
/// extraction bail out on the `(`, which routes the search to the safe
/// full-scan path.
///
/// The compiled regex still carries the flags, so this only affects literal
/// extraction, never matching semantics.
fn strip_leading_inline_flags(pattern: &str) -> &str {
    let Some(rest) = pattern.strip_prefix("(?") else {
        return pattern;
    };
    let Some(close) = rest.find(')') else {
        return pattern;
    };

    let flags = &rest[..close];
    let is_pure_setter =
        !flags.is_empty() && flags.bytes().all(|b| b.is_ascii_alphabetic() || b == b'-');
    if !is_pure_setter {
        return pattern;
    }

    // Flags before the first `-` are switched on; those after it are switched
    // off. Only an *enabled* `x` changes how the pattern text must be read.
    let enabled = flags.split('-').next().unwrap_or("");
    if enabled.contains('x') {
        return pattern;
    }

    &rest[close + 1..]
}
456
457/// Literal terms used for BOTH trigram planning and specificity scoring,
458/// together with a flag indicating whether they came from a top-level
459/// alternation (union plan) versus a conjunction (AND plan).
460///
461/// - Top-level literal alternation `a|b|c` → `(vec!["a","b","c"], true)`.
462/// - Mandatory-literal conjunction (e.g. `parse.*query`) → `(runs, false)`.
463/// - Anything else (char classes, groups, no ≥3 literal run) → `(vec![], false)`
464/// so the caller falls back to a bounded full scan.
465fn literal_terms(pattern: &str) -> (Vec<String>, bool) {
466 if let Some(branches) = literal_alternation(pattern) {
467 (branches, true)
468 } else if let Some(runs) = required_literals(pattern) {
469 (runs, false)
470 } else {
471 (Vec::new(), false)
472 }
473}
474
475/// If `pattern` is a top-level alternation of plain literals — every `|` is at
476/// the top level (no grouping/classes) and each branch reduces to a single
477/// mandatory literal of length ≥ 3 — return the per-branch literals. Otherwise
478/// `None`.
479///
480/// This is conservative: a branch that is too short or contains a wildcard
481/// (multiple runs) disqualifies the whole alternation, so the union plan it
482/// drives is always a safe trigram superset of the regex's true matches.
483fn literal_alternation(pattern: &str) -> Option<Vec<String>> {
484 if !pattern.contains('|') {
485 return None;
486 }
487 // Any grouping/class could scope a `|`, so only treat `|` as top-level when
488 // none are present.
489 if pattern.contains(['(', ')', '[', ']', '{', '}']) {
490 return None;
491 }
492 let mut branches: Vec<String> = Vec::new();
493 for raw in pattern.split('|') {
494 let lits = required_literals(raw)?;
495 // A clean term branch is exactly one mandatory literal run.
496 if lits.len() != 1 {
497 return None;
498 }
499 branches.push(lits.into_iter().next().unwrap());
500 }
501 if branches.is_empty() {
502 None
503 } else {
504 Some(branches)
505 }
506}
507
508/// Extract the mandatory trigram conjunction for `pattern`, or `None` if the
509/// pattern is too complex to prove a mandatory literal (caller must full-scan).
510///
511/// Safety contract: a returned trigram set is **required** — every document
512/// matching `pattern` contains all of them — so intersecting their postings can
513/// never drop a true match. When that cannot be proven, this returns `None`.
514pub fn required_trigrams(pattern: &str) -> Option<Vec<Trigram>> {
515 let literals = required_literals(pattern)?;
516 let mut trigrams: Vec<Trigram> = Vec::new();
517 for lit in &literals {
518 trigrams.extend(trigrams_of(lit));
519 }
520 if trigrams.is_empty() {
521 return None;
522 }
523 trigrams.sort_unstable();
524 trigrams.dedup();
525 Some(trigrams)
526}
527
/// Extract literal runs that must appear in every match of `pattern`.
///
/// Conservative by design: it bails out (returns `None`) on any construct that
/// can make a literal optional or contextual — alternation `|`, groups `( )`,
/// character classes `[ ]`, counted repetition `{ }` — and on escapes whose
/// payload is not literal text (`\x..`, `\u....`, `\U........`, `\pL`, `\PL`),
/// so the only literals it ever reports are unconditionally mandatory.
///
/// Handling of the remaining syntax:
/// - `*` and `?` make the *preceding* character optional, so that character is
///   trimmed from its run; `+` keeps it (one-or-more still requires one).
/// - `.`, `^`, `$`, escaped ASCII alphanumerics (`\d`, `\w`, `\s`, `\b`, ...)
///   and the word-boundary assertions `\<` / `\>` end the current run without
///   contributing a character.
/// - Any other escaped character (`\.`, `\+`, `\\`, `\|`, ...) is the literal
///   character itself.
///
/// Only runs of length ≥ 3 (trigram-indexable) are returned; if none remain
/// the result is `None`.
fn required_literals(pattern: &str) -> Option<Vec<String>> {
    let mut runs: Vec<String> = Vec::new();
    let mut cur = String::new();
    let mut chars = pattern.chars();

    while let Some(c) = chars.next() {
        match c {
            // Constructs that defeat "mandatory literal" reasoning → full scan.
            '|' | '(' | ')' | '[' | ']' | '{' | '}' => return None,
            '\\' => match chars.next() {
                // Hex / Unicode code-point escapes and one-letter Unicode
                // classes consume the characters that follow them; those
                // characters are NOT literal text, so give up rather than
                // report e.g. "41bcd" for `\x41bcd`.
                Some('x' | 'u' | 'U' | 'p' | 'P') => return None,
                // Escaped ASCII-alnum is a class or assertion (\d, \w, \s,
                // \b, ...): a separator.
                Some(n) if n.is_ascii_alphanumeric() => flush(&mut cur, &mut runs),
                // `\<` / `\>` are zero-width start/end-of-word assertions, not
                // literal angle brackets: a separator as well.
                Some('<' | '>') => flush(&mut cur, &mut runs),
                // Escaped punctuation is a literal character (\., \+, \\, ...).
                Some(n) => cur.push(n),
                // Dangling backslash: nothing to add.
                None => {}
            },
            // `*` / `?`: the preceding char becomes optional → drop it.
            '*' | '?' => {
                cur.pop();
                flush(&mut cur, &mut runs);
            }
            // Wildcard / anchors / `+` end the current literal run but keep it.
            '.' | '^' | '$' | '+' => flush(&mut cur, &mut runs),
            _ => cur.push(c),
        }
    }
    flush(&mut cur, &mut runs);

    // Runs shorter than a trigram cannot drive the index.
    runs.retain(|r| r.chars().count() >= 3);
    if runs.is_empty() {
        None
    } else {
        Some(runs)
    }
}

/// Move a completed literal run into `runs` if non-empty, leaving `cur` empty.
fn flush(cur: &mut String, runs: &mut Vec<String>) {
    if !cur.is_empty() {
        runs.push(std::mem::take(cur));
    }
}
582
// Unit tests for the grep lane: literal extraction, planning (index vs. full
// scan), filter pushdown, gate mode, error reporting, and rank scoring.
#[cfg(test)]
mod tests {
    use super::*;

    /// Small five-document corpus shared by most tests below.
    fn build_index() -> TrigramIndex {
        let mut idx = TrigramIndex::new();
        idx.insert(1, "fn parse_query(input: &str) -> Query");
        idx.insert(2, "let parser = build();");
        idx.insert(3, "// completely unrelated comment");
        idx.insert(4, "error: connection timeout occurred");
        idx.insert(5, "PARSE_MODE constant");
        idx
    }

    #[test]
    fn test_required_literals_extraction() {
        // Pure literal → mandatory.
        assert_eq!(required_literals("parse"), Some(vec!["parse".to_string()]));
        // Wildcard splits into two mandatory runs.
        assert_eq!(
            required_literals("parse.*query"),
            Some(vec!["parse".to_string(), "query".to_string()])
        );
        // Escaped dot is a literal, so the whole thing is one contiguous literal.
        assert_eq!(
            required_literals(r"config\.toml"),
            Some(vec!["config.toml".to_string()])
        );
        // `?` drops the optional preceding char: "color"/"colour".
        assert_eq!(required_literals("colou?r"), Some(vec!["colo".to_string()]));
        // Alternation / groups / classes → cannot prove a mandatory literal.
        assert_eq!(required_literals("cat|dog"), None);
        assert_eq!(required_literals("(foo)bar"), None);
        assert_eq!(required_literals("a[bc]def"), None);
        // No literal run of length ≥ 3.
        assert_eq!(required_literals("a.b"), None);
    }

    #[test]
    fn test_grep_substring_uses_index() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        let res = exec
            .search("parse", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        assert!(res.used_index, "a pure literal must use the trigram index");
        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
        // Docs 1 (parse_query) and 2 (parser) contain the lowercase substring
        // "parse"; doc 5 is PARSE (uppercase) and must NOT match a
        // case-sensitive search; doc 3 is unrelated.
        assert!(ids.contains(&1));
        assert!(ids.contains(&2));
        assert!(!ids.contains(&5));
        assert!(!ids.contains(&3));
    }

    #[test]
    fn test_grep_case_insensitive_pattern() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        // (?i) makes verification case-insensitive; the trigram pre-filter is a
        // safe superset, so doc 5 (PARSE) must now appear.
        let res = exec
            .search("(?i)parse", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
        assert!(ids.contains(&5));
    }

    #[test]
    fn test_grep_regex_with_wildcard() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        // Both "parse" and "query" are mandatory; only doc 1 has both.
        let res = exec
            .search("parse.*query", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        assert!(res.used_index);
        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
        assert_eq!(ids, vec![1]);
    }

    #[test]
    fn test_allowed_set_pushdown() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        // Restrict to docs {2} — even though doc 1 also matches "parse", the
        // gate must exclude it: result ⊆ allowed.
        let allowed = AllowedSet::from_iter([2u64]);
        let res = exec.search("parse", &allowed, 0, GrepMode::Rank).unwrap();
        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
        assert_eq!(ids, vec![2]);
    }

    #[test]
    fn test_gate_mode_to_allowed_set() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        let res = exec
            .search("parse", &AllowedSet::All, 0, GrepMode::Gate)
            .unwrap();
        // Gate results convert into a membership set over the matching docs.
        let gate = res.into_allowed_set();
        assert!(gate.contains(1));
        assert!(gate.contains(2));
        assert!(!gate.contains(3));
    }

    #[test]
    fn test_invalid_regex_errors() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        // An unbalanced group fails to compile and must surface as InvalidRegex.
        let err = exec
            .search("(unclosed", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap_err();
        assert!(matches!(err, GrepError::InvalidRegex(_)));
    }

    #[test]
    fn test_degenerate_pattern_rejected_over_budget() {
        let idx = build_index();
        // Budget of 1, corpus of 5, pattern "a." has no indexable trigram.
        let exec = GrepExecutor::new(&idx).with_max_scan(1);
        let err = exec
            .search("a.", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap_err();
        assert!(matches!(err, GrepError::DegeneratePattern { .. }));
    }

    #[test]
    fn test_degenerate_pattern_scans_within_budget() {
        let idx = build_index();
        // Another degenerate pattern ("er." has no literal run of length ≥ 3),
        // but here the budget covers the corpus → full scan.
        let exec = GrepExecutor::new(&idx).with_max_scan(1000);
        let res = exec
            .search("er.", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        assert!(!res.used_index, "degenerate pattern must full-scan");
        // "er" followed by any char appears in "parser"/"error"/... — at least
        // one hit, proving the scan path actually verifies.
        assert!(!res.hits.is_empty());
    }

    // ---- Alternation planning (Cox AND-of-ORs, union form) ----

    #[test]
    fn test_literal_alternation_extraction() {
        // Clean top-level literal alternation.
        assert_eq!(
            literal_alternation("parse|timeout"),
            Some(vec!["parse".to_string(), "timeout".to_string()])
        );
        // Not an alternation.
        assert_eq!(literal_alternation("parse"), None);
        // Grouping could scope the `|` → not provably top-level.
        assert_eq!(literal_alternation("(parse|query)x"), None);
        // A branch shorter than a trigram disqualifies the whole alternation.
        assert_eq!(literal_alternation("parse|ab"), None);
        // A branch with a wildcard is multiple runs → disqualified.
        assert_eq!(literal_alternation("parse|foo.*bar"), None);
    }

    #[test]
    fn test_strip_leading_inline_flags() {
        // Whole-pattern flag setters are stripped for literal extraction.
        assert_eq!(
            strip_leading_inline_flags("(?i)parse|timeout"),
            "parse|timeout"
        );
        assert_eq!(strip_leading_inline_flags("(?ims)parse"), "parse");
        // Disable-toggle flags (ASCII-only case folding) are also stripped.
        assert_eq!(strip_leading_inline_flags("(?i-u)parse|x"), "parse|x");
        // Scoped groups must be left intact (they constrain `|` scope).
        assert_eq!(strip_leading_inline_flags("(?i:parse|x)y"), "(?i:parse|x)y");
        // No flag group → returned unchanged.
        assert_eq!(strip_leading_inline_flags("parse|timeout"), "parse|timeout");
        assert_eq!(strip_leading_inline_flags("(parse)"), "(parse)");
    }

    #[test]
    fn test_case_insensitive_alternation_uses_index() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        // `(?i)` must still drive the trigram index + union, and now match the
        // uppercase PARSE in doc 5 that the case-sensitive variant skipped.
        let res = exec
            .search("(?i)parse|timeout", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        assert!(
            res.used_index,
            "flagged alternation must still use the index"
        );
        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&4));
        assert!(
            ids.contains(&5),
            "case-insensitive match must include PARSE"
        );
    }

    #[test]
    fn test_alternation_uses_index_and_unions_branches() {
        let idx = build_index();
        let exec = GrepExecutor::new(&idx);
        // `parse|timeout` cannot be handled by `required_literals` (it bails on
        // `|`); the alternation planner must use the trigram index and union
        // both branches instead of full-scanning.
        let res = exec
            .search("parse|timeout", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        assert!(res.used_index, "literal alternation must use the index");
        let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
        // Docs 1 & 2 contain "parse" (lowercase); doc 4 contains "timeout".
        assert!(ids.contains(&1));
        assert!(ids.contains(&2));
        assert!(ids.contains(&4));
        // Doc 5 is uppercase PARSE → case-sensitive regex must not match it.
        assert!(!ids.contains(&5));
    }

    // ---- Ranking: specificity / saturation / length pivot ----

    #[test]
    fn test_rank_prefers_rarer_term_over_common_frequent_term() {
        // "alpha" is common (df = 8); "zeta" is rare (df = 1). A single hit on
        // the rare term must outrank four hits on the common term — the exact
        // pathology a plain `matches / doc_len` density score gets backwards.
        let mut idx = TrigramIndex::new();
        idx.insert(1, "alpha alpha alpha alpha");
        for i in 2..=8u64 {
            idx.insert(i, "alpha context");
        }
        idx.insert(9, "zeta marker present here");

        let exec = GrepExecutor::new(&idx);
        let res = exec
            .search("alpha|zeta", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        assert!(res.used_index);
        // The top-ranked hit is the rare-term document, not the match-stuffed
        // common-term one.
        assert_eq!(res.hits.first().map(|h| h.doc_id), Some(9));
        let score_rare = res.hits.iter().find(|h| h.doc_id == 9).unwrap().score;
        let score_common = res.hits.iter().find(|h| h.doc_id == 1).unwrap().score;
        assert!(
            score_rare > score_common,
            "rare-term doc {score_rare} must outrank frequent common-term doc {score_common}"
        );
    }

    #[test]
    fn test_rank_saturates_repeated_matches() {
        // Two docs of equal length hit the same (equally rare) term; one has
        // many more matches. With length held constant, TF saturation means the
        // high-count doc scores higher, but far less than linearly.
        let mut idx = TrigramIndex::new();
        // 1 match, padded to the same char length as doc 2 (47 chars).
        idx.insert(1, "zebra xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
        // 8 matches, 47 chars.
        idx.insert(2, "zebra zebra zebra zebra zebra zebra zebra zebra");
        let exec = GrepExecutor::new(&idx);
        let res = exec
            .search("zebra", &AllowedSet::All, 0, GrepMode::Rank)
            .unwrap();
        let s1 = res.hits.iter().find(|h| h.doc_id == 1).unwrap().score;
        let s2 = res.hits.iter().find(|h| h.doc_id == 2).unwrap().score;
        // 8x the matches must score higher, but nowhere near 8x (saturation).
        assert!(s2 > s1, "more matches should still score higher");
        assert!(
            s2 < 4.0 * s1,
            "saturation must keep 8x matches well under 8x score (got {s2} vs {s1})"
        );
    }
}