1#![forbid(unsafe_code)]
6use crate::parser::{ClassItem, FinalNode};
7use safe_proc_macro2::{Ident, TokenStream};
8use safe_quote::{format_ident, quote};
9
10#[derive(Clone, PartialOrd, PartialEq)]
11pub enum Predicate {
12 Any,
13 Incl(Vec<ClassItem>),
14 Excl(Vec<ClassItem>),
15}
16impl core::fmt::Debug for Predicate {
17 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
18 match self {
19 Predicate::Any => write!(f, "Empty"),
20 Predicate::Incl(items) => write!(f, "Incl{items:?}"),
21 Predicate::Excl(items) => write!(f, "Excl{items:?}"),
22 }
23 }
24}
25
26#[derive(Clone, PartialOrd, PartialEq)]
35enum OptimizedNode {
36 Byte(Predicate),
37 Seq(Vec<OptimizedNode>),
38 Alt(Vec<OptimizedNode>),
39 Optional(Box<OptimizedNode>),
40 Star(Box<OptimizedNode>),
41 Group(Box<OptimizedNode>),
42}
43impl OptimizedNode {
44 pub fn non_capturing(&self) -> OptimizedNode {
45 match self {
46 OptimizedNode::Byte(_) => self.clone(),
47 OptimizedNode::Seq(nodes) => {
48 OptimizedNode::Seq(nodes.iter().map(OptimizedNode::non_capturing).collect())
49 }
50 OptimizedNode::Alt(nodes) => {
51 OptimizedNode::Alt(nodes.iter().map(OptimizedNode::non_capturing).collect())
52 }
53 OptimizedNode::Optional(node) => {
54 OptimizedNode::Optional(Box::new(node.non_capturing()))
55 }
56 OptimizedNode::Star(node) => OptimizedNode::Star(Box::new(node.non_capturing())),
57 OptimizedNode::Group(node) => node.non_capturing(),
58 }
59 }
60
61 pub fn from_final_node(final_node: &FinalNode) -> Option<Self> {
62 match final_node {
63 FinalNode::AnyByte => Some(OptimizedNode::Byte(Predicate::Any)),
64 FinalNode::Byte(b) => {
65 Some(OptimizedNode::Byte(Predicate::Incl(vec![ClassItem::Byte(
66 *b,
67 )])))
68 }
69 FinalNode::Class(true, items) => {
70 Some(OptimizedNode::Byte(Predicate::Incl(items.clone())))
71 }
72 FinalNode::Class(false, items) => {
73 Some(OptimizedNode::Byte(Predicate::Excl(items.clone())))
74 }
75 FinalNode::Seq(final_nodes) => {
76 let mut nodes: Vec<OptimizedNode> = final_nodes
77 .iter()
78 .filter_map(OptimizedNode::from_final_node)
79 .collect();
80 if nodes.is_empty() {
81 None
82 } else if nodes.len() == 1 {
83 Some(nodes.pop().unwrap())
84 } else {
85 Some(OptimizedNode::Seq(nodes))
86 }
87 }
88 FinalNode::Alt(final_nodes) => {
89 let mut nodes: Vec<OptimizedNode> = final_nodes
90 .iter()
91 .filter_map(OptimizedNode::from_final_node)
92 .collect();
93 if nodes.is_empty() {
94 None
95 } else if nodes.len() == 1 {
96 Some(nodes.pop().unwrap())
97 } else {
98 Some(OptimizedNode::Alt(nodes))
99 }
100 }
101 FinalNode::Repeat(inner_final_node, 0, None) => Some(OptimizedNode::Star(Box::new(
102 OptimizedNode::from_final_node(inner_final_node)?,
103 ))),
104 FinalNode::Repeat(inner_final_node, min, None) => {
105 let node = OptimizedNode::from_final_node(inner_final_node)?;
106 let non_capturing_node = node.non_capturing();
107 let mut src_nodes =
108 core::iter::once(node).chain(core::iter::repeat(non_capturing_node.clone()));
109 let mut nodes = Vec::with_capacity(min + 1);
110 nodes.extend(src_nodes.by_ref().take(*min));
111 nodes.push(OptimizedNode::Star(Box::new(non_capturing_node)));
112 Some(OptimizedNode::Seq(nodes))
113 }
114 FinalNode::Repeat(_node, 0, Some(0)) => None,
115 FinalNode::Repeat(node, 1, Some(1)) => OptimizedNode::from_final_node(node),
116 FinalNode::Repeat(_node, min, Some(max)) if max < min => unreachable!(),
117 FinalNode::Repeat(inner_final_node, min, Some(max)) => {
118 let node = OptimizedNode::from_final_node(inner_final_node)?;
119 let non_capturing_node = node.non_capturing();
120 let mut src_nodes =
121 core::iter::once(node).chain(core::iter::repeat(non_capturing_node));
122 let mut nodes = Vec::with_capacity(*max);
123 nodes.extend(src_nodes.by_ref().take(*min));
124 nodes.extend(
125 src_nodes
126 .map(|node| OptimizedNode::Optional(Box::new(node)))
127 .take(max - min),
128 );
129 Some(OptimizedNode::Seq(nodes))
130 }
131 FinalNode::Group(inner_final_node) => Some(OptimizedNode::Group(Box::new(
132 OptimizedNode::from_final_node(inner_final_node).expect("found empty group"),
133 ))),
134 FinalNode::NonCapturingGroup(inner_final_node) => {
135 Some(OptimizedNode::from_final_node(inner_final_node)?)
136 }
137 }
138 }
139}
140impl core::fmt::Debug for OptimizedNode {
141 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
142 match self {
143 OptimizedNode::Byte(items) => write!(f, "OptimizedNode::Byte({items:?})"),
144 OptimizedNode::Seq(nodes) => write!(f, "OptimizedNode::Seq{nodes:?}"),
145 OptimizedNode::Alt(nodes) => write!(f, "OptimizedNode::Alt{nodes:?}"),
146 OptimizedNode::Optional(node) => write!(f, "OptimizedNode::Optional({node:?})"),
147 OptimizedNode::Star(node) => write!(f, "OptimizedNode::Star({node:?})"),
148 OptimizedNode::Group(node) => write!(f, "OptimizedNode::Group({node:?})"),
149 }
150 }
151}
152
/// Counter that hands out consecutive `usize` values starting at zero.
///
/// `Clone` lets a caller take a snapshot and replay the same sequence of
/// values from that point.
#[derive(Clone)]
struct Counter {
    n: usize,
}
impl Counter {
    /// Creates a counter whose next value is `0`.
    pub fn new() -> Self {
        Counter { n: 0 }
    }

    /// Returns the next value without consuming it.
    pub fn get(&self) -> usize {
        self.n
    }

    /// Returns the next value and advances the counter past it.
    pub fn get_and_increment(&mut self) -> usize {
        let current = self.get();
        self.n = current + 1;
        current
    }
}
/// `get` reports the next value and `get_and_increment` consumes it.
#[cfg(test)]
#[test]
fn test_counter() {
    let mut counter = Counter::new();
    for expected in 0..3 {
        assert_eq!(expected, counter.get());
        assert_eq!(expected, counter.get_and_increment());
    }
    assert_eq!(3, counter.get());
}
182
183fn byte_and_prev_var_names(n: usize) -> (Ident, Ident) {
184 (format_ident!("b{}", n), format_ident!("prev_b{}", n))
185}
186
187#[derive(Clone, PartialOrd, PartialEq)]
188enum TaggedNode {
189 Byte(Predicate),
190 Seq(Vec<TaggedNode>),
191 Alt(Vec<TaggedNode>),
192 Optional(Box<TaggedNode>),
193 Star(Box<TaggedNode>),
194 Group(usize, Box<TaggedNode>),
195}
196impl TaggedNode {
197 pub fn from_optimized(group_counter: &mut Counter, source: &OptimizedNode) -> Self {
198 match source {
199 OptimizedNode::Byte(predicate) => TaggedNode::Byte(predicate.clone()),
200 OptimizedNode::Seq(nodes) => TaggedNode::Seq(
201 nodes
202 .iter()
203 .map(|node| TaggedNode::from_optimized(group_counter, node))
204 .collect(),
205 ),
206 OptimizedNode::Alt(nodes) => TaggedNode::Alt(
207 nodes
208 .iter()
209 .map(|node| TaggedNode::from_optimized(group_counter, node))
210 .collect(),
211 ),
212 OptimizedNode::Optional(node) => {
213 TaggedNode::Optional(Box::new(TaggedNode::from_optimized(group_counter, node)))
214 }
215 OptimizedNode::Star(node) => {
216 TaggedNode::Star(Box::new(TaggedNode::from_optimized(group_counter, node)))
217 }
218 OptimizedNode::Group(node) => {
219 let this_group = group_counter.get_and_increment();
220 TaggedNode::Group(
221 this_group,
222 Box::new(TaggedNode::from_optimized(group_counter, node)),
223 )
224 }
225 }
226 }
227}
228impl core::fmt::Debug for TaggedNode {
229 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
230 match self {
231 TaggedNode::Byte(predicate) => write!(f, "Byte({predicate:?})"),
232 TaggedNode::Seq(nodes) => write!(f, "Seq({nodes:?})"),
233 TaggedNode::Alt(nodes) => write!(f, "Alt({nodes:?})"),
234 TaggedNode::Optional(node) => write!(f, "Optional({node:?})"),
235 TaggedNode::Star(node) => write!(f, "Star({node:?})"),
236 TaggedNode::Group(group_num, node) => {
237 write!(f, "Group({group_num},{node:?})")
238 }
239 }
240 }
241}
242
/// Generates code that simulates `node` and returns an expression for the
/// matcher state just after `node`.
///
/// Each `TaggedNode::Byte` is assigned a state variable `b{i}`, numbered by
/// `var_counter`. For each one, a statement
/// `b{i} = <state before the byte>.clone() <filter> <group update>;` is pushed
/// onto `statements2_reversed`. `generate` emits these statements (in reverse
/// push order) inside `if let Some(b) = data_iter.next() { … }`, so `b` is the
/// current input byte. When there are groups, `n` is its position.
///
/// - `num_groups`: total number of capturing groups. With groups, a state
///   holds a tuple of that many ranges, destructured here as `(r0, r1, …)`.
/// - `enclosing_groups`: indices of the capturing groups that contain `node`.
/// - `prev_state_expr`: expression for the state just before `node`.
///
/// For a byte node the returned expression is `prev_b{i}`, the snapshot of
/// `b{i}` taken at the top of each loop iteration in `generate`.
#[allow(clippy::too_many_lines)]
fn build(
    var_counter: &mut Counter,
    num_groups: usize,
    enclosing_groups: &[usize],
    statements2_reversed: &mut Vec<TokenStream>,
    prev_state_expr: &TokenStream,
    node: &TaggedNode,
) -> TokenStream {
    crate::dprintln!("build {:?}", node);
    let result = match node {
        TaggedNode::Byte(predicate) => {
            let var_num = var_counter.get_and_increment();
            let (var_name, prev_var_name) = byte_and_prev_var_names(var_num);
            // Keep the incoming state only if the current byte `b` satisfies
            // the predicate. `Any` needs no filter.
            let filter = match predicate {
                Predicate::Any => quote! {},
                Predicate::Incl(items) => {
                    let comparisons = items.iter().map(|p| match p {
                        ClassItem::Byte(b) => quote! {*b == #b},
                        ClassItem::ByteRange(x, y) => quote! {(#x ..= #y).contains(b)},
                    });
                    quote! { .filter(|_| { #( #comparisons )||* } ) }
                }
                Predicate::Excl(items) => {
                    let comparisons = items.iter().map(|p| match p {
                        ClassItem::Byte(b) => quote! {*b != #b},
                        ClassItem::ByteRange(x, y) => quote! {!(#x ..= #y).contains(b)},
                    });
                    quote! { .filter(|_| { #( #comparisons )&&* } ) }
                }
            };
            // Inside capturing groups, matching this byte moves the end of
            // every enclosing group's range to `n + 1`. Other ranges pass
            // through unchanged.
            let update_groups = if enclosing_groups.is_empty() {
                quote! {}
            } else {
                let mut range_names = Vec::new();
                let mut range_values = Vec::new();
                for r in 0..num_groups {
                    let range_name = format_ident!("r{}", r);
                    range_names.push(range_name.clone());
                    range_values.push(if enclosing_groups.contains(&r) {
                        quote! { #range_name .start .. n + 1}
                    } else {
                        quote! { #range_name }
                    });
                }
                // A one-element tuple needs a trailing comma: `(r0,)`.
                let extra_comma = if num_groups > 1 {
                    quote! {}
                } else {
                    quote! {,}
                };
                quote! {
                    .map(
                        |( #( #range_names ),* #extra_comma )| ( #( #range_values ),* #extra_comma )
                    )
                }
            };
            statements2_reversed.push(quote! {
                #var_name = #prev_state_expr .clone() #filter #update_groups ;
            });
            quote! { #prev_var_name }
        }
        TaggedNode::Seq(inner_nodes) => {
            assert!(!inner_nodes.is_empty());
            // Thread the state through each element in order.
            let mut last_state_expr = prev_state_expr.clone();
            for node in inner_nodes {
                last_state_expr = build(
                    var_counter,
                    num_groups,
                    enclosing_groups,
                    statements2_reversed,
                    &last_state_expr,
                    node,
                );
            }
            last_state_expr
        }
        TaggedNode::Alt(inner_nodes) => {
            assert!(!inner_nodes.is_empty());
            // Every alternative starts from the same state. The result is the
            // first alternative (in order) whose state is `Some`.
            let mut arm_state_exprs: Vec<TokenStream> = Vec::new();
            for node in inner_nodes {
                arm_state_exprs.push(build(
                    var_counter,
                    num_groups,
                    enclosing_groups,
                    statements2_reversed,
                    prev_state_expr,
                    node,
                ));
            }
            quote! { None #( .or_else(|| #arm_state_exprs.clone()) )* }
        }
        TaggedNode::Optional(inner) => {
            let node_state_expr = build(
                var_counter,
                num_groups,
                enclosing_groups,
                statements2_reversed,
                prev_state_expr,
                inner,
            );
            // Prefer the state that skips `inner`; otherwise use the state
            // after `inner`.
            quote! { #prev_state_expr .clone() .or_else(|| #node_state_expr .clone()) }
        }
        TaggedNode::Star(inner) => {
            // Throwaway build: a cloned counter and a scratch statement list
            // yield an expression for the state after one pass through
            // `inner` from `prev_state_expr`. Because the counter is cloned,
            // this expression refers to the same `prev_b*` variables that the
            // real build below allocates.
            let first_expr = build(
                &mut var_counter.clone(),
                num_groups,
                enclosing_groups,
                &mut Vec::new(),
                &quote! { #prev_state_expr },
                inner,
            );
            // Real build: `inner` is entered either from the state before the
            // star or from the state after a previous pass (the loop edge).
            let expr = build(
                var_counter,
                num_groups,
                enclosing_groups,
                statements2_reversed,
                &quote! { #prev_state_expr .clone() .or_else(|| #first_expr .clone()) },
                inner,
            );
            // Zero passes are preferred; otherwise use the state after `inner`.
            quote! { #prev_state_expr .clone() .or_else(|| #expr .clone()) }
        }
        TaggedNode::Group(group_num, inner) => {
            let inner_enclosing_groups: Vec<usize> = enclosing_groups
                .iter()
                .chain(core::iter::once(group_num))
                .copied()
                .collect();
            // On entering group `group_num`, reset its range to the empty
            // range `n..n` at the current position. Other ranges are unchanged.
            let inner_prev_state_expr = {
                let mut range_names = Vec::new();
                let mut range_values = Vec::new();
                // A one-element tuple needs a trailing comma: `(r0,)`.
                let extra_comma = if num_groups > 1 {
                    quote! {}
                } else {
                    quote! {,}
                };
                for r in 0..num_groups {
                    let range_name = format_ident!("r{}", r);
                    range_names.push(range_name.clone());
                    range_values.push(if &r == group_num {
                        quote! { n .. n }
                    } else {
                        quote! { #range_name }
                    });
                }
                quote! {
                    #prev_state_expr .clone().map(
                        |( #( #range_names ),* #extra_comma )| ( #( #range_values ),* #extra_comma )
                    )
                }
            };
            build(
                var_counter,
                num_groups,
                &inner_enclosing_groups,
                statements2_reversed,
                &inner_prev_state_expr,
                inner,
            )
        }
    };
    crate::dprintln!("build returning {:?}", result);
    result
}
407
/// Generates the Rust expression that constructs the matcher for `final_node`.
///
/// The result is a `safe_regex::Matcher{N}::new(|data: &[u8]| …)` call, where
/// `N` is the number of capturing groups. The closure scans `data` one byte at
/// a time and keeps one `Option` state variable per byte-matching node. It
/// reports a match only when all of `data` has been consumed in an accepting
/// state.
///
/// If `OptimizedNode::from_final_node` returns `None` (the pattern can only
/// match the empty string), a `Matcher0` that accepts exactly the empty input
/// is returned.
///
/// # Panics
/// Panics whenever `OptimizedNode::from_final_node` panics. The generated
/// matcher for patterns with groups asserts `data.len() < usize::MAX - 2`,
/// presumably so positions never reach the `usize::MAX` "unset" sentinel.
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn generate(final_node: &FinalNode) -> safe_proc_macro2::TokenStream {
    let Some(optimized_node) = OptimizedNode::from_final_node(final_node) else {
        return quote! {
            safe_regex::Matcher0::new(|data: &[u8]| {
                if data.is_empty() {
                    Some(())
                } else {
                    None
                }
            })
        };
    };
    // Number the capturing groups; their count selects the `Matcher{N}` type.
    let mut group_counter = Counter::new();
    let tagged_node = TaggedNode::from_optimized(&mut group_counter, &optimized_node);
    let num_groups = group_counter.get();
    let matcher_type_name = format_ident!("Matcher{}", num_groups);
    // Generate the per-byte statements. `accept_expr` is the state after the
    // whole pattern, starting from the `start` variable.
    let mut statements2_reversed: Vec<TokenStream> = Vec::new();
    let mut var_counter = Counter::new();
    let accept_expr = build(
        &mut var_counter,
        num_groups,
        &Vec::new(),
        &mut statements2_reversed,
        &quote! { start },
        &tagged_node,
    );
    // At the top of each loop iteration, snapshot every `b{i}` into
    // `prev_b{i}`. The expressions produced by `build` refer to these
    // snapshots (and to `start`).
    let mut var_names: Vec<Ident> = Vec::new();
    let mut var_clone_statements: Vec<TokenStream> = Vec::new();
    for n in 0..var_counter.get() {
        let (var_name, prev_var_name) = byte_and_prev_var_names(n);
        var_clone_statements.push(quote! {
            let #prev_var_name = #var_name .clone() ;
        });
        var_names.push(var_name);
    }
    let statements2 = statements2_reversed.iter().rev();
    // Give up as soon as no byte state is live. With a single variable, `?`
    // returns `None` from the closure; otherwise check all of them.
    let give_up_stmt = if var_names.len() == 1 {
        quote! { #( #var_names .as_ref()? )* ; }
    } else {
        quote! {
            if #( #var_names .is_none() )&&* {
                return None;
            }
        }
    };
    let result = if num_groups == 0 {
        // No groups: each state is `Option<()>`. `start` is live only until
        // the first byte has been processed.
        quote! {
            safe_regex::#matcher_type_name::new(|data: &[u8]| {
                let mut start = Some(());
                #( let mut #var_names : Option<()> = None; )*
                let mut data_iter = data.iter();
                loop {
                    #( #var_clone_statements )*
                    if let Some(b) = data_iter.next() {
                        #( #statements2 )*
                        start = None;
                        #give_up_stmt
                    } else {
                        return #accept_expr ;
                    }
                }
            })
        }
    } else {
        // With groups: each state is `Option<(Range<usize>, …)>`, and
        // `usize::MAX..usize::MAX` marks a group range that was never set.
        let default_ranges = core::iter::repeat(quote! { usize::MAX..usize::MAX }).take(num_groups);
        // A one-element tuple needs a trailing comma.
        let extra_comma = if num_groups > 1 {
            quote! {}
        } else {
            quote! {,}
        };
        let range_types = core::iter::repeat(quote! { core::ops::Range<usize> }).take(num_groups);
        let range_type = quote! { Option<( #( #range_types ),* #extra_comma )> };
        let range_names: Vec<Ident> = (0..num_groups).map(|r| format_ident!("r{}", r)).collect();
        // `accept` is recomputed before each byte is read, so after the loop
        // it holds the state for the whole input. Unset or empty ranges are
        // reported as `0..0`.
        quote! {
            safe_regex::#matcher_type_name::new(|data: &[u8]| {
                assert!(data.len() < usize::MAX - 2);
                let mut start = Some(( #( #default_ranges ),* #extra_comma ));
                #( let mut #var_names : #range_type = None; )*
                let mut accept : #range_type = None;
                let mut data_iter = data.iter();
                let mut n = 0;
                loop {
                    #( #var_clone_statements )*
                    accept = #accept_expr .clone() ;
                    if let Some(b) = data_iter.next() {
                        #( #statements2 )*
                        start = None;
                        #give_up_stmt
                    } else {
                        break;
                    }
                    n +=1 ;
                }
                accept .map(|( #( #range_names ),* #extra_comma )| {
                    [
                        #(
                            if #range_names.start == usize::MAX || #range_names.end == usize::MAX || #range_names.is_empty() {
                                0..0usize
                            } else {
                                #range_names
                            },
                        )*
                    ]
                })
            })
        }
    };
    crate::dprintln!("result={}", result);
    #[allow(clippy::let_and_return)]
    result
}