Skip to main content

tui_math/
renderer.rs

1//! MathML to Unicode terminal renderer
2
3use crate::mathbox::MathBox;
4use crate::unicode_maps::{get_greek, get_symbol, to_subscript, to_superscript, BRACKETS};
5use latex2mathml::{latex_to_mathml, DisplayStyle};
6use roxmltree::{Document, Node};
7use std::fmt;
8
/// Errors that can occur during math rendering.
///
/// Every variant carries the human-readable message describing the failure.
/// The type is `Clone` and comparable so callers and tests can store and
/// match on errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RenderError {
    /// The LaTeX input could not be converted to MathML.
    LatexConversion(String),
    /// The MathML text could not be parsed as an XML document.
    MathMLParse(String),
    /// A MathML element did not have the child structure the renderer requires.
    InvalidStructure(String),
}

impl fmt::Display for RenderError {
    /// Writes a fixed category prefix followed by the stored message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LatexConversion(msg) => write!(f, "LaTeX conversion error: {}", msg),
            Self::MathMLParse(msg) => write!(f, "MathML parse error: {}", msg),
            Self::InvalidStructure(msg) => write!(f, "Invalid math structure: {}", msg),
        }
    }
}

impl std::error::Error for RenderError {}
28
/// Math renderer that converts LaTeX/MathML to Unicode terminal output.
///
/// The renderer holds only configuration, so it is cheap to clone and can be
/// printed with `{:?}` when debugging.
#[derive(Debug, Clone)]
pub struct MathRenderer {
    /// When `true`, simple scripts are emitted as Unicode superscript/subscript
    /// characters instead of being laid out on separate rows.
    use_unicode_scripts: bool,
}
33
34impl MathRenderer {
35    pub fn new() -> Self {
36        Self {
37            use_unicode_scripts: true,
38        }
39    }
40
41    /// Set whether to use Unicode superscript/subscript characters when possible
42    pub fn use_unicode_scripts(mut self, use_unicode: bool) -> Self {
43        self.use_unicode_scripts = use_unicode;
44        self
45    }
46
47    /// Render LaTeX math to Unicode string
48    pub fn render_latex(&self, latex: &str) -> Result<String, RenderError> {
49        let mathml = latex_to_mathml(latex, DisplayStyle::Inline)
50            .map_err(|e| RenderError::LatexConversion(e.to_string()))?;
51        self.render_mathml(&mathml)
52    }
53
54    /// Render MathML to Unicode string
55    pub fn render_mathml(&self, mathml: &str) -> Result<String, RenderError> {
56        let doc = Document::parse(mathml)
57            .map_err(|e| RenderError::MathMLParse(e.to_string()))?;
58        let root = doc.root_element();
59        let math_box = self.process_element(&root)?;
60        Ok(math_box.to_string())
61    }
62
63    /// Render to MathBox (for advanced usage)
64    pub fn render_to_box(&self, latex: &str) -> Result<MathBox, RenderError> {
65        let mathml = latex_to_mathml(latex, DisplayStyle::Inline)
66            .map_err(|e| RenderError::LatexConversion(e.to_string()))?;
67        let doc = Document::parse(&mathml)
68            .map_err(|e| RenderError::MathMLParse(e.to_string()))?;
69        let root = doc.root_element();
70        self.process_element(&root)
71    }
72
73    fn process_element(&self, node: &Node) -> Result<MathBox, RenderError> {
74        let tag = node.tag_name().name();
75
76        match tag {
77            "math" | "mrow" | "mstyle" | "mpadded" | "mphantom" => {
78                self.process_row(node)
79            }
80            "mi" | "mn" | "mtext" => {
81                self.process_text(node)
82            }
83            "mo" => {
84                self.process_operator(node)
85            }
86            "msup" => {
87                self.process_superscript(node)
88            }
89            "msub" => {
90                self.process_subscript(node)
91            }
92            "msubsup" => {
93                self.process_subsup(node)
94            }
95            "mfrac" => {
96                self.process_fraction(node)
97            }
98            "msqrt" => {
99                self.process_sqrt(node)
100            }
101            "mroot" => {
102                self.process_nthroot(node)
103            }
104            "mover" => {
105                self.process_over(node)
106            }
107            "munder" => {
108                self.process_under(node)
109            }
110            "munderover" => {
111                self.process_underover(node)
112            }
113            "mtable" => {
114                self.process_table(node)
115            }
116            "mtr" => {
117                self.process_table_row(node)
118            }
119            "mtd" => {
120                self.process_row(node)
121            }
122            "mfenced" => {
123                self.process_fenced(node)
124            }
125            "menclose" => {
126                self.process_row(node) // Simplified
127            }
128            "mspace" => {
129                Ok(MathBox::from_text(" "))
130            }
131            "semantics" => {
132                // Process first child only
133                if let Some(child) = node.children().filter(|n| n.is_element()).next() {
134                    self.process_element(&child)
135                } else {
136                    Ok(MathBox::empty(0, 1, 0))
137                }
138            }
139            "annotation" | "annotation-xml" => {
140                // Skip annotations
141                Ok(MathBox::empty(0, 1, 0))
142            }
143            _ => {
144                // Unknown element, try to process children
145                self.process_row(node)
146            }
147        }
148    }
149
150    fn process_row(&self, node: &Node) -> Result<MathBox, RenderError> {
151        self.process_row_inner(node, true)
152    }
153
154    fn process_row_compact(&self, node: &Node) -> Result<MathBox, RenderError> {
155        self.process_row_inner(node, false)
156    }
157
158    fn process_row_inner(&self, node: &Node, add_spacing: bool) -> Result<MathBox, RenderError> {
159        let child_nodes: Vec<_> = node.children().filter(|n| n.is_element()).collect();
160
161        if child_nodes.is_empty() {
162            let text = self.get_text_content(node);
163            if !text.is_empty() {
164                return Ok(MathBox::from_text(&text));
165            }
166            return Ok(MathBox::empty(0, 1, 0));
167        }
168
169        let mut boxes = Vec::new();
170        let mut prev_multiline = false;
171
172        for (i, child) in child_nodes.iter().enumerate() {
173            let child_box = self.process_element(child)?;
174            let is_multiline = child_box.height > 1;
175
176            // Add spacing between multi-line elements
177            if add_spacing && i > 0 && (prev_multiline || is_multiline) {
178                boxes.push(MathBox::from_text(" "));
179            }
180
181            // Add spacing around binary operators in row context (not in compact mode)
182            if add_spacing && child.tag_name().name() == "mo" {
183                let op = self.get_text_content(child);
184                let is_first = i == 0;
185                let is_binary_op = !is_first && matches!(op.as_str(), "+" | "-" | "±" | "∓");
186                let is_relation = matches!(
187                    op.as_str(),
188                    "=" | "≤" | "≥" | "≠" | "≈" | "≡" | "→" | "⇒" | "⟹" | "×" | "÷" | "·"
189                );
190
191                if is_binary_op || is_relation {
192                    // Don't add extra space if we just added one for multiline
193                    if !prev_multiline && !is_multiline {
194                        boxes.push(MathBox::from_text(" "));
195                    }
196                    boxes.push(child_box);
197                    boxes.push(MathBox::from_text(" "));
198                    prev_multiline = is_multiline;
199                    continue;
200                }
201            }
202            boxes.push(child_box);
203            prev_multiline = is_multiline;
204        }
205
206        Ok(MathBox::concat_horizontal(&boxes))
207    }
208
209    fn process_text(&self, node: &Node) -> Result<MathBox, RenderError> {
210        let text = self.get_text_content(node);
211
212        // Handle Greek letters and special identifiers
213        if let Some(greek) = get_greek(&text) {
214            return Ok(MathBox::from_text(&greek.to_string()));
215        }
216
217        Ok(MathBox::from_text(&text))
218    }
219
220    fn process_operator(&self, node: &Node) -> Result<MathBox, RenderError> {
221        let text = self.get_text_content(node);
222
223        // Handle special operators
224        let rendered = match text.as_str() {
225            "∑" | "∏" | "∫" | "∬" | "∭" | "∮" | "⋃" | "⋂" => {
226                // Big operators - keep as is
227                text
228            }
229            _ => {
230                // Check if it's a LaTeX command
231                if text.starts_with('\\') {
232                    let cmd = &text[1..];
233                    if let Some(sym) = get_symbol(cmd) {
234                        sym.to_string()
235                    } else if let Some(greek) = get_greek(cmd) {
236                        greek.to_string()
237                    } else {
238                        text
239                    }
240                } else {
241                    text
242                }
243            }
244        };
245
246        // Spacing is handled in process_row for context-aware operator spacing
247        Ok(MathBox::from_text(&rendered))
248    }
249
250    fn process_superscript(&self, node: &Node) -> Result<MathBox, RenderError> {
251        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
252        if children.len() != 2 {
253            return Err(RenderError::InvalidStructure(
254                "msup requires exactly 2 children".to_string(),
255            ));
256        }
257
258        let base = self.process_element(&children[0])?;
259        // Use compact mode for superscript content (no spacing around operators)
260        let sup = if children[1].tag_name().name() == "mrow" {
261            self.process_row_compact(&children[1])?
262        } else {
263            self.process_element(&children[1])?
264        };
265
266        // Try Unicode superscript for simple cases
267        if self.use_unicode_scripts && base.height == 1 && sup.height == 1 {
268            let sup_text = sup.to_string();
269            if let Some(unicode_sup) = to_superscript(sup_text.trim()) {
270                let combined = format!("{}{}", base.to_string(), unicode_sup);
271                return Ok(MathBox::from_text(&combined));
272            }
273        }
274
275        // Fall back to 2D rendering
276        let width = base.width + sup.width;
277        let height = base.height + 1;
278        let mut result = MathBox::empty(width, height, base.baseline + 1);
279
280        // Place base at bottom
281        result.blit(&base, 0, 1);
282        // Place superscript at top-right
283        result.blit(&sup, base.width, 0);
284
285        Ok(result)
286    }
287
288    fn process_subscript(&self, node: &Node) -> Result<MathBox, RenderError> {
289        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
290        if children.len() != 2 {
291            return Err(RenderError::InvalidStructure(
292                "msub requires exactly 2 children".to_string(),
293            ));
294        }
295
296        let base = self.process_element(&children[0])?;
297        // Use compact mode for subscript content (no spacing around operators)
298        let sub = if children[1].tag_name().name() == "mrow" {
299            self.process_row_compact(&children[1])?
300        } else {
301            self.process_element(&children[1])?
302        };
303
304        // Try Unicode subscript for simple cases
305        if self.use_unicode_scripts && base.height == 1 && sub.height == 1 {
306            let sub_text = sub.to_string();
307            if let Some(unicode_sub) = to_subscript(sub_text.trim()) {
308                let combined = format!("{}{}", base.to_string(), unicode_sub);
309                return Ok(MathBox::from_text(&combined));
310            }
311        }
312
313        // Fall back to 2D rendering
314        let width = base.width + sub.width;
315        let height = base.height + 1;
316        let mut result = MathBox::empty(width, height, base.baseline);
317
318        // Place base at top
319        result.blit(&base, 0, 0);
320        // Place subscript at bottom-right
321        result.blit(&sub, base.width, base.height);
322
323        Ok(result)
324    }
325
326    fn process_subsup(&self, node: &Node) -> Result<MathBox, RenderError> {
327        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
328        if children.len() != 3 {
329            return Err(RenderError::InvalidStructure(
330                "msubsup requires exactly 3 children".to_string(),
331            ));
332        }
333
334        // Check if base is a big operator (integral, sum, etc.)
335        let base_text = self.get_text_content(&children[0]);
336        let is_big_operator = matches!(
337            base_text.as_str(),
338            "∫" | "∬" | "∭" | "∮" | "∑" | "∏" | "⋃" | "⋂"
339        );
340
341        let base = self.process_element(&children[0])?;
342        let sub = self.process_element(&children[1])?;
343        let sup = self.process_element(&children[2])?;
344
345        // For big operators, stack limits vertically (centered)
346        if is_big_operator {
347            return Ok(MathBox::stack_vertical(&[sup, base, sub]));
348        }
349
350        // Try Unicode scripts for simple cases
351        if self.use_unicode_scripts && base.height == 1 && sub.height == 1 && sup.height == 1 {
352            let sub_text = sub.to_string();
353            let sup_text = sup.to_string();
354            if let (Some(unicode_sub), Some(unicode_sup)) =
355                (to_subscript(sub_text.trim()), to_superscript(sup_text.trim()))
356            {
357                let combined = format!("{}{}{}", base.to_string(), unicode_sub, unicode_sup);
358                return Ok(MathBox::from_text(&combined));
359            }
360        }
361
362        // 2D rendering with both
363        let script_width = sub.width.max(sup.width);
364        let width = base.width + script_width;
365        let height = base.height + 2;
366        let mut result = MathBox::empty(width, height, base.baseline + 1);
367
368        result.blit(&base, 0, 1);
369        result.blit(&sup, base.width, 0);
370        result.blit(&sub, base.width, height - 1);
371
372        Ok(result)
373    }
374
375    fn process_fraction(&self, node: &Node) -> Result<MathBox, RenderError> {
376        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
377        if children.len() != 2 {
378            return Err(RenderError::InvalidStructure(
379                "mfrac requires exactly 2 children".to_string(),
380            ));
381        }
382
383        let num = self.process_element(&children[0])?;
384        let den = self.process_element(&children[1])?;
385
386        let width = num.width.max(den.width);
387        let height = num.height + 1 + den.height;
388        let baseline = num.height;
389
390        let mut result = MathBox::empty(width, height, baseline);
391
392        // Center numerator
393        let num_offset = (width - num.width) / 2;
394        result.blit(&num, num_offset, 0);
395
396        // Draw fraction line using box-drawing character
397        result.fill_row(num.height, '─');
398
399        // Center denominator
400        let den_offset = (width - den.width) / 2;
401        result.blit(&den, den_offset, num.height + 1);
402
403        Ok(result)
404    }
405
406    fn process_sqrt(&self, node: &Node) -> Result<MathBox, RenderError> {
407        let inner = self.process_row(node)?;
408
409        // Simple sqrt rendering: √ followed by content with overline
410        // Layout: ___
411        //        √abc
412
413        if inner.height == 1 {
414            let inner_text = inner.to_string();
415            let inner_width = inner_text.chars().count();
416
417            // Single line: √ + content, with overline above content
418            let width = 1 + inner_width;
419            let height = 2;
420            let mut result = MathBox::empty(width, height, 1);
421
422            // Draw bar above the content (not above √)
423            for x in 1..width {
424                result.set(x, 0, '_');
425            }
426
427            // Draw √ and content
428            result.set(0, 1, '√');
429            for (i, ch) in inner_text.chars().enumerate() {
430                result.set(1 + i, 1, ch);
431            }
432
433            return Ok(result);
434        }
435
436        // Multi-line sqrt - use simple bracket approach
437        let width = inner.width + 1;
438        let height = inner.height + 1;
439        let mut result = MathBox::empty(width, height, inner.baseline + 1);
440
441        // Draw bar
442        for x in 1..width {
443            result.set(x, 0, '_');
444        }
445
446        // Draw √ at the left
447        result.set(0, 1, '√');
448
449        // Place content
450        result.blit(&inner, 1, 1);
451
452        Ok(result)
453    }
454
455    fn process_nthroot(&self, node: &Node) -> Result<MathBox, RenderError> {
456        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
457        if children.len() != 2 {
458            return Err(RenderError::InvalidStructure(
459                "mroot requires exactly 2 children".to_string(),
460            ));
461        }
462
463        let inner = self.process_element(&children[0])?;
464        let index = self.process_element(&children[1])?;
465
466        // Try Unicode superscript for index
467        let index_text = index.to_string();
468        if let Some(unicode_idx) = to_superscript(index_text.trim()) {
469            let text = format!("{}√{}", unicode_idx, inner.to_string());
470            return Ok(MathBox::from_text(&text));
471        }
472
473        // 2D rendering
474        let width = index.width + inner.width + 2;
475        let height = (inner.height + 1).max(index.height);
476        let mut result = MathBox::empty(width, height, height / 2);
477
478        // Place index
479        result.blit(&index, 0, 0);
480
481        // Draw sqrt and content
482        result.set(index.width, height - 1, '√');
483        for x in (index.width + 1)..width {
484            result.set(x, 0, '─');
485        }
486        result.blit(&inner, index.width + 2, 1);
487
488        Ok(result)
489    }
490
491    fn process_over(&self, node: &Node) -> Result<MathBox, RenderError> {
492        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
493        if children.len() != 2 {
494            return Err(RenderError::InvalidStructure(
495                "mover requires exactly 2 children".to_string(),
496            ));
497        }
498
499        let base = self.process_element(&children[0])?;
500        let over = self.process_element(&children[1])?;
501
502        let over_text = over.to_string().trim().to_string();
503
504        // Handle common accents on single-height bases
505        if base.height == 1 {
506            let accent = match over_text.as_str() {
507                "^" | "ˆ" => Some("̂"),  // Combining circumflex
508                "~" | "˜" => Some("̃"),  // Combining tilde
509                "¯" | "-" => Some("̄"),  // Combining macron (bar)
510                "." => Some("̇"),        // Combining dot above
511                ".." | "¨" => Some("̈"), // Combining diaeresis
512                "→" => Some("⃗"),        // Combining right arrow
513                _ => None,
514            };
515            if let Some(combining) = accent {
516                let base_text = base.to_string();
517                let text = format!("{}{}", base_text, combining);
518                return Ok(MathBox::from_text(&text));
519            }
520        }
521
522        // Stack vertically
523        Ok(MathBox::stack_vertical(&[over, base]))
524    }
525
526    fn process_under(&self, node: &Node) -> Result<MathBox, RenderError> {
527        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
528        if children.len() != 2 {
529            return Err(RenderError::InvalidStructure(
530                "munder requires exactly 2 children".to_string(),
531            ));
532        }
533
534        let base_text = self.get_text_content(&children[0]);
535        let base = self.process_element(&children[0])?;
536        let under = self.process_element(&children[1])?;
537
538        // For "lim" and similar operators, render subscript inline
539        if base_text == "lim" || base_text == "max" || base_text == "min" || base_text == "sup" || base_text == "inf" {
540            // Try to convert to Unicode subscript, fallback to parentheses
541            let under_text = under.to_string();
542            let under_trimmed = under_text.trim();
543
544            // Try full Unicode subscript conversion
545            if let Some(subscript) = to_subscript(under_trimmed) {
546                let combined = format!("{}{}", base_text, subscript);
547                return Ok(MathBox::from_text(&combined));
548            }
549
550            // Fallback: use parentheses notation
551            let combined = format!("{}({})", base_text, under_trimmed);
552            return Ok(MathBox::from_text(&combined));
553        }
554
555        // For other elements, stack with baseline at the base element
556        let width = base.width.max(under.width);
557        let height = base.height + under.height;
558        let mut result = MathBox::empty(width, height, base.baseline);
559
560        // Center base
561        let base_offset = (width - base.width) / 2;
562        result.blit(&base, base_offset, 0);
563
564        // Center under below base
565        let under_offset = (width - under.width) / 2;
566        result.blit(&under, under_offset, base.height);
567
568        Ok(result)
569    }
570
571    fn process_underover(&self, node: &Node) -> Result<MathBox, RenderError> {
572        let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
573        if children.len() != 3 {
574            return Err(RenderError::InvalidStructure(
575                "munderover requires exactly 3 children".to_string(),
576            ));
577        }
578
579        let base = self.process_element(&children[0])?;
580        let under = self.process_element(&children[1])?;
581        let over = self.process_element(&children[2])?;
582
583        Ok(MathBox::stack_vertical(&[over, base, under]))
584    }
585
586    fn process_table(&self, node: &Node) -> Result<MathBox, RenderError> {
587        let rows: Vec<Vec<MathBox>> = node
588            .children()
589            .filter(|n| n.is_element() && n.tag_name().name() == "mtr")
590            .map(|row| {
591                row.children()
592                    .filter(|n| n.is_element() && n.tag_name().name() == "mtd")
593                    .map(|cell| self.process_row(&cell))
594                    .collect::<Result<Vec<_>, _>>()
595            })
596            .collect::<Result<Vec<_>, _>>()?;
597
598        if rows.is_empty() {
599            return Ok(MathBox::empty(0, 1, 0));
600        }
601
602        // Calculate column widths and row heights
603        let num_cols = rows.iter().map(|r| r.len()).max().unwrap_or(0);
604        let mut col_widths = vec![0; num_cols];
605        let mut row_heights = vec![0; rows.len()];
606
607        for (i, row) in rows.iter().enumerate() {
608            for (j, cell) in row.iter().enumerate() {
609                col_widths[j] = col_widths[j].max(cell.width);
610                row_heights[i] = row_heights[i].max(cell.height);
611            }
612        }
613
614        // Add spacing
615        let spacing = 2;
616        let total_width: usize = col_widths.iter().sum::<usize>() + spacing * (num_cols.saturating_sub(1));
617        let total_height: usize = row_heights.iter().sum();
618
619        let mut result = MathBox::empty(total_width, total_height, total_height / 2);
620
621        let mut y_pos = 0;
622        for (i, row) in rows.iter().enumerate() {
623            let mut x_pos = 0;
624            for (j, cell) in row.iter().enumerate() {
625                // Center cell in its column
626                let x_offset = (col_widths[j] - cell.width) / 2;
627                result.blit(cell, x_pos + x_offset, y_pos);
628                x_pos += col_widths[j] + spacing;
629            }
630            y_pos += row_heights[i];
631        }
632
633        Ok(result)
634    }
635
636    fn process_table_row(&self, node: &Node) -> Result<MathBox, RenderError> {
637        let cells: Vec<MathBox> = node
638            .children()
639            .filter(|n| n.is_element())
640            .map(|n| self.process_row(&n))
641            .collect::<Result<Vec<_>, _>>()?;
642
643        // Join cells with spacing
644        let spacing = MathBox::from_text("  ");
645        let mut parts = Vec::new();
646        for (i, cell) in cells.into_iter().enumerate() {
647            if i > 0 {
648                parts.push(spacing.clone());
649            }
650            parts.push(cell);
651        }
652
653        Ok(MathBox::concat_horizontal(&parts))
654    }
655
656    fn process_fenced(&self, node: &Node) -> Result<MathBox, RenderError> {
657        let open = node.attribute("open").unwrap_or("(");
658        let close = node.attribute("close").unwrap_or(")");
659
660        let inner = self.process_row(node)?;
661
662        if inner.height <= 1 {
663            // Simple case
664            let text = format!("{}{}{}", open, inner.to_string(), close);
665            return Ok(MathBox::from_text(&text));
666        }
667
668        // Scaled brackets
669        let left_chars = BRACKETS.get_left(open, inner.height);
670        let right_chars = BRACKETS.get_right(close, inner.height);
671
672        let width = 1 + inner.width + 1;
673        let height = inner.height;
674        let mut result = MathBox::empty(width, height, inner.baseline);
675
676        // Draw brackets
677        for (y, &ch) in left_chars.iter().enumerate() {
678            result.set(0, y, ch);
679        }
680        for (y, &ch) in right_chars.iter().enumerate() {
681            result.set(width - 1, y, ch);
682        }
683
684        // Place content
685        result.blit(&inner, 1, 0);
686
687        Ok(result)
688    }
689
690    fn get_text_content(&self, node: &Node) -> String {
691        let mut text = String::new();
692        for child in node.children() {
693            if child.is_text() {
694                text.push_str(child.text().unwrap_or(""));
695            }
696        }
697        text.trim().to_string()
698    }
699}
700
701impl Default for MathRenderer {
702    fn default() -> Self {
703        Self::new()
704    }
705}
706
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `latex` with a renderer from `MathRenderer::new()`, panicking
    /// if rendering fails.
    fn render(latex: &str) -> String {
        MathRenderer::new().render_latex(latex).unwrap()
    }

    /// Both operands of a simple sum appear in the output.
    #[test]
    fn test_simple_expression() {
        let rendered = render("x + y");
        assert!(rendered.contains('x'));
        assert!(rendered.contains('y'));
    }

    /// The exponent appears either as a Unicode superscript or as a plain digit.
    #[test]
    fn test_superscript() {
        let rendered = render("x^2");
        assert!(rendered.contains('²') || rendered.contains('2'));
    }

    /// A fraction shows numerator, denominator and the box-drawing rule.
    #[test]
    fn test_fraction() {
        let rendered = render(r"\frac{a}{b}");
        assert!(rendered.contains('a'));
        assert!(rendered.contains('b'));
        assert!(rendered.contains('─'));
    }
}