//! Mathematical formula tokenizer.
//!
//! Module: `trustformers_tokenizers/math_tokenizer.rs`
1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
7/// Types of mathematical tokens
/// Types of mathematical tokens
///
/// Derives `Eq` + `Hash` so values can be used as `HashMap` keys (see
/// `MathTokenizer::vocab_stats`), and serde traits for persistence. The
/// `Debug` rendering of each variant is used verbatim inside vocabulary
/// keys (`format!("{}:{:?}", text, token_type)` in `assign_token_ids`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MathTokenType {
    /// Numbers (integers, floats, scientific notation)
    Number,
    /// Variables (single letters, subscripted/superscripted)
    Variable,
    /// Mathematical operators (+, -, *, /, ^, etc.)
    Operator,
    /// Mathematical functions (sin, cos, log, etc.)
    Function,
    /// Greek letters (α, β, γ, etc.)
    GreekLetter,
    /// Mathematical constants (π, e, ∞, etc.)
    Constant,
    /// Delimiters (parentheses, brackets, braces)
    Delimiter,
    /// LaTeX commands (\frac, \sum, \int, etc.)
    LaTeXCommand,
    /// Subscripts and superscripts
    Script,
    /// Mathematical symbols (∈, ∀, ∃, etc.)
    Symbol,
    /// Units (m, kg, s, etc.)
    Unit,
    /// Text within math (for labels, etc.)
    Text,
    /// Whitespace
    Whitespace,
    /// Unknown/other
    Unknown,
}
39
40/// A mathematical token with position and type information
/// A mathematical token with position and type information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathToken {
    /// The token text
    pub text: String,
    /// Token type
    pub token_type: MathTokenType,
    /// Start position in original text (byte offset, from `char_indices`)
    pub start: usize,
    /// End position in original text (exclusive byte offset)
    pub end: usize,
    /// Token ID (assigned during tokenization; `None` until
    /// `assign_token_ids` runs)
    pub id: Option<u32>,
    /// LaTeX representation if different from text
    pub latex: Option<String>,
    /// Mathematical meaning/description
    pub meaning: Option<String>,
}
58
59impl MathToken {
60    /// Create a new math token
61    pub fn new(text: String, token_type: MathTokenType, start: usize, end: usize) -> Self {
62        Self {
63            text,
64            token_type,
65            start,
66            end,
67            id: None,
68            latex: None,
69            meaning: None,
70        }
71    }
72
73    /// Set the token ID
74    pub fn with_id(mut self, id: u32) -> Self {
75        self.id = Some(id);
76        self
77    }
78
79    /// Set the LaTeX representation
80    pub fn with_latex(mut self, latex: String) -> Self {
81        self.latex = Some(latex);
82        self
83    }
84
85    /// Set the meaning
86    pub fn with_meaning(mut self, meaning: String) -> Self {
87        self.meaning = Some(meaning);
88        self
89    }
90}
91
92/// Configuration for math tokenizer
/// Configuration for math tokenizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathTokenizerConfig {
    /// Whether to preserve whitespace tokens
    pub preserve_whitespace: bool,
    /// Whether to recognize LaTeX commands
    pub recognize_latex: bool,
    /// Whether to recognize scientific notation
    pub recognize_scientific_notation: bool,
    /// Whether to handle subscripts and superscripts
    // NOTE(review): not consulted by the tokenization code visible in this
    // file — confirm whether script handling is implemented elsewhere.
    pub handle_scripts: bool,
    /// Whether to recognize units
    // NOTE(review): units are always added to the vocabulary regardless of
    // this flag; the flag itself is never read here — verify intent.
    pub recognize_units: bool,
    /// Maximum token length
    // NOTE(review): not enforced anywhere in this file — TODO confirm.
    pub max_token_length: usize,
    /// Custom function names to recognize
    pub custom_functions: HashSet<String>,
    /// Custom constants to recognize
    pub custom_constants: HashMap<String, String>, // symbol -> meaning
}
112
113impl Default for MathTokenizerConfig {
114    fn default() -> Self {
115        Self {
116            preserve_whitespace: false,
117            recognize_latex: true,
118            recognize_scientific_notation: true,
119            handle_scripts: true,
120            recognize_units: true,
121            max_token_length: 50,
122            custom_functions: HashSet::new(),
123            custom_constants: HashMap::new(),
124        }
125    }
126}
127
128/// Mathematical formula tokenizer
/// Mathematical formula tokenizer
///
/// Vocabulary keys are composite strings of the form `"{text}:{TypeTag}"`
/// (see `add_token_with_type` / `assign_token_ids`), so the same text can
/// hold distinct IDs under different token types.
pub struct MathTokenizer {
    config: MathTokenizerConfig,
    /// Regular expressions for tokenization (anchored; matched against
    /// complete candidate tokens, not searched within larger text)
    number_regex: Regex,
    scientific_regex: Regex,
    latex_command_regex: Regex,
    greek_letters: HashSet<String>,
    math_functions: HashSet<String>,
    math_constants: HashMap<String, String>,
    math_operators: HashSet<String>,
    math_symbols: HashMap<String, String>,
    units: HashSet<String>,
    /// Vocabulary mappings (composite key <-> ID)
    token_to_id: HashMap<String, u32>,
    id_to_token: HashMap<u32, String>,
    // Next ID to hand out; starts at 1 (0 reserved for special tokens).
    next_id: u32,
}
146
147impl MathTokenizer {
148    /// Create a new math tokenizer with default configuration
149    pub fn new() -> Result<Self> {
150        Self::with_config(MathTokenizerConfig::default())
151    }
152
    /// Create a new math tokenizer with custom configuration
    ///
    /// Compiles the matching regexes, builds the built-in tables (Greek
    /// letters, functions, constants, operators, symbols, units), merges
    /// in `config.custom_functions` / `config.custom_constants`, and
    /// pre-populates the vocabulary via `initialize_vocabulary`.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `Result` leaves room for
    /// fallible setup. The regex patterns are fixed literals, so the
    /// `expect`s encode a compile-time invariant, not a runtime failure.
    pub fn with_config(config: MathTokenizerConfig) -> Result<Self> {
        let number_regex = Regex::new(r"^\d+\.?\d*$").expect("valid regex");
        let scientific_regex = Regex::new(r"^\d+\.?\d*[eE][+-]?\d+$").expect("valid regex");
        let latex_command_regex = Regex::new(r"^\\[a-zA-Z]+$").expect("valid regex");

        // Build Greek letters set (lowercase and uppercase)
        let greek_letters = [
            "α", "β", "γ", "δ", "ε", "ζ", "η", "θ", "ι", "κ", "λ", "μ", "ν", "ξ", "ο", "π", "ρ",
            "σ", "τ", "υ", "φ", "χ", "ψ", "ω", "Α", "Β", "Γ", "Δ", "Ε", "Ζ", "Η", "Θ", "Ι", "Κ",
            "Λ", "Μ", "Ν", "Ξ", "Ο", "Π", "Ρ", "Σ", "Τ", "Υ", "Φ", "Χ", "Ψ", "Ω",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        // Build math functions set
        let mut math_functions = [
            "sin", "cos", "tan", "sec", "csc", "cot", "arcsin", "arccos", "arctan", "asin", "acos",
            "atan", "sinh", "cosh", "tanh", "sech", "csch", "coth", "log", "ln", "lg", "exp",
            "sqrt", "abs", "sgn", "min", "max", "gcd", "lcm", "floor", "ceil", "det", "tr", "rank",
            "dim", "span", "ker", "im",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect::<HashSet<_>>();

        // Add custom functions
        math_functions.extend(config.custom_functions.clone());

        // Build math constants (symbol -> descriptive name)
        let mut math_constants = HashMap::new();
        math_constants.insert("π".to_string(), "pi".to_string());
        math_constants.insert("e".to_string(), "euler_number".to_string());
        math_constants.insert("∞".to_string(), "infinity".to_string());
        math_constants.insert("i".to_string(), "imaginary_unit".to_string());
        math_constants.insert("φ".to_string(), "golden_ratio".to_string());
        math_constants.insert("γ".to_string(), "euler_gamma".to_string());

        // Add custom constants (same-symbol entries override built-ins)
        math_constants.extend(config.custom_constants.clone());

        // Build operators set
        let math_operators = [
            "+", "-", "*", "×", "·", "/", "÷", "^", "=", "≠", "≈", "≡", "<", ">", "≤", "≥", "≪",
            "≫", "±", "∓", "∝", "∼", "≅",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        // Build symbols map (symbol -> descriptive name)
        let mut math_symbols = HashMap::new();
        math_symbols.insert("∈".to_string(), "element_of".to_string());
        math_symbols.insert("∉".to_string(), "not_element_of".to_string());
        math_symbols.insert("⊂".to_string(), "subset".to_string());
        math_symbols.insert("⊃".to_string(), "superset".to_string());
        math_symbols.insert("⊆".to_string(), "subset_equal".to_string());
        math_symbols.insert("⊇".to_string(), "superset_equal".to_string());
        math_symbols.insert("∪".to_string(), "union".to_string());
        math_symbols.insert("∩".to_string(), "intersection".to_string());
        math_symbols.insert("∅".to_string(), "empty_set".to_string());
        math_symbols.insert("∀".to_string(), "for_all".to_string());
        math_symbols.insert("∃".to_string(), "exists".to_string());
        math_symbols.insert("∄".to_string(), "not_exists".to_string());
        math_symbols.insert("∇".to_string(), "nabla".to_string());
        math_symbols.insert("∂".to_string(), "partial".to_string());
        math_symbols.insert("∫".to_string(), "integral".to_string());
        math_symbols.insert("∬".to_string(), "double_integral".to_string());
        math_symbols.insert("∭".to_string(), "triple_integral".to_string());
        math_symbols.insert("∮".to_string(), "contour_integral".to_string());
        math_symbols.insert("∑".to_string(), "sum".to_string());
        math_symbols.insert("∏".to_string(), "product".to_string());
        math_symbols.insert("→".to_string(), "right_arrow".to_string());
        math_symbols.insert("←".to_string(), "left_arrow".to_string());
        math_symbols.insert("↔".to_string(), "double_arrow".to_string());
        math_symbols.insert("⇒".to_string(), "implies".to_string());
        math_symbols.insert("⇔".to_string(), "if_and_only_if".to_string());
        math_symbols.insert("√".to_string(), "square_root".to_string());

        // Build units set
        let units = [
            // Length
            "m", "cm", "mm", "km", "in", "ft", "yd", "mi", // Mass
            "g", "kg", "mg", "lb", "oz", // Time
            "s", "ms", "min", "h", "hr", "day", "yr", // Temperature
            "K", "°C", "°F", // Energy
            "J", "kJ", "cal", "kcal", "eV", "keV", "MeV", "GeV", // Power
            "W", "kW", "MW", "hp", // Frequency
            "Hz", "kHz", "MHz", "GHz", // Voltage
            "V", "mV", "kV", // Current
            "A", "mA", "μA", // Resistance
            "Ω", "kΩ", "MΩ",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let mut tokenizer = Self {
            config,
            number_regex,
            scientific_regex,
            latex_command_regex,
            greek_letters,
            math_functions,
            math_constants,
            math_operators,
            math_symbols,
            units,
            token_to_id: HashMap::new(),
            id_to_token: HashMap::new(),
            next_id: 1, // Reserve 0 for special tokens
        };

        // Initialize vocabulary with all mathematical tokens
        tokenizer.initialize_vocabulary();
        Ok(tokenizer)
    }
271
272    /// Initialize vocabulary with all mathematical tokens
273    fn initialize_vocabulary(&mut self) {
274        // Collect data to avoid borrowing issues
275        let greek_letters: Vec<String> = self.greek_letters.iter().cloned().collect();
276        let math_functions: Vec<String> = self.math_functions.iter().cloned().collect();
277        let math_constants: Vec<String> = self.math_constants.keys().cloned().collect();
278        let math_operators: Vec<String> = self.math_operators.iter().cloned().collect();
279        let math_symbols: Vec<String> = self.math_symbols.keys().cloned().collect();
280        let units: Vec<String> = self.units.iter().cloned().collect();
281
282        // Add Greek letters
283        for letter in &greek_letters {
284            self.add_token_with_type(letter, "GreekLetter");
285        }
286
287        // Add mathematical functions
288        for function in &math_functions {
289            self.add_token_with_type(function, "Function");
290        }
291
292        // Add mathematical constants
293        for constant in &math_constants {
294            self.add_token_with_type(constant, "Constant");
295        }
296
297        // Add mathematical operators
298        for operator in &math_operators {
299            self.add_token_with_type(operator, "Operator");
300        }
301
302        // Add mathematical symbols
303        for symbol in &math_symbols {
304            self.add_token_with_type(symbol, "Symbol");
305        }
306
307        // Add units
308        for unit in &units {
309            self.add_token_with_type(unit, "Unit");
310        }
311
312        // Add common variables
313        for var in ["x", "y", "z", "a", "b", "c", "n", "m", "k", "i", "j"] {
314            self.add_token_with_type(var, "Variable");
315        }
316
317        // Add numbers 0-9
318        for num in 0..10 {
319            self.add_token_with_type(&num.to_string(), "Number");
320        }
321
322        // Add common punctuation and delimiters
323        for punct in ["(", ")", "[", "]", "{", "}", ",", ".", "!", " "] {
324            self.add_token_with_type(punct, "Punctuation");
325        }
326    }
327
328    /// Add a token with type information to the vocabulary
329    fn add_token_with_type(&mut self, token: &str, token_type: &str) {
330        let key = format!("{}:{}", token, token_type);
331        if !self.token_to_id.contains_key(&key) {
332            let id = self.next_id;
333            self.token_to_id.insert(key.clone(), id);
334            self.id_to_token.insert(id, key);
335            self.next_id += 1;
336        }
337    }
338
    /// Tokenize mathematical text into MathTokens
    ///
    /// Matchers are tried in priority order: number, LaTeX command,
    /// function name, multi-character symbol, then a single-character
    /// fallback classified by `classify_single_char`. Token positions are
    /// byte offsets into `text` (from `char_indices`).
    ///
    /// NOTE(review): this takes `&mut self` but only calls `&self`
    /// helpers — confirm whether the mutable receiver is still required.
    pub fn tokenize_math(&mut self, text: &str) -> Result<Vec<MathToken>> {
        let mut tokens = Vec::new();
        let mut chars = text.char_indices().peekable();

        while let Some((pos, ch)) = chars.next() {
            let start_pos = pos;

            // Skip whitespace unless configured to preserve it
            if ch.is_whitespace() {
                if self.config.preserve_whitespace {
                    let mut whitespace = String::new();
                    whitespace.push(ch);

                    // Collect consecutive whitespace into one token
                    while let Some(&(_, next_ch)) = chars.peek() {
                        if next_ch.is_whitespace() {
                            chars.next();
                            whitespace.push(next_ch);
                        } else {
                            break;
                        }
                    }

                    // Byte length, consistent with `char_indices` offsets.
                    let end_pos = start_pos + whitespace.len();
                    tokens.push(MathToken::new(
                        whitespace,
                        MathTokenType::Whitespace,
                        start_pos,
                        end_pos,
                    ));
                }
                continue;
            }

            // Try to match different token types; `ch` is already consumed,
            // so each helper receives it plus the iterator positioned after
            // it and must not over-consume on failure.
            if let Some(token) = self.try_match_number(&mut chars, ch, start_pos)? {
                tokens.push(token);
            } else if let Some(token) = self.try_match_latex_command(&mut chars, ch, start_pos)? {
                tokens.push(token);
            } else if let Some(token) = self.try_match_function(&mut chars, ch, start_pos)? {
                tokens.push(token);
            } else if let Some(token) =
                self.try_match_multi_char_symbol(&mut chars, ch, start_pos)?
            {
                tokens.push(token);
            } else {
                // Single character token
                let token_text = ch.to_string();
                let token_type = self.classify_single_char(&token_text);
                let end_pos = start_pos + ch.len_utf8();

                tokens.push(MathToken::new(token_text, token_type, start_pos, end_pos));
            }
        }

        Ok(tokens)
    }
397
398    /// Try to match a number (including scientific notation)
399    fn try_match_number(
400        &self,
401        chars: &mut std::iter::Peekable<std::str::CharIndices>,
402        first_char: char,
403        start_pos: usize,
404    ) -> Result<Option<MathToken>> {
405        if !first_char.is_ascii_digit() && first_char != '.' {
406            return Ok(None);
407        }
408
409        let mut number = String::new();
410        number.push(first_char);
411        let mut current_pos = start_pos + first_char.len_utf8();
412
413        // Collect digits and decimal point
414        while let Some(&(_, ch)) = chars.peek() {
415            if ch.is_ascii_digit() || ch == '.' {
416                chars.next();
417                number.push(ch);
418                current_pos += ch.len_utf8();
419            } else {
420                break;
421            }
422        }
423
424        // Check for scientific notation
425        if let Some(&(_, ch)) = chars.peek() {
426            if ch == 'e' || ch == 'E' {
427                let mut temp_number = number.clone();
428                temp_number.push(ch);
429                chars.next();
430                current_pos += ch.len_utf8();
431
432                // Check for optional sign
433                if let Some(&(_, sign_ch)) = chars.peek() {
434                    if sign_ch == '+' || sign_ch == '-' {
435                        chars.next();
436                        temp_number.push(sign_ch);
437                        current_pos += sign_ch.len_utf8();
438                    }
439                }
440
441                // Must have digits after E
442                let mut has_exponent_digits = false;
443                while let Some(&(_, ch)) = chars.peek() {
444                    if ch.is_ascii_digit() {
445                        chars.next();
446                        temp_number.push(ch);
447                        current_pos += ch.len_utf8();
448                        has_exponent_digits = true;
449                    } else {
450                        break;
451                    }
452                }
453
454                if has_exponent_digits {
455                    number = temp_number;
456                }
457            }
458        }
459
460        let token_type = if (self.config.recognize_scientific_notation
461            && self.scientific_regex.is_match(&number))
462            || self.number_regex.is_match(&number)
463        {
464            MathTokenType::Number
465        } else {
466            MathTokenType::Unknown
467        };
468
469        Ok(Some(MathToken::new(
470            number,
471            token_type,
472            start_pos,
473            current_pos,
474        )))
475    }
476
477    /// Try to match a LaTeX command
478    fn try_match_latex_command(
479        &self,
480        chars: &mut std::iter::Peekable<std::str::CharIndices>,
481        first_char: char,
482        start_pos: usize,
483    ) -> Result<Option<MathToken>> {
484        if !self.config.recognize_latex || first_char != '\\' {
485            return Ok(None);
486        }
487
488        let mut command = String::new();
489        command.push(first_char);
490        let mut current_pos = start_pos + first_char.len_utf8();
491
492        // Collect alphabetic characters
493        while let Some(&(_, ch)) = chars.peek() {
494            if ch.is_ascii_alphabetic() {
495                chars.next();
496                command.push(ch);
497                current_pos += ch.len_utf8();
498            } else {
499                break;
500            }
501        }
502
503        if command.len() > 1 && self.latex_command_regex.is_match(&command) {
504            Ok(Some(MathToken::new(
505                command,
506                MathTokenType::LaTeXCommand,
507                start_pos,
508                current_pos,
509            )))
510        } else {
511            Ok(None)
512        }
513    }
514
515    /// Try to match a mathematical function
516    fn try_match_function(
517        &self,
518        chars: &mut std::iter::Peekable<std::str::CharIndices>,
519        first_char: char,
520        start_pos: usize,
521    ) -> Result<Option<MathToken>> {
522        if !first_char.is_ascii_alphabetic() {
523            return Ok(None);
524        }
525
526        let mut function = String::new();
527        function.push(first_char);
528        let mut current_pos = start_pos + first_char.len_utf8();
529
530        // Look ahead to build potential function name
531        let saved_chars = chars.clone();
532        while let Some(&(_, ch)) = chars.peek() {
533            if ch.is_ascii_alphabetic() {
534                chars.next();
535                function.push(ch);
536                current_pos += ch.len_utf8();
537            } else {
538                break;
539            }
540        }
541
542        if self.math_functions.contains(&function) {
543            Ok(Some(MathToken::new(
544                function,
545                MathTokenType::Function,
546                start_pos,
547                current_pos,
548            )))
549        } else {
550            // Restore chars iterator and try shorter match
551            *chars = saved_chars;
552            Ok(None)
553        }
554    }
555
    /// Try to match multi-character symbols
    ///
    /// Greedily extends the current character by up to two more while the
    /// growing string stays a known symbol/operator, then classifies the
    /// result. NOTE(review): every entry `with_config` currently places in
    /// `math_symbols` / `math_operators` is a single character, so the
    /// extension loop never succeeds and this always restores the iterator
    /// and returns `Ok(None)` — effectively a hook for future
    /// multi-character entries. Confirm whether that is intended.
    fn try_match_multi_char_symbol(
        &self,
        chars: &mut std::iter::Peekable<std::str::CharIndices>,
        first_char: char,
        start_pos: usize,
    ) -> Result<Option<MathToken>> {
        // Look for 2-3 character symbols first
        let mut symbol = String::new();
        symbol.push(first_char);
        let mut current_pos = start_pos + first_char.len_utf8();

        // Try to match longer symbols first; snapshot the iterator so a
        // failed match can restore it.
        let saved_chars = chars.clone();
        for _ in 0..2 {
            if let Some(&(_, ch)) = chars.peek() {
                let temp_symbol = format!("{}{}", symbol, ch);
                if self.math_symbols.contains_key(&temp_symbol)
                    || self.math_operators.contains(&temp_symbol)
                {
                    chars.next();
                    symbol = temp_symbol;
                    current_pos += ch.len_utf8();
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        // Only emit if we actually extended beyond the first character.
        if symbol.chars().count() > 1 {
            let token_type = if self.math_symbols.contains_key(&symbol) {
                MathTokenType::Symbol
            } else if self.math_operators.contains(&symbol) {
                MathTokenType::Operator
            } else {
                MathTokenType::Unknown
            };
            Ok(Some(MathToken::new(
                symbol,
                token_type,
                start_pos,
                current_pos,
            )))
        } else {
            *chars = saved_chars;
            Ok(None)
        }
    }
606
    /// Classify a single character token
    ///
    /// Checks are ordered by precedence: Greek letters win over constants
    /// (so "π"/"γ"/"φ" classify as `GreekLetter` even though they also
    /// appear in `math_constants`), and constants win over variables (so
    /// "e"/"i" classify as `Constant` rather than `Variable`).
    fn classify_single_char(&self, ch: &str) -> MathTokenType {
        if self.greek_letters.contains(ch) {
            MathTokenType::GreekLetter
        } else if self.math_constants.contains_key(ch) {
            MathTokenType::Constant
        } else if self.math_operators.contains(ch) {
            MathTokenType::Operator
        } else if self.math_symbols.contains_key(ch) {
            MathTokenType::Symbol
        } else if self.units.contains(ch) {
            MathTokenType::Unit
        } else if matches!(ch, "(" | ")" | "[" | "]" | "{" | "}") {
            MathTokenType::Delimiter
        } else if ch.chars().all(|c| c.is_ascii_alphabetic()) {
            // Any remaining ASCII letter is treated as a variable name.
            MathTokenType::Variable
        } else {
            MathTokenType::Unknown
        }
    }
627
628    /// Assign IDs to tokens and build vocabulary
629    fn assign_token_ids(&mut self, tokens: &mut [MathToken]) {
630        for token in tokens {
631            let token_key = format!("{}:{:?}", token.text, token.token_type);
632
633            if let Some(&id) = self.token_to_id.get(&token_key) {
634                token.id = Some(id);
635            } else {
636                let id = self.next_id;
637                self.next_id += 1;
638
639                self.token_to_id.insert(token_key.clone(), id);
640                self.id_to_token.insert(id, token_key);
641                token.id = Some(id);
642            }
643        }
644    }
645
646    /// Convert MathTokens to standard TokenizedInput
647    pub fn math_tokens_to_input(&mut self, mut tokens: Vec<MathToken>) -> TokenizedInput {
648        self.assign_token_ids(&mut tokens);
649
650        let input_ids: Vec<u32> = tokens.iter().filter_map(|t| t.id).collect();
651
652        let token_strings: Vec<String> = tokens.iter().map(|t| t.text.clone()).collect();
653
654        let _offsets: Vec<(u32, u32)> =
655            tokens.iter().map(|t| (t.start as u32, t.end as u32)).collect();
656
657        TokenizedInput {
658            input_ids,
659            attention_mask: vec![1u8; token_strings.len()],
660            token_type_ids: None,
661            special_tokens_mask: None,
662            offset_mapping: None,
663            overflowing_tokens: None,
664        }
665    }
666
    /// Get mathematical analysis of tokenized text
    ///
    /// Produces per-type counts, the set of distinct token texts, and
    /// frequency tables for function and operator tokens.
    pub fn analyze_math(&self, tokens: &[MathToken]) -> MathAnalysis {
        let mut analysis = MathAnalysis::new();

        for token in tokens {
            analysis.total_tokens += 1;

            // Exhaustive match: adding a `MathTokenType` variant forces a
            // compile error here, keeping the analysis in sync.
            match token.token_type {
                MathTokenType::Number => analysis.numbers += 1,
                MathTokenType::Variable => analysis.variables += 1,
                MathTokenType::Operator => analysis.operators += 1,
                MathTokenType::Function => analysis.functions += 1,
                MathTokenType::GreekLetter => analysis.greek_letters += 1,
                MathTokenType::Constant => analysis.constants += 1,
                MathTokenType::Delimiter => analysis.delimiters += 1,
                MathTokenType::LaTeXCommand => analysis.latex_commands += 1,
                MathTokenType::Script => analysis.scripts += 1,
                MathTokenType::Symbol => analysis.symbols += 1,
                MathTokenType::Unit => analysis.units += 1,
                MathTokenType::Text => analysis.text_tokens += 1,
                MathTokenType::Whitespace => analysis.whitespace += 1,
                MathTokenType::Unknown => analysis.unknown += 1,
            }

            // Track unique tokens (by text only, ignoring type)
            analysis.unique_tokens.insert(token.text.clone());

            // Track function names
            if token.token_type == MathTokenType::Function {
                *analysis.function_frequency.entry(token.text.clone()).or_insert(0) += 1;
            }

            // Track operators
            if token.token_type == MathTokenType::Operator {
                *analysis.operator_frequency.entry(token.text.clone()).or_insert(0) += 1;
            }
        }

        analysis.unique_token_count = analysis.unique_tokens.len();
        analysis
    }
708
709    /// Get vocabulary statistics
710    pub fn vocab_stats(&self) -> HashMap<String, usize> {
711        let mut stats = HashMap::new();
712        stats.insert("total_tokens".to_string(), self.token_to_id.len());
713        stats.insert("next_id".to_string(), self.next_id as usize);
714
715        // Count by token type
716        let mut type_counts: HashMap<MathTokenType, usize> = HashMap::new();
717        for key in self.token_to_id.keys() {
718            if let Some(type_str) = key.split(':').nth(1) {
719                if let Ok(token_type) =
720                    serde_json::from_str::<MathTokenType>(&format!("\"{}\"", type_str))
721                {
722                    *type_counts.entry(token_type).or_insert(0) += 1;
723                }
724            }
725        }
726
727        for (token_type, count) in type_counts {
728            stats.insert(format!("{:?}_tokens", token_type).to_lowercase(), count);
729        }
730
731        stats
732    }
733}
734
impl Default for MathTokenizer {
    /// Delegates to [`MathTokenizer::new`]. The only failure path would be
    /// an invalid built-in regex — a programming error — so the `expect`
    /// documents an invariant rather than a recoverable condition.
    fn default() -> Self {
        Self::new().expect("MathTokenizer::new() should not fail with default config")
    }
}
740
impl Tokenizer for MathTokenizer {
    /// Tokenize `text` and convert to `TokenizedInput`.
    ///
    /// NOTE(review): the trait takes `&self`, so this works on a full
    /// clone of the tokenizer. IDs minted for previously unseen tokens
    /// live only in that clone and are discarded on return — `decode` on
    /// `self` will then not recognize those IDs. Confirm whether interior
    /// mutability (or pre-registering the full vocabulary) is intended.
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        let mut tokenizer = self.clone();
        let tokens = tokenizer.tokenize_math(text)?;
        Ok(tokenizer.math_tokens_to_input(tokens))
    }

    /// Map each ID back to the text half of its `"{text}:{type}"` key and
    /// join with single spaces (original spacing is not reconstructed).
    ///
    /// # Errors
    /// Returns an error for any ID absent from the vocabulary.
    fn decode(&self, token_ids: &[u32]) -> Result<String> {
        let tokens: std::result::Result<Vec<String>, TrustformersError> = token_ids
            .iter()
            .map(|&id| {
                self.id_to_token
                    .get(&id)
                    .and_then(|key| key.split(':').next())
                    .map(|s| s.to_string())
                    .ok_or_else(|| TrustformersError::other(format!("Unknown token ID: {}", id)))
            })
            .collect();

        Ok(tokens?.join(" "))
    }

    /// Expose the vocabulary keyed by plain token text.
    ///
    /// NOTE(review): distinct `"{text}:{type}"` entries sharing the same
    /// text collapse to one map entry here (arbitrary winner) — confirm
    /// callers tolerate that.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.token_to_id
            .iter()
            .map(|(key, &id)| {
                let token = key.split(':').next().unwrap_or(key).to_string();
                (token, id)
            })
            .collect()
    }

    /// Look up a token's ID.
    ///
    /// The exact-match branch only hits if the caller passes a composite
    /// `"text:Type"` key (all stored keys carry the type suffix); plain
    /// token text falls through to the typed probes below, tried in a
    /// fixed priority order.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        // Try exact match first
        self.token_to_id.get(token).copied().or_else(|| {
            // Try with different token types
            for token_type in [
                MathTokenType::Number,
                MathTokenType::Variable,
                MathTokenType::Operator,
                MathTokenType::Function,
                MathTokenType::GreekLetter,
                MathTokenType::Constant,
            ] {
                let key = format!("{}:{:?}", token, token_type);
                if let Some(&id) = self.token_to_id.get(&key) {
                    return Some(id);
                }
            }
            None
        })
    }

    /// Map an ID to the text half of its composite key.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.id_to_token
            .get(&id)
            .and_then(|key| key.split(':').next())
            .map(|s| s.to_string())
    }

    /// Encode two texts joined by a synthetic `[SEP]` token.
    ///
    /// NOTE(review): the `[SEP]` positions (0..5) are placeholders, not
    /// offsets into either input, and `tokens_b` offsets remain relative
    /// to `text_b` — verify consumers do not rely on global offsets.
    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
        let mut tokenizer = self.clone();
        let tokens_a = tokenizer.tokenize_math(text_a)?;
        let tokens_b = tokenizer.tokenize_math(text_b)?;

        let mut combined_tokens = tokens_a;
        combined_tokens.push(MathToken {
            text: "[SEP]".to_string(),
            token_type: MathTokenType::Symbol,
            start: 0,
            end: 5,
            id: None,
            latex: None,
            meaning: None,
        });
        combined_tokens.extend(tokens_b);

        Ok(tokenizer.math_tokens_to_input(combined_tokens))
    }

    /// Number of distinct `"{text}:{type}"` vocabulary entries.
    fn vocab_size(&self) -> usize {
        self.token_to_id.len()
    }
}
825
826// Make MathTokenizer cloneable for the Tokenizer trait
827impl Clone for MathTokenizer {
828    fn clone(&self) -> Self {
829        Self {
830            config: self.config.clone(),
831            number_regex: Regex::new(r"^\d+\.?\d*$").expect("valid regex"),
832            scientific_regex: Regex::new(r"^\d+\.?\d*[eE][+-]?\d+$").expect("valid regex"),
833            latex_command_regex: Regex::new(r"^\\[a-zA-Z]+$").expect("valid regex"),
834            greek_letters: self.greek_letters.clone(),
835            math_functions: self.math_functions.clone(),
836            math_constants: self.math_constants.clone(),
837            math_operators: self.math_operators.clone(),
838            math_symbols: self.math_symbols.clone(),
839            units: self.units.clone(),
840            token_to_id: self.token_to_id.clone(),
841            id_to_token: self.id_to_token.clone(),
842            next_id: self.next_id,
843        }
844    }
845}
846
/// Analysis results for mathematical text
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MathAnalysis {
    /// Total number of tokens analyzed, across all token types.
    pub total_tokens: usize,
    /// Number of distinct token strings observed.
    pub unique_token_count: usize,
    /// Count of `MathTokenType::Number` tokens.
    pub numbers: usize,
    /// Count of `MathTokenType::Variable` tokens.
    pub variables: usize,
    /// Count of `MathTokenType::Operator` tokens.
    pub operators: usize,
    /// Count of `MathTokenType::Function` tokens.
    pub functions: usize,
    /// Count of `MathTokenType::GreekLetter` tokens.
    pub greek_letters: usize,
    /// Count of `MathTokenType::Constant` tokens.
    pub constants: usize,
    /// Count of `MathTokenType::Delimiter` tokens.
    pub delimiters: usize,
    /// Count of `MathTokenType::LaTeXCommand` tokens.
    pub latex_commands: usize,
    /// Count of `MathTokenType::Script` (subscript/superscript) tokens.
    pub scripts: usize,
    /// Count of `MathTokenType::Symbol` tokens.
    pub symbols: usize,
    /// Count of `MathTokenType::Unit` tokens.
    pub units: usize,
    /// Count of `MathTokenType::Text` tokens.
    pub text_tokens: usize,
    /// Count of `MathTokenType::Whitespace` tokens.
    pub whitespace: usize,
    /// Count of `MathTokenType::Unknown` tokens.
    pub unknown: usize,
    /// The set of distinct token strings encountered.
    pub unique_tokens: HashSet<String>,
    /// Occurrence count per function token (e.g. "sin", "log").
    pub function_frequency: HashMap<String, usize>,
    /// Occurrence count per operator token (e.g. "+", "=").
    pub operator_frequency: HashMap<String, usize>,
}
870
871impl MathAnalysis {
872    fn new() -> Self {
873        Self {
874            total_tokens: 0,
875            unique_token_count: 0,
876            numbers: 0,
877            variables: 0,
878            operators: 0,
879            functions: 0,
880            greek_letters: 0,
881            constants: 0,
882            delimiters: 0,
883            latex_commands: 0,
884            scripts: 0,
885            symbols: 0,
886            units: 0,
887            text_tokens: 0,
888            whitespace: 0,
889            unknown: 0,
890            unique_tokens: HashSet::new(),
891            function_frequency: HashMap::new(),
892            operator_frequency: HashMap::new(),
893        }
894    }
895
896    /// Get the most common functions
897    pub fn top_functions(&self, n: usize) -> Vec<(String, usize)> {
898        let mut functions: Vec<(String, usize)> =
899            self.function_frequency.iter().map(|(k, &v)| (k.clone(), v)).collect();
900        functions.sort_by_key(|item| std::cmp::Reverse(item.1));
901        functions.into_iter().take(n).collect()
902    }
903
904    /// Get the most common operators
905    pub fn top_operators(&self, n: usize) -> Vec<(String, usize)> {
906        let mut operators: Vec<(String, usize)> =
907            self.operator_frequency.iter().map(|(k, &v)| (k.clone(), v)).collect();
908        operators.sort_by_key(|item| std::cmp::Reverse(item.1));
909        operators.into_iter().take(n).collect()
910    }
911
912    /// Calculate complexity score based on token diversity
913    pub fn complexity_score(&self) -> f64 {
914        if self.total_tokens == 0 {
915            return 0.0;
916        }
917
918        let _type_diversity = (self.functions + self.symbols + self.latex_commands) as f64;
919        let token_diversity = self.unique_token_count as f64 / self.total_tokens as f64;
920
921        // Combine different factors
922        // Weight mathematical complexity higher
923        let advanced_math_score =
924            (self.functions * 3 + self.symbols * 2 + self.latex_commands * 3) as f64;
925        let greek_constants_score = (self.greek_letters + self.constants) as f64;
926        let operator_complexity = self.operators as f64 * 0.5;
927
928        advanced_math_score + greek_constants_score + operator_complexity + (token_diversity * 2.0)
929    }
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_math_tokenizer_creation() {
        let tokenizer = MathTokenizer::new().expect("Construction failed");
        // The default vocabularies must contain representative entries.
        assert!(tokenizer.math_operators.contains("+"));
        assert!(tokenizer.math_functions.contains("sin"));
        assert!(tokenizer.greek_letters.contains("π"));
    }

    #[test]
    fn test_number_tokenization() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer.tokenize_math("123 3.14 2e10").expect("Operation failed in test");

        let expected = ["123", "3.14", "2e10"];
        assert_eq!(tokens.len(), expected.len());
        for (token, text) in tokens.iter().zip(expected.iter().copied()) {
            assert_eq!(token.text, text);
            assert_eq!(token.token_type, MathTokenType::Number);
        }
    }

    #[test]
    fn test_function_tokenization() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer
            .tokenize_math("sin(x) cos(θ) log(n)")
            .expect("Operation failed in test");

        let names: Vec<&str> = tokens
            .iter()
            .filter(|t| matches!(t.token_type, MathTokenType::Function))
            .map(|t| t.text.as_str())
            .collect();

        assert_eq!(names, ["sin", "cos", "log"]);
    }

    #[test]
    fn test_latex_commands() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer
            .tokenize_math("\\frac{x}{y} \\sum_{i=1}^n")
            .expect("Operation failed in test");

        let commands: Vec<&str> = tokens
            .iter()
            .filter(|t| matches!(t.token_type, MathTokenType::LaTeXCommand))
            .map(|t| t.text.as_str())
            .collect();

        assert_eq!(commands, ["\\frac", "\\sum"]);
    }

    #[test]
    fn test_greek_letters() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer.tokenize_math("α + β = γ").expect("Operation failed in test");

        let letters: Vec<&str> = tokens
            .iter()
            .filter(|t| matches!(t.token_type, MathTokenType::GreekLetter))
            .map(|t| t.text.as_str())
            .collect();

        assert_eq!(letters, ["α", "β", "γ"]);
    }

    #[test]
    fn test_operators_and_symbols() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer.tokenize_math("x ∈ ℝ, x ≥ 0").expect("Operation failed in test");

        // "∈" and "≥" should be classified as symbols or operators.
        let relevant = tokens
            .iter()
            .filter(|t| {
                t.token_type == MathTokenType::Symbol || t.token_type == MathTokenType::Operator
            })
            .count();

        assert!(relevant >= 2);
    }

    #[test]
    fn test_tokenizer_interface() {
        let tokenizer = MathTokenizer::new().expect("Construction failed");

        let encoded = tokenizer.encode("x^2 + y^2 = r^2").expect("Encoding failed");
        assert!(!encoded.input_ids.is_empty());

        // Every produced ID must map back to some token string.
        let mut tokens: Vec<String> = Vec::with_capacity(encoded.input_ids.len());
        for &id in &encoded.input_ids {
            let text = match tokenizer.id_to_token(id) {
                Some(text) => text,
                None => format!("UNK_{}", id),
            };
            tokens.push(text);
        }
        assert!(!tokens.is_empty());

        // The vocabulary built during encoding must be retrievable.
        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_math_analysis() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens = tokenizer
            .tokenize_math("sin(x) + cos(y) = 1")
            .expect("Operation failed in test");
        let analysis = tokenizer.analyze_math(&tokens);

        assert!(analysis.total_tokens > 0);
        // sin, cos
        assert!(analysis.functions >= 2);
        // +, =
        assert!(analysis.operators >= 2);
        // x, y
        assert!(analysis.variables >= 2);
        // 1
        assert!(analysis.numbers >= 1);
    }

    #[test]
    fn test_custom_config() {
        let mut config = MathTokenizerConfig::default();
        config.preserve_whitespace = true;
        config.custom_functions.insert("myFunc".to_string());

        let mut tokenizer = MathTokenizer::with_config(config).expect("Operation failed in test");
        let tokens = tokenizer.tokenize_math("myFunc (x)").expect("Operation failed in test");

        // Expected stream: function, whitespace, delimiter, variable, delimiter.
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].token_type, MathTokenType::Function);
        assert_eq!(tokens[1].token_type, MathTokenType::Whitespace);
    }

    #[test]
    fn test_scientific_notation() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        let tokens =
            tokenizer.tokenize_math("6.022e23 1.602E-19").expect("Operation failed in test");

        let expected = ["6.022e23", "1.602E-19"];
        assert_eq!(tokens.len(), expected.len());
        for (token, text) in tokens.iter().zip(expected.iter().copied()) {
            assert_eq!(token.text, text);
            assert_eq!(token.token_type, MathTokenType::Number);
        }
    }

    #[test]
    fn test_complexity_analysis() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");

        // A simple expression vs. an integral with special symbols.
        let basic = tokenizer.tokenize_math("x + 1").expect("Operation failed in test");
        let basic_score = tokenizer.analyze_math(&basic).complexity_score();

        let advanced = tokenizer
            .tokenize_math("∫₀^∞ e^(-x²) dx = √π/2")
            .expect("Operation failed in test");
        let advanced_score = tokenizer.analyze_math(&advanced).complexity_score();

        assert!(advanced_score > basic_score);
    }

    #[test]
    fn test_vocab_stats() {
        let mut tokenizer = MathTokenizer::new().expect("Construction failed");
        tokenizer.tokenize_math("sin(x) + cos(y)").expect("Operation failed in test");

        let stats = tokenizer.vocab_stats();
        match stats.get("total_tokens") {
            Some(count) => assert!(*count > 0),
            None => panic!("vocab_stats is missing the total_tokens entry"),
        }
    }
}