1use std::ops::Deref;
15
16use crate::ast::{ArgumentKind, ArgumentSlot, ArgumentValue, Ast, Node, NodeId};
17use crate::knowledge::{KnowledgeBase, lookup_command_node_name, lookup_environment_node_name};
18use crate::parse::ContentMode;
19use crate::rewrite::rule::RuleKey;
20use crate::rewrite::{RewriteReport, RuleError};
21use texform_knowledge::specs::{
22 ActiveCharacterRecord, ActiveCommandRecord, ActiveEnvironmentRecord, BuiltinCommandRecord,
23 BuiltinEnvironmentRecord,
24};
25
26#[derive(Clone, Copy)]
28pub struct CommandView<'a> {
29 pub name: &'a str,
31 pub args: &'a [ArgumentSlot],
33}
34
35impl CommandView<'_> {
36 pub fn subject(&self) -> String {
38 format!(r"\{}", self.name)
39 }
40}
41
42#[derive(Clone, Copy)]
44pub struct InfixView<'a> {
45 pub name: &'a str,
47 pub args: &'a [ArgumentSlot],
49 pub left: NodeId,
51 pub right: NodeId,
53}
54
55impl InfixView<'_> {
56 pub fn subject(&self) -> String {
58 format!(r"\{}", self.name)
59 }
60}
61
62#[derive(Clone, Copy)]
64pub struct DeclarativeView<'a> {
65 pub name: &'a str,
67 pub args: &'a [ArgumentSlot],
69}
70
71#[derive(Clone, Copy)]
73pub struct EnvironmentView<'a> {
74 pub name: &'a str,
76 pub args: &'a [ArgumentSlot],
78 pub body: NodeId,
80}
81
82pub struct RuleContext<'a> {
97 pub ast: &'a mut Ast,
102 math_kb: &'a KnowledgeBase,
103 text_kb: &'a KnowledgeBase,
104 report: &'a mut RewriteReport,
105}
106
107pub struct RuleScopedContext<'cx, 'ctx> {
109 cx: &'cx RuleContext<'ctx>,
110 rule: RuleKey,
111}
112
113impl<'cx, 'ctx> Deref for RuleScopedContext<'cx, 'ctx> {
114 type Target = RuleContext<'ctx>;
115
116 fn deref(&self) -> &Self::Target {
117 self.cx
118 }
119}
120
121impl RuleScopedContext<'_, '_> {
122 pub fn invalid_shape(&self, message: impl Into<String>) -> RuleError {
124 self.cx.invalid_shape(self.rule, message)
125 }
126
127 pub fn missing_metadata(&self, name: impl Into<String>) -> RuleError {
129 self.cx.missing_metadata(self.rule, name)
130 }
131
132 pub fn ensure_shape(
134 &self,
135 condition: bool,
136 message: impl Into<String>,
137 ) -> Result<(), RuleError> {
138 self.cx.ensure_shape(condition, self.rule, message)
139 }
140
141 pub fn expect_arg_len(
143 &self,
144 args: &[ArgumentSlot],
145 expected: usize,
146 subject: &str,
147 ) -> Result<(), RuleError> {
148 self.cx.expect_arg_len(self.rule, args, expected, subject)
149 }
150
151 pub fn expect_no_args(&self, args: &[ArgumentSlot], subject: &str) -> Result<(), RuleError> {
153 self.cx.expect_no_args(self.rule, args, subject)
154 }
155
156 pub fn star_arg_value(&self, slot: &ArgumentSlot, subject: &str) -> Result<bool, RuleError> {
158 match slot {
159 Some(arg) if arg.kind == ArgumentKind::Star => match arg.value {
160 ArgumentValue::Boolean(value) => Ok(value),
161 _ => {
162 Err(self
163 .invalid_shape(format!("{subject} star slot should carry a boolean value")))
164 }
165 },
166 _ => Err(self.invalid_shape(format!("{subject} should carry a star slot"))),
167 }
168 }
169
170 pub fn optional_math_content(
172 &self,
173 slot: &ArgumentSlot,
174 subject: &str,
175 label: &str,
176 ) -> Result<Option<NodeId>, RuleError> {
177 match slot {
178 None => Ok(None),
179 Some(arg) if arg.kind == ArgumentKind::Optional => match arg.value {
180 ArgumentValue::MathContent(node_id) => Ok(Some(node_id)),
181 _ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
182 },
183 _ => Err(self.invalid_shape(format!(
184 "{subject} {label} should be an optional math argument"
185 ))),
186 }
187 }
188
189 pub fn optional_group_math_content(
191 &self,
192 slot: &ArgumentSlot,
193 subject: &str,
194 label: &str,
195 ) -> Result<Option<NodeId>, RuleError> {
196 match slot {
197 None => Ok(None),
198 Some(arg) if arg.kind == ArgumentKind::Group => match arg.value {
199 ArgumentValue::MathContent(node_id) => Ok(Some(node_id)),
200 _ => Err(self
201 .invalid_shape(format!("{subject} optional {label} should be math content"))),
202 },
203 _ => Err(self.invalid_shape(format!(
204 "{subject} optional {label} should be a braced group"
205 ))),
206 }
207 }
208
209 pub fn mandatory_math_content(
211 &self,
212 slot: &ArgumentSlot,
213 subject: &str,
214 label: &str,
215 ) -> Result<NodeId, RuleError> {
216 match slot {
217 Some(arg) if arg.kind == ArgumentKind::Mandatory => match arg.value {
218 ArgumentValue::MathContent(node_id) => Ok(node_id),
219 _ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
220 },
221 _ => Err(self.invalid_shape(format!(
222 "{subject} {label} should be a mandatory math argument"
223 ))),
224 }
225 }
226
227 pub fn mandatory_or_group_math_content(
229 &self,
230 slot: &ArgumentSlot,
231 subject: &str,
232 label: &str,
233 ) -> Result<NodeId, RuleError> {
234 match slot {
235 Some(arg) if matches!(arg.kind, ArgumentKind::Mandatory | ArgumentKind::Group) => {
236 match arg.value {
237 ArgumentValue::MathContent(node_id) => Ok(node_id),
238 _ => {
239 Err(self.invalid_shape(format!("{subject} {label} should be math content")))
240 }
241 }
242 }
243 _ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
244 }
245 }
246}
247
248impl<'a> RuleContext<'a> {
249 pub fn new(
250 ast: &'a mut Ast,
251 math_kb: &'a KnowledgeBase,
252 text_kb: &'a KnowledgeBase,
253 report: &'a mut RewriteReport,
254 ) -> Self {
255 Self {
256 ast,
257 math_kb,
258 text_kb,
259 report,
260 }
261 }
262
263 fn kb_for(&self, mode: ContentMode) -> &'a KnowledgeBase {
264 match mode {
265 ContentMode::Math => self.math_kb,
266 ContentMode::Text => self.text_kb,
267 }
268 }
269
270 pub fn for_rule(&self, rule: RuleKey) -> RuleScopedContext<'_, 'a> {
272 RuleScopedContext { cx: self, rule }
273 }
274
275 pub fn knows_command_name(&self, name: &str) -> bool {
276 self.lookup_command(name, ContentMode::Math).is_some()
277 || self.lookup_command(name, ContentMode::Text).is_some()
278 }
279
280 pub fn knows_env_name(&self, name: &str) -> bool {
281 self.lookup_env(name, ContentMode::Math).is_some()
282 || self.lookup_env(name, ContentMode::Text).is_some()
283 }
284
285 pub fn command_has_tag(&self, name: &str, tag: &str) -> bool {
286 self.lookup_command(name, ContentMode::Math)
287 .is_some_and(|record| record.tags.contains(&tag))
288 || self
289 .lookup_command(name, ContentMode::Text)
290 .is_some_and(|record| record.tags.contains(&tag))
291 }
292
293 pub fn env_has_tag(&self, name: &str, tag: &str) -> bool {
294 self.lookup_env(name, ContentMode::Math)
295 .is_some_and(|record| record.tags.contains(&tag))
296 || self
297 .lookup_env(name, ContentMode::Text)
298 .is_some_and(|record| record.tags.contains(&tag))
299 }
300
301 pub fn active_command(&self, node_id: NodeId) -> Option<&ActiveCommandRecord> {
303 let name = lookup_command_node_name(self.ast.node(node_id))?;
304 self.lookup_command(name, ContentMode::Math)
305 .or_else(|| self.lookup_command(name, ContentMode::Text))
306 }
307
308 pub fn active_env(&self, node_id: NodeId) -> Option<&ActiveEnvironmentRecord> {
310 let name = lookup_environment_node_name(self.ast.node(node_id))?;
311 self.lookup_env(name, ContentMode::Math)
312 .or_else(|| self.lookup_env(name, ContentMode::Text))
313 }
314
315 pub fn lookup_command(&self, name: &str, mode: ContentMode) -> Option<&ActiveCommandRecord> {
317 self.kb_for(mode).lookup_command(name)
318 }
319
320 pub fn lookup_character(
322 &self,
323 name: &str,
324 mode: ContentMode,
325 ) -> Option<&ActiveCharacterRecord> {
326 self.kb_for(mode).lookup_character(name)
327 }
328
329 pub fn lookup_env(&self, name: &str, mode: ContentMode) -> Option<&ActiveEnvironmentRecord> {
331 self.kb_for(mode).lookup_env(name)
332 }
333
334 pub fn mark_rule_applied(&mut self, key: RuleKey) {
336 self.report.mark_rule_applied(key);
337 }
338
339 pub fn mark_rule_skipped(&mut self, key: RuleKey) {
341 self.report.mark_rule_skipped(key);
342 }
343
344 pub fn record_iteration(&mut self, iterations: usize) {
346 self.report.record_iteration(iterations);
347 }
348
349 pub fn node(&self, node_id: NodeId) -> &Node {
351 self.ast.node(node_id)
352 }
353
354 pub fn invalid_shape(&self, _rule: RuleKey, message: impl Into<String>) -> RuleError {
356 RuleError::InvalidNodeShape {
357 message: message.into(),
358 }
359 }
360
361 pub fn missing_metadata(&self, _rule: RuleKey, name: impl Into<String>) -> RuleError {
363 RuleError::MissingMetadata { name: name.into() }
364 }
365
366 pub fn ensure_shape(
368 &self,
369 condition: bool,
370 rule: RuleKey,
371 message: impl Into<String>,
372 ) -> Result<(), RuleError> {
373 if condition {
374 Ok(())
375 } else {
376 Err(self.invalid_shape(rule, message))
377 }
378 }
379
380 pub fn expect_arg_len(
382 &self,
383 rule: RuleKey,
384 args: &[ArgumentSlot],
385 expected: usize,
386 subject: &str,
387 ) -> Result<(), RuleError> {
388 self.ensure_shape(
389 args.len() == expected,
390 rule,
391 format!(
392 "{subject} should carry exactly {expected} explicit argument slots, got {}",
393 args.len()
394 ),
395 )
396 }
397
398 pub fn expect_no_args(
400 &self,
401 rule: RuleKey,
402 args: &[ArgumentSlot],
403 subject: &str,
404 ) -> Result<(), RuleError> {
405 self.expect_arg_len(rule, args, 0, subject)
406 }
407
408 pub fn match_command(
410 &self,
411 node_id: NodeId,
412 record: &'static BuiltinCommandRecord,
413 ) -> Option<CommandView<'_>> {
414 match self.ast.node(node_id) {
415 Node::Command { name, args, .. } if name == record.name => Some(CommandView {
416 name: name.as_str(),
417 args: args.as_slice(),
418 }),
419 _ => None,
420 }
421 }
422
423 pub fn match_infix(
425 &self,
426 node_id: NodeId,
427 record: &'static BuiltinCommandRecord,
428 ) -> Option<InfixView<'_>> {
429 match self.ast.node(node_id) {
430 Node::Infix {
431 name,
432 args,
433 left,
434 right,
435 } if name == record.name => Some(InfixView {
436 name: name.as_str(),
437 args: args.as_slice(),
438 left: *left,
439 right: *right,
440 }),
441 _ => None,
442 }
443 }
444
445 pub fn match_declarative(
447 &self,
448 node_id: NodeId,
449 record: &'static BuiltinCommandRecord,
450 ) -> Option<DeclarativeView<'_>> {
451 match self.ast.node(node_id) {
452 Node::Declarative { name, args } if name == record.name => Some(DeclarativeView {
453 name: name.as_str(),
454 args: args.as_slice(),
455 }),
456 _ => None,
457 }
458 }
459
460 pub fn match_environment(
462 &self,
463 node_id: NodeId,
464 record: &'static BuiltinEnvironmentRecord,
465 ) -> Option<EnvironmentView<'_>> {
466 match self.ast.node(node_id) {
467 Node::Environment {
468 name, args, body, ..
469 } if name == record.name => Some(EnvironmentView {
470 name: name.as_str(),
471 args: args.as_slice(),
472 body: *body,
473 }),
474 _ => None,
475 }
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Argument;
    use crate::parse::ParseContext;
    use crate::rewrite::{PackageName, RewriteReport, RuleKey};

    /// Rule key passed to the scoped helpers under test.
    const TEST_RULE: RuleKey = RuleKey {
        package: PackageName::Base,
        name: "rule-context-test",
    };

    /// Builds a filled argument slot of `kind` whose value is math content `node_id`.
    fn math_slot(kind: ArgumentKind, node_id: NodeId) -> ArgumentSlot {
        Some(Argument {
            kind,
            value: ArgumentValue::MathContent(node_id),
        })
    }

    #[test]
    fn extracts_common_prefix_argument_shapes() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let required_id = ast.new_node(Node::Char('x'));
        let optional_id = ast.new_node(Node::Char('2'));
        let grouped_id = ast.new_node(Node::Char('t'));
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);

        let star = Some(Argument {
            kind: ArgumentKind::Star,
            value: ArgumentValue::Boolean(true),
        });
        let required = math_slot(ArgumentKind::Mandatory, required_id);
        let optional = math_slot(ArgumentKind::Optional, optional_id);
        let grouped = math_slot(ArgumentKind::Group, grouped_id);

        assert!(scoped.star_arg_value(&star, r"\example").unwrap());
        assert_eq!(
            scoped
                .mandatory_math_content(&required, r"\example", "argument")
                .unwrap(),
            required_id
        );
        assert_eq!(
            scoped
                .optional_math_content(&optional, r"\example", "order")
                .unwrap(),
            Some(optional_id)
        );
        assert_eq!(
            scoped
                .optional_group_math_content(&grouped, r"\example", "denominator")
                .unwrap(),
            Some(grouped_id)
        );
        assert_eq!(
            scoped
                .mandatory_or_group_math_content(&grouped, r"\example", "argument")
                .unwrap(),
            grouped_id
        );
        assert_eq!(
            scoped
                .optional_math_content(&None, r"\example", "order")
                .unwrap(),
            None
        );
    }

    /// Covers the rejection paths of the scoped helpers. Only `is_err`/`is_ok` are
    /// checked, so the test does not depend on the error message text.
    #[test]
    fn rejects_mismatched_argument_shapes() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let node_id = ast.new_node(Node::Char('x'));
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);

        let empty: ArgumentSlot = None;
        let required = math_slot(ArgumentKind::Mandatory, node_id);
        let optional = math_slot(ArgumentKind::Optional, node_id);
        let grouped = math_slot(ArgumentKind::Group, node_id);
        let star_with_math = math_slot(ArgumentKind::Star, node_id);
        let mandatory_flag = Some(Argument {
            kind: ArgumentKind::Mandatory,
            value: ArgumentValue::Boolean(false),
        });

        // Star slots must be present, star-kinded, and boolean-valued.
        assert!(scoped.star_arg_value(&empty, r"\example").is_err());
        assert!(scoped.star_arg_value(&required, r"\example").is_err());
        assert!(scoped.star_arg_value(&star_with_math, r"\example").is_err());

        // Optional slots may be empty, but a filled slot must have the right kind.
        assert!(scoped
            .optional_math_content(&required, r"\example", "order")
            .is_err());
        assert_eq!(
            scoped
                .optional_group_math_content(&empty, r"\example", "denominator")
                .unwrap(),
            None
        );
        assert!(scoped
            .optional_group_math_content(&optional, r"\example", "denominator")
            .is_err());

        // Mandatory slots must be present, of an accepted kind, and carry math content.
        assert!(scoped
            .mandatory_math_content(&empty, r"\example", "argument")
            .is_err());
        assert!(scoped
            .mandatory_math_content(&grouped, r"\example", "argument")
            .is_err());
        assert!(scoped
            .mandatory_math_content(&mandatory_flag, r"\example", "argument")
            .is_err());
        assert_eq!(
            scoped
                .mandatory_or_group_math_content(&required, r"\example", "argument")
                .unwrap(),
            node_id
        );
        assert!(scoped
            .mandatory_or_group_math_content(&optional, r"\example", "argument")
            .is_err());
        assert!(scoped
            .mandatory_or_group_math_content(&empty, r"\example", "argument")
            .is_err());

        // Arity and generic shape checks.
        assert!(scoped.expect_no_args(&[], r"\example").is_ok());
        assert!(scoped
            .expect_no_args(std::slice::from_ref(&required), r"\example")
            .is_err());
        assert!(scoped
            .expect_arg_len(std::slice::from_ref(&required), 1, r"\example")
            .is_ok());
        assert!(scoped.ensure_shape(false, "never holds").is_err());
        assert!(scoped.ensure_shape(true, "always holds").is_ok());
    }
}