1use regex_syntax::ast::parse::Parser as AstParser;
10use regex_syntax::hir::literal::{ExtractKind, Extractor};
11use regex_syntax::hir::{self, Hir};
12
13use crate::index::trigram::extract_trigrams_from_bytes;
14use crate::search::SearchOptions;
15
/// The trigrams of one literal derived from a search pattern, as produced by
/// `extract_trigrams_from_bytes`.
pub type Arm = Vec<[u8; 3]>;

/// How a search should use the trigram index to choose candidate files.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrigramPlan {
    /// Restrict the search using these arms. Each arm holds the trigrams of
    /// one literal derived from one pattern; for regex patterns the literals
    /// of a pattern are alternatives (a match contains at least one of them).
    /// NOTE(review): presumably a file is a candidate when it contains every
    /// trigram of at least one arm — confirm against the index query code.
    Narrow { arms: Vec<Arm> },
    /// No safe narrowing exists; every file has to be searched.
    FullScan,
}
25
26impl TrigramPlan {
27 #[must_use]
29 pub fn for_patterns(patterns: &[String], opts: &SearchOptions) -> Self {
30 if opts.invert_match() {
31 return Self::FullScan;
32 }
33 let mut trigram_arms: Vec<Arm> = Vec::new();
34 for p in patterns {
35 let arms = if opts.fixed_strings() {
36 fixed_string_literals(p.as_bytes(), opts.case_insensitive())
37 } else {
38 match plan_pattern(p.as_str(), opts) {
39 Some(a) => a,
40 None => return Self::FullScan,
41 }
42 };
43 for lit in arms {
44 if lit.len() < 3 {
45 return Self::FullScan;
46 }
47 trigram_arms.push(extract_trigrams_from_bytes(&lit));
48 }
49 }
50 if trigram_arms.is_empty() {
51 return Self::FullScan;
52 }
53 Self::Narrow { arms: trigram_arms }
54 }
55}
56
57fn build_configured_hir(pattern: &str, opts: &SearchOptions) -> Option<Hir> {
59 let ast = AstParser::new().parse(pattern).ok()?;
60
61 let mut builder = regex_syntax::hir::translate::TranslatorBuilder::new();
62 builder.unicode(true);
63 if opts.case_insensitive() {
64 builder.case_insensitive(true);
65 }
66 let mut translator = builder.build();
67 let hir = translator.translate(pattern, &ast).ok()?;
68 Some(hir)
69}
70
71fn wrap_word(hir: Hir, unicode: bool) -> Hir {
73 let start_half = if unicode {
74 hir::Look::WordStartHalfUnicode
75 } else {
76 hir::Look::WordStartHalfAscii
77 };
78 let end_half = if unicode {
79 hir::Look::WordEndHalfUnicode
80 } else {
81 hir::Look::WordEndHalfAscii
82 };
83 Hir::concat(vec![Hir::look(start_half), hir, Hir::look(end_half)])
84}
85
86fn wrap_line(hir: Hir) -> Hir {
88 Hir::concat(vec![
89 Hir::look(hir::Look::StartLF),
90 hir,
91 Hir::look(hir::Look::EndLF),
92 ])
93}
94
95fn shape_hir(hir: Hir, opts: &SearchOptions) -> Hir {
97 if opts.word_regexp() {
98 wrap_word(hir, true)
99 } else if opts.line_regexp() {
100 wrap_line(hir)
101 } else {
102 hir
103 }
104}
105
106fn extract_literals(hir: &Hir) -> Vec<Vec<u8>> {
109 let extractor_prefix = Extractor::new();
110 let extractor_suffix = {
111 let mut e = Extractor::new();
112 e.kind(ExtractKind::Suffix);
113 e
114 };
115
116 let seq_prefix = extractor_prefix.extract(hir);
117 let seq_suffix = extractor_suffix.extract(hir);
118
119 let lits_prefix = seq_prefix.literals();
120 let lits_suffix = seq_suffix.literals();
121
122 pick_better_lits(lits_prefix, lits_suffix)
123}
124
125fn pick_better_lits(
127 lits_a: Option<&[regex_syntax::hir::literal::Literal]>,
128 lits_b: Option<&[regex_syntax::hir::literal::Literal]>,
129) -> Vec<Vec<u8>> {
130 fn total_bytes(lits: Option<&[regex_syntax::hir::literal::Literal]>) -> usize {
131 lits.map_or(0, |l| l.iter().map(|lit| lit.as_bytes().len()).sum())
132 }
133
134 let a_count = lits_a.map_or(0, <[regex_syntax::hir::literal::Literal]>::len);
135 let b_count = lits_b.map_or(0, <[regex_syntax::hir::literal::Literal]>::len);
136 let a_has = a_count > 0;
137 let b_has = b_count > 0;
138
139 let lits = match (a_has, b_has) {
140 (true, false) => lits_a,
141 (false, true) => lits_b,
142 (false, false) => return Vec::new(),
143 (true, true) => {
144 let a_total = total_bytes(lits_a);
145 let b_total = total_bytes(lits_b);
146 if a_total >= b_total {
147 lits_a
148 } else {
149 lits_b
150 }
151 }
152 };
153
154 let lits = match lits {
155 Some(l) if !l.is_empty() => l,
156 _ => return Vec::new(),
157 };
158
159 let mut out = Vec::new();
160 for lit in lits {
161 let bytes = lit.as_bytes();
162 if bytes.len() >= 3 {
163 out.push(bytes.to_vec());
164 }
165 }
166 out
167}
168
/// Turns a fixed-string pattern into its single required literal.
///
/// With `case_insensitive`, the literal is lowercased. The lowercasing is
/// ASCII-only: non-ASCII bytes pass through unchanged.
fn fixed_string_literals(lit: &[u8], case_insensitive: bool) -> Vec<Vec<u8>> {
    let mut bytes = lit.to_vec();
    if case_insensitive {
        bytes.make_ascii_lowercase();
    }
    vec![bytes]
}
177
178fn plan_pattern(pattern: &str, opts: &SearchOptions) -> Option<Vec<Vec<u8>>> {
180 let hir = build_configured_hir(pattern, opts)?;
181 let shaped = shape_hir(hir, opts);
182 let lits = extract_literals(&shaped);
183 if lits.is_empty() {
184 None
185 } else {
186 Some(lits)
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::SearchMatchFlags;

    /// Runs the planner over `patterns` with `flags` and no result cap.
    fn plan(patterns: &[String], flags: SearchMatchFlags) -> TrigramPlan {
        let opts = SearchOptions {
            flags,
            max_results: None,
        };
        TrigramPlan::for_patterns(patterns, &opts)
    }

    /// True when the planner narrows the search via the trigram index.
    fn narrow(patterns: &[String], flags: SearchMatchFlags) -> bool {
        matches!(plan(patterns, flags), TrigramPlan::Narrow { .. })
    }

    /// True when the planner falls back to scanning every file.
    fn full_scan(patterns: &[String], flags: SearchMatchFlags) -> bool {
        matches!(plan(patterns, flags), TrigramPlan::FullScan)
    }

    /// Wraps a single pattern in the owned form the planner takes.
    fn one(pattern: &str) -> Vec<String> {
        vec![pattern.to_owned()]
    }

    #[test]
    fn literal_narrows() {
        assert!(narrow(&one("beta"), SearchMatchFlags::empty()));
    }

    #[test]
    fn dot_star_full_scan() {
        assert!(full_scan(&one(".*"), SearchMatchFlags::empty()));
    }

    #[test]
    fn alternation_narrows() {
        assert!(narrow(&one(r"foo|bar"), SearchMatchFlags::empty()));
    }

    #[test]
    fn word_literal_narrows() {
        assert!(narrow(&one("beta"), SearchMatchFlags::WORD_REGEXP));
    }

    #[test]
    fn line_regexp_narrows() {
        assert!(narrow(&one("beta"), SearchMatchFlags::LINE_REGEXP));
    }

    #[test]
    fn case_insensitive_narrows() {
        assert!(narrow(&one("beta"), SearchMatchFlags::CASE_INSENSITIVE));
    }

    #[test]
    fn required_literal_inside_regex_narrows() {
        assert!(narrow(&one("[A-Z]+_RESUME"), SearchMatchFlags::empty()));
    }

    #[test]
    fn unicode_class_full_scan() {
        assert!(full_scan(&one(r"\p{Greek}"), SearchMatchFlags::empty()));
    }

    #[test]
    fn no_literal_full_scan() {
        assert!(full_scan(&one(r"\w{5}\s+\w{5}"), SearchMatchFlags::empty()));
    }

    #[test]
    fn short_literal_full_scan() {
        assert!(full_scan(&one("ab"), SearchMatchFlags::empty()));
    }

    #[test]
    fn fixed_string_narrows() {
        assert!(narrow(&one("beta.gamma"), SearchMatchFlags::FIXED_STRINGS));
    }
}