safe_regex_compiler/
generator.rs

//! Provides a [`generate`](fn.generate.html) function used by the `regex!`
//! proc macro.
//!
//! How-to develop proc macros: <https://github.com/dtolnay/proc-macro-workshop>
5#![forbid(unsafe_code)]
6use crate::parser::{ClassItem, FinalNode};
7use safe_proc_macro2::{Ident, TokenStream};
8use safe_quote::{format_ident, quote};
9
10#[derive(Clone, PartialOrd, PartialEq)]
11pub enum Predicate {
12    Any,
13    Incl(Vec<ClassItem>),
14    Excl(Vec<ClassItem>),
15}
16impl core::fmt::Debug for Predicate {
17    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
18        match self {
19            Predicate::Any => write!(f, "Empty"),
20            Predicate::Incl(items) => write!(f, "Incl{items:?}"),
21            Predicate::Excl(items) => write!(f, "Excl{items:?}"),
22        }
23    }
24}
25
// TODO(mleonhard) Add more tree simplifications:
// - Collapse nested Seq into one
// - Collapse nested Alt into one
// - Merge peer Bytes in Alt
// - Deduplicate Empty in Alt
// - Drop Optional(x) that comes right after Star(x)
// - Reorder Optional(x),x so the optional comes later
// - Translate x{2,5} into "xx(x(x(x)?)?)?" rather than "xxx?x?x?"
34#[derive(Clone, PartialOrd, PartialEq)]
35enum OptimizedNode {
36    Byte(Predicate),
37    Seq(Vec<OptimizedNode>),
38    Alt(Vec<OptimizedNode>),
39    Optional(Box<OptimizedNode>),
40    Star(Box<OptimizedNode>),
41    Group(Box<OptimizedNode>),
42}
43impl OptimizedNode {
44    pub fn non_capturing(&self) -> OptimizedNode {
45        match self {
46            OptimizedNode::Byte(_) => self.clone(),
47            OptimizedNode::Seq(nodes) => {
48                OptimizedNode::Seq(nodes.iter().map(OptimizedNode::non_capturing).collect())
49            }
50            OptimizedNode::Alt(nodes) => {
51                OptimizedNode::Alt(nodes.iter().map(OptimizedNode::non_capturing).collect())
52            }
53            OptimizedNode::Optional(node) => {
54                OptimizedNode::Optional(Box::new(node.non_capturing()))
55            }
56            OptimizedNode::Star(node) => OptimizedNode::Star(Box::new(node.non_capturing())),
57            OptimizedNode::Group(node) => node.non_capturing(),
58        }
59    }
60
61    pub fn from_final_node(final_node: &FinalNode) -> Option<Self> {
62        match final_node {
63            FinalNode::AnyByte => Some(OptimizedNode::Byte(Predicate::Any)),
64            FinalNode::Byte(b) => {
65                Some(OptimizedNode::Byte(Predicate::Incl(vec![ClassItem::Byte(
66                    *b,
67                )])))
68            }
69            FinalNode::Class(true, items) => {
70                Some(OptimizedNode::Byte(Predicate::Incl(items.clone())))
71            }
72            FinalNode::Class(false, items) => {
73                Some(OptimizedNode::Byte(Predicate::Excl(items.clone())))
74            }
75            FinalNode::Seq(final_nodes) => {
76                let mut nodes: Vec<OptimizedNode> = final_nodes
77                    .iter()
78                    .filter_map(OptimizedNode::from_final_node)
79                    .collect();
80                if nodes.is_empty() {
81                    None
82                } else if nodes.len() == 1 {
83                    Some(nodes.pop().unwrap())
84                } else {
85                    Some(OptimizedNode::Seq(nodes))
86                }
87            }
88            FinalNode::Alt(final_nodes) => {
89                let mut nodes: Vec<OptimizedNode> = final_nodes
90                    .iter()
91                    .filter_map(OptimizedNode::from_final_node)
92                    .collect();
93                if nodes.is_empty() {
94                    None
95                } else if nodes.len() == 1 {
96                    Some(nodes.pop().unwrap())
97                } else {
98                    Some(OptimizedNode::Alt(nodes))
99                }
100            }
101            FinalNode::Repeat(inner_final_node, 0, None) => Some(OptimizedNode::Star(Box::new(
102                OptimizedNode::from_final_node(inner_final_node)?,
103            ))),
104            FinalNode::Repeat(inner_final_node, min, None) => {
105                let node = OptimizedNode::from_final_node(inner_final_node)?;
106                let non_capturing_node = node.non_capturing();
107                let mut src_nodes =
108                    core::iter::once(node).chain(core::iter::repeat(non_capturing_node.clone()));
109                let mut nodes = Vec::with_capacity(min + 1);
110                nodes.extend(src_nodes.by_ref().take(*min));
111                nodes.push(OptimizedNode::Star(Box::new(non_capturing_node)));
112                Some(OptimizedNode::Seq(nodes))
113            }
114            FinalNode::Repeat(_node, 0, Some(0)) => None,
115            FinalNode::Repeat(node, 1, Some(1)) => OptimizedNode::from_final_node(node),
116            FinalNode::Repeat(_node, min, Some(max)) if max < min => unreachable!(),
117            FinalNode::Repeat(inner_final_node, min, Some(max)) => {
118                let node = OptimizedNode::from_final_node(inner_final_node)?;
119                let non_capturing_node = node.non_capturing();
120                let mut src_nodes =
121                    core::iter::once(node).chain(core::iter::repeat(non_capturing_node));
122                let mut nodes = Vec::with_capacity(*max);
123                nodes.extend(src_nodes.by_ref().take(*min));
124                nodes.extend(
125                    src_nodes
126                        .map(|node| OptimizedNode::Optional(Box::new(node)))
127                        .take(max - min),
128                );
129                Some(OptimizedNode::Seq(nodes))
130            }
131            FinalNode::Group(inner_final_node) => Some(OptimizedNode::Group(Box::new(
132                OptimizedNode::from_final_node(inner_final_node).expect("found empty group"),
133            ))),
134            FinalNode::NonCapturingGroup(inner_final_node) => {
135                Some(OptimizedNode::from_final_node(inner_final_node)?)
136            }
137        }
138    }
139}
140impl core::fmt::Debug for OptimizedNode {
141    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
142        match self {
143            OptimizedNode::Byte(items) => write!(f, "OptimizedNode::Byte({items:?})"),
144            OptimizedNode::Seq(nodes) => write!(f, "OptimizedNode::Seq{nodes:?}"),
145            OptimizedNode::Alt(nodes) => write!(f, "OptimizedNode::Alt{nodes:?}"),
146            OptimizedNode::Optional(node) => write!(f, "OptimizedNode::Optional({node:?})"),
147            OptimizedNode::Star(node) => write!(f, "OptimizedNode::Star({node:?})"),
148            OptimizedNode::Group(node) => write!(f, "OptimizedNode::Group({node:?})"),
149        }
150    }
151}
152
/// A counter that starts at zero and hands out consecutive numbers.
#[derive(Clone)]
struct Counter {
    /// The next number to hand out.
    n: usize,
}
impl Counter {
    /// Creates a counter whose next number is `0`.
    pub fn new() -> Self {
        Counter { n: 0 }
    }

    /// Returns the next number without consuming it.
    pub fn get(&self) -> usize {
        self.n
    }

    /// Returns the next number and advances the counter by one.
    pub fn get_and_increment(&mut self) -> usize {
        let next = self.n + 1;
        core::mem::replace(&mut self.n, next)
    }
}
/// Checks that `Counter` starts at zero and advances by one per
/// `get_and_increment` call, while `get` leaves it unchanged.
#[cfg(test)]
#[test]
fn test_counter() {
    let mut counter = Counter::new();
    for expected in 0..3 {
        assert_eq!(expected, counter.get());
        assert_eq!(expected, counter.get_and_increment());
    }
    assert_eq!(3, counter.get());
}
182
183fn byte_and_prev_var_names(n: usize) -> (Ident, Ident) {
184    (format_ident!("b{}", n), format_ident!("prev_b{}", n))
185}
186
187#[derive(Clone, PartialOrd, PartialEq)]
188enum TaggedNode {
189    Byte(Predicate),
190    Seq(Vec<TaggedNode>),
191    Alt(Vec<TaggedNode>),
192    Optional(Box<TaggedNode>),
193    Star(Box<TaggedNode>),
194    Group(usize, Box<TaggedNode>),
195}
196impl TaggedNode {
197    pub fn from_optimized(group_counter: &mut Counter, source: &OptimizedNode) -> Self {
198        match source {
199            OptimizedNode::Byte(predicate) => TaggedNode::Byte(predicate.clone()),
200            OptimizedNode::Seq(nodes) => TaggedNode::Seq(
201                nodes
202                    .iter()
203                    .map(|node| TaggedNode::from_optimized(group_counter, node))
204                    .collect(),
205            ),
206            OptimizedNode::Alt(nodes) => TaggedNode::Alt(
207                nodes
208                    .iter()
209                    .map(|node| TaggedNode::from_optimized(group_counter, node))
210                    .collect(),
211            ),
212            OptimizedNode::Optional(node) => {
213                TaggedNode::Optional(Box::new(TaggedNode::from_optimized(group_counter, node)))
214            }
215            OptimizedNode::Star(node) => {
216                TaggedNode::Star(Box::new(TaggedNode::from_optimized(group_counter, node)))
217            }
218            OptimizedNode::Group(node) => {
219                let this_group = group_counter.get_and_increment();
220                TaggedNode::Group(
221                    this_group,
222                    Box::new(TaggedNode::from_optimized(group_counter, node)),
223                )
224            }
225        }
226    }
227}
228impl core::fmt::Debug for TaggedNode {
229    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
230        match self {
231            TaggedNode::Byte(predicate) => write!(f, "Byte({predicate:?})"),
232            TaggedNode::Seq(nodes) => write!(f, "Seq({nodes:?})"),
233            TaggedNode::Alt(nodes) => write!(f, "Alt({nodes:?})"),
234            TaggedNode::Optional(node) => write!(f, "Optional({node:?})"),
235            TaggedNode::Star(node) => write!(f, "Star({node:?})"),
236            TaggedNode::Group(group_num, node) => {
237                write!(f, "Group({group_num},{node:?})")
238            }
239        }
240    }
241}
242
243#[allow(clippy::too_many_lines)]
244fn build(
245    var_counter: &mut Counter,
246    num_groups: usize,
247    enclosing_groups: &[usize],
248    statements2_reversed: &mut Vec<TokenStream>,
249    prev_state_expr: &TokenStream,
250    node: &TaggedNode,
251) -> TokenStream {
252    crate::dprintln!("build {:?}", node);
253    let result = match node {
254        TaggedNode::Byte(predicate) => {
255            let var_num = var_counter.get_and_increment();
256            let (var_name, prev_var_name) = byte_and_prev_var_names(var_num);
257            let filter = match predicate {
258                Predicate::Any => quote! {},
259                Predicate::Incl(items) => {
260                    let comparisons = items.iter().map(|p| match p {
261                        ClassItem::Byte(b) => quote! {*b == #b},
262                        ClassItem::ByteRange(x, y) => quote! {(#x ..= #y).contains(b)},
263                    });
264                    quote! { .filter(|_| { #( #comparisons )||* } )  }
265                }
266                Predicate::Excl(items) => {
267                    let comparisons = items.iter().map(|p| match p {
268                        ClassItem::Byte(b) => quote! {*b != #b},
269                        ClassItem::ByteRange(x, y) => quote! {!(#x ..= #y).contains(b)},
270                    });
271                    quote! { .filter(|_| { #( #comparisons )&&* } )  }
272                }
273            };
274            let update_groups = if enclosing_groups.is_empty() {
275                quote! {}
276            } else {
277                let mut range_names = Vec::new();
278                let mut range_values = Vec::new();
279                for r in 0..num_groups {
280                    let range_name = format_ident!("r{}", r);
281                    range_names.push(range_name.clone());
282                    range_values.push(if enclosing_groups.contains(&r) {
283                        quote! { #range_name .start .. n + 1}
284                    } else {
285                        quote! { #range_name }
286                    });
287                }
288                let extra_comma = if num_groups > 1 {
289                    quote! {}
290                } else {
291                    quote! {,}
292                };
293                quote! {
294                    .map(
295                        |( #( #range_names ),* #extra_comma )| ( #( #range_values ),* #extra_comma )
296                    )
297                }
298            };
299            statements2_reversed.push(quote! {
300                #var_name = #prev_state_expr .clone() #filter #update_groups ;
301            });
302            quote! { #prev_var_name }
303        }
304        TaggedNode::Seq(inner_nodes) => {
305            assert!(!inner_nodes.is_empty());
306            let mut last_state_expr = prev_state_expr.clone();
307            for node in inner_nodes {
308                last_state_expr = build(
309                    var_counter,
310                    num_groups,
311                    enclosing_groups,
312                    statements2_reversed,
313                    &last_state_expr,
314                    node,
315                );
316            }
317            last_state_expr
318        }
319        TaggedNode::Alt(inner_nodes) => {
320            assert!(!inner_nodes.is_empty());
321            let mut arm_state_exprs: Vec<TokenStream> = Vec::new();
322            for node in inner_nodes {
323                arm_state_exprs.push(build(
324                    var_counter,
325                    num_groups,
326                    enclosing_groups,
327                    statements2_reversed,
328                    prev_state_expr,
329                    node,
330                ));
331            }
332            quote! { None #( .or_else(|| #arm_state_exprs.clone()) )* }
333        }
334        TaggedNode::Optional(inner) => {
335            let node_state_expr = build(
336                var_counter,
337                num_groups,
338                enclosing_groups,
339                statements2_reversed,
340                prev_state_expr,
341                inner,
342            );
343            quote! { #prev_state_expr .clone() .or_else(|| #node_state_expr .clone()) }
344        }
345        // See safe-regex/tests/machine::seq_in_star .
346        TaggedNode::Star(inner) => {
347            let first_expr = build(
348                &mut var_counter.clone(), // <-- discards
349                num_groups,
350                enclosing_groups,
351                &mut Vec::new(), // <-- discards
352                &quote! { #prev_state_expr },
353                inner,
354            );
355            let expr = build(
356                var_counter,
357                num_groups,
358                enclosing_groups,
359                statements2_reversed,
360                &quote! { #prev_state_expr .clone() .or_else(|| #first_expr .clone()) },
361                inner,
362            );
363            quote! { #prev_state_expr .clone() .or_else(|| #expr .clone()) }
364        }
365        TaggedNode::Group(group_num, inner) => {
366            let inner_enclosing_groups: Vec<usize> = enclosing_groups
367                .iter()
368                .chain(core::iter::once(group_num))
369                .copied()
370                .collect();
371            let inner_prev_state_expr = {
372                let mut range_names = Vec::new();
373                let mut range_values = Vec::new();
374                let extra_comma = if num_groups > 1 {
375                    quote! {}
376                } else {
377                    quote! {,}
378                };
379                for r in 0..num_groups {
380                    let range_name = format_ident!("r{}", r);
381                    range_names.push(range_name.clone());
382                    range_values.push(if &r == group_num {
383                        quote! { n .. n }
384                    } else {
385                        quote! { #range_name }
386                    });
387                }
388                quote! {
389                    #prev_state_expr .clone().map(
390                        |( #( #range_names ),* #extra_comma )| ( #( #range_values ),* #extra_comma )
391                    )
392                }
393            };
394            build(
395                var_counter,
396                num_groups,
397                &inner_enclosing_groups,
398                statements2_reversed,
399                &inner_prev_state_expr,
400                inner,
401            )
402        }
403    };
404    crate::dprintln!("build returning {:?}", result);
405    result
406}
407
408/// Generates an enum that implements `parsed_re` and implements the
409/// [`safe_regex::internal::Machine`](https://docs.rs/safe-regex/latest/safe_regex/internal/trait.Machine.html)
410/// trait.
411#[must_use]
412#[allow(clippy::too_many_lines)]
413pub fn generate(final_node: &FinalNode) -> safe_proc_macro2::TokenStream {
414    let Some(optimized_node) = OptimizedNode::from_final_node(final_node) else {
415        return quote! {
416            safe_regex::Matcher0::new(|data: &[u8]| {
417                if data.is_empty() {
418                    Some(())
419                } else {
420                    None
421                }
422            })
423        };
424    };
425    let mut group_counter = Counter::new();
426    let tagged_node = TaggedNode::from_optimized(&mut group_counter, &optimized_node);
427    let num_groups = group_counter.get();
428    let matcher_type_name = format_ident!("Matcher{}", num_groups);
429    let mut statements2_reversed: Vec<TokenStream> = Vec::new();
430    let mut var_counter = Counter::new();
431    let accept_expr = build(
432        &mut var_counter,
433        num_groups,
434        &Vec::new(),
435        &mut statements2_reversed,
436        &quote! { start },
437        &tagged_node,
438    );
439    let mut var_names: Vec<Ident> = Vec::new();
440    let mut var_clone_statements: Vec<TokenStream> = Vec::new();
441    for n in 0..var_counter.get() {
442        let (var_name, prev_var_name) = byte_and_prev_var_names(n);
443        var_clone_statements.push(quote! {
444            let #prev_var_name = #var_name .clone() ;
445        });
446        var_names.push(var_name);
447    }
448    let statements2 = statements2_reversed.iter().rev();
449    let give_up_stmt = if var_names.len() == 1 {
450        quote! { #( #var_names .as_ref()? )* ; }
451    } else {
452        quote! {
453            if #( #var_names .is_none() )&&* {
454                return None;
455            }
456        }
457    };
458    let result = if num_groups == 0 {
459        quote! {
460            safe_regex::#matcher_type_name::new(|data: &[u8]| {
461                let mut start = Some(());
462                #( let mut #var_names : Option<()> = None; )*
463                let mut data_iter = data.iter();
464                loop {
465                    #( #var_clone_statements )*
466                    if let Some(b) = data_iter.next() {
467                        #( #statements2 )*
468                        start = None;
469                        #give_up_stmt
470                    } else {
471                        return #accept_expr ;
472                    }
473                }
474            })
475        }
476    } else {
477        let default_ranges = core::iter::repeat(quote! { usize::MAX..usize::MAX }).take(num_groups);
478        let extra_comma = if num_groups > 1 {
479            quote! {}
480        } else {
481            quote! {,}
482        };
483        let range_types = core::iter::repeat(quote! { core::ops::Range<usize> }).take(num_groups);
484        let range_type = quote! { Option<( #( #range_types ),* #extra_comma )> };
485        let range_names: Vec<Ident> = (0..num_groups).map(|r| format_ident!("r{}", r)).collect();
486        quote! {
487            safe_regex::#matcher_type_name::new(|data: &[u8]| {
488                assert!(data.len() < usize::MAX - 2);
489                let mut start = Some(( #( #default_ranges ),* #extra_comma ));
490                #( let mut #var_names : #range_type = None; )*
491                let mut accept : #range_type = None;
492                let mut data_iter = data.iter();
493                let mut n = 0;
494                loop {
495                    #( #var_clone_statements )*
496                    accept = #accept_expr .clone() ;
497                    if let Some(b) = data_iter.next() {
498                        #( #statements2 )*
499                        start = None;
500                        #give_up_stmt
501                    } else {
502                        break;
503                    }
504                    n +=1 ;
505                }
506                accept .map(|( #( #range_names ),* #extra_comma )| {
507                    [
508                        #(
509                            if #range_names.start == usize::MAX || #range_names.end == usize::MAX || #range_names.is_empty() {
510                                0..0usize
511                            } else {
512                                #range_names
513                            },
514                        )*
515                    ]
516                })
517            })
518        }
519    };
520    crate::dprintln!("result={}", result);
521    #[allow(clippy::let_and_return)]
522    result
523}