1use crate::mathbox::MathBox;
4use crate::unicode_maps::{get_greek, get_symbol, to_subscript, to_superscript, BRACKETS};
5use latex2mathml::{latex_to_mathml, DisplayStyle};
6use roxmltree::{Document, Node};
7use std::fmt;
8
/// Errors produced while turning LaTeX or MathML into text.
///
/// Every variant carries a human-readable message; `Display` prefixes it with
/// the stage that failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RenderError {
    /// `latex2mathml` rejected the LaTeX input; holds its error message.
    LatexConversion(String),
    /// The MathML text could not be parsed as XML; holds the parser message.
    MathMLParse(String),
    /// A MathML element had an unexpected shape (e.g. wrong child count).
    InvalidStructure(String),
}

impl fmt::Display for RenderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RenderError::LatexConversion(e) => write!(f, "LaTeX conversion error: {}", e),
            RenderError::MathMLParse(e) => write!(f, "MathML parse error: {}", e),
            RenderError::InvalidStructure(e) => write!(f, "Invalid math structure: {}", e),
        }
    }
}

impl std::error::Error for RenderError {}
28
/// Converts LaTeX or MathML into a plain-text, possibly multi-row rendering.
///
/// Construct with [`MathRenderer::new`] (or `Default`) and optionally toggle
/// Unicode script substitution with [`MathRenderer::use_unicode_scripts`].
#[derive(Debug, Clone)]
pub struct MathRenderer {
    /// Whether simple one-row scripts may be written with Unicode
    /// superscript/subscript characters instead of a multi-row layout.
    use_unicode_scripts: bool,
}
33
34impl MathRenderer {
35 pub fn new() -> Self {
36 Self {
37 use_unicode_scripts: true,
38 }
39 }
40
41 pub fn use_unicode_scripts(mut self, use_unicode: bool) -> Self {
43 self.use_unicode_scripts = use_unicode;
44 self
45 }
46
47 pub fn render_latex(&self, latex: &str) -> Result<String, RenderError> {
49 let mathml = latex_to_mathml(latex, DisplayStyle::Inline)
50 .map_err(|e| RenderError::LatexConversion(e.to_string()))?;
51 self.render_mathml(&mathml)
52 }
53
54 pub fn render_mathml(&self, mathml: &str) -> Result<String, RenderError> {
56 let doc = Document::parse(mathml)
57 .map_err(|e| RenderError::MathMLParse(e.to_string()))?;
58 let root = doc.root_element();
59 let math_box = self.process_element(&root)?;
60 Ok(math_box.to_string())
61 }
62
63 pub fn render_to_box(&self, latex: &str) -> Result<MathBox, RenderError> {
65 let mathml = latex_to_mathml(latex, DisplayStyle::Inline)
66 .map_err(|e| RenderError::LatexConversion(e.to_string()))?;
67 let doc = Document::parse(&mathml)
68 .map_err(|e| RenderError::MathMLParse(e.to_string()))?;
69 let root = doc.root_element();
70 self.process_element(&root)
71 }
72
73 fn process_element(&self, node: &Node) -> Result<MathBox, RenderError> {
74 let tag = node.tag_name().name();
75
76 match tag {
77 "math" | "mrow" | "mstyle" | "mpadded" | "mphantom" => {
78 self.process_row(node)
79 }
80 "mi" | "mn" | "mtext" => {
81 self.process_text(node)
82 }
83 "mo" => {
84 self.process_operator(node)
85 }
86 "msup" => {
87 self.process_superscript(node)
88 }
89 "msub" => {
90 self.process_subscript(node)
91 }
92 "msubsup" => {
93 self.process_subsup(node)
94 }
95 "mfrac" => {
96 self.process_fraction(node)
97 }
98 "msqrt" => {
99 self.process_sqrt(node)
100 }
101 "mroot" => {
102 self.process_nthroot(node)
103 }
104 "mover" => {
105 self.process_over(node)
106 }
107 "munder" => {
108 self.process_under(node)
109 }
110 "munderover" => {
111 self.process_underover(node)
112 }
113 "mtable" => {
114 self.process_table(node)
115 }
116 "mtr" => {
117 self.process_table_row(node)
118 }
119 "mtd" => {
120 self.process_row(node)
121 }
122 "mfenced" => {
123 self.process_fenced(node)
124 }
125 "menclose" => {
126 self.process_row(node) }
128 "mspace" => {
129 Ok(MathBox::from_text(" "))
130 }
131 "semantics" => {
132 if let Some(child) = node.children().filter(|n| n.is_element()).next() {
134 self.process_element(&child)
135 } else {
136 Ok(MathBox::empty(0, 1, 0))
137 }
138 }
139 "annotation" | "annotation-xml" => {
140 Ok(MathBox::empty(0, 1, 0))
142 }
143 _ => {
144 self.process_row(node)
146 }
147 }
148 }
149
150 fn process_row(&self, node: &Node) -> Result<MathBox, RenderError> {
151 self.process_row_inner(node, true)
152 }
153
154 fn process_row_compact(&self, node: &Node) -> Result<MathBox, RenderError> {
155 self.process_row_inner(node, false)
156 }
157
158 fn process_row_inner(&self, node: &Node, add_spacing: bool) -> Result<MathBox, RenderError> {
159 let child_nodes: Vec<_> = node.children().filter(|n| n.is_element()).collect();
160
161 if child_nodes.is_empty() {
162 let text = self.get_text_content(node);
163 if !text.is_empty() {
164 return Ok(MathBox::from_text(&text));
165 }
166 return Ok(MathBox::empty(0, 1, 0));
167 }
168
169 let mut boxes = Vec::new();
170 let mut prev_multiline = false;
171
172 for (i, child) in child_nodes.iter().enumerate() {
173 let child_box = self.process_element(child)?;
174 let is_multiline = child_box.height > 1;
175
176 if add_spacing && i > 0 && (prev_multiline || is_multiline) {
178 boxes.push(MathBox::from_text(" "));
179 }
180
181 if add_spacing && child.tag_name().name() == "mo" {
183 let op = self.get_text_content(child);
184 let is_first = i == 0;
185 let is_binary_op = !is_first && matches!(op.as_str(), "+" | "-" | "±" | "∓");
186 let is_relation = matches!(
187 op.as_str(),
188 "=" | "≤" | "≥" | "≠" | "≈" | "≡" | "→" | "⇒" | "⟹" | "×" | "÷" | "·"
189 );
190
191 if is_binary_op || is_relation {
192 if !prev_multiline && !is_multiline {
194 boxes.push(MathBox::from_text(" "));
195 }
196 boxes.push(child_box);
197 boxes.push(MathBox::from_text(" "));
198 prev_multiline = is_multiline;
199 continue;
200 }
201 }
202 boxes.push(child_box);
203 prev_multiline = is_multiline;
204 }
205
206 Ok(MathBox::concat_horizontal(&boxes))
207 }
208
209 fn process_text(&self, node: &Node) -> Result<MathBox, RenderError> {
210 let text = self.get_text_content(node);
211
212 if let Some(greek) = get_greek(&text) {
214 return Ok(MathBox::from_text(&greek.to_string()));
215 }
216
217 Ok(MathBox::from_text(&text))
218 }
219
220 fn process_operator(&self, node: &Node) -> Result<MathBox, RenderError> {
221 let text = self.get_text_content(node);
222
223 let rendered = match text.as_str() {
225 "∑" | "∏" | "∫" | "∬" | "∭" | "∮" | "⋃" | "⋂" => {
226 text
228 }
229 _ => {
230 if text.starts_with('\\') {
232 let cmd = &text[1..];
233 if let Some(sym) = get_symbol(cmd) {
234 sym.to_string()
235 } else if let Some(greek) = get_greek(cmd) {
236 greek.to_string()
237 } else {
238 text
239 }
240 } else {
241 text
242 }
243 }
244 };
245
246 Ok(MathBox::from_text(&rendered))
248 }
249
250 fn process_superscript(&self, node: &Node) -> Result<MathBox, RenderError> {
251 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
252 if children.len() != 2 {
253 return Err(RenderError::InvalidStructure(
254 "msup requires exactly 2 children".to_string(),
255 ));
256 }
257
258 let base = self.process_element(&children[0])?;
259 let sup = if children[1].tag_name().name() == "mrow" {
261 self.process_row_compact(&children[1])?
262 } else {
263 self.process_element(&children[1])?
264 };
265
266 if self.use_unicode_scripts && base.height == 1 && sup.height == 1 {
268 let sup_text = sup.to_string();
269 if let Some(unicode_sup) = to_superscript(sup_text.trim()) {
270 let combined = format!("{}{}", base.to_string(), unicode_sup);
271 return Ok(MathBox::from_text(&combined));
272 }
273 }
274
275 let width = base.width + sup.width;
277 let height = base.height + 1;
278 let mut result = MathBox::empty(width, height, base.baseline + 1);
279
280 result.blit(&base, 0, 1);
282 result.blit(&sup, base.width, 0);
284
285 Ok(result)
286 }
287
288 fn process_subscript(&self, node: &Node) -> Result<MathBox, RenderError> {
289 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
290 if children.len() != 2 {
291 return Err(RenderError::InvalidStructure(
292 "msub requires exactly 2 children".to_string(),
293 ));
294 }
295
296 let base = self.process_element(&children[0])?;
297 let sub = if children[1].tag_name().name() == "mrow" {
299 self.process_row_compact(&children[1])?
300 } else {
301 self.process_element(&children[1])?
302 };
303
304 if self.use_unicode_scripts && base.height == 1 && sub.height == 1 {
306 let sub_text = sub.to_string();
307 if let Some(unicode_sub) = to_subscript(sub_text.trim()) {
308 let combined = format!("{}{}", base.to_string(), unicode_sub);
309 return Ok(MathBox::from_text(&combined));
310 }
311 }
312
313 let width = base.width + sub.width;
315 let height = base.height + 1;
316 let mut result = MathBox::empty(width, height, base.baseline);
317
318 result.blit(&base, 0, 0);
320 result.blit(&sub, base.width, base.height);
322
323 Ok(result)
324 }
325
326 fn process_subsup(&self, node: &Node) -> Result<MathBox, RenderError> {
327 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
328 if children.len() != 3 {
329 return Err(RenderError::InvalidStructure(
330 "msubsup requires exactly 3 children".to_string(),
331 ));
332 }
333
334 let base_text = self.get_text_content(&children[0]);
336 let is_big_operator = matches!(
337 base_text.as_str(),
338 "∫" | "∬" | "∭" | "∮" | "∑" | "∏" | "⋃" | "⋂"
339 );
340
341 let base = self.process_element(&children[0])?;
342 let sub = self.process_element(&children[1])?;
343 let sup = self.process_element(&children[2])?;
344
345 if is_big_operator {
347 return Ok(MathBox::stack_vertical(&[sup, base, sub]));
348 }
349
350 if self.use_unicode_scripts && base.height == 1 && sub.height == 1 && sup.height == 1 {
352 let sub_text = sub.to_string();
353 let sup_text = sup.to_string();
354 if let (Some(unicode_sub), Some(unicode_sup)) =
355 (to_subscript(sub_text.trim()), to_superscript(sup_text.trim()))
356 {
357 let combined = format!("{}{}{}", base.to_string(), unicode_sub, unicode_sup);
358 return Ok(MathBox::from_text(&combined));
359 }
360 }
361
362 let script_width = sub.width.max(sup.width);
364 let width = base.width + script_width;
365 let height = base.height + 2;
366 let mut result = MathBox::empty(width, height, base.baseline + 1);
367
368 result.blit(&base, 0, 1);
369 result.blit(&sup, base.width, 0);
370 result.blit(&sub, base.width, height - 1);
371
372 Ok(result)
373 }
374
375 fn process_fraction(&self, node: &Node) -> Result<MathBox, RenderError> {
376 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
377 if children.len() != 2 {
378 return Err(RenderError::InvalidStructure(
379 "mfrac requires exactly 2 children".to_string(),
380 ));
381 }
382
383 let num = self.process_element(&children[0])?;
384 let den = self.process_element(&children[1])?;
385
386 let width = num.width.max(den.width);
387 let height = num.height + 1 + den.height;
388 let baseline = num.height;
389
390 let mut result = MathBox::empty(width, height, baseline);
391
392 let num_offset = (width - num.width) / 2;
394 result.blit(&num, num_offset, 0);
395
396 result.fill_row(num.height, '─');
398
399 let den_offset = (width - den.width) / 2;
401 result.blit(&den, den_offset, num.height + 1);
402
403 Ok(result)
404 }
405
406 fn process_sqrt(&self, node: &Node) -> Result<MathBox, RenderError> {
407 let inner = self.process_row(node)?;
408
409 if inner.height == 1 {
414 let inner_text = inner.to_string();
415 let inner_width = inner_text.chars().count();
416
417 let width = 1 + inner_width;
419 let height = 2;
420 let mut result = MathBox::empty(width, height, 1);
421
422 for x in 1..width {
424 result.set(x, 0, '_');
425 }
426
427 result.set(0, 1, '√');
429 for (i, ch) in inner_text.chars().enumerate() {
430 result.set(1 + i, 1, ch);
431 }
432
433 return Ok(result);
434 }
435
436 let width = inner.width + 1;
438 let height = inner.height + 1;
439 let mut result = MathBox::empty(width, height, inner.baseline + 1);
440
441 for x in 1..width {
443 result.set(x, 0, '_');
444 }
445
446 result.set(0, 1, '√');
448
449 result.blit(&inner, 1, 1);
451
452 Ok(result)
453 }
454
455 fn process_nthroot(&self, node: &Node) -> Result<MathBox, RenderError> {
456 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
457 if children.len() != 2 {
458 return Err(RenderError::InvalidStructure(
459 "mroot requires exactly 2 children".to_string(),
460 ));
461 }
462
463 let inner = self.process_element(&children[0])?;
464 let index = self.process_element(&children[1])?;
465
466 let index_text = index.to_string();
468 if let Some(unicode_idx) = to_superscript(index_text.trim()) {
469 let text = format!("{}√{}", unicode_idx, inner.to_string());
470 return Ok(MathBox::from_text(&text));
471 }
472
473 let width = index.width + inner.width + 2;
475 let height = (inner.height + 1).max(index.height);
476 let mut result = MathBox::empty(width, height, height / 2);
477
478 result.blit(&index, 0, 0);
480
481 result.set(index.width, height - 1, '√');
483 for x in (index.width + 1)..width {
484 result.set(x, 0, '─');
485 }
486 result.blit(&inner, index.width + 2, 1);
487
488 Ok(result)
489 }
490
491 fn process_over(&self, node: &Node) -> Result<MathBox, RenderError> {
492 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
493 if children.len() != 2 {
494 return Err(RenderError::InvalidStructure(
495 "mover requires exactly 2 children".to_string(),
496 ));
497 }
498
499 let base = self.process_element(&children[0])?;
500 let over = self.process_element(&children[1])?;
501
502 let over_text = over.to_string().trim().to_string();
503
504 if base.height == 1 {
506 let accent = match over_text.as_str() {
507 "^" | "ˆ" => Some("̂"), "~" | "˜" => Some("̃"), "¯" | "-" => Some("̄"), "." => Some("̇"), ".." | "¨" => Some("̈"), "→" => Some("⃗"), _ => None,
514 };
515 if let Some(combining) = accent {
516 let base_text = base.to_string();
517 let text = format!("{}{}", base_text, combining);
518 return Ok(MathBox::from_text(&text));
519 }
520 }
521
522 Ok(MathBox::stack_vertical(&[over, base]))
524 }
525
526 fn process_under(&self, node: &Node) -> Result<MathBox, RenderError> {
527 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
528 if children.len() != 2 {
529 return Err(RenderError::InvalidStructure(
530 "munder requires exactly 2 children".to_string(),
531 ));
532 }
533
534 let base_text = self.get_text_content(&children[0]);
535 let base = self.process_element(&children[0])?;
536 let under = self.process_element(&children[1])?;
537
538 if base_text == "lim" || base_text == "max" || base_text == "min" || base_text == "sup" || base_text == "inf" {
540 let under_text = under.to_string();
542 let under_trimmed = under_text.trim();
543
544 if let Some(subscript) = to_subscript(under_trimmed) {
546 let combined = format!("{}{}", base_text, subscript);
547 return Ok(MathBox::from_text(&combined));
548 }
549
550 let combined = format!("{}({})", base_text, under_trimmed);
552 return Ok(MathBox::from_text(&combined));
553 }
554
555 let width = base.width.max(under.width);
557 let height = base.height + under.height;
558 let mut result = MathBox::empty(width, height, base.baseline);
559
560 let base_offset = (width - base.width) / 2;
562 result.blit(&base, base_offset, 0);
563
564 let under_offset = (width - under.width) / 2;
566 result.blit(&under, under_offset, base.height);
567
568 Ok(result)
569 }
570
571 fn process_underover(&self, node: &Node) -> Result<MathBox, RenderError> {
572 let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
573 if children.len() != 3 {
574 return Err(RenderError::InvalidStructure(
575 "munderover requires exactly 3 children".to_string(),
576 ));
577 }
578
579 let base = self.process_element(&children[0])?;
580 let under = self.process_element(&children[1])?;
581 let over = self.process_element(&children[2])?;
582
583 Ok(MathBox::stack_vertical(&[over, base, under]))
584 }
585
586 fn process_table(&self, node: &Node) -> Result<MathBox, RenderError> {
587 let rows: Vec<Vec<MathBox>> = node
588 .children()
589 .filter(|n| n.is_element() && n.tag_name().name() == "mtr")
590 .map(|row| {
591 row.children()
592 .filter(|n| n.is_element() && n.tag_name().name() == "mtd")
593 .map(|cell| self.process_row(&cell))
594 .collect::<Result<Vec<_>, _>>()
595 })
596 .collect::<Result<Vec<_>, _>>()?;
597
598 if rows.is_empty() {
599 return Ok(MathBox::empty(0, 1, 0));
600 }
601
602 let num_cols = rows.iter().map(|r| r.len()).max().unwrap_or(0);
604 let mut col_widths = vec![0; num_cols];
605 let mut row_heights = vec![0; rows.len()];
606
607 for (i, row) in rows.iter().enumerate() {
608 for (j, cell) in row.iter().enumerate() {
609 col_widths[j] = col_widths[j].max(cell.width);
610 row_heights[i] = row_heights[i].max(cell.height);
611 }
612 }
613
614 let spacing = 2;
616 let total_width: usize = col_widths.iter().sum::<usize>() + spacing * (num_cols.saturating_sub(1));
617 let total_height: usize = row_heights.iter().sum();
618
619 let mut result = MathBox::empty(total_width, total_height, total_height / 2);
620
621 let mut y_pos = 0;
622 for (i, row) in rows.iter().enumerate() {
623 let mut x_pos = 0;
624 for (j, cell) in row.iter().enumerate() {
625 let x_offset = (col_widths[j] - cell.width) / 2;
627 result.blit(cell, x_pos + x_offset, y_pos);
628 x_pos += col_widths[j] + spacing;
629 }
630 y_pos += row_heights[i];
631 }
632
633 Ok(result)
634 }
635
636 fn process_table_row(&self, node: &Node) -> Result<MathBox, RenderError> {
637 let cells: Vec<MathBox> = node
638 .children()
639 .filter(|n| n.is_element())
640 .map(|n| self.process_row(&n))
641 .collect::<Result<Vec<_>, _>>()?;
642
643 let spacing = MathBox::from_text(" ");
645 let mut parts = Vec::new();
646 for (i, cell) in cells.into_iter().enumerate() {
647 if i > 0 {
648 parts.push(spacing.clone());
649 }
650 parts.push(cell);
651 }
652
653 Ok(MathBox::concat_horizontal(&parts))
654 }
655
656 fn process_fenced(&self, node: &Node) -> Result<MathBox, RenderError> {
657 let open = node.attribute("open").unwrap_or("(");
658 let close = node.attribute("close").unwrap_or(")");
659
660 let inner = self.process_row(node)?;
661
662 if inner.height <= 1 {
663 let text = format!("{}{}{}", open, inner.to_string(), close);
665 return Ok(MathBox::from_text(&text));
666 }
667
668 let left_chars = BRACKETS.get_left(open, inner.height);
670 let right_chars = BRACKETS.get_right(close, inner.height);
671
672 let width = 1 + inner.width + 1;
673 let height = inner.height;
674 let mut result = MathBox::empty(width, height, inner.baseline);
675
676 for (y, &ch) in left_chars.iter().enumerate() {
678 result.set(0, y, ch);
679 }
680 for (y, &ch) in right_chars.iter().enumerate() {
681 result.set(width - 1, y, ch);
682 }
683
684 result.blit(&inner, 1, 0);
686
687 Ok(result)
688 }
689
690 fn get_text_content(&self, node: &Node) -> String {
691 let mut text = String::new();
692 for child in node.children() {
693 if child.is_text() {
694 text.push_str(child.text().unwrap_or(""));
695 }
696 }
697 text.trim().to_string()
698 }
699}
700
701impl Default for MathRenderer {
702 fn default() -> Self {
703 Self::new()
704 }
705}
706
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `latex` with a default renderer, panicking on any error.
    fn render(latex: &str) -> String {
        MathRenderer::new().render_latex(latex).unwrap()
    }

    #[test]
    fn test_simple_expression() {
        let result = render("x + y");
        assert!(result.contains('x'));
        assert!(result.contains('y'));
    }

    #[test]
    fn test_superscript() {
        let result = render("x^2");
        assert!(result.contains('²') || result.contains('2'));
    }

    #[test]
    fn test_fraction() {
        let result = render(r"\frac{a}{b}");
        assert!(result.contains('a'));
        assert!(result.contains('b'));
        assert!(result.contains('─'));
    }

    #[test]
    fn test_superscript_without_unicode() {
        let renderer = MathRenderer::new().use_unicode_scripts(false);
        let result = renderer
            .render_mathml("<math><msup><mi>x</mi><mn>2</mn></msup></math>")
            .unwrap();
        assert!(result.contains('x'));
        assert!(result.contains('2'));
        assert!(!result.contains('²'));
    }

    #[test]
    fn test_malformed_mathml_is_parse_error() {
        let err = MathRenderer::new()
            .render_mathml("<math><mi>x</mi>")
            .unwrap_err();
        assert!(matches!(err, RenderError::MathMLParse(_)));
    }

    #[test]
    fn test_wrong_child_count_is_structure_error() {
        let err = MathRenderer::new()
            .render_mathml("<math><mfrac><mi>a</mi></mfrac></math>")
            .unwrap_err();
        match err {
            RenderError::InvalidStructure(msg) => {
                assert_eq!(msg, "mfrac requires exactly 2 children")
            }
            other => panic!("unexpected error: {}", other),
        }
    }
}