1use crate::{Plugin, ProcessResult};
16use regex::Regex;
17use streamdown_config::ComputedStyle;
18use streamdown_core::state::ParseState;
19use std::collections::HashMap;
20use std::sync::LazyLock;
21
/// Line-oriented plugin state for rewriting LaTeX math as Unicode text.
///
/// The plugin tracks whether it is currently inside a `$$ ... $$` display
/// block and accumulates the block's content until the closing delimiter
/// is seen.
#[derive(Debug)]
pub struct LatexPlugin {
    /// `true` after an opening `$$` has been seen and before its closing `$$`.
    in_block: bool,
    /// Content of the currently open `$$` block collected so far.
    buffer: String,
}
29
30impl LatexPlugin {
31 pub fn new() -> Self {
33 Self {
34 in_block: false,
35 buffer: String::new(),
36 }
37 }
38}
39
40impl Default for LatexPlugin {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl Plugin for LatexPlugin {
47 fn name(&self) -> &str {
48 "latex"
49 }
50
51 fn process_line(
52 &mut self,
53 line: &str,
54 _state: &ParseState,
55 _style: &ComputedStyle,
56 ) -> Option<ProcessResult> {
57 if !self.in_block && line.contains('$') && !line.contains("$$") {
59 let converted = convert_inline_math(line);
61 if converted != line {
62 return Some(ProcessResult::Lines(vec![converted]));
63 }
64 }
65
66 if !self.in_block {
68 if let Some(idx) = line.find("$$") {
69 self.in_block = true;
70 self.buffer.clear();
71
72 let after = &line[idx + 2..];
74
75 if let Some(end_idx) = after.find("$$") {
77 self.in_block = false;
79 let expr = &after[..end_idx];
80 let converted = latex_to_unicode(expr);
81 return Some(ProcessResult::Lines(vec![converted]));
82 }
83
84 self.buffer.push_str(after);
86 return Some(ProcessResult::Continue);
87 }
88 return None;
89 }
90
91 if let Some(idx) = line.find("$$") {
93 self.in_block = false;
95 self.buffer.push_str(&line[..idx]);
96
97 let converted = latex_to_unicode(&self.buffer);
98 self.buffer.clear();
99
100 return Some(ProcessResult::Lines(vec![converted]));
101 }
102
103 if !self.buffer.is_empty() {
105 self.buffer.push(' ');
106 }
107 self.buffer.push_str(line);
108 Some(ProcessResult::Continue)
109 }
110
111 fn flush(&mut self) -> Option<Vec<String>> {
112 if self.buffer.is_empty() {
113 return None;
114 }
115
116 let result = std::mem::take(&mut self.buffer);
118 self.in_block = false;
119 Some(vec![format!("$$ {} (incomplete)", result)])
120 }
121
122 fn reset(&mut self) {
123 self.in_block = false;
124 self.buffer.clear();
125 }
126
127 fn is_active(&self) -> bool {
128 self.in_block
129 }
130
131 fn priority(&self) -> i32 {
132 10 }
134}
135
136fn convert_inline_math(line: &str) -> String {
138 static INLINE_RE: LazyLock<Regex> =
139 LazyLock::new(|| Regex::new(r"\$([^$]+)\$").unwrap());
140
141 INLINE_RE
142 .replace_all(line, |caps: ®ex::Captures| latex_to_unicode(&caps[1]))
143 .to_string()
144}
145
146pub fn latex_to_unicode(latex: &str) -> String {
148 let mut result = latex.to_string();
149
150 result = convert_commands(&result);
152 result = convert_fractions(&result);
153 result = convert_subscripts(&result);
154 result = convert_superscripts(&result);
155 result = cleanup(&result);
156
157 result
158}
159
/// Maps LaTeX Greek-letter command names (without the backslash) to Unicode.
static GREEK_LETTERS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        // Lowercase letters and their `var` variants.
        ("alpha", "α"),
        ("beta", "β"),
        ("gamma", "γ"),
        ("delta", "δ"),
        ("epsilon", "ε"),
        ("varepsilon", "ε"),
        ("zeta", "ζ"),
        ("eta", "η"),
        ("theta", "θ"),
        ("vartheta", "ϑ"),
        ("iota", "ι"),
        ("kappa", "κ"),
        ("lambda", "λ"),
        ("mu", "μ"),
        ("nu", "ν"),
        ("xi", "ξ"),
        ("omicron", "ο"),
        ("pi", "π"),
        ("varpi", "ϖ"),
        ("rho", "ρ"),
        ("varrho", "ϱ"),
        ("sigma", "σ"),
        ("varsigma", "ς"),
        ("tau", "τ"),
        ("upsilon", "υ"),
        ("phi", "φ"),
        ("varphi", "ϕ"),
        ("chi", "χ"),
        ("psi", "ψ"),
        ("omega", "ω"),
        // Uppercase letters.
        ("Gamma", "Γ"),
        ("Delta", "Δ"),
        ("Theta", "Θ"),
        ("Lambda", "Λ"),
        ("Xi", "Ξ"),
        ("Pi", "Π"),
        ("Sigma", "Σ"),
        ("Upsilon", "Υ"),
        ("Phi", "Φ"),
        ("Psi", "Ψ"),
        ("Omega", "Ω"),
    ])
});
208
/// Maps LaTeX operator command names (without the backslash) to Unicode.
static OPERATORS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        ("sum", "Σ"),
        ("prod", "Π"),
        ("int", "∫"),
        ("iint", "∬"),
        ("iiint", "∭"),
        ("oint", "∮"),
        ("partial", "∂"),
        ("nabla", "∇"),
        ("sqrt", "√"),
        ("cbrt", "∛"),
        ("times", "×"),
        ("div", "÷"),
        ("cdot", "·"),
        ("ast", "∗"),
        ("star", "⋆"),
        ("circ", "∘"),
        ("bullet", "•"),
        ("oplus", "⊕"),
        ("ominus", "⊖"),
        ("otimes", "⊗"),
        ("oslash", "⊘"),
        ("odot", "⊙"),
    ])
});
236
/// Maps LaTeX relation and quantifier command names (without the backslash)
/// to Unicode.
static RELATIONS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        ("leq", "≤"),
        ("le", "≤"),
        ("geq", "≥"),
        ("ge", "≥"),
        ("neq", "≠"),
        ("ne", "≠"),
        ("approx", "≈"),
        ("equiv", "≡"),
        ("sim", "∼"),
        ("simeq", "≃"),
        ("cong", "≅"),
        ("propto", "∝"),
        ("ll", "≪"),
        ("gg", "≫"),
        ("subset", "⊂"),
        ("supset", "⊃"),
        ("subseteq", "⊆"),
        ("supseteq", "⊇"),
        ("in", "∈"),
        ("notin", "∉"),
        ("ni", "∋"),
        ("forall", "∀"),
        ("exists", "∃"),
        ("nexists", "∄"),
    ])
});
266
/// Maps miscellaneous LaTeX symbol command names (without the backslash)
/// to Unicode: arrows, dots, logic, set operations and letter-like symbols.
static SYMBOLS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        ("infty", "∞"),
        ("pm", "±"),
        ("mp", "∓"),
        // Arrows.
        ("to", "→"),
        ("rightarrow", "→"),
        ("leftarrow", "←"),
        ("leftrightarrow", "↔"),
        ("Rightarrow", "⇒"),
        ("Leftarrow", "⇐"),
        ("Leftrightarrow", "⇔"),
        ("uparrow", "↑"),
        ("downarrow", "↓"),
        ("mapsto", "↦"),
        // Ellipses.
        ("ldots", "…"),
        ("cdots", "⋯"),
        ("vdots", "⋮"),
        ("ddots", "⋱"),
        ("therefore", "∴"),
        ("because", "∵"),
        // Geometry.
        ("angle", "∠"),
        ("perp", "⊥"),
        ("parallel", "∥"),
        ("triangle", "△"),
        ("square", "□"),
        ("diamond", "◇"),
        // Sets and logic.
        ("emptyset", "∅"),
        ("varnothing", "∅"),
        ("neg", "¬"),
        ("lnot", "¬"),
        ("land", "∧"),
        ("wedge", "∧"),
        ("lor", "∨"),
        ("vee", "∨"),
        ("cap", "∩"),
        ("cup", "∪"),
        ("setminus", "∖"),
        // Letter-like symbols.
        ("aleph", "ℵ"),
        ("hbar", "ℏ"),
        ("ell", "ℓ"),
        ("Re", "ℜ"),
        ("Im", "ℑ"),
        ("wp", "℘"),
        ("prime", "′"),
        ("degree", "°"),
    ])
});
316
/// Maps characters to their Unicode subscript forms.
///
/// Despite the name, the table covers digits, a few punctuation marks and
/// the lowercase letters that are present below.
static SUBSCRIPT_DIGITS: LazyLock<HashMap<char, char>> = LazyLock::new(|| {
    HashMap::from([
        // Digits.
        ('0', '₀'),
        ('1', '₁'),
        ('2', '₂'),
        ('3', '₃'),
        ('4', '₄'),
        ('5', '₅'),
        ('6', '₆'),
        ('7', '₇'),
        ('8', '₈'),
        ('9', '₉'),
        // Punctuation.
        ('+', '₊'),
        ('-', '₋'),
        ('=', '₌'),
        ('(', '₍'),
        (')', '₎'),
        // Letters.
        ('a', 'ₐ'),
        ('e', 'ₑ'),
        ('h', 'ₕ'),
        ('i', 'ᵢ'),
        ('j', 'ⱼ'),
        ('k', 'ₖ'),
        ('l', 'ₗ'),
        ('m', 'ₘ'),
        ('n', 'ₙ'),
        ('o', 'ₒ'),
        ('p', 'ₚ'),
        ('r', 'ᵣ'),
        ('s', 'ₛ'),
        ('t', 'ₜ'),
        ('u', 'ᵤ'),
        ('v', 'ᵥ'),
        ('x', 'ₓ'),
    ])
});
354
/// Maps characters to their Unicode superscript forms: digits, a few
/// punctuation marks and the lowercase letters that are present below.
static SUPERSCRIPT_CHARS: LazyLock<HashMap<char, char>> = LazyLock::new(|| {
    HashMap::from([
        // Digits.
        ('0', '⁰'),
        ('1', '¹'),
        ('2', '²'),
        ('3', '³'),
        ('4', '⁴'),
        ('5', '⁵'),
        ('6', '⁶'),
        ('7', '⁷'),
        ('8', '⁸'),
        ('9', '⁹'),
        // Punctuation.
        ('+', '⁺'),
        ('-', '⁻'),
        ('=', '⁼'),
        ('(', '⁽'),
        (')', '⁾'),
        // Letters.
        ('a', 'ᵃ'),
        ('b', 'ᵇ'),
        ('c', 'ᶜ'),
        ('d', 'ᵈ'),
        ('e', 'ᵉ'),
        ('f', 'ᶠ'),
        ('g', 'ᵍ'),
        ('h', 'ʰ'),
        ('i', 'ⁱ'),
        ('j', 'ʲ'),
        ('k', 'ᵏ'),
        ('l', 'ˡ'),
        ('m', 'ᵐ'),
        ('n', 'ⁿ'),
        ('o', 'ᵒ'),
        ('p', 'ᵖ'),
        ('r', 'ʳ'),
        ('s', 'ˢ'),
        ('t', 'ᵗ'),
        ('u', 'ᵘ'),
        ('v', 'ᵛ'),
        ('w', 'ʷ'),
        ('x', 'ˣ'),
        ('y', 'ʸ'),
        ('z', 'ᶻ'),
    ])
});
400
401fn convert_commands(input: &str) -> String {
403 static CMD_RE: LazyLock<Regex> =
404 LazyLock::new(|| Regex::new(r"\\([a-zA-Z]+)").unwrap());
405
406 CMD_RE
407 .replace_all(input, |caps: ®ex::Captures| {
408 let cmd = &caps[1];
409
410 if let Some(s) = GREEK_LETTERS.get(cmd) {
412 return (*s).to_string();
413 }
414 if let Some(s) = OPERATORS.get(cmd) {
415 return (*s).to_string();
416 }
417 if let Some(s) = RELATIONS.get(cmd) {
418 return (*s).to_string();
419 }
420 if let Some(s) = SYMBOLS.get(cmd) {
421 return (*s).to_string();
422 }
423
424 format!("\\{}", cmd)
426 })
427 .to_string()
428}
429
430fn convert_fractions(input: &str) -> String {
432 static FRAC_RE: LazyLock<Regex> =
433 LazyLock::new(|| Regex::new(r"\\frac\{([^}]*)\}\{([^}]*)\}").unwrap());
434
435 FRAC_RE
436 .replace_all(input, |caps: ®ex::Captures| {
437 let num = &caps[1];
438 let den = &caps[2];
439 format!("({}/{})", num, den)
440 })
441 .to_string()
442}
443
444fn convert_subscripts(input: &str) -> String {
446 static BRACED_SUB_RE: LazyLock<Regex> =
448 LazyLock::new(|| Regex::new(r"_\{([^}]+)\}").unwrap());
449
450 let result = BRACED_SUB_RE
451 .replace_all(input, |caps: ®ex::Captures| {
452 let content = &caps[1];
453 to_subscript(content)
454 })
455 .to_string();
456
457 static SINGLE_SUB_RE: LazyLock<Regex> =
459 LazyLock::new(|| Regex::new(r"_([0-9a-z])").unwrap());
460
461 SINGLE_SUB_RE
462 .replace_all(&result, |caps: ®ex::Captures| {
463 let c = caps[1].chars().next().unwrap();
464 SUBSCRIPT_DIGITS.get(&c).map(|&s| s.to_string()).unwrap_or_else(|| format!("_{}", c))
465 })
466 .to_string()
467}
468
469fn convert_superscripts(input: &str) -> String {
471 static BRACED_SUP_RE: LazyLock<Regex> =
473 LazyLock::new(|| Regex::new(r"\^\{([^}]+)\}").unwrap());
474
475 let result = BRACED_SUP_RE
476 .replace_all(input, |caps: ®ex::Captures| {
477 let content = &caps[1];
478 to_superscript(content)
479 })
480 .to_string();
481
482 static SINGLE_SUP_RE: LazyLock<Regex> =
484 LazyLock::new(|| Regex::new(r"\^([0-9a-z])").unwrap());
485
486 SINGLE_SUP_RE
487 .replace_all(&result, |caps: ®ex::Captures| {
488 let c = caps[1].chars().next().unwrap();
489 SUPERSCRIPT_CHARS.get(&c).map(|&s| s.to_string()).unwrap_or_else(|| format!("^{}", c))
490 })
491 .to_string()
492}
493
494fn to_subscript(s: &str) -> String {
496 s.chars()
497 .map(|c| SUBSCRIPT_DIGITS.get(&c).copied().unwrap_or(c))
498 .collect()
499}
500
501fn to_superscript(s: &str) -> String {
503 s.chars()
504 .map(|c| SUPERSCRIPT_CHARS.get(&c).copied().unwrap_or(c))
505 .collect()
506}
507
/// Strips leftover brace fragments and surrounding whitespace.
///
/// The substrings `"{ "`, `" }"` and `"{}"` are removed one after another
/// (each pass sees the output of the previous one), then the result is
/// trimmed.
fn cleanup(input: &str) -> String {
    let stripped = ["{ ", " }", "{}"]
        .iter()
        .fold(input.to_string(), |text, fragment| text.replace(fragment, ""));
    stripped.trim().to_string()
}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds one line to `plugin` using a default parse state and style.
    fn feed(plugin: &mut LatexPlugin, line: &str) -> Option<ProcessResult> {
        let state = ParseState::new();
        let style = ComputedStyle::default();
        plugin.process_line(line, &state, &style)
    }

    #[test]
    fn test_greek_letters() {
        assert_eq!(latex_to_unicode(r"\alpha + \beta"), "α + β");
        assert_eq!(latex_to_unicode(r"\Gamma\Delta"), "ΓΔ");
        assert_eq!(latex_to_unicode(r"\pi r^2"), "π r²");
    }

    #[test]
    fn test_operators() {
        assert_eq!(latex_to_unicode(r"\sum x"), "Σ x");
        assert_eq!(latex_to_unicode(r"\int f(x) dx"), "∫ f(x) dx");
        // The braced subscript is converted character by character.
        assert_eq!(latex_to_unicode(r"\prod_{i=1}"), "Πᵢ₌₁");
    }

    #[test]
    fn test_relations() {
        assert_eq!(latex_to_unicode(r"x \leq y"), "x ≤ y");
        assert_eq!(latex_to_unicode(r"a \neq b"), "a ≠ b");
        assert_eq!(latex_to_unicode(r"A \subset B"), "A ⊂ B");
    }

    #[test]
    fn test_symbols() {
        assert_eq!(latex_to_unicode(r"\infty"), "∞");
        assert_eq!(latex_to_unicode(r"\pm 1"), "± 1");
        assert_eq!(latex_to_unicode(r"x \to y"), "x → y");
    }

    #[test]
    fn test_subscripts() {
        assert_eq!(latex_to_unicode("x_1"), "x₁");
        assert_eq!(latex_to_unicode("x_{12}"), "x₁₂");
        assert_eq!(latex_to_unicode("a_n"), "aₙ");
    }

    #[test]
    fn test_superscripts() {
        assert_eq!(latex_to_unicode("x^2"), "x²");
        assert_eq!(latex_to_unicode("x^{10}"), "x¹⁰");
        assert_eq!(latex_to_unicode("e^x"), "eˣ");
    }

    #[test]
    fn test_fractions() {
        assert_eq!(latex_to_unicode(r"\frac{a}{b}"), "(a/b)");
        assert_eq!(latex_to_unicode(r"\frac{1}{2}"), "(1/2)");
    }

    #[test]
    fn test_complex_expression() {
        assert_eq!(latex_to_unicode(r"E = mc^2"), "E = mc²");

        // Command, braced subscript, single superscript and single subscript
        // all in one expression.
        assert_eq!(latex_to_unicode(r"\sum_{i=1}^n x_i"), "Σᵢ₌₁ⁿ xᵢ");
    }

    #[test]
    fn test_inline_math() {
        assert_eq!(convert_inline_math("The value $x^2$ is"), "The value x² is");
        assert_eq!(
            convert_inline_math("We have $\\alpha$ and $\\beta$"),
            "We have α and β"
        );
    }

    #[test]
    fn test_latex_plugin_single_line() {
        let mut plugin = LatexPlugin::new();

        let Some(ProcessResult::Lines(lines)) = feed(&mut plugin, "$$E = mc^2$$") else {
            panic!("a single-line $$ block should produce converted lines");
        };
        assert_eq!(lines, vec!["E = mc²".to_string()]);
        assert!(!plugin.is_active());
    }

    #[test]
    fn test_latex_plugin_multiline() {
        let mut plugin = LatexPlugin::new();

        // The opening line has no closing delimiter, so it is buffered.
        let result = feed(&mut plugin, "$$\\sum_{i=1}^n");
        assert!(matches!(result, Some(ProcessResult::Continue)));
        assert!(plugin.is_active());

        // The closing line ends the block and emits the converted content.
        let Some(ProcessResult::Lines(lines)) = feed(&mut plugin, "x_i$$") else {
            panic!("the closing $$ should produce converted lines");
        };
        assert_eq!(lines.len(), 1);
        assert!(lines[0].contains("Σ"));
        assert!(!plugin.is_active());
    }

    #[test]
    fn test_latex_plugin_inline() {
        let mut plugin = LatexPlugin::new();

        let Some(ProcessResult::Lines(lines)) =
            feed(&mut plugin, "The value $x^2$ is important")
        else {
            panic!("inline math should produce a converted line");
        };
        assert_eq!(lines, vec!["The value x² is important".to_string()]);
    }

    #[test]
    fn test_latex_plugin_no_match() {
        let mut plugin = LatexPlugin::new();

        assert!(feed(&mut plugin, "Normal text without math").is_none());
    }

    #[test]
    fn test_latex_plugin_flush() {
        let mut plugin = LatexPlugin::new();

        // Open a block and never close it.
        feed(&mut plugin, "$$x^2 + y^2");

        // Flushing returns the raw buffered content, marked as incomplete.
        assert_eq!(
            plugin.flush(),
            Some(vec!["$$ x^2 + y^2 (incomplete)".to_string()])
        );
        assert!(!plugin.is_active());
    }

    #[test]
    fn test_latex_plugin_reset() {
        let mut plugin = LatexPlugin::new();

        feed(&mut plugin, "$$x^2");
        assert!(plugin.is_active());

        plugin.reset();
        assert!(!plugin.is_active());
    }
}