1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Semantic category assigned to each token produced by the math tokenizer.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MathTokenType {
    /// Numeric literal: integer, decimal, or scientific notation.
    Number,
    /// Single-letter identifier such as `x` or `y`.
    Variable,
    /// Arithmetic or relational operator (`+`, `=`, `≤`, ...).
    Operator,
    /// Named mathematical function (`sin`, `log`, `sqrt`, ...).
    Function,
    /// Greek letter, upper- or lower-case.
    GreekLetter,
    /// Known mathematical constant (`π`, `e`, `∞`, ...).
    Constant,
    /// Bracketing character: `(`, `)`, `[`, `]`, `{`, `}`.
    Delimiter,
    /// Backslash-prefixed LaTeX command (`\frac`, `\sum`, ...).
    LaTeXCommand,
    /// Sub-/superscript marker. NOTE(review): not produced by the current
    /// tokenizer — confirm whether script handling is still planned.
    Script,
    /// Math symbol with a named meaning (`∈`, `∀`, `∫`, ...).
    Symbol,
    /// Physical unit abbreviation such as `kg` or `Hz`.
    Unit,
    /// Plain-text run. NOTE(review): not produced by the current tokenizer.
    Text,
    /// Run of whitespace (emitted only when `preserve_whitespace` is set).
    Whitespace,
    /// Anything that matched no other category.
    Unknown,
}
39
/// A single token produced by the math tokenizer, carrying its text, its
/// category, and the byte span it covers in the original input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathToken {
    /// Token text exactly as it appeared in the input.
    pub text: String,
    /// Semantic category assigned during tokenization.
    pub token_type: MathTokenType,
    /// Byte offset of the token's first character in the source string.
    pub start: usize,
    /// Byte offset one past the token's last character.
    pub end: usize,
    /// Vocabulary ID; `None` until assigned by the tokenizer.
    pub id: Option<u32>,
    /// Optional LaTeX rendering of the token.
    pub latex: Option<String>,
    /// Optional machine-readable meaning (e.g. `element_of` for `∈`).
    pub meaning: Option<String>,
}
58
59impl MathToken {
60 pub fn new(text: String, token_type: MathTokenType, start: usize, end: usize) -> Self {
62 Self {
63 text,
64 token_type,
65 start,
66 end,
67 id: None,
68 latex: None,
69 meaning: None,
70 }
71 }
72
73 pub fn with_id(mut self, id: u32) -> Self {
75 self.id = Some(id);
76 self
77 }
78
79 pub fn with_latex(mut self, latex: String) -> Self {
81 self.latex = Some(latex);
82 self
83 }
84
85 pub fn with_meaning(mut self, meaning: String) -> Self {
87 self.meaning = Some(meaning);
88 self
89 }
90}
91
/// Configuration switches and user-supplied extensions for the tokenizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathTokenizerConfig {
    /// Emit whitespace runs as tokens instead of skipping them.
    pub preserve_whitespace: bool,
    /// Recognize backslash-prefixed LaTeX commands as single tokens.
    pub recognize_latex: bool,
    /// Classify forms like `2e10` / `1.5E-3` as numbers.
    pub recognize_scientific_notation: bool,
    /// Handle sub-/superscripts. NOTE(review): not currently consulted
    /// anywhere in the tokenizer — confirm intent.
    pub handle_scripts: bool,
    /// Recognize physical units. NOTE(review): not currently consulted
    /// anywhere in the tokenizer — confirm intent.
    pub recognize_units: bool,
    /// Maximum token length. NOTE(review): not currently enforced.
    pub max_token_length: usize,
    /// Extra function names merged into the built-in set.
    pub custom_functions: HashSet<String>,
    /// Extra constants (glyph -> name) merged into the built-in map.
    pub custom_constants: HashMap<String, String>,
}
112
113impl Default for MathTokenizerConfig {
114 fn default() -> Self {
115 Self {
116 preserve_whitespace: false,
117 recognize_latex: true,
118 recognize_scientific_notation: true,
119 handle_scripts: true,
120 recognize_units: true,
121 max_token_length: 50,
122 custom_functions: HashSet::new(),
123 custom_constants: HashMap::new(),
124 }
125 }
126}
127
/// Tokenizer for mathematical text: plain notation, Unicode math symbols,
/// and a subset of LaTeX.
///
/// Maintains a growable vocabulary that maps `"text:Type"` keys to numeric
/// IDs; built-in tokens are registered at construction time and previously
/// unseen tokens are interned on the fly during tokenization.
pub struct MathTokenizer {
    /// Behavioral switches and user-supplied extensions.
    config: MathTokenizerConfig,
    /// Matches plain integers/decimals, e.g. `123`, `3.14`.
    number_regex: Regex,
    /// Matches scientific notation, e.g. `6.022e23`.
    scientific_regex: Regex,
    /// Matches backslash-prefixed LaTeX commands, e.g. `\frac`.
    latex_command_regex: Regex,
    /// Upper- and lower-case Greek letters.
    greek_letters: HashSet<String>,
    /// Known function names (built-ins plus `config.custom_functions`).
    math_functions: HashSet<String>,
    /// Constant glyph -> machine-readable name (e.g. `π` -> `pi`).
    math_constants: HashMap<String, String>,
    /// Operator glyphs (`+`, `≤`, ...).
    math_operators: HashSet<String>,
    /// Symbol glyph -> machine-readable meaning (e.g. `∈` -> `element_of`).
    math_symbols: HashMap<String, String>,
    /// Recognized unit abbreviations (`kg`, `Hz`, ...).
    units: HashSet<String>,
    /// Vocabulary: `"text:Type"` key -> ID.
    token_to_id: HashMap<String, u32>,
    /// Inverse vocabulary: ID -> `"text:Type"` key.
    id_to_token: HashMap<u32, String>,
    /// Next unassigned ID; starts at 1, so ID 0 is never issued.
    next_id: u32,
}
146
147impl MathTokenizer {
148 pub fn new() -> Result<Self> {
150 Self::with_config(MathTokenizerConfig::default())
151 }
152
153 pub fn with_config(config: MathTokenizerConfig) -> Result<Self> {
155 let number_regex = Regex::new(r"^\d+\.?\d*$").expect("valid regex");
156 let scientific_regex = Regex::new(r"^\d+\.?\d*[eE][+-]?\d+$").expect("valid regex");
157 let latex_command_regex = Regex::new(r"^\\[a-zA-Z]+$").expect("valid regex");
158
159 let greek_letters = [
161 "α", "β", "γ", "δ", "ε", "ζ", "η", "θ", "ι", "κ", "λ", "μ", "ν", "ξ", "ο", "π", "ρ",
162 "σ", "τ", "υ", "φ", "χ", "ψ", "ω", "Α", "Β", "Γ", "Δ", "Ε", "Ζ", "Η", "Θ", "Ι", "Κ",
163 "Λ", "Μ", "Ν", "Ξ", "Ο", "Π", "Ρ", "Σ", "Τ", "Υ", "Φ", "Χ", "Ψ", "Ω",
164 ]
165 .iter()
166 .map(|s| s.to_string())
167 .collect();
168
169 let mut math_functions = [
171 "sin", "cos", "tan", "sec", "csc", "cot", "arcsin", "arccos", "arctan", "asin", "acos",
172 "atan", "sinh", "cosh", "tanh", "sech", "csch", "coth", "log", "ln", "lg", "exp",
173 "sqrt", "abs", "sgn", "min", "max", "gcd", "lcm", "floor", "ceil", "det", "tr", "rank",
174 "dim", "span", "ker", "im",
175 ]
176 .iter()
177 .map(|s| s.to_string())
178 .collect::<HashSet<_>>();
179
180 math_functions.extend(config.custom_functions.clone());
182
183 let mut math_constants = HashMap::new();
185 math_constants.insert("π".to_string(), "pi".to_string());
186 math_constants.insert("e".to_string(), "euler_number".to_string());
187 math_constants.insert("∞".to_string(), "infinity".to_string());
188 math_constants.insert("i".to_string(), "imaginary_unit".to_string());
189 math_constants.insert("φ".to_string(), "golden_ratio".to_string());
190 math_constants.insert("γ".to_string(), "euler_gamma".to_string());
191
192 math_constants.extend(config.custom_constants.clone());
194
195 let math_operators = [
197 "+", "-", "*", "×", "·", "/", "÷", "^", "=", "≠", "≈", "≡", "<", ">", "≤", "≥", "≪",
198 "≫", "±", "∓", "∝", "∼", "≅",
199 ]
200 .iter()
201 .map(|s| s.to_string())
202 .collect();
203
204 let mut math_symbols = HashMap::new();
206 math_symbols.insert("∈".to_string(), "element_of".to_string());
207 math_symbols.insert("∉".to_string(), "not_element_of".to_string());
208 math_symbols.insert("⊂".to_string(), "subset".to_string());
209 math_symbols.insert("⊃".to_string(), "superset".to_string());
210 math_symbols.insert("⊆".to_string(), "subset_equal".to_string());
211 math_symbols.insert("⊇".to_string(), "superset_equal".to_string());
212 math_symbols.insert("∪".to_string(), "union".to_string());
213 math_symbols.insert("∩".to_string(), "intersection".to_string());
214 math_symbols.insert("∅".to_string(), "empty_set".to_string());
215 math_symbols.insert("∀".to_string(), "for_all".to_string());
216 math_symbols.insert("∃".to_string(), "exists".to_string());
217 math_symbols.insert("∄".to_string(), "not_exists".to_string());
218 math_symbols.insert("∇".to_string(), "nabla".to_string());
219 math_symbols.insert("∂".to_string(), "partial".to_string());
220 math_symbols.insert("∫".to_string(), "integral".to_string());
221 math_symbols.insert("∬".to_string(), "double_integral".to_string());
222 math_symbols.insert("∭".to_string(), "triple_integral".to_string());
223 math_symbols.insert("∮".to_string(), "contour_integral".to_string());
224 math_symbols.insert("∑".to_string(), "sum".to_string());
225 math_symbols.insert("∏".to_string(), "product".to_string());
226 math_symbols.insert("→".to_string(), "right_arrow".to_string());
227 math_symbols.insert("←".to_string(), "left_arrow".to_string());
228 math_symbols.insert("↔".to_string(), "double_arrow".to_string());
229 math_symbols.insert("⇒".to_string(), "implies".to_string());
230 math_symbols.insert("⇔".to_string(), "if_and_only_if".to_string());
231 math_symbols.insert("√".to_string(), "square_root".to_string());
232
233 let units = [
235 "m", "cm", "mm", "km", "in", "ft", "yd", "mi", "g", "kg", "mg", "lb", "oz", "s", "ms", "min", "h", "hr", "day", "yr", "K", "°C", "°F", "J", "kJ", "cal", "kcal", "eV", "keV", "MeV", "GeV", "W", "kW", "MW", "hp", "Hz", "kHz", "MHz", "GHz", "V", "mV", "kV", "A", "mA", "μA", "Ω", "kΩ", "MΩ",
246 ]
247 .iter()
248 .map(|s| s.to_string())
249 .collect();
250
251 let mut tokenizer = Self {
252 config,
253 number_regex,
254 scientific_regex,
255 latex_command_regex,
256 greek_letters,
257 math_functions,
258 math_constants,
259 math_operators,
260 math_symbols,
261 units,
262 token_to_id: HashMap::new(),
263 id_to_token: HashMap::new(),
264 next_id: 1, };
266
267 tokenizer.initialize_vocabulary();
269 Ok(tokenizer)
270 }
271
272 fn initialize_vocabulary(&mut self) {
274 let greek_letters: Vec<String> = self.greek_letters.iter().cloned().collect();
276 let math_functions: Vec<String> = self.math_functions.iter().cloned().collect();
277 let math_constants: Vec<String> = self.math_constants.keys().cloned().collect();
278 let math_operators: Vec<String> = self.math_operators.iter().cloned().collect();
279 let math_symbols: Vec<String> = self.math_symbols.keys().cloned().collect();
280 let units: Vec<String> = self.units.iter().cloned().collect();
281
282 for letter in &greek_letters {
284 self.add_token_with_type(letter, "GreekLetter");
285 }
286
287 for function in &math_functions {
289 self.add_token_with_type(function, "Function");
290 }
291
292 for constant in &math_constants {
294 self.add_token_with_type(constant, "Constant");
295 }
296
297 for operator in &math_operators {
299 self.add_token_with_type(operator, "Operator");
300 }
301
302 for symbol in &math_symbols {
304 self.add_token_with_type(symbol, "Symbol");
305 }
306
307 for unit in &units {
309 self.add_token_with_type(unit, "Unit");
310 }
311
312 for var in ["x", "y", "z", "a", "b", "c", "n", "m", "k", "i", "j"] {
314 self.add_token_with_type(var, "Variable");
315 }
316
317 for num in 0..10 {
319 self.add_token_with_type(&num.to_string(), "Number");
320 }
321
322 for punct in ["(", ")", "[", "]", "{", "}", ",", ".", "!", " "] {
324 self.add_token_with_type(punct, "Punctuation");
325 }
326 }
327
328 fn add_token_with_type(&mut self, token: &str, token_type: &str) {
330 let key = format!("{}:{}", token, token_type);
331 if !self.token_to_id.contains_key(&key) {
332 let id = self.next_id;
333 self.token_to_id.insert(key.clone(), id);
334 self.id_to_token.insert(id, key);
335 self.next_id += 1;
336 }
337 }
338
339 pub fn tokenize_math(&mut self, text: &str) -> Result<Vec<MathToken>> {
341 let mut tokens = Vec::new();
342 let mut chars = text.char_indices().peekable();
343
344 while let Some((pos, ch)) = chars.next() {
345 let start_pos = pos;
346
347 if ch.is_whitespace() {
349 if self.config.preserve_whitespace {
350 let mut whitespace = String::new();
351 whitespace.push(ch);
352
353 while let Some(&(_, next_ch)) = chars.peek() {
355 if next_ch.is_whitespace() {
356 chars.next();
357 whitespace.push(next_ch);
358 } else {
359 break;
360 }
361 }
362
363 let end_pos = start_pos + whitespace.len();
364 tokens.push(MathToken::new(
365 whitespace,
366 MathTokenType::Whitespace,
367 start_pos,
368 end_pos,
369 ));
370 }
371 continue;
372 }
373
374 if let Some(token) = self.try_match_number(&mut chars, ch, start_pos)? {
376 tokens.push(token);
377 } else if let Some(token) = self.try_match_latex_command(&mut chars, ch, start_pos)? {
378 tokens.push(token);
379 } else if let Some(token) = self.try_match_function(&mut chars, ch, start_pos)? {
380 tokens.push(token);
381 } else if let Some(token) =
382 self.try_match_multi_char_symbol(&mut chars, ch, start_pos)?
383 {
384 tokens.push(token);
385 } else {
386 let token_text = ch.to_string();
388 let token_type = self.classify_single_char(&token_text);
389 let end_pos = start_pos + ch.len_utf8();
390
391 tokens.push(MathToken::new(token_text, token_type, start_pos, end_pos));
392 }
393 }
394
395 Ok(tokens)
396 }
397
398 fn try_match_number(
400 &self,
401 chars: &mut std::iter::Peekable<std::str::CharIndices>,
402 first_char: char,
403 start_pos: usize,
404 ) -> Result<Option<MathToken>> {
405 if !first_char.is_ascii_digit() && first_char != '.' {
406 return Ok(None);
407 }
408
409 let mut number = String::new();
410 number.push(first_char);
411 let mut current_pos = start_pos + first_char.len_utf8();
412
413 while let Some(&(_, ch)) = chars.peek() {
415 if ch.is_ascii_digit() || ch == '.' {
416 chars.next();
417 number.push(ch);
418 current_pos += ch.len_utf8();
419 } else {
420 break;
421 }
422 }
423
424 if let Some(&(_, ch)) = chars.peek() {
426 if ch == 'e' || ch == 'E' {
427 let mut temp_number = number.clone();
428 temp_number.push(ch);
429 chars.next();
430 current_pos += ch.len_utf8();
431
432 if let Some(&(_, sign_ch)) = chars.peek() {
434 if sign_ch == '+' || sign_ch == '-' {
435 chars.next();
436 temp_number.push(sign_ch);
437 current_pos += sign_ch.len_utf8();
438 }
439 }
440
441 let mut has_exponent_digits = false;
443 while let Some(&(_, ch)) = chars.peek() {
444 if ch.is_ascii_digit() {
445 chars.next();
446 temp_number.push(ch);
447 current_pos += ch.len_utf8();
448 has_exponent_digits = true;
449 } else {
450 break;
451 }
452 }
453
454 if has_exponent_digits {
455 number = temp_number;
456 }
457 }
458 }
459
460 let token_type = if (self.config.recognize_scientific_notation
461 && self.scientific_regex.is_match(&number))
462 || self.number_regex.is_match(&number)
463 {
464 MathTokenType::Number
465 } else {
466 MathTokenType::Unknown
467 };
468
469 Ok(Some(MathToken::new(
470 number,
471 token_type,
472 start_pos,
473 current_pos,
474 )))
475 }
476
477 fn try_match_latex_command(
479 &self,
480 chars: &mut std::iter::Peekable<std::str::CharIndices>,
481 first_char: char,
482 start_pos: usize,
483 ) -> Result<Option<MathToken>> {
484 if !self.config.recognize_latex || first_char != '\\' {
485 return Ok(None);
486 }
487
488 let mut command = String::new();
489 command.push(first_char);
490 let mut current_pos = start_pos + first_char.len_utf8();
491
492 while let Some(&(_, ch)) = chars.peek() {
494 if ch.is_ascii_alphabetic() {
495 chars.next();
496 command.push(ch);
497 current_pos += ch.len_utf8();
498 } else {
499 break;
500 }
501 }
502
503 if command.len() > 1 && self.latex_command_regex.is_match(&command) {
504 Ok(Some(MathToken::new(
505 command,
506 MathTokenType::LaTeXCommand,
507 start_pos,
508 current_pos,
509 )))
510 } else {
511 Ok(None)
512 }
513 }
514
515 fn try_match_function(
517 &self,
518 chars: &mut std::iter::Peekable<std::str::CharIndices>,
519 first_char: char,
520 start_pos: usize,
521 ) -> Result<Option<MathToken>> {
522 if !first_char.is_ascii_alphabetic() {
523 return Ok(None);
524 }
525
526 let mut function = String::new();
527 function.push(first_char);
528 let mut current_pos = start_pos + first_char.len_utf8();
529
530 let saved_chars = chars.clone();
532 while let Some(&(_, ch)) = chars.peek() {
533 if ch.is_ascii_alphabetic() {
534 chars.next();
535 function.push(ch);
536 current_pos += ch.len_utf8();
537 } else {
538 break;
539 }
540 }
541
542 if self.math_functions.contains(&function) {
543 Ok(Some(MathToken::new(
544 function,
545 MathTokenType::Function,
546 start_pos,
547 current_pos,
548 )))
549 } else {
550 *chars = saved_chars;
552 Ok(None)
553 }
554 }
555
556 fn try_match_multi_char_symbol(
558 &self,
559 chars: &mut std::iter::Peekable<std::str::CharIndices>,
560 first_char: char,
561 start_pos: usize,
562 ) -> Result<Option<MathToken>> {
563 let mut symbol = String::new();
565 symbol.push(first_char);
566 let mut current_pos = start_pos + first_char.len_utf8();
567
568 let saved_chars = chars.clone();
570 for _ in 0..2 {
571 if let Some(&(_, ch)) = chars.peek() {
572 let temp_symbol = format!("{}{}", symbol, ch);
573 if self.math_symbols.contains_key(&temp_symbol)
574 || self.math_operators.contains(&temp_symbol)
575 {
576 chars.next();
577 symbol = temp_symbol;
578 current_pos += ch.len_utf8();
579 } else {
580 break;
581 }
582 } else {
583 break;
584 }
585 }
586
587 if symbol.chars().count() > 1 {
588 let token_type = if self.math_symbols.contains_key(&symbol) {
589 MathTokenType::Symbol
590 } else if self.math_operators.contains(&symbol) {
591 MathTokenType::Operator
592 } else {
593 MathTokenType::Unknown
594 };
595 Ok(Some(MathToken::new(
596 symbol,
597 token_type,
598 start_pos,
599 current_pos,
600 )))
601 } else {
602 *chars = saved_chars;
603 Ok(None)
604 }
605 }
606
607 fn classify_single_char(&self, ch: &str) -> MathTokenType {
609 if self.greek_letters.contains(ch) {
610 MathTokenType::GreekLetter
611 } else if self.math_constants.contains_key(ch) {
612 MathTokenType::Constant
613 } else if self.math_operators.contains(ch) {
614 MathTokenType::Operator
615 } else if self.math_symbols.contains_key(ch) {
616 MathTokenType::Symbol
617 } else if self.units.contains(ch) {
618 MathTokenType::Unit
619 } else if matches!(ch, "(" | ")" | "[" | "]" | "{" | "}") {
620 MathTokenType::Delimiter
621 } else if ch.chars().all(|c| c.is_ascii_alphabetic()) {
622 MathTokenType::Variable
623 } else {
624 MathTokenType::Unknown
625 }
626 }
627
628 fn assign_token_ids(&mut self, tokens: &mut [MathToken]) {
630 for token in tokens {
631 let token_key = format!("{}:{:?}", token.text, token.token_type);
632
633 if let Some(&id) = self.token_to_id.get(&token_key) {
634 token.id = Some(id);
635 } else {
636 let id = self.next_id;
637 self.next_id += 1;
638
639 self.token_to_id.insert(token_key.clone(), id);
640 self.id_to_token.insert(id, token_key);
641 token.id = Some(id);
642 }
643 }
644 }
645
646 pub fn math_tokens_to_input(&mut self, mut tokens: Vec<MathToken>) -> TokenizedInput {
648 self.assign_token_ids(&mut tokens);
649
650 let input_ids: Vec<u32> = tokens.iter().filter_map(|t| t.id).collect();
651
652 let token_strings: Vec<String> = tokens.iter().map(|t| t.text.clone()).collect();
653
654 let _offsets: Vec<(u32, u32)> =
655 tokens.iter().map(|t| (t.start as u32, t.end as u32)).collect();
656
657 TokenizedInput {
658 input_ids,
659 attention_mask: vec![1u8; token_strings.len()],
660 token_type_ids: None,
661 special_tokens_mask: None,
662 offset_mapping: None,
663 overflowing_tokens: None,
664 }
665 }
666
667 pub fn analyze_math(&self, tokens: &[MathToken]) -> MathAnalysis {
669 let mut analysis = MathAnalysis::new();
670
671 for token in tokens {
672 analysis.total_tokens += 1;
673
674 match token.token_type {
675 MathTokenType::Number => analysis.numbers += 1,
676 MathTokenType::Variable => analysis.variables += 1,
677 MathTokenType::Operator => analysis.operators += 1,
678 MathTokenType::Function => analysis.functions += 1,
679 MathTokenType::GreekLetter => analysis.greek_letters += 1,
680 MathTokenType::Constant => analysis.constants += 1,
681 MathTokenType::Delimiter => analysis.delimiters += 1,
682 MathTokenType::LaTeXCommand => analysis.latex_commands += 1,
683 MathTokenType::Script => analysis.scripts += 1,
684 MathTokenType::Symbol => analysis.symbols += 1,
685 MathTokenType::Unit => analysis.units += 1,
686 MathTokenType::Text => analysis.text_tokens += 1,
687 MathTokenType::Whitespace => analysis.whitespace += 1,
688 MathTokenType::Unknown => analysis.unknown += 1,
689 }
690
691 analysis.unique_tokens.insert(token.text.clone());
693
694 if token.token_type == MathTokenType::Function {
696 *analysis.function_frequency.entry(token.text.clone()).or_insert(0) += 1;
697 }
698
699 if token.token_type == MathTokenType::Operator {
701 *analysis.operator_frequency.entry(token.text.clone()).or_insert(0) += 1;
702 }
703 }
704
705 analysis.unique_token_count = analysis.unique_tokens.len();
706 analysis
707 }
708
709 pub fn vocab_stats(&self) -> HashMap<String, usize> {
711 let mut stats = HashMap::new();
712 stats.insert("total_tokens".to_string(), self.token_to_id.len());
713 stats.insert("next_id".to_string(), self.next_id as usize);
714
715 let mut type_counts: HashMap<MathTokenType, usize> = HashMap::new();
717 for key in self.token_to_id.keys() {
718 if let Some(type_str) = key.split(':').nth(1) {
719 if let Ok(token_type) =
720 serde_json::from_str::<MathTokenType>(&format!("\"{}\"", type_str))
721 {
722 *type_counts.entry(token_type).or_insert(0) += 1;
723 }
724 }
725 }
726
727 for (token_type, count) in type_counts {
728 stats.insert(format!("{:?}_tokens", token_type).to_lowercase(), count);
729 }
730
731 stats
732 }
733}
734
735impl Default for MathTokenizer {
736 fn default() -> Self {
737 Self::new().expect("MathTokenizer::new() should not fail with default config")
738 }
739}
740
impl Tokenizer for MathTokenizer {
    /// Tokenizes `text` and packs the result into a `TokenizedInput`.
    ///
    /// NOTE(review): the trait takes `&self` but tokenization mutates the
    /// vocabulary, so the whole tokenizer is cloned per call; any IDs minted
    /// for previously unseen tokens are discarded with the clone. Confirm
    /// this is intended.
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        let mut tokenizer = self.clone();
        let tokens = tokenizer.tokenize_math(text)?;
        Ok(tokenizer.math_tokens_to_input(tokens))
    }

    /// Maps each ID back to its token text (the part of the vocabulary key
    /// before the first ':') and joins the pieces with single spaces.
    /// Errors on any ID not present in the vocabulary.
    fn decode(&self, token_ids: &[u32]) -> Result<String> {
        let tokens: std::result::Result<Vec<String>, TrustformersError> = token_ids
            .iter()
            .map(|&id| {
                self.id_to_token
                    .get(&id)
                    .and_then(|key| key.split(':').next())
                    .map(|s| s.to_string())
                    .ok_or_else(|| TrustformersError::other(format!("Unknown token ID: {}", id)))
            })
            .collect();

        Ok(tokens?.join(" "))
    }

    /// Returns a text -> ID map.
    ///
    /// NOTE(review): internal keys are `"text:Type"`, so the same text
    /// registered under several types collapses to a single (arbitrary)
    /// entry here.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.token_to_id
            .iter()
            .map(|(key, &id)| {
                let token = key.split(':').next().unwrap_or(key).to_string();
                (token, id)
            })
            .collect()
    }

    /// Looks `token` up first as a raw vocabulary key, then by probing the
    /// common token types in a fixed priority order.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.token_to_id.get(token).copied().or_else(|| {
            for token_type in [
                MathTokenType::Number,
                MathTokenType::Variable,
                MathTokenType::Operator,
                MathTokenType::Function,
                MathTokenType::GreekLetter,
                MathTokenType::Constant,
            ] {
                let key = format!("{}:{:?}", token, token_type);
                if let Some(&id) = self.token_to_id.get(&key) {
                    return Some(id);
                }
            }
            None
        })
    }

    /// Returns the token text for `id`, stripping the internal type suffix.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.id_to_token
            .get(&id)
            .and_then(|key| key.split(':').next())
            .map(|s| s.to_string())
    }

    /// Encodes two texts joined by a synthetic `[SEP]` token.
    ///
    /// NOTE(review): the `[SEP]` token carries offsets 0..5 and the second
    /// text's token offsets restart at 0, so offsets in the combined stream
    /// are not globally meaningful.
    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
        let mut tokenizer = self.clone();
        let tokens_a = tokenizer.tokenize_math(text_a)?;
        let tokens_b = tokenizer.tokenize_math(text_b)?;

        let mut combined_tokens = tokens_a;
        combined_tokens.push(MathToken {
            text: "[SEP]".to_string(),
            token_type: MathTokenType::Symbol,
            start: 0,
            end: 5,
            id: None,
            latex: None,
            meaning: None,
        });
        combined_tokens.extend(tokens_b);

        Ok(tokenizer.math_tokens_to_input(combined_tokens))
    }

    /// Number of entries in the (typed) vocabulary.
    fn vocab_size(&self) -> usize {
        self.token_to_id.len()
    }
}
825
826impl Clone for MathTokenizer {
828 fn clone(&self) -> Self {
829 Self {
830 config: self.config.clone(),
831 number_regex: Regex::new(r"^\d+\.?\d*$").expect("valid regex"),
832 scientific_regex: Regex::new(r"^\d+\.?\d*[eE][+-]?\d+$").expect("valid regex"),
833 latex_command_regex: Regex::new(r"^\\[a-zA-Z]+$").expect("valid regex"),
834 greek_letters: self.greek_letters.clone(),
835 math_functions: self.math_functions.clone(),
836 math_constants: self.math_constants.clone(),
837 math_operators: self.math_operators.clone(),
838 math_symbols: self.math_symbols.clone(),
839 units: self.units.clone(),
840 token_to_id: self.token_to_id.clone(),
841 id_to_token: self.id_to_token.clone(),
842 next_id: self.next_id,
843 }
844 }
845}
846
/// Aggregate statistics computed over a stream of math tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathAnalysis {
    /// Total number of tokens processed.
    pub total_tokens: usize,
    /// Number of distinct token texts (cached from `unique_tokens.len()`).
    pub unique_token_count: usize,
    /// Count of `Number` tokens.
    pub numbers: usize,
    /// Count of `Variable` tokens.
    pub variables: usize,
    /// Count of `Operator` tokens.
    pub operators: usize,
    /// Count of `Function` tokens.
    pub functions: usize,
    /// Count of `GreekLetter` tokens.
    pub greek_letters: usize,
    /// Count of `Constant` tokens.
    pub constants: usize,
    /// Count of `Delimiter` tokens.
    pub delimiters: usize,
    /// Count of `LaTeXCommand` tokens.
    pub latex_commands: usize,
    /// Count of `Script` tokens.
    pub scripts: usize,
    /// Count of `Symbol` tokens.
    pub symbols: usize,
    /// Count of `Unit` tokens.
    pub units: usize,
    /// Count of `Text` tokens.
    pub text_tokens: usize,
    /// Count of `Whitespace` tokens.
    pub whitespace: usize,
    /// Count of `Unknown` tokens.
    pub unknown: usize,
    /// Distinct token texts seen.
    pub unique_tokens: HashSet<String>,
    /// Function name -> occurrence count.
    pub function_frequency: HashMap<String, usize>,
    /// Operator text -> occurrence count.
    pub operator_frequency: HashMap<String, usize>,
}
870
871impl MathAnalysis {
872 fn new() -> Self {
873 Self {
874 total_tokens: 0,
875 unique_token_count: 0,
876 numbers: 0,
877 variables: 0,
878 operators: 0,
879 functions: 0,
880 greek_letters: 0,
881 constants: 0,
882 delimiters: 0,
883 latex_commands: 0,
884 scripts: 0,
885 symbols: 0,
886 units: 0,
887 text_tokens: 0,
888 whitespace: 0,
889 unknown: 0,
890 unique_tokens: HashSet::new(),
891 function_frequency: HashMap::new(),
892 operator_frequency: HashMap::new(),
893 }
894 }
895
896 pub fn top_functions(&self, n: usize) -> Vec<(String, usize)> {
898 let mut functions: Vec<(String, usize)> =
899 self.function_frequency.iter().map(|(k, &v)| (k.clone(), v)).collect();
900 functions.sort_by_key(|item| std::cmp::Reverse(item.1));
901 functions.into_iter().take(n).collect()
902 }
903
904 pub fn top_operators(&self, n: usize) -> Vec<(String, usize)> {
906 let mut operators: Vec<(String, usize)> =
907 self.operator_frequency.iter().map(|(k, &v)| (k.clone(), v)).collect();
908 operators.sort_by_key(|item| std::cmp::Reverse(item.1));
909 operators.into_iter().take(n).collect()
910 }
911
912 pub fn complexity_score(&self) -> f64 {
914 if self.total_tokens == 0 {
915 return 0.0;
916 }
917
918 let _type_diversity = (self.functions + self.symbols + self.latex_commands) as f64;
919 let token_diversity = self.unique_token_count as f64 / self.total_tokens as f64;
920
921 let advanced_math_score =
924 (self.functions * 3 + self.symbols * 2 + self.latex_commands * 3) as f64;
925 let greek_constants_score = (self.greek_letters + self.constants) as f64;
926 let operator_complexity = self.operators as f64 * 0.5;
927
928 advanced_math_score + greek_constants_score + operator_complexity + (token_diversity * 2.0)
929 }
930}
931
#[cfg(test)]
mod tests {
    //! Unit tests covering construction, each matcher (numbers, functions,
    //! LaTeX, Greek letters, symbols), the `Tokenizer` trait surface, and
    //! the analysis helpers.
    use super::*;

    #[test]
    fn test_math_tokenizer_creation() {
        let tokenizer = MathTokenizer::new().expect("Construction failed");
        assert!(tokenizer.math_functions.contains("sin"));
        assert!(tokenizer.greek_letters.contains("π"));
        assert!(tokenizer.math_operators.contains("+"));
    }

    #[test]
    fn test_number_tokenization() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer.tokenize_math("123 3.14 2e10").expect("Operation failed in test");

        // Integer, decimal, and scientific notation all classify as Number.
        assert_eq!(tokens.len(), 3);
        assert_eq!(tokens[0].text, "123");
        assert_eq!(tokens[0].token_type, MathTokenType::Number);
        assert_eq!(tokens[1].text, "3.14");
        assert_eq!(tokens[1].token_type, MathTokenType::Number);
        assert_eq!(tokens[2].text, "2e10");
        assert_eq!(tokens[2].token_type, MathTokenType::Number);
    }

    #[test]
    fn test_function_tokenization() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer
            .tokenize_math("sin(x) cos(θ) log(n)")
            .expect("Operation failed in test");

        let function_tokens: Vec<&MathToken> =
            tokens.iter().filter(|t| t.token_type == MathTokenType::Function).collect();

        assert_eq!(function_tokens.len(), 3);
        assert_eq!(function_tokens[0].text, "sin");
        assert_eq!(function_tokens[1].text, "cos");
        assert_eq!(function_tokens[2].text, "log");
    }

    #[test]
    fn test_latex_commands() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer
            .tokenize_math("\\frac{x}{y} \\sum_{i=1}^n")
            .expect("Operation failed in test");

        let latex_tokens: Vec<&MathToken> =
            tokens.iter().filter(|t| t.token_type == MathTokenType::LaTeXCommand).collect();

        assert_eq!(latex_tokens.len(), 2);
        assert_eq!(latex_tokens[0].text, "\\frac");
        assert_eq!(latex_tokens[1].text, "\\sum");
    }

    #[test]
    fn test_greek_letters() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer.tokenize_math("α + β = γ").expect("Operation failed in test");

        let greek_tokens: Vec<&MathToken> =
            tokens.iter().filter(|t| t.token_type == MathTokenType::GreekLetter).collect();

        assert_eq!(greek_tokens.len(), 3);
        assert_eq!(greek_tokens[0].text, "α");
        assert_eq!(greek_tokens[1].text, "β");
        assert_eq!(greek_tokens[2].text, "γ");
    }

    #[test]
    fn test_operators_and_symbols() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer.tokenize_math("x ∈ ℝ, x ≥ 0").expect("Operation failed in test");

        let symbol_tokens: Vec<&MathToken> = tokens
            .iter()
            .filter(|t| {
                matches!(
                    t.token_type,
                    MathTokenType::Symbol | MathTokenType::Operator
                )
            })
            .collect();

        // At least "∈" (Symbol) and "≥" (Operator).
        assert!(symbol_tokens.len() >= 2);
    }

    #[test]
    fn test_tokenizer_interface() {
        let tokenizer = MathTokenizer::new().expect("Construction failed");

        let encoded = tokenizer.encode("x^2 + y^2 = r^2").expect("Encoding failed");
        assert!(!encoded.input_ids.is_empty());
        let tokens: Vec<String> = encoded
            .input_ids
            .iter()
            .map(|&id| tokenizer.id_to_token(id).unwrap_or_else(|| format!("UNK_{}", id)))
            .collect();
        assert!(!tokens.is_empty());

        let vocab = tokenizer.get_vocab();
        assert!(!vocab.is_empty());
    }

    #[test]
    fn test_math_analysis() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer
            .tokenize_math("sin(x) + cos(y) = 1")
            .expect("Operation failed in test");
        let analysis = tokenizer.analyze_math(&tokens);

        assert!(analysis.total_tokens > 0);
        assert!(analysis.functions >= 2);
        assert!(analysis.operators >= 2);
        assert!(analysis.variables >= 2);
        assert!(analysis.numbers >= 1);
    }

    #[test]
    fn test_custom_config() {
        let mut config = MathTokenizerConfig::default();
        config.custom_functions.insert("myFunc".to_string());
        config.preserve_whitespace = true;

        let mut tokenizer = MathTokenizer::with_config(config).expect("Operation failed in test");
        let tokens = tokenizer.tokenize_math("myFunc (x)").expect("Operation failed in test");

        // Expected: myFunc, whitespace, '(', 'x', ')'.
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].token_type, MathTokenType::Function);
        assert_eq!(tokens[1].token_type, MathTokenType::Whitespace);
    }

    #[test]
    fn test_scientific_notation() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens =
            tokenizer.tokenize_math("6.022e23 1.602E-19").expect("Operation failed in test");

        assert_eq!(tokens.len(), 2);
        assert_eq!(tokens[0].text, "6.022e23");
        assert_eq!(tokens[0].token_type, MathTokenType::Number);
        assert_eq!(tokens[1].text, "1.602E-19");
        assert_eq!(tokens[1].token_type, MathTokenType::Number);
    }

    #[test]
    fn test_complexity_analysis() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");

        let simple_tokens = tokenizer.tokenize_math("x + 1").expect("Operation failed in test");
        let simple_analysis = tokenizer.analyze_math(&simple_tokens);

        let complex_tokens = tokenizer
            .tokenize_math("∫₀^∞ e^(-x²) dx = √π/2")
            .expect("Operation failed in test");
        let complex_analysis = tokenizer.analyze_math(&complex_tokens);

        assert!(complex_analysis.complexity_score() > simple_analysis.complexity_score());
    }

    #[test]
    fn test_vocab_stats() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        tokenizer.tokenize_math("sin(x) + cos(y)").expect("Operation failed in test");

        let stats = tokenizer.vocab_stats();
        assert!(stats.contains_key("total_tokens"));
        assert!(stats["total_tokens"] > 0);
    }
}