1use crate::ast::{ArgumentValue, Ast, ContentMode, GroupKind, Node, NodeId, ParentLink, Slot};
4
/// Switches controlling the group-flattening pass.
///
/// `enabled` gates the whole pass. Every `preserve_*` flag, when `true`, keeps
/// a group wrapper in place in the situation described on the field instead of
/// letting the pass unwrap it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FlattenGroupsConfig {
    /// Master switch: when `false`, `run` returns without touching the AST.
    pub enabled: bool,
    /// Keep any group whose subtree contains a `Node::Declarative`.
    pub preserve_group_containing_declarative_command: bool,
    /// Keep a group in a `Slot::ScriptBase` slot unless its single child is an
    /// atomic base (see `is_atomic_base`).
    pub preserve_group_in_script_base_slot: bool,
    /// Keep every group located anywhere beneath a `Slot::EnvBody` edge.
    pub preserve_group_inside_env_body: bool,
    /// Keep a group in a group-child slot when its subtree contains a
    /// `Node::Infix`.
    pub preserve_group_containing_infix: bool,
    /// Keep a group in a group-child slot when its previous sibling or its own
    /// first child is a command or declarative node.
    pub preserve_group_adjacent_to_command_like: bool,
    /// Keep a single-child group in an argument slot when that child's subtree
    /// contains a command or declarative node.
    pub preserve_group_as_argument_of_command: bool,
    /// Extend the adjacency check to scripted nodes whose base is
    /// command-like. Only consulted when
    /// `preserve_group_adjacent_to_command_like` is also set.
    pub preserve_group_after_scripted_command_like: bool,
    /// Keep a childless group in a group-child slot.
    pub preserve_empty_group: bool,
    /// Keep a group in a group-child or script-base slot whose only child is a
    /// character accepted by `is_atom_spacing_char`.
    pub preserve_group_with_lone_atom_spacing_char: bool,
    /// Keep a multi-child group in a group-child slot whose first child is a
    /// character accepted by `is_atom_spacing_char`.
    pub preserve_group_starting_with_atom_spacing_char: bool,
    /// Keep a group in a group-child slot when its subtree contains a
    /// `GroupKind::Delimited` group.
    pub preserve_group_containing_delimited_pair: bool,
}
44
45impl FlattenGroupsConfig {
46 pub const STRICT: Self = Self {
48 enabled: true,
49 preserve_group_containing_declarative_command: true,
50 preserve_group_in_script_base_slot: true,
51 preserve_group_inside_env_body: true,
52 preserve_group_containing_infix: true,
53 preserve_group_adjacent_to_command_like: true,
54 preserve_group_as_argument_of_command: true,
55 preserve_group_after_scripted_command_like: true,
56 preserve_empty_group: true,
57 preserve_group_with_lone_atom_spacing_char: true,
58 preserve_group_starting_with_atom_spacing_char: true,
59 preserve_group_containing_delimited_pair: true,
60 };
61 pub const STRUCTURAL_ONLY: Self = Self {
63 enabled: true,
64 preserve_group_containing_declarative_command: true,
65 preserve_group_in_script_base_slot: true,
66 preserve_group_inside_env_body: true,
67 preserve_group_containing_infix: true,
68 preserve_group_adjacent_to_command_like: false,
69 preserve_group_as_argument_of_command: false,
70 preserve_group_after_scripted_command_like: false,
71 preserve_empty_group: false,
72 preserve_group_with_lone_atom_spacing_char: false,
73 preserve_group_starting_with_atom_spacing_char: false,
74 preserve_group_containing_delimited_pair: false,
75 };
76 pub const ENABLED: Self = Self::STRICT;
77 pub const DISABLED: Self = Self {
78 enabled: false,
79 ..Self::STRICT
80 };
81 pub const DEFAULTS: Self = Self::STRICT;
82}
83
84#[derive(Clone, Debug, Default, PartialEq, Eq)]
85pub struct FlattenGroupsReport {
86 pub actions: FlattenGroupsActionCounts,
87 pub guards: FlattenGroupsGuardCounts,
88}
89
/// Tallies of the rewrites the pass actually performed.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct FlattenGroupsActionCounts {
    /// Childless groups removed from a parent's child list.
    pub removed_empty: usize,
    /// Groups in a child list replaced by their single child.
    pub replaced_single_child: usize,
    /// Groups in a child list replaced by two or more of their children.
    pub inlined_multi_child: usize,
    /// Single-child groups unwrapped from a non-list slot (argument, script
    /// or infix operand).
    pub unwrapped_slot: usize,
}
97
/// Tallies of how often each preservation guard stopped an unwrap.
///
/// Every field mirrors the `FlattenGroupsConfig` flag of the same name and is
/// incremented each time that guard is the reason a group is left alone.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct FlattenGroupsGuardCounts {
    pub preserve_group_containing_declarative_command: usize,
    pub preserve_group_in_script_base_slot: usize,
    pub preserve_group_inside_env_body: usize,
    pub preserve_group_containing_infix: usize,
    pub preserve_group_adjacent_to_command_like: usize,
    pub preserve_group_as_argument_of_command: usize,
    /// Incremented in addition to `preserve_group_adjacent_to_command_like`
    /// when the contact was only found by looking through a scripted node.
    pub preserve_group_after_scripted_command_like: usize,
    pub preserve_empty_group: usize,
    pub preserve_group_with_lone_atom_spacing_char: usize,
    pub preserve_group_starting_with_atom_spacing_char: usize,
    pub preserve_group_containing_delimited_pair: usize,
}
112
113pub fn run(ast: &mut Ast, config: &FlattenGroupsConfig, report: &mut FlattenGroupsReport) {
114 if !config.enabled {
115 return;
116 }
117
118 visit(ast, ast.root(), false, config, report);
119}
120
/// Summary of which node kinds occur in a subtree (the subtree root included).
#[derive(Debug, Default, Clone, Copy)]
struct SubtreeFlags {
    /// The subtree contains a `Node::Declarative`.
    has_declarative: bool,
    /// The subtree contains a `Node::Infix`.
    has_infix: bool,
    /// The subtree contains a group of kind `GroupKind::Delimited`.
    has_delimited: bool,
}
127
128fn visit(
129 ast: &mut Ast,
130 node: NodeId,
131 in_env_body: bool,
132 config: &FlattenGroupsConfig,
133 report: &mut FlattenGroupsReport,
134) -> SubtreeFlags {
135 let edges = ast.edges(node);
136 let mut flags = SubtreeFlags {
137 has_declarative: matches!(ast.node(node), Node::Declarative { .. }),
138 has_infix: matches!(ast.node(node), Node::Infix { .. }),
139 has_delimited: matches!(
140 ast.node(node),
141 Node::Group {
142 kind: GroupKind::Delimited { .. },
143 ..
144 }
145 ),
146 };
147 for (child, slot) in edges {
148 if ast.contains(child) {
149 let child_flags = visit(
150 ast,
151 child,
152 in_env_body || slot == Slot::EnvBody,
153 config,
154 report,
155 );
156 flags.has_declarative |= child_flags.has_declarative;
157 flags.has_infix |= child_flags.has_infix;
158 flags.has_delimited |= child_flags.has_delimited;
159 }
160 }
161
162 if ast.contains(node) {
163 try_unwrap(ast, node, flags, in_env_body, config, report);
164 }
165
166 flags
167}
168
169fn try_unwrap(
170 ast: &mut Ast,
171 node: NodeId,
172 flags: SubtreeFlags,
173 in_env_body: bool,
174 config: &FlattenGroupsConfig,
175 report: &mut FlattenGroupsReport,
176) {
177 let (kind, mode, child_count) = match ast.node(node) {
178 Node::Group {
179 kind,
180 mode,
181 children,
182 } => (kind.clone(), *mode, children.len()),
183 _ => return,
184 };
185 if !matches!(kind, GroupKind::Explicit | GroupKind::Implicit) {
186 return;
187 }
188 if config.preserve_group_containing_declarative_command && flags.has_declarative {
189 report.guards.preserve_group_containing_declarative_command += 1;
190 return;
191 }
192 if config.preserve_group_inside_env_body && in_env_body {
193 report.guards.preserve_group_inside_env_body += 1;
194 return;
195 }
196
197 let Some(link) = ast.parent(node) else {
198 return;
199 };
200 if !slot_can_unwrap(link.slot, child_count) {
201 return;
202 }
203 if matches!(link.slot, Slot::GroupChild(_))
204 && config.preserve_group_containing_infix
205 && flags.has_infix
206 {
207 report.guards.preserve_group_containing_infix += 1;
208 return;
209 }
210 if matches!(link.slot, Slot::GroupChild(_))
211 && config.preserve_group_containing_delimited_pair
212 && flags.has_delimited
213 {
214 report.guards.preserve_group_containing_delimited_pair += 1;
215 return;
216 }
217 if let Slot::GroupChild(index) = link.slot
218 && config.preserve_group_adjacent_to_command_like
219 {
220 let command_contact = group_child_touches_command(
221 ast,
222 node,
223 link.parent,
224 index,
225 config.preserve_group_after_scripted_command_like,
226 );
227 if command_contact.touches_command {
228 report.guards.preserve_group_adjacent_to_command_like += 1;
229 if command_contact.used_scripted_base {
230 report.guards.preserve_group_after_scripted_command_like += 1;
231 }
232 return;
233 }
234 }
235 let children = ast.children(node);
236 let first_is_atom = children
237 .first()
238 .is_some_and(|child| is_atom_spacing_char(ast, *child));
239 if matches!(link.slot, Slot::GroupChild(_)) {
240 if config.preserve_empty_group && child_count == 0 {
241 report.guards.preserve_empty_group += 1;
242 return;
243 }
244 if config.preserve_group_with_lone_atom_spacing_char && child_count == 1 && first_is_atom {
245 report.guards.preserve_group_with_lone_atom_spacing_char += 1;
246 return;
247 }
248 if config.preserve_group_starting_with_atom_spacing_char && child_count > 1 && first_is_atom
249 {
250 report.guards.preserve_group_starting_with_atom_spacing_char += 1;
251 return;
252 }
253 }
254 if matches!(link.slot, Slot::ScriptBase)
255 && config.preserve_group_with_lone_atom_spacing_char
256 && child_count == 1
257 && first_is_atom
258 {
259 report.guards.preserve_group_with_lone_atom_spacing_char += 1;
260 return;
261 }
262 if matches!(link.slot, Slot::Argument(_))
263 && config.preserve_group_as_argument_of_command
264 && group_as_argument_of_command_needs_boundary(ast, node)
265 {
266 report.guards.preserve_group_as_argument_of_command += 1;
267 return;
268 }
269
270 let Some(parent_mode) = context_mode(ast, link) else {
271 return;
272 };
273 if mode != parent_mode {
274 return;
275 }
276
277 if matches!(link.slot, Slot::ScriptBase)
278 && config.preserve_group_in_script_base_slot
279 && !is_atomic_base(ast, ast.children(node)[0])
280 {
281 report.guards.preserve_group_in_script_base_slot += 1;
282 return;
283 }
284
285 match link.slot {
286 Slot::GroupChild(index) => unwrap_group_child(ast, node, link.parent, index, report),
287 Slot::Argument(_)
288 | Slot::ScriptBase
289 | Slot::ScriptSub
290 | Slot::ScriptSup
291 | Slot::InfixLeft
292 | Slot::InfixRight => redirect_single_child_slot(ast, node, report),
293 Slot::EnvBody => {}
294 }
295}
296
297fn slot_can_unwrap(slot: Slot, child_count: usize) -> bool {
298 match slot {
299 Slot::GroupChild(_) => true,
300 Slot::Argument(_)
301 | Slot::ScriptBase
302 | Slot::ScriptSub
303 | Slot::ScriptSup
304 | Slot::InfixLeft
305 | Slot::InfixRight => child_count == 1,
306 Slot::EnvBody => false,
307 }
308}
309
/// Outcome of checking whether a group touches a command-like node.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct CommandContact {
    /// A command-like node was found at one of the inspected positions.
    touches_command: bool,
    /// The match was only found by looking through a scripted node's base.
    used_scripted_base: bool,
}
315
316fn group_child_touches_command(
317 ast: &Ast,
318 node: NodeId,
319 parent: NodeId,
320 index: usize,
321 include_scripted: bool,
322) -> CommandContact {
323 let previous = index
324 .checked_sub(1)
325 .and_then(|previous| ast.children(parent).get(previous).copied());
326 let first_child = ast.children(node).first().copied();
327
328 command_contact_for_node(ast, previous, include_scripted).merge(command_contact_for_node(
329 ast,
330 first_child,
331 include_scripted,
332 ))
333}
334
335impl CommandContact {
336 fn merge(self, other: Self) -> Self {
337 Self {
338 touches_command: self.touches_command || other.touches_command,
339 used_scripted_base: self.used_scripted_base || other.used_scripted_base,
340 }
341 }
342}
343
344fn command_contact_for_node(
345 ast: &Ast,
346 node: Option<NodeId>,
347 include_scripted: bool,
348) -> CommandContact {
349 let Some(node) = node else {
350 return CommandContact::default();
351 };
352 if is_command_like(ast, node, false) {
353 return CommandContact {
354 touches_command: true,
355 used_scripted_base: false,
356 };
357 }
358 if include_scripted && is_command_like(ast, node, true) {
359 return CommandContact {
360 touches_command: true,
361 used_scripted_base: true,
362 };
363 }
364 CommandContact::default()
365}
366
367fn is_atom_spacing_char(ast: &Ast, node: NodeId) -> bool {
368 matches!(
369 ast.node(node),
370 Node::Char(
371 '=' | '<' | '>' | '+' | '-' | ',' | ':' | ';' | '.' | '/' | '*' | '!' | '?' | '|' | 'ยท'
372 )
373 )
374}
375
376fn is_command_like(ast: &Ast, node: NodeId, include_scripted: bool) -> bool {
377 match ast.node(node) {
378 Node::Command { .. } | Node::Declarative { .. } => true,
379 Node::Scripted { base, .. } if include_scripted => is_command_like(ast, *base, true),
380 _ => false,
381 }
382}
383
384fn is_atomic_base(ast: &Ast, node: NodeId) -> bool {
385 match ast.node(node) {
386 Node::Char(_) | Node::Prime { .. } => true,
387 Node::Command { name, args, .. } => {
388 args.iter().all(Option::is_none)
389 && !subtree_has_scripted(ast, node)
390 && !is_script_placement_sensitive_command(name)
391 }
392 _ => false,
393 }
394}
395
396fn group_as_argument_of_command_needs_boundary(ast: &Ast, node: NodeId) -> bool {
397 let children = ast.children(node);
398 if children.len() != 1 {
399 return false;
400 }
401 subtree_has_command_like(ast, children[0])
402}
403
404fn subtree_has_command_like(ast: &Ast, node: NodeId) -> bool {
405 if is_command_like(ast, node, false) {
406 return true;
407 }
408 ast.edges(node)
409 .into_iter()
410 .any(|(child, _)| subtree_has_command_like(ast, child))
411}
412
/// Returns `true` when `name` is one of the command names this pass treats as
/// script-placement sensitive (function-style operators such as `lim` or
/// `sin`, and big operators such as `sum` or `bigcup`).
///
/// The comparison is exact and case-sensitive.
fn is_script_placement_sensitive_command(name: &str) -> bool {
    const SENSITIVE_NAMES: &[&str] = &[
        "arccos",
        "arcsin",
        "arctan",
        "arg",
        "bigcap",
        "bigcup",
        "bigodot",
        "bigoplus",
        "bigotimes",
        "bigsqcup",
        "bigtriangledown",
        "bigtriangleup",
        "biguplus",
        "bigvee",
        "bigwedge",
        "cos",
        "cosh",
        "cot",
        "coth",
        "csc",
        "deg",
        "det",
        "dim",
        "exp",
        "gcd",
        "hom",
        "inf",
        "int",
        "ker",
        "lg",
        "lim",
        "liminf",
        "limsup",
        "ln",
        "log",
        "max",
        "min",
        "operatorname",
        "Pr",
        "prod",
        "sec",
        "sin",
        "sinh",
        "sup",
        "sum",
        "tan",
        "tanh",
    ];

    SENSITIVE_NAMES.contains(&name)
}
465
466fn subtree_has_scripted(ast: &Ast, node: NodeId) -> bool {
467 if matches!(ast.node(node), Node::Scripted { .. }) {
468 return true;
469 }
470 ast.edges(node)
471 .into_iter()
472 .any(|(child, _)| subtree_has_scripted(ast, child))
473}
474
475fn context_mode(ast: &Ast, link: ParentLink) -> Option<ContentMode> {
476 match link.slot {
477 Slot::GroupChild(_) => match ast.node(link.parent) {
478 Node::Root { mode, .. } | Node::Group { mode, .. } => Some(*mode),
479 _ => None,
480 },
481 Slot::Argument(index) => argument_slot_mode(ast, link.parent, index),
482 Slot::ScriptBase
483 | Slot::ScriptSub
484 | Slot::ScriptSup
485 | Slot::InfixLeft
486 | Slot::InfixRight => Some(ContentMode::Math),
487 Slot::EnvBody => None,
488 }
489}
490
491fn argument_slot_mode(ast: &Ast, parent: NodeId, index: usize) -> Option<ContentMode> {
492 let argument = ast.arg_slots(parent).get(index)?.as_ref()?;
493 match argument.value {
494 ArgumentValue::MathContent(_) => Some(ContentMode::Math),
495 ArgumentValue::TextContent(_) => Some(ContentMode::Text),
496 _ => None,
497 }
498}
499
500fn unwrap_group_child(
501 ast: &mut Ast,
502 node: NodeId,
503 parent: NodeId,
504 index: usize,
505 report: &mut FlattenGroupsReport,
506) {
507 let child_count = ast.children(node).len();
508 let children = ast.detach_children_range(node, 0..child_count);
509 let mut parent_children = ast.children(parent).to_vec();
510 assert_eq!(
511 parent_children.get(index),
512 Some(&node),
513 "group child index must match parent link"
514 );
515
516 parent_children.splice(index..index + 1, children);
517 ast.replace_children(parent, parent_children);
518 ast.remove_detached(node);
519
520 match child_count {
521 0 => report.actions.removed_empty += 1,
522 1 => report.actions.replaced_single_child += 1,
523 _ => report.actions.inlined_multi_child += 1,
524 }
525}
526
527fn redirect_single_child_slot(ast: &mut Ast, node: NodeId, report: &mut FlattenGroupsReport) {
528 let mut children = ast.detach_children_range(node, 0..1);
529 let child = children
530 .pop()
531 .expect("single-child slot unwrap requires one child");
532 ast.replace_content_child(node, child);
533 ast.remove_detached(node);
534 report.actions.unwrapped_slot += 1;
535}