Skip to main content

trampoline_parser/
prefix_factoring.rs

1//! Automatic detection and rewriting of exponential backtracking patterns.
2//!
3//! When `Choice` alternatives share a common prefix containing recursive rules,
4//! parsing becomes O(2^n). This module detects such patterns and rewrites them
5//! to factor out the common prefix, achieving O(n) parsing time.
6//!
7//! ## Example
8//!
9//! ```text
10//! // BAD: O(2^n) - shared prefix '(' datum+ is re-parsed on backtrack
11//! Choice([
12//!     Sequence(['(', datum+, '.', datum, ')']),  // dotted_list
13//!     Sequence(['(', datum+, ')']),              // proper_list
14//! ])
15//!
16//! // GOOD: O(n) - prefix parsed once, suffix is optional
17//! Sequence([
18//!     '(',
19//!     datum+,
20//!     Optional(Sequence(['.', datum])),
21//!     ')',
22//! ])
23//! ```
24
25use crate::ir::{Combinator, InfixOp, PostfixOp, PrattDef, PrefixOp, RuleDef, TernaryOp};
26use crate::validation;
27use std::collections::{HashMap, HashSet};
28
/// Severity of a backtracking issue.
///
/// Variants are declared from least to most severe, so the derived ordering
/// gives `None < Linear < Exponential` and severities can be compared or
/// reduced with `max`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum BacktrackingSeverity {
    /// No backtracking concern: the alternatives share no prefix, or the
    /// shared prefix is nullable
    None,
    /// Shared prefix exists but contains no recursion (O(k*n) worst case)
    Linear,
    /// Shared prefix contains recursive rules (O(2^n) worst case)
    Exponential,
}
39
40/// Result of analyzing a Choice node for common prefixes
41#[derive(Debug)]
42pub struct PrefixAnalysis {
43    /// The shared prefix elements
44    pub prefix: Vec<Combinator>,
45    /// What remains after the prefix for each alternative
46    pub suffixes: Vec<Suffix>,
47    /// How severe the backtracking issue is
48    pub severity: BacktrackingSeverity,
49}
50
51/// What remains of an alternative after factoring out the common prefix
52#[derive(Debug, Clone)]
53pub enum Suffix {
54    /// The alternative was exactly the prefix (nothing remains)
55    Empty,
56    /// A single combinator remains
57    Single(Combinator),
58    /// Multiple combinators remain
59    Sequence(Vec<Combinator>),
60}
61
62/// Warning about a backtracking issue in a grammar
63#[derive(Debug)]
64pub struct BacktrackingWarning {
65    /// Name of the rule containing the problematic choice
66    pub rule_name: String,
67    /// Human-readable description of the issue
68    pub description: String,
69    /// Severity of the issue
70    pub severity: BacktrackingSeverity,
71}
72
73// ============================================================================
74// Combinator Equality
75// ============================================================================
76
77/// Check if two combinators are structurally equal.
78///
79/// This is a deep structural comparison. Two combinators may match the same
80/// language but not be structurally equal (e.g., different orderings in choice).
81pub fn combinators_equal(a: &Combinator, b: &Combinator) -> bool {
82    match (a, b) {
83        // Leaf nodes - direct comparison
84        (Combinator::Literal(s1), Combinator::Literal(s2)) => s1 == s2,
85        (Combinator::Char(c1), Combinator::Char(c2)) => c1 == c2,
86        (Combinator::CharClass(cc1), Combinator::CharClass(cc2)) => cc1 == cc2,
87        (Combinator::CharRange(a1, b1), Combinator::CharRange(a2, b2)) => a1 == a2 && b1 == b2,
88        (Combinator::AnyChar, Combinator::AnyChar) => true,
89        (Combinator::Rule(r1), Combinator::Rule(r2)) => r1 == r2,
90
91        // Recursive nodes - compare children
92        (Combinator::Sequence(items1), Combinator::Sequence(items2)) => {
93            items1.len() == items2.len()
94                && items1
95                    .iter()
96                    .zip(items2)
97                    .all(|(a, b)| combinators_equal(a, b))
98        }
99        (Combinator::Choice(items1), Combinator::Choice(items2)) => {
100            items1.len() == items2.len()
101                && items1
102                    .iter()
103                    .zip(items2)
104                    .all(|(a, b)| combinators_equal(a, b))
105        }
106
107        // Single-child wrappers
108        (Combinator::ZeroOrMore(inner1), Combinator::ZeroOrMore(inner2))
109        | (Combinator::OneOrMore(inner1), Combinator::OneOrMore(inner2))
110        | (Combinator::Optional(inner1), Combinator::Optional(inner2))
111        | (Combinator::Skip(inner1), Combinator::Skip(inner2))
112        | (Combinator::Capture(inner1), Combinator::Capture(inner2))
113        | (Combinator::NotFollowedBy(inner1), Combinator::NotFollowedBy(inner2))
114        | (Combinator::FollowedBy(inner1), Combinator::FollowedBy(inner2)) => {
115            combinators_equal(inner1, inner2)
116        }
117
118        // SeparatedBy
119        (
120            Combinator::SeparatedBy {
121                item: i1,
122                separator: s1,
123                trailing: t1,
124            },
125            Combinator::SeparatedBy {
126                item: i2,
127                separator: s2,
128                trailing: t2,
129            },
130        ) => t1 == t2 && combinators_equal(i1, i2) && combinators_equal(s1, s2),
131
132        // Mapped
133        (
134            Combinator::Mapped {
135                inner: i1,
136                mapping: m1,
137            },
138            Combinator::Mapped {
139                inner: i2,
140                mapping: m2,
141            },
142        ) => m1 == m2 && combinators_equal(i1, i2),
143
144        // Pratt - compare all parts
145        (Combinator::Pratt(p1), Combinator::Pratt(p2)) => pratt_equal(p1, p2),
146
147        // Different variants are not equal
148        _ => false,
149    }
150}
151
152fn pratt_equal(p1: &PrattDef, p2: &PrattDef) -> bool {
153    // Compare operands
154    match (p1.operand.as_ref(), p2.operand.as_ref()) {
155        (Some(o1), Some(o2)) => {
156            if !combinators_equal(o1, o2) {
157                return false;
158            }
159        }
160        (None, None) => {}
161        _ => return false,
162    }
163
164    // Compare prefix ops
165    if p1.prefix_ops.len() != p2.prefix_ops.len() {
166        return false;
167    }
168    for (op1, op2) in p1.prefix_ops.iter().zip(&p2.prefix_ops) {
169        if !prefix_op_equal(op1, op2) {
170            return false;
171        }
172    }
173
174    // Compare infix ops
175    if p1.infix_ops.len() != p2.infix_ops.len() {
176        return false;
177    }
178    for (op1, op2) in p1.infix_ops.iter().zip(&p2.infix_ops) {
179        if !infix_op_equal(op1, op2) {
180            return false;
181        }
182    }
183
184    // Compare postfix ops
185    if p1.postfix_ops.len() != p2.postfix_ops.len() {
186        return false;
187    }
188    for (op1, op2) in p1.postfix_ops.iter().zip(&p2.postfix_ops) {
189        if !postfix_op_equal(op1, op2) {
190            return false;
191        }
192    }
193
194    // Compare ternary
195    match (&p1.ternary, &p2.ternary) {
196        (Some(t1), Some(t2)) => ternary_op_equal(t1, t2),
197        (None, None) => true,
198        _ => false,
199    }
200}
201
202fn prefix_op_equal(op1: &PrefixOp, op2: &PrefixOp) -> bool {
203    op1.precedence == op2.precedence
204        && op1.mapping == op2.mapping
205        && combinators_equal(&op1.pattern, &op2.pattern)
206}
207
208fn infix_op_equal(op1: &InfixOp, op2: &InfixOp) -> bool {
209    op1.precedence == op2.precedence
210        && op1.assoc == op2.assoc
211        && op1.mapping == op2.mapping
212        && combinators_equal(&op1.pattern, &op2.pattern)
213}
214
215fn postfix_op_equal(op1: &PostfixOp, op2: &PostfixOp) -> bool {
216    match (op1, op2) {
217        (
218            PostfixOp::Simple {
219                pattern: p1,
220                precedence: prec1,
221                mapping: m1,
222            },
223            PostfixOp::Simple {
224                pattern: p2,
225                precedence: prec2,
226                mapping: m2,
227            },
228        ) => prec1 == prec2 && m1 == m2 && combinators_equal(p1, p2),
229
230        (
231            PostfixOp::Call {
232                open: o1,
233                close: c1,
234                separator: s1,
235                arg_rule: ar1,
236                precedence: prec1,
237                mapping: m1,
238            },
239            PostfixOp::Call {
240                open: o2,
241                close: c2,
242                separator: s2,
243                arg_rule: ar2,
244                precedence: prec2,
245                mapping: m2,
246            },
247        ) => {
248            prec1 == prec2
249                && m1 == m2
250                && ar1 == ar2
251                && combinators_equal(o1, o2)
252                && combinators_equal(c1, c2)
253                && combinators_equal(s1, s2)
254        }
255
256        (
257            PostfixOp::Index {
258                open: o1,
259                close: c1,
260                precedence: prec1,
261                mapping: m1,
262            },
263            PostfixOp::Index {
264                open: o2,
265                close: c2,
266                precedence: prec2,
267                mapping: m2,
268            },
269        ) => prec1 == prec2 && m1 == m2 && combinators_equal(o1, o2) && combinators_equal(c1, c2),
270
271        (
272            PostfixOp::Member {
273                pattern: p1,
274                precedence: prec1,
275                mapping: m1,
276            },
277            PostfixOp::Member {
278                pattern: p2,
279                precedence: prec2,
280                mapping: m2,
281            },
282        ) => prec1 == prec2 && m1 == m2 && combinators_equal(p1, p2),
283
284        _ => false,
285    }
286}
287
288fn ternary_op_equal(t1: &TernaryOp, t2: &TernaryOp) -> bool {
289    t1.precedence == t2.precedence
290        && t1.mapping == t2.mapping
291        && combinators_equal(&t1.first, &t2.first)
292        && combinators_equal(&t1.second, &t2.second)
293}
294
295// ============================================================================
296// Recursion Detection
297// ============================================================================
298
299/// Check if a combinator tree contains rules that can be recursive.
300///
301/// A rule is considered recursive if following rule references leads back
302/// to a rule we've already seen.
303fn contains_recursion(
304    comb: &Combinator,
305    rule_map: &HashMap<&str, &Combinator>,
306    visited: &mut HashSet<String>,
307) -> bool {
308    match comb {
309        Combinator::Rule(name) => {
310            if visited.contains(name) {
311                return true; // Found cycle = recursion
312            }
313            visited.insert(name.clone());
314            if let Some(rule_comb) = rule_map.get(name.as_str()) {
315                contains_recursion(rule_comb, rule_map, visited)
316            } else {
317                false
318            }
319        }
320        Combinator::Sequence(items) | Combinator::Choice(items) => items
321            .iter()
322            .any(|c| contains_recursion(c, rule_map, visited)),
323        Combinator::ZeroOrMore(inner)
324        | Combinator::OneOrMore(inner)
325        | Combinator::Optional(inner)
326        | Combinator::Skip(inner)
327        | Combinator::Capture(inner)
328        | Combinator::NotFollowedBy(inner)
329        | Combinator::FollowedBy(inner)
330        | Combinator::Mapped { inner, .. }
331        | Combinator::Memoize { inner, .. } => contains_recursion(inner, rule_map, visited),
332        Combinator::SeparatedBy {
333            item, separator, ..
334        } => {
335            contains_recursion(item, rule_map, visited)
336                || contains_recursion(separator, rule_map, visited)
337        }
338        Combinator::Pratt(pratt) => {
339            if let Some(ref operand) = *pratt.operand {
340                if contains_recursion(operand, rule_map, visited) {
341                    return true;
342                }
343            }
344            // Check operator patterns too
345            for op in &pratt.prefix_ops {
346                if contains_recursion(&op.pattern, rule_map, visited) {
347                    return true;
348                }
349            }
350            for op in &pratt.infix_ops {
351                if contains_recursion(&op.pattern, rule_map, visited) {
352                    return true;
353                }
354            }
355            false
356        }
357        // Leaf nodes
358        Combinator::Literal(_)
359        | Combinator::Char(_)
360        | Combinator::CharClass(_)
361        | Combinator::CharRange(_, _)
362        | Combinator::AnyChar => false,
363    }
364}
365
366// ============================================================================
367// Common Prefix Detection
368// ============================================================================
369
370/// A view of a combinator as a sequence of elements.
371/// Single items become 1-element sequences.
372struct SequenceView<'a> {
373    items: Vec<&'a Combinator>,
374}
375
376impl<'a> SequenceView<'a> {
377    fn from_combinator(c: &'a Combinator) -> Self {
378        match c {
379            Combinator::Sequence(items) => SequenceView {
380                items: items.iter().collect(),
381            },
382            other => SequenceView { items: vec![other] },
383        }
384    }
385
386    fn len(&self) -> usize {
387        self.items.len()
388    }
389
390    fn get(&self, i: usize) -> Option<&'a Combinator> {
391        self.items.get(i).copied()
392    }
393}
394
395/// Expand a combinator by resolving rule references to their definitions.
396/// Only expands one level - the top-level combinator.
397fn expand_combinator<'a>(
398    comb: &'a Combinator,
399    rule_map: &'a HashMap<&str, &'a Combinator>,
400) -> &'a Combinator {
401    if let Combinator::Rule(name) = comb {
402        if let Some(expanded) = rule_map.get(name.as_str()) {
403            return expanded;
404        }
405    }
406    comb
407}
408
409/// Find the longest common prefix among Choice alternatives.
410///
411/// Returns a `PrefixAnalysis` with:
412/// - The shared prefix elements
413/// - The suffixes (what remains) for each alternative
414/// - The severity of the backtracking issue
415///
416/// This function expands rule references to find hidden shared prefixes.
417/// For example, if `dotted_list` and `proper_list` both start with `'(' datum+`,
418/// this will detect that shared prefix.
419pub fn find_common_prefix(
420    alternatives: &[Combinator],
421    rule_map: &HashMap<&str, &Combinator>,
422) -> PrefixAnalysis {
423    if alternatives.len() < 2 {
424        return PrefixAnalysis {
425            prefix: vec![],
426            suffixes: alternatives
427                .iter()
428                .map(|c| Suffix::Single(c.clone()))
429                .collect(),
430            severity: BacktrackingSeverity::None,
431        };
432    }
433
434    // Expand rule references to get their actual content
435    let expanded: Vec<&Combinator> = alternatives
436        .iter()
437        .map(|c| expand_combinator(c, rule_map))
438        .collect();
439
440    // Normalize to sequence views
441    let views: Vec<SequenceView> = expanded
442        .iter()
443        .map(|c| SequenceView::from_combinator(c))
444        .collect();
445
446    // Find the minimum length across all alternatives
447    let min_len = views.iter().map(|v| v.len()).min().unwrap_or(0);
448
449    // Find LCP length by comparing element-by-element
450    let mut prefix_len = 0;
451    for i in 0..min_len {
452        let first = match views[0].get(i) {
453            Some(c) => c,
454            None => break,
455        };
456
457        let all_equal = views.iter().skip(1).all(|v| {
458            v.get(i)
459                .map(|c| combinators_equal(first, c))
460                .unwrap_or(false)
461        });
462
463        if all_equal {
464            prefix_len = i + 1;
465        } else {
466            break;
467        }
468    }
469
470    if prefix_len == 0 {
471        return PrefixAnalysis {
472            prefix: vec![],
473            suffixes: alternatives
474                .iter()
475                .map(|c| Suffix::Single(c.clone()))
476                .collect(),
477            severity: BacktrackingSeverity::None,
478        };
479    }
480
481    // Extract prefix
482    let prefix: Vec<Combinator> = (0..prefix_len)
483        .filter_map(|i| views[0].get(i).cloned())
484        .collect();
485
486    // Extract suffixes
487    let suffixes: Vec<Suffix> = views
488        .iter()
489        .map(|v| {
490            let remaining: Vec<Combinator> = (prefix_len..v.len())
491                .filter_map(|i| v.get(i).cloned())
492                .collect();
493            match remaining.len() {
494                0 => Suffix::Empty,
495                1 => Suffix::Single(remaining.into_iter().next().unwrap_or_else(|| {
496                    // This shouldn't happen due to the length check
497                    Combinator::Literal(String::new())
498                })),
499                _ => Suffix::Sequence(remaining),
500            }
501        })
502        .collect();
503
504    // Check if prefix consumes input (is non-nullable)
505    let prefix_as_seq = Combinator::Sequence(prefix.clone());
506    let mut visited = HashSet::new();
507    let prefix_is_nullable = validation::is_nullable(&prefix_as_seq, rule_map, &mut visited);
508
509    // Check if prefix contains recursive rules
510    let mut recursion_visited = HashSet::new();
511    let prefix_has_recursion = prefix
512        .iter()
513        .any(|c| contains_recursion(c, rule_map, &mut recursion_visited));
514
515    // Determine severity
516    let severity = if prefix_is_nullable {
517        // Nullable prefix doesn't cause backtracking issues
518        BacktrackingSeverity::None
519    } else if prefix_has_recursion {
520        BacktrackingSeverity::Exponential
521    } else {
522        BacktrackingSeverity::Linear
523    };
524
525    PrefixAnalysis {
526        prefix,
527        suffixes,
528        severity,
529    }
530}
531
532// ============================================================================
533// Transformation
534// ============================================================================
535
536/// Transform a Choice with common prefix into factored form.
537///
538/// Before: `Choice([Seq([A, B, C]), Seq([A, B, D])])`
539/// After:  `Seq([A, B, Choice([C, D])])`
540///
541/// If one suffix is empty, wraps the tail choice in `Optional`.
542pub fn factor_common_prefix(analysis: &PrefixAnalysis) -> Option<Combinator> {
543    if analysis.severity == BacktrackingSeverity::None || analysis.prefix.is_empty() {
544        return None; // Nothing to factor
545    }
546
547    let prefix = &analysis.prefix;
548    let suffixes = &analysis.suffixes;
549
550    // Build the suffix alternatives (non-empty ones)
551    let suffix_alternatives: Vec<Combinator> = suffixes
552        .iter()
553        .filter_map(|s| match s {
554            Suffix::Empty => None,
555            Suffix::Single(c) => Some(c.clone()),
556            Suffix::Sequence(items) => Some(Combinator::Sequence(items.clone())),
557        })
558        .collect();
559
560    let has_empty_suffix = suffixes.iter().any(|s| matches!(s, Suffix::Empty));
561
562    // Build the tail (what comes after the prefix)
563    let tail = if suffix_alternatives.is_empty() {
564        // All suffixes were empty - just return the prefix
565        None
566    } else if suffix_alternatives.len() == 1 && has_empty_suffix {
567        // One real suffix + empty = optional(suffix)
568        Some(Combinator::Optional(Box::new(
569            suffix_alternatives.into_iter().next()?,
570        )))
571    } else if has_empty_suffix {
572        // Multiple suffixes + empty = optional(choice(suffixes))
573        Some(Combinator::Optional(Box::new(Combinator::Choice(
574            suffix_alternatives,
575        ))))
576    } else {
577        // No empty suffix = choice(suffixes)
578        Some(Combinator::Choice(suffix_alternatives))
579    };
580
581    // Combine prefix and tail
582    let mut result_items = prefix.clone();
583    if let Some(tail) = tail {
584        result_items.push(tail);
585    }
586
587    // Return as sequence, or single item if only one element
588    Some(if result_items.len() == 1 {
589        result_items.into_iter().next()?
590    } else {
591        Combinator::Sequence(result_items)
592    })
593}
594
595// ============================================================================
596// Grammar Optimization
597// ============================================================================
598
599/// Recursively optimize all Choice nodes in a combinator tree.
600///
601/// Only transforms choices with `Exponential` severity.
602pub fn optimize_combinator(comb: &Combinator, rule_map: &HashMap<&str, &Combinator>) -> Combinator {
603    match comb {
604        Combinator::Choice(alternatives) => {
605            // First optimize children
606            let optimized_alts: Vec<Combinator> = alternatives
607                .iter()
608                .map(|c| optimize_combinator(c, rule_map))
609                .collect();
610
611            // Then check for common prefix
612            let analysis = find_common_prefix(&optimized_alts, rule_map);
613
614            if analysis.severity == BacktrackingSeverity::Exponential {
615                if let Some(factored) = factor_common_prefix(&analysis) {
616                    return factored;
617                }
618            }
619
620            Combinator::Choice(optimized_alts)
621        }
622        Combinator::Sequence(items) => Combinator::Sequence(
623            items
624                .iter()
625                .map(|c| optimize_combinator(c, rule_map))
626                .collect(),
627        ),
628        Combinator::ZeroOrMore(inner) => {
629            Combinator::ZeroOrMore(Box::new(optimize_combinator(inner, rule_map)))
630        }
631        Combinator::OneOrMore(inner) => {
632            Combinator::OneOrMore(Box::new(optimize_combinator(inner, rule_map)))
633        }
634        Combinator::Optional(inner) => {
635            Combinator::Optional(Box::new(optimize_combinator(inner, rule_map)))
636        }
637        Combinator::Skip(inner) => Combinator::Skip(Box::new(optimize_combinator(inner, rule_map))),
638        Combinator::Capture(inner) => {
639            Combinator::Capture(Box::new(optimize_combinator(inner, rule_map)))
640        }
641        Combinator::NotFollowedBy(inner) => {
642            Combinator::NotFollowedBy(Box::new(optimize_combinator(inner, rule_map)))
643        }
644        Combinator::FollowedBy(inner) => {
645            Combinator::FollowedBy(Box::new(optimize_combinator(inner, rule_map)))
646        }
647        Combinator::Mapped { inner, mapping } => Combinator::Mapped {
648            inner: Box::new(optimize_combinator(inner, rule_map)),
649            mapping: mapping.clone(),
650        },
651        Combinator::Memoize { inner, id } => Combinator::Memoize {
652            inner: Box::new(optimize_combinator(inner, rule_map)),
653            id: *id,
654        },
655        Combinator::SeparatedBy {
656            item,
657            separator,
658            trailing,
659        } => Combinator::SeparatedBy {
660            item: Box::new(optimize_combinator(item, rule_map)),
661            separator: Box::new(optimize_combinator(separator, rule_map)),
662            trailing: *trailing,
663        },
664        Combinator::Pratt(pratt) => {
665            // Optimize operand if present
666            let optimized_operand = pratt
667                .operand
668                .as_ref()
669                .as_ref()
670                .map(|o| optimize_combinator(o, rule_map));
671            Combinator::Pratt(PrattDef {
672                operand: Box::new(optimized_operand),
673                prefix_ops: pratt.prefix_ops.clone(),
674                infix_ops: pratt.infix_ops.clone(),
675                postfix_ops: pratt.postfix_ops.clone(),
676                ternary: pratt.ternary.clone(),
677            })
678        }
679        // Leaf nodes - return unchanged
680        Combinator::Rule(_)
681        | Combinator::Literal(_)
682        | Combinator::Char(_)
683        | Combinator::CharClass(_)
684        | Combinator::CharRange(_, _)
685        | Combinator::AnyChar => comb.clone(),
686    }
687}
688
689// ============================================================================
690// Grammar Analysis
691// ============================================================================
692
693/// Analyze a grammar for backtracking issues.
694///
695/// Returns warnings for each problematic Choice node found.
696pub fn analyze_grammar(rules: &[RuleDef]) -> Vec<BacktrackingWarning> {
697    let rule_map: HashMap<&str, &Combinator> = rules
698        .iter()
699        .map(|r| (r.name.as_str(), &r.combinator))
700        .collect();
701
702    let mut warnings = Vec::new();
703
704    for rule in rules {
705        analyze_combinator_for_backtracking(&rule.name, &rule.combinator, &rule_map, &mut warnings);
706    }
707
708    warnings
709}
710
711fn analyze_combinator_for_backtracking(
712    rule_name: &str,
713    comb: &Combinator,
714    rule_map: &HashMap<&str, &Combinator>,
715    warnings: &mut Vec<BacktrackingWarning>,
716) {
717    match comb {
718        Combinator::Choice(alternatives) => {
719            let analysis = find_common_prefix(alternatives, rule_map);
720
721            if analysis.severity == BacktrackingSeverity::Exponential {
722                warnings.push(BacktrackingWarning {
723                    rule_name: rule_name.to_string(),
724                    description: format!(
725                        "Choice with {} alternatives shares a prefix of {} elements containing recursive rules. \
726                         This causes O(2^n) parsing time. Consider factoring out the common prefix.",
727                        alternatives.len(),
728                        analysis.prefix.len()
729                    ),
730                    severity: analysis.severity,
731                });
732            }
733
734            // Recurse into alternatives
735            for alt in alternatives {
736                analyze_combinator_for_backtracking(rule_name, alt, rule_map, warnings);
737            }
738        }
739        Combinator::Sequence(items) => {
740            for item in items {
741                analyze_combinator_for_backtracking(rule_name, item, rule_map, warnings);
742            }
743        }
744        Combinator::ZeroOrMore(inner)
745        | Combinator::OneOrMore(inner)
746        | Combinator::Optional(inner)
747        | Combinator::Skip(inner)
748        | Combinator::Capture(inner)
749        | Combinator::NotFollowedBy(inner)
750        | Combinator::FollowedBy(inner)
751        | Combinator::Mapped { inner, .. }
752        | Combinator::Memoize { inner, .. } => {
753            analyze_combinator_for_backtracking(rule_name, inner, rule_map, warnings);
754        }
755        Combinator::SeparatedBy {
756            item, separator, ..
757        } => {
758            analyze_combinator_for_backtracking(rule_name, item, rule_map, warnings);
759            analyze_combinator_for_backtracking(rule_name, separator, rule_map, warnings);
760        }
761        Combinator::Pratt(pratt) => {
762            if let Some(ref operand) = *pratt.operand {
763                analyze_combinator_for_backtracking(rule_name, operand, rule_map, warnings);
764            }
765        }
766        // Leaf nodes - nothing to analyze
767        Combinator::Rule(_)
768        | Combinator::Literal(_)
769        | Combinator::Char(_)
770        | Combinator::CharClass(_)
771        | Combinator::CharRange(_, _)
772        | Combinator::AnyChar => {}
773    }
774}
775
776// ============================================================================
777// Memoization Candidate Detection
778// ============================================================================
779
780/// Identify rules that should be memoized to avoid exponential backtracking.
781///
782/// This function analyzes the grammar to find rules that:
783/// 1. Appear at the start of Choice alternatives that share a common prefix
784/// 2. Contain recursion (either directly or through rule references)
785/// 3. Would cause exponential backtracking without memoization
786///
787/// Returns a set of rule names that should be wrapped with `.memoize()`.
788pub fn identify_memoization_candidates(rules: &[RuleDef]) -> HashSet<String> {
789    let rule_map: HashMap<&str, &Combinator> = rules
790        .iter()
791        .map(|r| (r.name.as_str(), &r.combinator))
792        .collect();
793
794    let mut candidates = HashSet::new();
795
796    for rule in rules {
797        find_memoization_candidates_in_combinator(&rule.combinator, &rule_map, &mut candidates);
798    }
799
800    candidates
801}
802
803/// Recursively search a combinator for memoization candidates.
804fn find_memoization_candidates_in_combinator(
805    comb: &Combinator,
806    rule_map: &HashMap<&str, &Combinator>,
807    candidates: &mut HashSet<String>,
808) {
809    match comb {
810        Combinator::Choice(alternatives) => {
811            // Check if this choice has exponential backtracking potential
812            let analysis = find_common_prefix(alternatives, rule_map);
813
814            if analysis.severity == BacktrackingSeverity::Exponential {
815                // Find the first rule reference in the common prefix
816                // That rule (and rules it calls) are memoization candidates
817                for prefix_elem in &analysis.prefix {
818                    collect_rule_references(prefix_elem, candidates);
819                }
820            }
821
822            // Also check for patterns where one alternative starts with another
823            // E.g., Choice([generic_call, identifier]) where generic_call starts with identifier
824            find_overlapping_rule_starts(alternatives, rule_map, candidates);
825
826            // Recurse into alternatives
827            for alt in alternatives {
828                find_memoization_candidates_in_combinator(alt, rule_map, candidates);
829            }
830        }
831        Combinator::Sequence(items) => {
832            for item in items {
833                find_memoization_candidates_in_combinator(item, rule_map, candidates);
834            }
835        }
836        Combinator::ZeroOrMore(inner)
837        | Combinator::OneOrMore(inner)
838        | Combinator::Optional(inner)
839        | Combinator::Skip(inner)
840        | Combinator::Capture(inner)
841        | Combinator::NotFollowedBy(inner)
842        | Combinator::FollowedBy(inner)
843        | Combinator::Mapped { inner, .. }
844        | Combinator::Memoize { inner, .. } => {
845            find_memoization_candidates_in_combinator(inner, rule_map, candidates);
846        }
847        Combinator::SeparatedBy {
848            item, separator, ..
849        } => {
850            find_memoization_candidates_in_combinator(item, rule_map, candidates);
851            find_memoization_candidates_in_combinator(separator, rule_map, candidates);
852        }
853        Combinator::Pratt(pratt) => {
854            if let Some(ref operand) = *pratt.operand {
855                find_memoization_candidates_in_combinator(operand, rule_map, candidates);
856            }
857        }
858        // Leaf nodes - nothing to recurse into
859        Combinator::Rule(_)
860        | Combinator::Literal(_)
861        | Combinator::Char(_)
862        | Combinator::CharClass(_)
863        | Combinator::CharRange(_, _)
864        | Combinator::AnyChar => {}
865    }
866}
867
868/// Find cases where one alternative starts with a rule that another alternative would match.
869///
870/// E.g., Choice([generic_call, identifier]) where generic_call starts with identifier.
871/// In this case, generic_call should be memoized because after matching identifier
872/// and failing to complete generic_call, we'll backtrack and try identifier again.
873fn find_overlapping_rule_starts(
874    alternatives: &[Combinator],
875    rule_map: &HashMap<&str, &Combinator>,
876    candidates: &mut HashSet<String>,
877) {
878    // Collect the first rule reference from each alternative
879    let mut first_rules: Vec<(usize, &str)> = Vec::new();
880    for (idx, alt) in alternatives.iter().enumerate() {
881        if let Some(first_rule) = get_first_rule(alt, rule_map) {
882            first_rules.push((idx, first_rule));
883        }
884    }
885
886    // For each alternative that's a rule reference, check if its expansion
887    // starts with any other alternative's first rule
888    for alt in alternatives {
889        if let Combinator::Rule(rule_name) = alt {
890            if let Some(rule_def) = rule_map.get(rule_name.as_str()) {
891                // Get what this rule starts with
892                if let Some(starts_with) = get_first_rule(rule_def, rule_map) {
893                    // Check if any other alternative is this same rule
894                    for (_, other_first) in &first_rules {
895                        if *other_first == starts_with && *other_first != rule_name.as_str() {
896                            // This rule starts with something another alternative matches
897                            // Mark this rule as a memoization candidate
898                            candidates.insert(rule_name.clone());
899                        }
900                    }
901                }
902            }
903        }
904    }
905}
906
907/// Get the first rule reference at the start of a combinator.
908fn get_first_rule<'a>(
909    comb: &'a Combinator,
910    _rule_map: &HashMap<&str, &'a Combinator>,
911) -> Option<&'a str> {
912    match comb {
913        Combinator::Rule(name) => Some(name.as_str()),
914        Combinator::Sequence(items) if !items.is_empty() => get_first_rule(&items[0], _rule_map),
915        Combinator::Optional(inner) => get_first_rule(inner, _rule_map),
916        Combinator::Skip(inner) => get_first_rule(inner, _rule_map),
917        Combinator::Capture(inner) => get_first_rule(inner, _rule_map),
918        Combinator::Mapped { inner, .. } => get_first_rule(inner, _rule_map),
919        Combinator::Memoize { inner, .. } => get_first_rule(inner, _rule_map),
920        _ => None,
921    }
922}
923
924/// Collect all rule references from a combinator.
925fn collect_rule_references(comb: &Combinator, rules: &mut HashSet<String>) {
926    match comb {
927        Combinator::Rule(name) => {
928            rules.insert(name.clone());
929        }
930        Combinator::Sequence(items) | Combinator::Choice(items) => {
931            for item in items {
932                collect_rule_references(item, rules);
933            }
934        }
935        Combinator::ZeroOrMore(inner)
936        | Combinator::OneOrMore(inner)
937        | Combinator::Optional(inner)
938        | Combinator::Skip(inner)
939        | Combinator::Capture(inner)
940        | Combinator::NotFollowedBy(inner)
941        | Combinator::FollowedBy(inner)
942        | Combinator::Mapped { inner, .. }
943        | Combinator::Memoize { inner, .. } => {
944            collect_rule_references(inner, rules);
945        }
946        Combinator::SeparatedBy {
947            item, separator, ..
948        } => {
949            collect_rule_references(item, rules);
950            collect_rule_references(separator, rules);
951        }
952        Combinator::Pratt(pratt) => {
953            if let Some(ref operand) = *pratt.operand {
954                collect_rule_references(operand, rules);
955            }
956            for op in &pratt.prefix_ops {
957                collect_rule_references(&op.pattern, rules);
958            }
959            for op in &pratt.infix_ops {
960                collect_rule_references(&op.pattern, rules);
961            }
962        }
963        // Leaf nodes
964        Combinator::Literal(_)
965        | Combinator::Char(_)
966        | Combinator::CharClass(_)
967        | Combinator::CharRange(_, _)
968        | Combinator::AnyChar => {}
969    }
970}
971
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Combinator::Rule` referring to `name`.
    fn rule_ref(name: &str) -> Combinator {
        Combinator::Rule(name.to_string())
    }

    /// Build a `Combinator::Literal` for `text`.
    fn literal(text: &str) -> Combinator {
        Combinator::Literal(text.to_string())
    }

    /// Build a `Combinator::Sequence` made of one `Char` per entry in `chars`.
    fn char_sequence(chars: &[char]) -> Combinator {
        Combinator::Sequence(chars.iter().map(|c| Combinator::Char(*c)).collect())
    }

    /// Build a `RuleDef` from a name and its body.
    fn rule_def(name: &str, combinator: Combinator) -> RuleDef {
        RuleDef {
            name: name.to_string(),
            combinator,
        }
    }

    /// Build an analysis whose prefix is `'('`, `'a'` with the given suffixes.
    fn paren_a_analysis(suffixes: Vec<Suffix>) -> PrefixAnalysis {
        PrefixAnalysis {
            prefix: vec![Combinator::Char('('), Combinator::Char('a')],
            suffixes,
            severity: BacktrackingSeverity::Exponential,
        }
    }

    #[test]
    fn test_combinators_equal_literals() {
        let foo = literal("foo");

        assert!(combinators_equal(&foo, &literal("foo")));
        assert!(!combinators_equal(&foo, &literal("bar")));
    }

    #[test]
    fn test_combinators_equal_sequences() {
        // '(' followed by a reference to the named rule.
        let paren_then = |name: &str| Combinator::Sequence(vec![Combinator::Char('('), rule_ref(name)]);

        let datum_a = paren_then("datum");
        let datum_b = paren_then("datum");
        let other = paren_then("other");

        assert!(combinators_equal(&datum_a, &datum_b));
        assert!(!combinators_equal(&datum_a, &other));
    }

    #[test]
    fn test_find_common_prefix_simple() {
        let rule_map = HashMap::new();

        // Two sequences that agree on '(' 'a' and differ in the last element.
        let alternatives = vec![
            char_sequence(&['(', 'a', '.']),
            char_sequence(&['(', 'a', ')']),
        ];

        let analysis = find_common_prefix(&alternatives, &rule_map);

        let expected_prefix = [Combinator::Char('('), Combinator::Char('a')];
        assert_eq!(analysis.prefix.len(), 2);
        for (found, wanted) in analysis.prefix.iter().zip(expected_prefix.iter()) {
            assert!(combinators_equal(found, wanted));
        }
        assert_eq!(analysis.suffixes.len(), 2);
    }

    #[test]
    fn test_factor_common_prefix() {
        let analysis = paren_a_analysis(vec![
            Suffix::Single(Combinator::Char('.')),
            Suffix::Single(Combinator::Char(')')),
        ]);

        // Expected shape: Sequence(['(', 'a', Choice(['.', ')'])])
        let items = match factor_common_prefix(&analysis).unwrap() {
            Combinator::Sequence(items) => items,
            _ => panic!("Expected Sequence"),
        };

        assert_eq!(items.len(), 3);
        assert!(combinators_equal(&items[0], &Combinator::Char('(')));
        assert!(combinators_equal(&items[1], &Combinator::Char('a')));
        match &items[2] {
            Combinator::Choice(alts) => assert_eq!(alts.len(), 2),
            _ => panic!("Expected Choice"),
        }
    }

    #[test]
    fn test_factor_with_empty_suffix() {
        // The second alternative consists of the prefix alone.
        let analysis = paren_a_analysis(vec![
            Suffix::Single(Combinator::Char('.')),
            Suffix::Empty,
        ]);

        // Expected shape: Sequence(['(', 'a', Optional('.')])
        let items = match factor_common_prefix(&analysis).unwrap() {
            Combinator::Sequence(items) => items,
            _ => panic!("Expected Sequence"),
        };

        assert_eq!(items.len(), 3);
        match &items[2] {
            // The remaining suffix must be wrapped in Optional.
            Combinator::Optional(_) => {}
            _ => panic!("Expected Optional for empty suffix case"),
        }
    }

    #[test]
    fn test_identify_memoization_candidates() {
        // Grammar modelled on the generic_call vs identifier pattern, where the
        // two alternatives of primary_inner begin the same way.
        let rules = vec![
            rule_def(
                "primary_inner",
                Combinator::Choice(vec![rule_ref("generic_call"), rule_ref("identifier")]),
            ),
            rule_def(
                "generic_call",
                Combinator::Sequence(vec![
                    rule_ref("identifier"),
                    rule_ref("type_arguments"),
                    literal("("),
                    literal(")"),
                ]),
            ),
            rule_def(
                "type_arguments",
                Combinator::Sequence(vec![literal("<"), rule_ref("type"), literal(">")]),
            ),
            rule_def(
                "type",
                Combinator::Choice(vec![rule_ref("type_reference"), rule_ref("identifier")]),
            ),
            rule_def(
                "type_reference",
                Combinator::Sequence(vec![
                    rule_ref("identifier"),
                    Combinator::Optional(Box::new(rule_ref("type_arguments"))),
                ]),
            ),
            rule_def(
                "identifier",
                Combinator::Capture(Box::new(Combinator::OneOrMore(Box::new(
                    Combinator::CharClass(crate::ir::CharClass::Alpha),
                )))),
            ),
        ];

        let candidates = identify_memoization_candidates(&rules);

        // primary_inner chooses between generic_call and identifier, and
        // generic_call itself begins with identifier, so generic_call is
        // expected among the memoization candidates.
        assert!(
            candidates.contains("generic_call"),
            "Should identify generic_call as memoization candidate, got: {:?}",
            candidates
        );
    }
}