1use rust_latex_parser::{AccentKind, EqNode, MathFontKind, MatrixKind};
4
5use crate::renderer::MathRenderer;
6
7pub struct LatexRenderer;
9
10impl MathRenderer for LatexRenderer {
11 type Output = String;
12
13 fn render(&self, node: &EqNode) -> String {
14 node_to_latex(node)
15 }
16}
17
18fn node_to_latex(node: &EqNode) -> String {
19 match node {
20 EqNode::Text(s) => latex_escape_text(s),
21 EqNode::Space(pts) => space_to_latex(*pts),
22 EqNode::Seq(children) => children.iter().map(node_to_latex).collect(),
23 EqNode::Frac(num, den) => {
24 format!(r"\frac{{{}}}{{{}}}", node_to_latex(num), node_to_latex(den))
25 }
26 EqNode::Sup(base, sup) => {
27 format!("{}^{{{}}}", node_to_latex(base), node_to_latex(sup))
28 }
29 EqNode::Sub(base, sub) => {
30 format!("{}_{{{}}} ", node_to_latex(base), node_to_latex(sub))
31 }
32 EqNode::SupSub(base, sup, sub) => {
33 format!(
34 "{}^{{{}}}_{{{}}}",
35 node_to_latex(base),
36 node_to_latex(sup),
37 node_to_latex(sub)
38 )
39 }
40 EqNode::Sqrt(body) => format!(r"\sqrt{{{}}}", node_to_latex(body)),
41 EqNode::BigOp {
42 symbol,
43 lower,
44 upper,
45 } => {
46 let sym = unicode_to_latex_op(symbol);
47 let mut s = sym;
48 if let Some(lo) = lower {
49 s.push_str(&format!("_{{{}}}", node_to_latex(lo)));
50 }
51 if let Some(up) = upper {
52 s.push_str(&format!("^{{{}}}", node_to_latex(up)));
53 }
54 s
55 }
56 EqNode::Accent(body, kind) => {
57 let cmd = match kind {
58 AccentKind::Hat => r"\hat",
59 AccentKind::Bar => r"\overline",
60 AccentKind::Dot => r"\dot",
61 AccentKind::DoubleDot => r"\ddot",
62 AccentKind::Tilde => r"\tilde",
63 AccentKind::Vec => r"\vec",
64 };
65 format!("{}{{{}}}", cmd, node_to_latex(body))
66 }
67 EqNode::Limit { name, lower } => {
68 let latex_name = format!(r"\{}", name);
69 if let Some(lo) = lower {
70 format!("{}_{{{}}}", latex_name, node_to_latex(lo))
71 } else {
72 latex_name
73 }
74 }
75 EqNode::TextBlock(s) => format!(r"\text{{{}}}", s),
76 EqNode::MathFont { kind, content } => {
77 let cmd = match kind {
78 MathFontKind::Bold => r"\mathbf",
79 MathFontKind::Blackboard => r"\mathbb",
80 MathFontKind::Calligraphic => r"\mathcal",
81 MathFontKind::Roman => r"\mathrm",
82 MathFontKind::Fraktur => r"\mathfrak",
83 MathFontKind::SansSerif => r"\mathsf",
84 MathFontKind::Monospace => r"\mathtt",
85 };
86 format!("{}{{{}}}", cmd, node_to_latex(content))
87 }
88 EqNode::Delimited {
89 left,
90 right,
91 content,
92 } => {
93 format!(
94 r"\left{} {} \right{}",
95 latex_delim(left),
96 node_to_latex(content),
97 latex_delim(right)
98 )
99 }
100 EqNode::Matrix { kind, rows } => {
101 let env = match kind {
102 MatrixKind::Plain => "matrix",
103 MatrixKind::Paren => "pmatrix",
104 MatrixKind::Bracket => "bmatrix",
105 MatrixKind::Brace => "Bmatrix",
106 MatrixKind::VBar => "vmatrix",
107 MatrixKind::DoubleVBar => "Vmatrix",
108 };
109 let rows_str: Vec<String> = rows
110 .iter()
111 .map(|row| {
112 row.iter()
113 .map(node_to_latex)
114 .collect::<Vec<_>>()
115 .join(" & ")
116 })
117 .collect();
118 format!(
119 r"\begin{{{}}} {} \end{{{}}}",
120 env,
121 rows_str.join(r" \\ "),
122 env
123 )
124 }
125 EqNode::Cases { rows } => {
126 let rows_str: Vec<String> = rows
127 .iter()
128 .map(|(val, cond)| {
129 if let Some(c) = cond {
130 format!("{} & {}", node_to_latex(val), node_to_latex(c))
131 } else {
132 node_to_latex(val)
133 }
134 })
135 .collect();
136 format!(r"\begin{{cases}} {} \end{{cases}}", rows_str.join(r" \\ "))
137 }
138 EqNode::Binom(top, bottom) => {
139 format!(
140 r"\binom{{{}}}{{{}}}",
141 node_to_latex(top),
142 node_to_latex(bottom)
143 )
144 }
145 EqNode::Brace {
146 content,
147 label,
148 over,
149 } => {
150 let cmd = if *over { r"\overbrace" } else { r"\underbrace" };
151 let mut s = format!("{}{{{}}}", cmd, node_to_latex(content));
152 if let Some(lbl) = label {
153 if *over {
154 s.push_str(&format!("^{{{}}}", node_to_latex(lbl)));
155 } else {
156 s.push_str(&format!("_{{{}}}", node_to_latex(lbl)));
157 }
158 }
159 s
160 }
161 EqNode::StackRel {
162 base,
163 annotation,
164 over,
165 } => {
166 let cmd = if *over { r"\overset" } else { r"\underset" };
167 format!(
168 "{}{{{}}}{{{}}}",
169 cmd,
170 node_to_latex(annotation),
171 node_to_latex(base)
172 )
173 }
174 }
175}
176
/// Converts a plain text run into math-mode LaTeX.
///
/// - Characters that are special to TeX (`# $ % & _ { }`) are
///   backslash-escaped. Left raw, they would start a comment, leave math mode,
///   open a group, and so on.
/// - The Unicode minus sign (U+2212) becomes an ASCII `-`.
/// - Unicode symbols with a standard LaTeX command (Greek letters, common
///   operators, relations, arrows) are replaced by that command followed by a
///   space, so the control word cannot merge with a following letter.
/// - Every other character is copied through unchanged.
fn latex_escape_text(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '#' | '$' | '%' | '&' | '_' | '{' | '}' => {
                result.push('\\');
                result.push(ch);
            }
            '\u{2212}' => result.push('-'),
            _ => match unicode_symbol_command(ch) {
                Some(cmd) => {
                    result.push_str(cmd);
                    result.push(' ');
                }
                None => result.push(ch),
            },
        }
    }
    result
}

/// Returns the LaTeX command (without a trailing space) for a Unicode math
/// symbol, or `None` when the character has no dedicated command here.
fn unicode_symbol_command(ch: char) -> Option<&'static str> {
    let cmd = match ch {
        // Lowercase Greek.
        'α' => r"\alpha",
        'β' => r"\beta",
        'γ' => r"\gamma",
        'δ' => r"\delta",
        'ε' => r"\epsilon",
        'ζ' => r"\zeta",
        'η' => r"\eta",
        'θ' => r"\theta",
        'ι' => r"\iota",
        'κ' => r"\kappa",
        'λ' => r"\lambda",
        'μ' => r"\mu",
        'ν' => r"\nu",
        'ξ' => r"\xi",
        'π' => r"\pi",
        'ρ' => r"\rho",
        'σ' => r"\sigma",
        'τ' => r"\tau",
        'υ' => r"\upsilon",
        'φ' => r"\phi",
        'χ' => r"\chi",
        'ψ' => r"\psi",
        'ω' => r"\omega",
        'ϑ' => r"\vartheta",
        'ϖ' => r"\varpi",
        'ϱ' => r"\varrho",
        'ς' => r"\varsigma",
        // Uppercase Greek (only the letters that differ from Latin capitals).
        'Γ' => r"\Gamma",
        'Δ' => r"\Delta",
        'Θ' => r"\Theta",
        'Λ' => r"\Lambda",
        'Ξ' => r"\Xi",
        'Π' => r"\Pi",
        'Σ' => r"\Sigma",
        'Υ' => r"\Upsilon",
        'Φ' => r"\Phi",
        'Ψ' => r"\Psi",
        'Ω' => r"\Omega",
        // Large operators and calculus.
        '∞' => r"\infty",
        '∑' => r"\sum",
        '∏' => r"\prod",
        '∫' => r"\int",
        '∂' => r"\partial",
        '∇' => r"\nabla",
        // Binary operators.
        '±' => r"\pm",
        '∓' => r"\mp",
        '·' | '⋅' => r"\cdot",
        '×' => r"\times",
        '÷' => r"\div",
        '∘' => r"\circ",
        '∪' => r"\cup",
        '∩' => r"\cap",
        '∧' => r"\wedge",
        '∨' => r"\vee",
        '⊕' => r"\oplus",
        '⊗' => r"\otimes",
        // Relations.
        '≤' => r"\leq",
        '≥' => r"\geq",
        '≠' => r"\neq",
        '≈' => r"\approx",
        '≡' => r"\equiv",
        '≅' => r"\cong",
        '∼' => r"\sim",
        '∝' => r"\propto",
        '≪' => r"\ll",
        '≫' => r"\gg",
        '∈' => r"\in",
        '∉' => r"\notin",
        '∋' => r"\ni",
        '⊂' => r"\subset",
        '⊃' => r"\supset",
        '⊆' => r"\subseteq",
        '⊇' => r"\supseteq",
        '⊥' => r"\perp",
        '∥' => r"\parallel",
        // Arrows.
        '→' => r"\rightarrow",
        '←' => r"\leftarrow",
        '↔' => r"\leftrightarrow",
        '⇒' => r"\Rightarrow",
        '⇐' => r"\Leftarrow",
        '⇔' => r"\Leftrightarrow",
        '↦' => r"\mapsto",
        '↑' => r"\uparrow",
        '↓' => r"\downarrow",
        // Logic and miscellaneous symbols.
        '∀' => r"\forall",
        '∃' => r"\exists",
        '¬' => r"\neg",
        '∅' => r"\emptyset",
        'ℓ' => r"\ell",
        'ℏ' => r"\hbar",
        '…' => r"\ldots",
        '⋯' => r"\cdots",
        _ => return None,
    };
    Some(cmd)
}
227
/// Converts a horizontal space of `pts` units into the nearest LaTeX spacing
/// command.
///
/// The scale follows the original thresholds, which treat 18 units as a
/// `\quad`. NOTE(review): that suggests the parser measures in `mu` (1/18 em);
/// confirm against the parser.
///
/// | range           | output     |
/// |-----------------|------------|
/// | `< 0`           | `\!`       |
/// | `[0, 3)`        | `\,`       |
/// | `[3, 5)`        | `\;`       |
/// | `[5, 18)`, NaN  | `\ `       |
/// | `[18, 36)`      | `\quad `   |
/// | `>= 36`         | `\qquad `  |
fn space_to_latex(pts: f32) -> String {
    let cmd = if pts < 0.0 {
        r"\!"
    } else if pts < 3.0 {
        r"\,"
    } else if pts < 5.0 {
        r"\;"
    } else if pts >= 36.0 {
        r"\qquad "
    } else if pts >= 18.0 {
        r"\quad "
    } else {
        // A literal space is discarded by TeX in math mode, so emit a control
        // space to keep a visible gap.
        r"\ "
    };
    cmd.to_string()
}
241
/// Maps a big-operator glyph to its LaTeX command.
///
/// Both the dedicated n-ary glyphs (e.g. `⨁` U+2A01, `⨂` U+2A02) and the
/// binary-operator glyphs already accepted here (`⊕`, `⊗`) map to the
/// big-operator command. Unknown symbols are returned unchanged.
fn unicode_to_latex_op(symbol: &str) -> String {
    let cmd = match symbol {
        "∑" => r"\sum",
        "∏" => r"\prod",
        "∐" => r"\coprod",
        "∫" => r"\int",
        "∬" => r"\iint",
        "∭" => r"\iiint",
        "∮" => r"\oint",
        "⋃" => r"\bigcup",
        "⋂" => r"\bigcap",
        "⊕" | "⨁" => r"\bigoplus",
        "⊗" | "⨂" => r"\bigotimes",
        "⨀" => r"\bigodot",
        "⨄" => r"\biguplus",
        "⨆" => r"\bigsqcup",
        "⋁" => r"\bigvee",
        "⋀" => r"\bigwedge",
        other => other,
    };
    cmd.to_string()
}
256
/// Translates a delimiter glyph into the token placed after `\left`/`\right`.
///
/// Only braces and the double bar need a backslash. Every other delimiter
/// (parentheses, brackets, `|`, the null delimiter `.`, or anything unknown)
/// is emitted verbatim.
fn latex_delim(d: &str) -> String {
    let token = match d {
        "{" => r"\{",
        "}" => r"\}",
        "‖" => r"\|",
        verbatim => verbatim,
    };
    String::from(token)
}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::renderer::MathRenderer;
    use rust_latex_parser::parse_equation;

    /// Parses `src` and renders the resulting AST straight back to LaTeX.
    fn render_source(src: &str) -> String {
        let renderer = LatexRenderer;
        renderer.render(&parse_equation(src))
    }

    #[test]
    fn test_simple_fraction_roundtrip() {
        let output = render_source(r"\frac{a}{b}");
        assert!(output.contains(r"\frac"));
        assert!(output.contains('a') && output.contains('b'));
    }

    #[test]
    fn test_superscript_roundtrip() {
        let output = render_source(r"x^2");
        assert!(output.contains("x^"));
        assert!(output.contains('2'));
    }

    #[test]
    fn test_matrix_roundtrip() {
        let output = render_source(r"\begin{pmatrix} a & b \\ c & d \end{pmatrix}");
        assert!(output.contains("pmatrix"));
        assert!(output.contains('&'));
    }
}