Skip to main content

texform_transform/rewrite/
rule_context.rs

1//! Rule context and typed node views for rule matching.
2//!
3//! This module provides [`RuleContext`], the runtime context object passed to
4//! [`RewriteRule::apply()`](super::rule::RewriteRule::apply) during AST
5//! transformation. It bundles mutable AST access with knowledge-base lookups,
6//! validation helpers, and statistics tracking.
7//!
8//! It also defines a family of read-only *view* structs ([`CommandView`],
9//! [`InfixView`], [`DeclarativeView`], [`EnvironmentView`]) that the
10//! `match_*` helpers extract from AST nodes. Rules operate on these views
11//! instead of pattern-matching raw [`Node`] variants directly, which keeps
12//! rule implementations concise and type-safe.
13
14use std::ops::Deref;
15
16use crate::ast::{ArgumentKind, ArgumentSlot, ArgumentValue, Ast, Node, NodeId};
17use crate::knowledge::{KnowledgeBase, lookup_command_node_name, lookup_environment_node_name};
18use crate::parse::ContentMode;
19use crate::rewrite::rule::RuleKey;
20use crate::rewrite::{RewriteReport, RuleError};
21use texform_knowledge::specs::{
22    ActiveCharacterRecord, ActiveCommandRecord, ActiveEnvironmentRecord, BuiltinCommandRecord,
23    BuiltinEnvironmentRecord,
24};
25
26/// A read-only view of a prefix command node for use in rule matching.
27#[derive(Clone, Copy)]
28pub struct CommandView<'a> {
29    /// The command name without the leading backslash.
30    pub name: &'a str,
31    /// The explicit argument slots parsed for this command.
32    pub args: &'a [ArgumentSlot],
33}
34
35impl CommandView<'_> {
36    /// Returns the command subject used in transform diagnostics, such as `\frac`.
37    pub fn subject(&self) -> String {
38        format!(r"\{}", self.name)
39    }
40}
41
42/// A read-only view of an infix command node for use in rule matching.
43#[derive(Clone, Copy)]
44pub struct InfixView<'a> {
45    /// The command name without the leading backslash.
46    pub name: &'a str,
47    /// The explicit argument slots parsed for this command.
48    pub args: &'a [ArgumentSlot],
49    /// The left operand subtree collected by the parser.
50    pub left: NodeId,
51    /// The right operand subtree collected by the parser.
52    pub right: NodeId,
53}
54
55impl InfixView<'_> {
56    /// Returns the infix command subject used in transform diagnostics, such as `\over`.
57    pub fn subject(&self) -> String {
58        format!(r"\{}", self.name)
59    }
60}
61
62/// A read-only view of a declarative command node for use in rule matching.
63#[derive(Clone, Copy)]
64pub struct DeclarativeView<'a> {
65    /// The command name without the leading backslash.
66    pub name: &'a str,
67    /// The explicit argument slots parsed for this command.
68    pub args: &'a [ArgumentSlot],
69}
70
71/// A read-only view of an environment node for use in rule matching.
72#[derive(Clone, Copy)]
73pub struct EnvironmentView<'a> {
74    /// The environment name (as it appears between `\begin{…}` and `\end{…}`).
75    pub name: &'a str,
76    /// The explicit argument slots parsed for this environment.
77    pub args: &'a [ArgumentSlot],
78    /// The body subtree between `\begin` and `\end`.
79    pub body: NodeId,
80}
81
82/// The runtime context object passed to [`RewriteRule::apply()`](super::rule::RewriteRule::apply).
83///
84/// It bundles mutable AST access with knowledge-base lookups, node-shape
85/// validation helpers, and statistics tracking. Rules receive a mutable
86/// reference to this context and use it both to inspect the current tree
87/// and to record replacement nodes.
88///
89/// `ast` is intentionally public because many transforms need unrestricted
90/// structural mutation, not just a narrow helper surface. The tradeoff is that
91/// rules can also violate AST invariants if they misuse low-level operations,
92/// so debug builds re-run [`Ast::assert_invariants()`](crate::ast::Ast::assert_invariants)
93/// after every successful rewrite. Knowledge-base access, transform-context
94/// queries, and report mutation stay mediated through methods because those interactions are
95/// semantic rather than structural.
96pub struct RuleContext<'a> {
97    /// Mutable access to the AST being transformed.
98    ///
99    /// This field stays public so rules can perform bespoke tree surgery when
100    /// helper functions are not expressive enough.
101    pub ast: &'a mut Ast,
102    math_kb: &'a KnowledgeBase,
103    text_kb: &'a KnowledgeBase,
104    report: &'a mut RewriteReport,
105}
106
107/// A read-only scoped context bound to a rule key for diagnostics and slot extraction.
108pub struct RuleScopedContext<'cx, 'ctx> {
109    cx: &'cx RuleContext<'ctx>,
110    rule: RuleKey,
111}
112
113impl<'cx, 'ctx> Deref for RuleScopedContext<'cx, 'ctx> {
114    type Target = RuleContext<'ctx>;
115
116    fn deref(&self) -> &Self::Target {
117        self.cx
118    }
119}
120
121impl RuleScopedContext<'_, '_> {
122    /// Creates an [`InvalidNodeShape`](RuleError::InvalidNodeShape) error for the bound rule.
123    pub fn invalid_shape(&self, message: impl Into<String>) -> RuleError {
124        self.cx.invalid_shape(self.rule, message)
125    }
126
127    /// Creates a [`MissingMetadata`](RuleError::MissingMetadata) error for the bound rule.
128    pub fn missing_metadata(&self, name: impl Into<String>) -> RuleError {
129        self.cx.missing_metadata(self.rule, name)
130    }
131
132    /// Returns `Ok(())` when `condition` is true, or an invalid-shape error otherwise.
133    pub fn ensure_shape(
134        &self,
135        condition: bool,
136        message: impl Into<String>,
137    ) -> Result<(), RuleError> {
138        self.cx.ensure_shape(condition, self.rule, message)
139    }
140
141    /// Asserts that `args` has exactly `expected` slots, returning an error that names `subject` on mismatch.
142    pub fn expect_arg_len(
143        &self,
144        args: &[ArgumentSlot],
145        expected: usize,
146        subject: &str,
147    ) -> Result<(), RuleError> {
148        self.cx.expect_arg_len(self.rule, args, expected, subject)
149    }
150
151    /// Shorthand for [`expect_arg_len`](Self::expect_arg_len) with `expected = 0`.
152    pub fn expect_no_args(&self, args: &[ArgumentSlot], subject: &str) -> Result<(), RuleError> {
153        self.cx.expect_no_args(self.rule, args, subject)
154    }
155
156    /// Extracts a boolean star argument from a parsed star slot.
157    pub fn star_arg_value(&self, slot: &ArgumentSlot, subject: &str) -> Result<bool, RuleError> {
158        match slot {
159            Some(arg) if arg.kind == ArgumentKind::Star => match arg.value {
160                ArgumentValue::Boolean(value) => Ok(value),
161                _ => {
162                    Err(self
163                        .invalid_shape(format!("{subject} star slot should carry a boolean value")))
164                }
165            },
166            _ => Err(self.invalid_shape(format!("{subject} should carry a star slot"))),
167        }
168    }
169
170    /// Extracts an optional math-content argument.
171    pub fn optional_math_content(
172        &self,
173        slot: &ArgumentSlot,
174        subject: &str,
175        label: &str,
176    ) -> Result<Option<NodeId>, RuleError> {
177        match slot {
178            None => Ok(None),
179            Some(arg) if arg.kind == ArgumentKind::Optional => match arg.value {
180                ArgumentValue::MathContent(node_id) => Ok(Some(node_id)),
181                _ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
182            },
183            _ => Err(self.invalid_shape(format!(
184                "{subject} {label} should be an optional math argument"
185            ))),
186        }
187    }
188
189    /// Extracts an optional braced-group math-content argument.
190    pub fn optional_group_math_content(
191        &self,
192        slot: &ArgumentSlot,
193        subject: &str,
194        label: &str,
195    ) -> Result<Option<NodeId>, RuleError> {
196        match slot {
197            None => Ok(None),
198            Some(arg) if arg.kind == ArgumentKind::Group => match arg.value {
199                ArgumentValue::MathContent(node_id) => Ok(Some(node_id)),
200                _ => Err(self
201                    .invalid_shape(format!("{subject} optional {label} should be math content"))),
202            },
203            _ => Err(self.invalid_shape(format!(
204                "{subject} optional {label} should be a braced group"
205            ))),
206        }
207    }
208
209    /// Extracts a mandatory math-content argument.
210    pub fn mandatory_math_content(
211        &self,
212        slot: &ArgumentSlot,
213        subject: &str,
214        label: &str,
215    ) -> Result<NodeId, RuleError> {
216        match slot {
217            Some(arg) if arg.kind == ArgumentKind::Mandatory => match arg.value {
218                ArgumentValue::MathContent(node_id) => Ok(node_id),
219                _ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
220            },
221            _ => Err(self.invalid_shape(format!(
222                "{subject} {label} should be a mandatory math argument"
223            ))),
224        }
225    }
226
227    /// Extracts a math-content argument that may be either mandatory or a braced group.
228    pub fn mandatory_or_group_math_content(
229        &self,
230        slot: &ArgumentSlot,
231        subject: &str,
232        label: &str,
233    ) -> Result<NodeId, RuleError> {
234        match slot {
235            Some(arg) if matches!(arg.kind, ArgumentKind::Mandatory | ArgumentKind::Group) => {
236                match arg.value {
237                    ArgumentValue::MathContent(node_id) => Ok(node_id),
238                    _ => {
239                        Err(self.invalid_shape(format!("{subject} {label} should be math content")))
240                    }
241                }
242            }
243            _ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
244        }
245    }
246}
247
248impl<'a> RuleContext<'a> {
249    pub fn new(
250        ast: &'a mut Ast,
251        math_kb: &'a KnowledgeBase,
252        text_kb: &'a KnowledgeBase,
253        report: &'a mut RewriteReport,
254    ) -> Self {
255        Self {
256            ast,
257            math_kb,
258            text_kb,
259            report,
260        }
261    }
262
263    fn kb_for(&self, mode: ContentMode) -> &'a KnowledgeBase {
264        match mode {
265            ContentMode::Math => self.math_kb,
266            ContentMode::Text => self.text_kb,
267        }
268    }
269
270    /// Returns a lightweight context that binds diagnostics and slot extraction to one rule.
271    pub fn for_rule(&self, rule: RuleKey) -> RuleScopedContext<'_, 'a> {
272        RuleScopedContext { cx: self, rule }
273    }
274
275    pub fn knows_command_name(&self, name: &str) -> bool {
276        self.lookup_command(name, ContentMode::Math).is_some()
277            || self.lookup_command(name, ContentMode::Text).is_some()
278    }
279
280    pub fn knows_env_name(&self, name: &str) -> bool {
281        self.lookup_env(name, ContentMode::Math).is_some()
282            || self.lookup_env(name, ContentMode::Text).is_some()
283    }
284
285    pub fn command_has_tag(&self, name: &str, tag: &str) -> bool {
286        self.lookup_command(name, ContentMode::Math)
287            .is_some_and(|record| record.tags.contains(&tag))
288            || self
289                .lookup_command(name, ContentMode::Text)
290                .is_some_and(|record| record.tags.contains(&tag))
291    }
292
293    pub fn env_has_tag(&self, name: &str, tag: &str) -> bool {
294        self.lookup_env(name, ContentMode::Math)
295            .is_some_and(|record| record.tags.contains(&tag))
296            || self
297                .lookup_env(name, ContentMode::Text)
298                .is_some_and(|record| record.tags.contains(&tag))
299    }
300
301    /// Looks up the active command record for the node at `node_id` by extracting its name from the AST.
302    pub fn active_command(&self, node_id: NodeId) -> Option<&ActiveCommandRecord> {
303        let name = lookup_command_node_name(self.ast.node(node_id))?;
304        self.lookup_command(name, ContentMode::Math)
305            .or_else(|| self.lookup_command(name, ContentMode::Text))
306    }
307
308    /// Looks up the active environment record for the node at `node_id` by extracting its name from the AST.
309    pub fn active_env(&self, node_id: NodeId) -> Option<&ActiveEnvironmentRecord> {
310        let name = lookup_environment_node_name(self.ast.node(node_id))?;
311        self.lookup_env(name, ContentMode::Math)
312            .or_else(|| self.lookup_env(name, ContentMode::Text))
313    }
314
315    /// Looks up a command record by name directly in the selected knowledge-base lane.
316    pub fn lookup_command(&self, name: &str, mode: ContentMode) -> Option<&ActiveCommandRecord> {
317        self.kb_for(mode).lookup_command(name)
318    }
319
320    /// Looks up a character record by name directly in the selected knowledge-base lane.
321    pub fn lookup_character(
322        &self,
323        name: &str,
324        mode: ContentMode,
325    ) -> Option<&ActiveCharacterRecord> {
326        self.kb_for(mode).lookup_character(name)
327    }
328
329    /// Looks up an environment record by name directly in the selected knowledge-base lane.
330    pub fn lookup_env(&self, name: &str, mode: ContentMode) -> Option<&ActiveEnvironmentRecord> {
331        self.kb_for(mode).lookup_env(name)
332    }
333
334    /// Records that a rule was successfully applied, incrementing its count in the report.
335    pub fn mark_rule_applied(&mut self, key: RuleKey) {
336        self.report.mark_rule_applied(key);
337    }
338
339    /// Records that a rule was attempted after consumed target matching but made no change.
340    pub fn mark_rule_skipped(&mut self, key: RuleKey) {
341        self.report.mark_rule_skipped(key);
342    }
343
344    /// Records the total number of fixed-point iterations the engine performed.
345    pub fn record_iteration(&mut self, iterations: usize) {
346        self.report.record_iteration(iterations);
347    }
348
349    /// Returns the AST node for the given identifier.
350    pub fn node(&self, node_id: NodeId) -> &Node {
351        self.ast.node(node_id)
352    }
353
354    /// Creates an [`InvalidNodeShape`](RuleError::InvalidNodeShape) error for the given rule.
355    pub fn invalid_shape(&self, _rule: RuleKey, message: impl Into<String>) -> RuleError {
356        RuleError::InvalidNodeShape {
357            message: message.into(),
358        }
359    }
360
361    /// Creates a [`MissingMetadata`](RuleError::MissingMetadata) error for the given rule.
362    pub fn missing_metadata(&self, _rule: RuleKey, name: impl Into<String>) -> RuleError {
363        RuleError::MissingMetadata { name: name.into() }
364    }
365
366    /// Returns `Ok(())` when `condition` is true, or an [`InvalidNodeShape`](RuleError::InvalidNodeShape) error otherwise.
367    pub fn ensure_shape(
368        &self,
369        condition: bool,
370        rule: RuleKey,
371        message: impl Into<String>,
372    ) -> Result<(), RuleError> {
373        if condition {
374            Ok(())
375        } else {
376            Err(self.invalid_shape(rule, message))
377        }
378    }
379
380    /// Asserts that `args` has exactly `expected` slots, returning an error that names `subject` on mismatch.
381    pub fn expect_arg_len(
382        &self,
383        rule: RuleKey,
384        args: &[ArgumentSlot],
385        expected: usize,
386        subject: &str,
387    ) -> Result<(), RuleError> {
388        self.ensure_shape(
389            args.len() == expected,
390            rule,
391            format!(
392                "{subject} should carry exactly {expected} explicit argument slots, got {}",
393                args.len()
394            ),
395        )
396    }
397
398    /// Shorthand for [`expect_arg_len`](Self::expect_arg_len) with `expected = 0`.
399    pub fn expect_no_args(
400        &self,
401        rule: RuleKey,
402        args: &[ArgumentSlot],
403        subject: &str,
404    ) -> Result<(), RuleError> {
405        self.expect_arg_len(rule, args, 0, subject)
406    }
407
408    /// Tries to extract a [`CommandView`] from the node, returning `None` if it is not a matching prefix command.
409    pub fn match_command(
410        &self,
411        node_id: NodeId,
412        record: &'static BuiltinCommandRecord,
413    ) -> Option<CommandView<'_>> {
414        match self.ast.node(node_id) {
415            Node::Command { name, args, .. } if name == record.name => Some(CommandView {
416                name: name.as_str(),
417                args: args.as_slice(),
418            }),
419            _ => None,
420        }
421    }
422
423    /// Tries to extract an [`InfixView`] from the node, returning `None` if it is not a matching infix command.
424    pub fn match_infix(
425        &self,
426        node_id: NodeId,
427        record: &'static BuiltinCommandRecord,
428    ) -> Option<InfixView<'_>> {
429        match self.ast.node(node_id) {
430            Node::Infix {
431                name,
432                args,
433                left,
434                right,
435            } if name == record.name => Some(InfixView {
436                name: name.as_str(),
437                args: args.as_slice(),
438                left: *left,
439                right: *right,
440            }),
441            _ => None,
442        }
443    }
444
445    /// Tries to extract a [`DeclarativeView`] from the node, returning `None` if it is not a matching declarative command.
446    pub fn match_declarative(
447        &self,
448        node_id: NodeId,
449        record: &'static BuiltinCommandRecord,
450    ) -> Option<DeclarativeView<'_>> {
451        match self.ast.node(node_id) {
452            Node::Declarative { name, args } if name == record.name => Some(DeclarativeView {
453                name: name.as_str(),
454                args: args.as_slice(),
455            }),
456            _ => None,
457        }
458    }
459
460    /// Tries to extract an [`EnvironmentView`] from the node, returning `None` if it is not a matching environment.
461    pub fn match_environment(
462        &self,
463        node_id: NodeId,
464        record: &'static BuiltinEnvironmentRecord,
465    ) -> Option<EnvironmentView<'_>> {
466        match self.ast.node(node_id) {
467            Node::Environment {
468                name, args, body, ..
469            } if name == record.name => Some(EnvironmentView {
470                name: name.as_str(),
471                args: args.as_slice(),
472                body: *body,
473            }),
474            _ => None,
475        }
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Argument;
    use crate::parse::ParseContext;
    use crate::rewrite::{PackageName, RewriteReport, RuleKey};

    const TEST_RULE: RuleKey = RuleKey {
        package: PackageName::Base,
        name: "rule-context-test",
    };

    /// Builds a populated argument slot of the given kind and value.
    fn slot(kind: ArgumentKind, value: ArgumentValue) -> ArgumentSlot {
        Some(Argument { kind, value })
    }

    /// Unwraps an `InvalidNodeShape` error into its message, failing the test
    /// for any other variant.
    fn shape_message(err: RuleError) -> String {
        match err {
            RuleError::InvalidNodeShape { message } => message,
            other => panic!("expected InvalidNodeShape, got {other:?}"),
        }
    }

    #[test]
    fn extracts_common_prefix_argument_shapes() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let required_id = ast.new_node(Node::Char('x'));
        let optional_id = ast.new_node(Node::Char('2'));
        let grouped_id = ast.new_node(Node::Char('t'));
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);

        let star = slot(ArgumentKind::Star, ArgumentValue::Boolean(true));
        let required = slot(
            ArgumentKind::Mandatory,
            ArgumentValue::MathContent(required_id),
        );
        let optional = slot(
            ArgumentKind::Optional,
            ArgumentValue::MathContent(optional_id),
        );
        let grouped = slot(ArgumentKind::Group, ArgumentValue::MathContent(grouped_id));

        assert!(scoped.star_arg_value(&star, r"\example").unwrap());
        assert_eq!(
            scoped
                .mandatory_math_content(&required, r"\example", "argument")
                .unwrap(),
            required_id
        );
        assert_eq!(
            scoped
                .optional_math_content(&optional, r"\example", "order")
                .unwrap(),
            Some(optional_id)
        );
        assert_eq!(
            scoped
                .optional_group_math_content(&grouped, r"\example", "denominator")
                .unwrap(),
            Some(grouped_id)
        );
        assert_eq!(
            scoped
                .mandatory_or_group_math_content(&grouped, r"\example", "argument")
                .unwrap(),
            grouped_id
        );
        assert_eq!(
            scoped
                .mandatory_or_group_math_content(&required, r"\example", "argument")
                .unwrap(),
            required_id
        );
        assert_eq!(
            scoped
                .optional_math_content(&None, r"\example", "order")
                .unwrap(),
            None
        );
        assert_eq!(
            scoped
                .optional_group_math_content(&None, r"\example", "denominator")
                .unwrap(),
            None
        );
    }

    #[test]
    fn rejects_mismatched_argument_shapes() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let content_id = ast.new_node(Node::Char('x'));
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);

        let star = slot(ArgumentKind::Star, ArgumentValue::Boolean(true));
        let non_boolean_star = slot(ArgumentKind::Star, ArgumentValue::MathContent(content_id));
        let required = slot(
            ArgumentKind::Mandatory,
            ArgumentValue::MathContent(content_id),
        );
        let optional = slot(
            ArgumentKind::Optional,
            ArgumentValue::MathContent(content_id),
        );
        let grouped = slot(ArgumentKind::Group, ArgumentValue::MathContent(content_id));

        // Star slots: missing, wrong kind, wrong value.
        assert_eq!(
            shape_message(scoped.star_arg_value(&None, r"\example").unwrap_err()),
            r"\example should carry a star slot"
        );
        assert_eq!(
            shape_message(scoped.star_arg_value(&required, r"\example").unwrap_err()),
            r"\example should carry a star slot"
        );
        assert_eq!(
            shape_message(
                scoped
                    .star_arg_value(&non_boolean_star, r"\example")
                    .unwrap_err()
            ),
            r"\example star slot should carry a boolean value"
        );

        // Mandatory slots reject empty and non-mandatory slots.
        assert_eq!(
            shape_message(
                scoped
                    .mandatory_math_content(&None, r"\example", "argument")
                    .unwrap_err()
            ),
            r"\example argument should be a mandatory math argument"
        );
        assert_eq!(
            shape_message(
                scoped
                    .mandatory_math_content(&optional, r"\example", "argument")
                    .unwrap_err()
            ),
            r"\example argument should be a mandatory math argument"
        );
        assert!(scoped
            .mandatory_math_content(&star, r"\example", "argument")
            .is_err());

        // Optional extractors accept only their own slot kind.
        assert_eq!(
            shape_message(
                scoped
                    .optional_math_content(&grouped, r"\example", "order")
                    .unwrap_err()
            ),
            r"\example order should be an optional math argument"
        );
        assert_eq!(
            shape_message(
                scoped
                    .optional_group_math_content(&optional, r"\example", "denominator")
                    .unwrap_err()
            ),
            r"\example optional denominator should be a braced group"
        );

        // Mandatory-or-group rejects empty, optional, and star slots.
        for bad in [&None, &optional, &star] {
            assert!(matches!(
                scoped.mandatory_or_group_math_content(bad, r"\example", "argument"),
                Err(RuleError::InvalidNodeShape { .. })
            ));
        }
    }

    #[test]
    fn validates_arity_and_shape_conditions() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);
        let two_slots: [ArgumentSlot; 2] = [None, None];

        assert!(scoped.expect_no_args(&[], r"\example").is_ok());
        assert!(scoped.expect_arg_len(&two_slots, 2, r"\example").is_ok());
        assert_eq!(
            shape_message(
                scoped
                    .expect_arg_len(&two_slots, 1, r"\example")
                    .unwrap_err()
            ),
            r"\example should carry exactly 1 explicit argument slots, got 2"
        );
        assert!(scoped.expect_no_args(&two_slots, r"\example").is_err());

        assert!(scoped.ensure_shape(true, "unused").is_ok());
        assert_eq!(
            shape_message(scoped.ensure_shape(false, "broken").unwrap_err()),
            "broken"
        );
        assert!(matches!(
            scoped.missing_metadata("example"),
            RuleError::MissingMetadata { .. }
        ));
    }
}