1use regex_syntax::ast::parse::Parser as AstParser;
10use regex_syntax::hir::literal::{ExtractKind, Extractor};
11use regex_syntax::hir::{self, Hir};
12
13use crate::index::trigram::extract_trigrams_from_bytes;
14use crate::search::SearchOptions;
15
/// Trigrams (3-byte sequences) derived from a single literal by
/// `extract_trigrams_from_bytes`.
///
/// NOTE(review): presumably a file satisfies an arm only when it contains
/// every trigram in it — confirm against the index lookup code.
pub type Arm = Vec<[u8; 3]>;
18
/// Outcome of planning how a search should consult the trigram index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrigramPlan {
    /// Literals were extracted for every pattern; `arms` holds one trigram
    /// list per extracted literal, across all patterns.
    ///
    /// NOTE(review): arms look like alternatives (a candidate needs to satisfy
    /// any one of them) — confirm with the consumer of this plan.
    Narrow { arms: Vec<Arm> },
    /// The index cannot be used to narrow the search: the match is inverted,
    /// a pattern yielded no literals, a literal was shorter than three bytes,
    /// or no arms were produced at all.
    FullScan,
}
25
26impl TrigramPlan {
27 #[must_use]
29 pub fn for_patterns(patterns: &[String], opts: &SearchOptions) -> Self {
30 if opts.invert_match() {
31 return Self::FullScan;
32 }
33 let mut trigram_arms: Vec<Arm> = Vec::new();
34 for p in patterns {
35 let arms = if opts.fixed_strings() {
36 fixed_string_literals(p.as_bytes(), opts.case_insensitive())
37 } else {
38 match plan_pattern(p.as_str(), opts) {
39 Some(a) => a,
40 None => return Self::FullScan,
41 }
42 };
43 for lit in arms {
44 if lit.len() < 3 {
45 return Self::FullScan;
46 }
47 trigram_arms.push(extract_trigrams_from_bytes(&lit));
48 }
49 }
50 if trigram_arms.is_empty() {
51 return Self::FullScan;
52 }
53 Self::Narrow { arms: trigram_arms }
54 }
55}
56
57fn build_configured_hir(pattern: &str, opts: &SearchOptions) -> Option<Hir> {
59 let ast = AstParser::new().parse(pattern).ok()?;
60
61 let mut builder = regex_syntax::hir::translate::TranslatorBuilder::new();
62 builder.unicode(true);
63 if opts.case_insensitive() {
64 builder.case_insensitive(true);
65 }
66 let mut translator = builder.build();
67 let hir = translator.translate(pattern, &ast).ok()?;
68 Some(hir)
69}
70
71fn wrap_word(hir: Hir, unicode: bool) -> Hir {
73 let start_half = if unicode {
74 hir::Look::WordStartHalfUnicode
75 } else {
76 hir::Look::WordStartHalfAscii
77 };
78 let end_half = if unicode {
79 hir::Look::WordEndHalfUnicode
80 } else {
81 hir::Look::WordEndHalfAscii
82 };
83 Hir::concat(vec![Hir::look(start_half), hir, Hir::look(end_half)])
84}
85
86fn wrap_line(hir: Hir) -> Hir {
88 Hir::concat(vec![
89 Hir::look(hir::Look::StartLF),
90 hir,
91 Hir::look(hir::Look::EndLF),
92 ])
93}
94
95fn shape_hir(hir: Hir, opts: &SearchOptions) -> Hir {
97 if opts.line_regexp() {
98 wrap_line(hir)
99 } else if opts.word_regexp() {
100 wrap_word(hir, true)
101 } else {
102 hir
103 }
104}
105
106fn extract_literals(hir: &Hir) -> Vec<Vec<u8>> {
109 let extractor_prefix = Extractor::new();
110 let extractor_suffix = {
111 let mut e = Extractor::new();
112 e.kind(ExtractKind::Suffix);
113 e
114 };
115
116 let seq_prefix = extractor_prefix.extract(hir);
117 let seq_suffix = extractor_suffix.extract(hir);
118
119 let lits_prefix = seq_prefix.literals();
120 let lits_suffix = seq_suffix.literals();
121
122 pick_better_lits(lits_prefix, lits_suffix)
123}
124
125fn pick_better_lits(
127 lits_a: Option<&[regex_syntax::hir::literal::Literal]>,
128 lits_b: Option<&[regex_syntax::hir::literal::Literal]>,
129) -> Vec<Vec<u8>> {
130 fn total_bytes(lits: Option<&[regex_syntax::hir::literal::Literal]>) -> usize {
131 lits.map_or(0, |l| l.iter().map(|lit| lit.as_bytes().len()).sum())
132 }
133
134 let a_count = lits_a.map_or(0, <[regex_syntax::hir::literal::Literal]>::len);
135 let b_count = lits_b.map_or(0, <[regex_syntax::hir::literal::Literal]>::len);
136 let a_has = a_count > 0;
137 let b_has = b_count > 0;
138
139 let lits = match (a_has, b_has) {
140 (true, false) => lits_a,
141 (false, true) => lits_b,
142 (false, false) => return Vec::new(),
143 (true, true) => {
144 let a_total = total_bytes(lits_a);
145 let b_total = total_bytes(lits_b);
146 if a_total >= b_total {
147 lits_a
148 } else {
149 lits_b
150 }
151 }
152 };
153
154 let lits = match lits {
155 Some(l) if !l.is_empty() => l,
156 _ => return Vec::new(),
157 };
158
159 let mut out = Vec::new();
160 for lit in lits {
161 let bytes = lit.as_bytes();
162 if bytes.len() >= 3 {
163 out.push(bytes.to_vec());
164 }
165 }
166 out
167}
168
/// Returns the single literal to look up for a fixed-string pattern.
///
/// With `case_insensitive` set, ASCII letters are lowercased and all other
/// bytes are left untouched; otherwise the bytes are copied verbatim.
///
/// NOTE(review): lowercasing presumably matches how the trigram index folds
/// case — confirm against the index builder.
fn fixed_string_literals(lit: &[u8], case_insensitive: bool) -> Vec<Vec<u8>> {
    let needle = match case_insensitive {
        true => lit.to_ascii_lowercase(),
        false => lit.to_vec(),
    };
    vec![needle]
}
177
178fn plan_pattern(pattern: &str, opts: &SearchOptions) -> Option<Vec<Vec<u8>>> {
180 let hir = build_configured_hir(pattern, opts)?;
181 let shaped = shape_hir(hir, opts);
182 let lits = extract_literals(&shaped);
183 if lits.is_empty() {
184 None
185 } else {
186 Some(lits)
187 }
188}
189
190#[cfg(test)]
191mod tests {
192 use super::*;
193 use crate::search::{CaseMode, SearchMatchFlags};
194
195 fn narrow(patterns: &[String], opts: &SearchOptions) -> bool {
196 matches!(
197 TrigramPlan::for_patterns(patterns, opts),
198 TrigramPlan::Narrow { .. }
199 )
200 }
201
202 fn full_scan(patterns: &[String], opts: &SearchOptions) -> bool {
203 matches!(
204 TrigramPlan::for_patterns(patterns, opts),
205 TrigramPlan::FullScan
206 )
207 }
208
209 #[test]
210 fn literal_narrows() {
211 assert!(narrow(&["beta".to_string()], &SearchOptions::default()));
212 }
213
214 #[test]
215 fn dot_star_full_scan() {
216 assert!(full_scan(&[".*".to_string()], &SearchOptions::default()));
217 }
218
219 #[test]
220 fn alternation_narrows() {
221 assert!(narrow(&[r"foo|bar".to_string()], &SearchOptions::default()));
222 }
223
224 #[test]
225 fn word_literal_narrows() {
226 let opts = SearchOptions {
227 flags: SearchMatchFlags::WORD_REGEXP,
228 case_mode: CaseMode::Sensitive,
229 max_results: None,
230 };
231 assert!(narrow(&["beta".to_string()], &opts));
232 }
233
234 #[test]
235 fn line_regexp_narrows() {
236 let opts = SearchOptions {
237 flags: SearchMatchFlags::LINE_REGEXP,
238 case_mode: CaseMode::Sensitive,
239 max_results: None,
240 };
241 assert!(narrow(&["beta".to_string()], &opts));
242 }
243
244 #[test]
245 fn case_insensitive_narrows() {
246 let opts = SearchOptions {
247 flags: SearchMatchFlags::empty(),
248 case_mode: CaseMode::Insensitive,
249 max_results: None,
250 };
251 assert!(narrow(&["beta".to_string()], &opts));
252 }
253
254 #[test]
255 fn required_literal_inside_regex_narrows() {
256 assert!(narrow(
257 &["[A-Z]+_RESUME".to_string()],
258 &SearchOptions::default()
259 ));
260 }
261
262 #[test]
263 fn unicode_class_full_scan() {
264 assert!(full_scan(
265 &[r"\p{Greek}".to_string()],
266 &SearchOptions::default()
267 ));
268 }
269
270 #[test]
271 fn no_literal_full_scan() {
272 assert!(full_scan(
273 &[r"\w{5}\s+\w{5}".to_string()],
274 &SearchOptions::default()
275 ));
276 }
277
278 #[test]
279 fn short_literal_full_scan() {
280 assert!(full_scan(&["ab".to_string()], &SearchOptions::default()));
281 }
282
283 #[test]
284 fn fixed_string_narrows() {
285 let opts = SearchOptions {
286 flags: SearchMatchFlags::FIXED_STRINGS,
287 case_mode: CaseMode::Sensitive,
288 max_results: None,
289 };
290 assert!(narrow(&["beta.gamma".to_string()], &opts));
291 }
292}