//! Code tokenizer (`trustformers_tokenizers/code_tokenizer.rs`).
//!
//! Language-aware lexer that splits source code into typed tokens
//! (keywords, identifiers, literals, comments, operators) with optional
//! position information.
1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet};
3use trustformers_core::errors::Result;
4use trustformers_core::traits::{TokenizedInput, Tokenizer};
5
/// Supported programming languages
///
/// Each variant maps to a set of file extensions ([`Language::extensions`]),
/// a keyword list ([`Language::keywords`]) and comment syntax
/// ([`Language::comment_patterns`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Language {
    // General-purpose languages
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Java,
    CSharp,
    CPlusPlus,
    C,
    Go,
    Ruby,
    PHP,
    Swift,
    Kotlin,
    Scala,
    Haskell,
    Clojure,
    // Query / markup / data formats
    SQL,
    HTML,
    CSS,
    JSON,
    XML,
    YAML,
    Markdown,
    // Scripting / numerical environments
    Shell,
    PowerShell,
    R,
    Matlab,
}
37
impl Language {
    /// Get file extensions for the language
    ///
    /// Returns the file extensions conventionally associated with the
    /// language, as lowercase strings (with two noted exceptions below).
    pub fn extensions(&self) -> &'static [&'static str] {
        match self {
            Language::Rust => &["rs"],
            Language::Python => &["py", "pyx", "pyi", "pyw"],
            Language::JavaScript => &["js", "jsx", "mjs", "cjs"],
            // NOTE(review): "d.ts" only matches if the caller extracts
            // compound extensions; `Path::extension()` would yield "ts".
            Language::TypeScript => &["ts", "tsx", "d.ts"],
            Language::Java => &["java"],
            Language::CSharp => &["cs"],
            Language::CPlusPlus => &["cpp", "cxx", "cc", "hpp", "hxx", "hh"],
            Language::C => &["c", "h"],
            Language::Go => &["go"],
            Language::Ruby => &["rb", "rbx", "rjs", "gemspec"],
            Language::PHP => &["php", "phtml", "php3", "php4", "php5"],
            Language::Swift => &["swift"],
            Language::Kotlin => &["kt", "kts"],
            Language::Scala => &["scala", "sc"],
            Language::Haskell => &["hs", "lhs"],
            Language::Clojure => &["clj", "cljs", "cljc", "edn"],
            Language::SQL => &["sql"],
            Language::HTML => &["html", "htm", "xhtml"],
            Language::CSS => &["css", "scss", "sass", "less"],
            Language::JSON => &["json", "jsonl", "ndjson"],
            Language::XML => &["xml", "xsd", "xsl", "xslt"],
            Language::YAML => &["yaml", "yml"],
            Language::Markdown => &["md", "markdown", "mdown", "mkd"],
            Language::Shell => &["sh", "bash", "zsh", "fish"],
            Language::PowerShell => &["ps1", "psm1", "psd1"],
            // NOTE(review): uppercase "R" is unreachable via
            // `from_extension`, which lowercases its input before matching.
            Language::R => &["r", "R"],
            Language::Matlab => &["m"],
        }
    }

    /// Get keywords for the language
    ///
    /// Used by `CodeTokenizer` to seed the vocabulary and to classify
    /// identifiers as `CodeTokenType::Keyword`. Languages not listed here
    /// fall through to an empty slice (no keyword classification).
    pub fn keywords(&self) -> &'static [&'static str] {
        match self {
            Language::Rust => &[
                "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false",
                "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut",
                "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
                "true", "type", "unsafe", "use", "where", "while", "async", "await", "dyn",
            ],
            Language::Python => &[
                "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
                "continue", "def", "del", "elif", "else", "except", "finally", "for", "from",
                "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
                "raise", "return", "try", "while", "with", "yield",
            ],
            // TypeScript shares the JavaScript keyword set (plus the
            // contextual/reserved words included below).
            Language::JavaScript | Language::TypeScript => &[
                "break",
                "case",
                "catch",
                "class",
                "const",
                "continue",
                "debugger",
                "default",
                "delete",
                "do",
                "else",
                "export",
                "extends",
                "false",
                "finally",
                "for",
                "function",
                "if",
                "import",
                "in",
                "instanceof",
                "new",
                "null",
                "return",
                "super",
                "switch",
                "this",
                "throw",
                "true",
                "try",
                "typeof",
                "var",
                "void",
                "while",
                "with",
                "yield",
                "let",
                "static",
                "enum",
                "implements",
                "package",
                "protected",
                "interface",
                "private",
                "public",
                "async",
                "await",
            ],
            Language::Java => &[
                "abstract",
                "assert",
                "boolean",
                "break",
                "byte",
                "case",
                "catch",
                "char",
                "class",
                "const",
                "continue",
                "default",
                "do",
                "double",
                "else",
                "enum",
                "extends",
                "final",
                "finally",
                "float",
                "for",
                "goto",
                "if",
                "implements",
                "import",
                "instanceof",
                "int",
                "interface",
                "long",
                "native",
                "new",
                "package",
                "private",
                "protected",
                "public",
                "return",
                "short",
                "static",
                "strictfp",
                "super",
                "switch",
                "synchronized",
                "this",
                "throw",
                "throws",
                "transient",
                "try",
                "void",
                "volatile",
                "while",
            ],
            Language::CSharp => &[
                "abstract",
                "as",
                "base",
                "bool",
                "break",
                "byte",
                "case",
                "catch",
                "char",
                "checked",
                "class",
                "const",
                "continue",
                "decimal",
                "default",
                "delegate",
                "do",
                "double",
                "else",
                "enum",
                "event",
                "explicit",
                "extern",
                "false",
                "finally",
                "fixed",
                "float",
                "for",
                "foreach",
                "goto",
                "if",
                "implicit",
                "in",
                "int",
                "interface",
                "internal",
                "is",
                "lock",
                "long",
                "namespace",
                "new",
                "null",
                "object",
                "operator",
                "out",
                "override",
                "params",
                "private",
                "protected",
                "public",
                "readonly",
                "ref",
                "return",
                "sbyte",
                "sealed",
                "short",
                "sizeof",
                "stackalloc",
                "static",
                "string",
                "struct",
                "switch",
                "this",
                "throw",
                "true",
                "try",
                "typeof",
                "uint",
                "ulong",
                "unchecked",
                "unsafe",
                "ushort",
                "using",
                "virtual",
                "void",
                "volatile",
                "while",
            ],
            Language::Go => &[
                "break",
                "case",
                "chan",
                "const",
                "continue",
                "default",
                "defer",
                "else",
                "fallthrough",
                "for",
                "func",
                "go",
                "goto",
                "if",
                "import",
                "interface",
                "map",
                "package",
                "range",
                "return",
                "select",
                "struct",
                "switch",
                "type",
                "var",
            ],
            _ => &[], // Add more languages as needed
        }
    }

    /// Get comment patterns for the language
    ///
    /// Returns the line/block/doc comment markers the tokenizer uses when
    /// lexing comments. Languages with no entry get all-`None` patterns
    /// (no comment recognition).
    ///
    /// NOTE(review): C is modeled with block comments only — C99-style
    /// `//` line comments are not recognized; confirm this is intended.
    /// NOTE(review): the `///` doc marker is applied to all C-family
    /// languages, though e.g. Java conventionally uses `/** … */`.
    pub fn comment_patterns(&self) -> CommentPatterns {
        match self {
            Language::Rust
            | Language::JavaScript
            | Language::TypeScript
            | Language::Java
            | Language::CSharp
            | Language::CPlusPlus
            | Language::Go
            | Language::Swift
            | Language::Kotlin
            | Language::Scala => CommentPatterns {
                line_comment: Some("//"),
                block_comment: Some(("/*", "*/")),
                doc_comment: Some("///"),
            },
            Language::Python | Language::Ruby | Language::Shell => CommentPatterns {
                line_comment: Some("#"),
                block_comment: None,
                doc_comment: Some("#"),
            },
            Language::C => CommentPatterns {
                line_comment: None,
                block_comment: Some(("/*", "*/")),
                doc_comment: None,
            },
            Language::HTML | Language::XML => CommentPatterns {
                line_comment: None,
                block_comment: Some(("<!--", "-->")),
                doc_comment: None,
            },
            Language::CSS => CommentPatterns {
                line_comment: None,
                block_comment: Some(("/*", "*/")),
                doc_comment: None,
            },
            Language::SQL => CommentPatterns {
                line_comment: Some("--"),
                block_comment: Some(("/*", "*/")),
                doc_comment: None,
            },
            Language::Haskell => CommentPatterns {
                line_comment: Some("--"),
                block_comment: Some(("{-", "-}")),
                doc_comment: Some("-- |"),
            },
            _ => CommentPatterns {
                line_comment: None,
                block_comment: None,
                doc_comment: None,
            },
        }
    }

    /// Detect language from file extension
    ///
    /// Case-insensitive: the input is lowercased before matching against
    /// each language's `extensions()` list. Returns `None` for unknown
    /// extensions.
    ///
    /// NOTE(review): this list must be kept in sync by hand when a new
    /// `Language` variant is added — there is no compile-time check.
    pub fn from_extension(ext: &str) -> Option<Language> {
        let ext = ext.to_lowercase();
        [
            Language::Rust,
            Language::Python,
            Language::JavaScript,
            Language::TypeScript,
            Language::Java,
            Language::CSharp,
            Language::CPlusPlus,
            Language::C,
            Language::Go,
            Language::Ruby,
            Language::PHP,
            Language::Swift,
            Language::Kotlin,
            Language::Scala,
            Language::Haskell,
            Language::Clojure,
            Language::SQL,
            Language::HTML,
            Language::CSS,
            Language::JSON,
            Language::XML,
            Language::YAML,
            Language::Markdown,
            Language::Shell,
            Language::PowerShell,
            Language::R,
            Language::Matlab,
        ]
        .into_iter()
        .find(|&lang| lang.extensions().contains(&ext.as_str()))
    }
}
389
/// Comment patterns for a language
///
/// All markers are `'static` string slices supplied by
/// `Language::comment_patterns`.
#[derive(Debug, Clone)]
pub struct CommentPatterns {
    /// Marker that starts a comment running to end of line, e.g. "//" or "#".
    pub line_comment: Option<&'static str>,
    /// `(start, end)` delimiter pair for block comments, e.g. `("/*", "*/")`.
    pub block_comment: Option<(&'static str, &'static str)>,
    /// Marker for documentation comments, e.g. "///". Not consulted by the
    /// visible tokenization code — doc comments lex as ordinary comments.
    pub doc_comment: Option<&'static str>,
}
397
/// Token types for code
///
/// Classification assigned to each `CodeToken` by the tokenizer.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CodeTokenType {
    /// A reserved word of the detected language.
    Keyword,
    /// A name that is not a keyword.
    Identifier,
    /// A typed literal. NOTE(review): not produced by the visible lexer,
    /// which emits `String` / `Number` directly — confirm intended use.
    Literal(LiteralType),
    Operator,
    Punctuation,
    /// Line or block comment (emitted only when `preserve_comments` is set).
    Comment,
    /// Whitespace run (emitted only when `preserve_whitespace` is set).
    Whitespace,
    /// Quoted string literal, including the quote characters.
    String,
    /// Numeric literal.
    Number,
    Unknown,
}
412
/// Types of literals
///
/// Payload for `CodeTokenType::Literal`, distinguishing literal categories.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LiteralType {
    String,
    Character,
    Integer,
    Float,
    Boolean,
    /// Null-like literal (`null`, `nil`, `None`, …).
    Null,
}
423
/// A code token with type information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeToken {
    /// Raw source text of the token (quotes/comment markers included).
    pub text: String,
    /// Lexical classification of the token.
    pub token_type: CodeTokenType,
    /// Where the token starts/ends in the input.
    pub position: TokenPosition,
    /// The language whose rules produced this token.
    pub language: Language,
}
432
/// Position information for a token
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPosition {
    /// 1-based line number of the token's first character.
    pub line: usize,
    /// 1-based column of the token's first character.
    pub column: usize,
    /// Byte offset of the token's first character in the input.
    pub start_offset: usize,
    /// Byte offset of the token's last character (inclusive).
    pub end_offset: usize,
}
441
/// Configuration for code tokenization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeTokenizerConfig {
    /// Language whose rules drive lexing; `None` falls back to JavaScript
    /// in `tokenize_code`.
    pub language: Option<Language>,
    /// Emit `Whitespace` tokens instead of skipping whitespace runs.
    pub preserve_whitespace: bool,
    /// Keep `Comment` tokens in the output.
    pub preserve_comments: bool,
    /// NOTE(review): not consulted by the visible code — confirm semantics.
    pub include_position_info: bool,
    /// NOTE(review): not consulted by the visible code — confirm semantics.
    pub normalize_identifiers: bool,
    /// NOTE(review): not enforced by the visible code — presumably a cap
    /// applied elsewhere; verify.
    pub max_token_length: Option<usize>,
    /// Extra words treated as keywords and added to the vocabulary.
    pub custom_keywords: Option<HashSet<String>>,
}
453
454impl Default for CodeTokenizerConfig {
455    fn default() -> Self {
456        Self {
457            language: None,
458            preserve_whitespace: false,
459            preserve_comments: true,
460            include_position_info: false,
461            normalize_identifiers: false,
462            max_token_length: Some(128),
463            custom_keywords: None,
464        }
465    }
466}
467
/// Code tokenizer implementation
pub struct CodeTokenizer {
    /// Tokenization options (language, preservation flags, …).
    config: CodeTokenizerConfig,
    /// Words classified as `Keyword` (language keywords + custom keywords).
    keywords: HashSet<String>,
    /// Vocabulary: token text -> id.
    token_to_id: HashMap<String, u32>,
    /// Reverse vocabulary: id -> token text.
    id_to_token: HashMap<u32, String>,
    /// Ids of the special tokens ("[PAD]", "[UNK]", …).
    special_tokens: HashMap<String, u32>,
}
476
477impl CodeTokenizer {
478    /// Create a new code tokenizer
479    pub fn new(config: CodeTokenizerConfig) -> Self {
480        let mut tokenizer = Self {
481            config,
482            keywords: HashSet::new(),
483            token_to_id: HashMap::new(),
484            id_to_token: HashMap::new(),
485            special_tokens: HashMap::new(),
486        };
487
488        tokenizer.initialize_vocabulary();
489        tokenizer
490    }
491
492    /// Create tokenizer for a specific language
493    pub fn for_language(language: Language) -> Self {
494        let config = CodeTokenizerConfig {
495            language: Some(language),
496            ..Default::default()
497        };
498        Self::new(config)
499    }
500
501    /// Initialize vocabulary with common tokens
502    fn initialize_vocabulary(&mut self) {
503        let mut next_id = 0u32;
504
505        // Add special tokens
506        for special in &[
507            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]",
508        ] {
509            self.add_token(special, &mut next_id);
510            self.special_tokens.insert(special.to_string(), next_id - 1);
511        }
512
513        // Add language keywords
514        if let Some(language) = self.config.language {
515            for keyword in language.keywords() {
516                self.keywords.insert(keyword.to_string());
517                self.add_token(keyword, &mut next_id);
518            }
519        }
520
521        // Add custom keywords
522        if let Some(custom_keywords) = &self.config.custom_keywords {
523            let keywords_to_add: Vec<String> = custom_keywords.iter().cloned().collect();
524            for keyword in keywords_to_add {
525                self.keywords.insert(keyword.clone());
526                self.add_token(&keyword, &mut next_id);
527            }
528        }
529
530        // Add common operators and punctuation
531        for op in &[
532            "+", "-", "*", "/", "%", "=", "==", "!=", "<", ">", "<=", ">=", "&&", "||", "!", "&",
533            "|", "^", "~", "<<", ">>", "++", "--", "+=", "-=", "*=", "/=", "%=", "(", ")", "[",
534            "]", "{", "}", ";", ",", ".", ":", "::", "->", "=>", "?",
535        ] {
536            self.add_token(op, &mut next_id);
537        }
538
539        // Add common literals
540        for literal in &["true", "false", "null", "undefined", "nil", "None"] {
541            self.add_token(literal, &mut next_id);
542        }
543    }
544
545    /// Add a token to the vocabulary
546    fn add_token(&mut self, token: &str, next_id: &mut u32) {
547        if !self.token_to_id.contains_key(token) {
548            self.token_to_id.insert(token.to_string(), *next_id);
549            self.id_to_token.insert(*next_id, token.to_string());
550            *next_id += 1;
551        }
552    }
553
    /// Tokenize code into structured tokens
    ///
    /// Walks the input character by character, dispatching on the first
    /// character of each token: whitespace, comment, string literal,
    /// numeric literal, identifier/keyword, then operator/punctuation as
    /// the fallback. Whitespace tokens are emitted only when
    /// `config.preserve_whitespace` is set, comment tokens only when
    /// `config.preserve_comments` is set.
    ///
    /// If no language is configured, JavaScript lexing rules are assumed.
    ///
    /// NOTE(review): `current_line`/`current_column` are advanced only for
    /// the first character of each token; characters consumed inside the
    /// helper parsers do not update them, so line/column values drift after
    /// any multi-character token (and newlines inside strings/comments are
    /// not counted). Byte offsets remain accurate. Confirm before relying
    /// on line/column data.
    pub fn tokenize_code(&self, code: &str) -> Result<Vec<CodeToken>> {
        // Fall back to JavaScript rules when no language is configured.
        let language = self.config.language.unwrap_or(Language::JavaScript);
        let comment_patterns = language.comment_patterns();

        let mut tokens = Vec::new();
        let mut current_line = 1;
        let mut current_column = 1;
        let mut char_indices = code.char_indices().peekable();

        while let Some((start_offset, ch)) = char_indices.next() {
            // Capture the position BEFORE advancing past `ch` so the token
            // records where it starts.
            let token_start_line = current_line;
            let token_start_column = current_column;

            // Update position
            if ch == '\n' {
                current_line += 1;
                current_column = 1;
            } else {
                current_column += 1;
            }

            // Skip whitespace (unless preserving)
            if ch.is_whitespace() {
                if self.config.preserve_whitespace {
                    let (text, end_offset) =
                        self.consume_whitespace(&mut char_indices, start_offset, ch);
                    tokens.push(CodeToken {
                        text,
                        token_type: CodeTokenType::Whitespace,
                        position: TokenPosition {
                            line: token_start_line,
                            column: token_start_column,
                            start_offset,
                            end_offset,
                        },
                        language,
                    });
                }
                continue;
            }

            // Handle comments (parsed even when not preserved, so the
            // comment's characters are consumed either way)
            if let Some(token) = self.try_parse_comment(
                &mut char_indices,
                start_offset,
                ch,
                &comment_patterns,
                token_start_line,
                token_start_column,
                language,
            )? {
                if self.config.preserve_comments {
                    tokens.push(token);
                }
                continue;
            }

            // Handle string literals (backticks only for JS/TS templates)
            if ch == '"'
                || ch == '\''
                || (ch == '`' && matches!(language, Language::JavaScript | Language::TypeScript))
            {
                let token = self.parse_string_literal(
                    &mut char_indices,
                    start_offset,
                    ch,
                    token_start_line,
                    token_start_column,
                    language,
                )?;
                tokens.push(token);
                continue;
            }

            // Handle numeric literals (a leading '.' counts only when
            // followed by a digit, e.g. ".5")
            if ch.is_ascii_digit()
                || (ch == '.'
                    && char_indices.peek().map(|(_, c)| c.is_ascii_digit()).unwrap_or(false))
            {
                let token = self.parse_numeric_literal(
                    &mut char_indices,
                    start_offset,
                    ch,
                    token_start_line,
                    token_start_column,
                    language,
                )?;
                tokens.push(token);
                continue;
            }

            // Handle identifiers and keywords
            if ch.is_alphabetic() || ch == '_' || ch == '$' {
                let token = self.parse_identifier(
                    &mut char_indices,
                    start_offset,
                    ch,
                    token_start_line,
                    token_start_column,
                    language,
                )?;
                tokens.push(token);
                continue;
            }

            // Handle operators and punctuation
            let token = self.parse_operator_or_punctuation(
                &mut char_indices,
                start_offset,
                ch,
                token_start_line,
                token_start_column,
                language,
            )?;
            tokens.push(token);
        }

        Ok(tokens)
    }
674
675    /// Consume whitespace characters
676    fn consume_whitespace(
677        &self,
678        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
679        start_offset: usize,
680        first_char: char,
681    ) -> (String, usize) {
682        let mut text = String::new();
683        text.push(first_char);
684        let mut end_offset = start_offset;
685
686        while let Some((offset, ch)) = char_indices.peek() {
687            if ch.is_whitespace() {
688                text.push(*ch);
689                end_offset = *offset;
690                char_indices.next();
691            } else {
692                break;
693            }
694        }
695
696        (text, end_offset)
697    }
698
    /// Try to parse a comment
    ///
    /// Cheap first-character screen: when `first_char` matches the first
    /// character of the language's line- or block-comment marker, the
    /// corresponding helper is invoked (line comments are tried first).
    /// Returns `Ok(None)` when neither marker applies, so the caller can
    /// lex the character as something else.
    ///
    /// NOTE(review): only the FIRST marker character is compared here; the
    /// helpers must confirm (or reject) the remaining marker characters
    /// before consuming input — verify they do.
    fn try_parse_comment(
        &self,
        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
        start_offset: usize,
        first_char: char,
        patterns: &CommentPatterns,
        token_start_line: usize,
        token_start_column: usize,
        language: Language,
    ) -> Result<Option<CodeToken>> {
        // Check for line comments
        if let Some(line_comment) = patterns.line_comment {
            if first_char
                == line_comment.chars().next().expect("line_comment pattern must be non-empty")
            {
                if let Some(token) = self.try_parse_line_comment(
                    char_indices,
                    start_offset,
                    line_comment,
                    token_start_line,
                    token_start_column,
                    language,
                )? {
                    return Ok(Some(token));
                }
            }
        }

        // Check for block comments
        if let Some((start_delim, end_delim)) = patterns.block_comment {
            if first_char
                == start_delim
                    .chars()
                    .next()
                    .expect("block comment start delimiter must be non-empty")
            {
                if let Some(token) = self.try_parse_block_comment(
                    char_indices,
                    start_offset,
                    start_delim,
                    end_delim,
                    token_start_line,
                    token_start_column,
                    language,
                )? {
                    return Ok(Some(token));
                }
            }
        }

        Ok(None)
    }
752
753    /// Parse a line comment
754    fn try_parse_line_comment(
755        &self,
756        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
757        start_offset: usize,
758        comment_start: &str,
759        token_start_line: usize,
760        token_start_column: usize,
761        language: Language,
762    ) -> Result<Option<CodeToken>> {
763        let mut text = String::new();
764        text.push_str(comment_start);
765
766        // Skip the remaining characters of the comment start
767        for _ in 1..comment_start.len() {
768            if let Some((_, ch)) = char_indices.next() {
769                text.push(ch);
770            }
771        }
772
773        // Read until end of line
774        let mut end_offset = start_offset;
775        while let Some((offset, ch)) = char_indices.peek() {
776            if *ch == '\n' {
777                break;
778            }
779            text.push(*ch);
780            end_offset = *offset;
781            char_indices.next();
782        }
783
784        Ok(Some(CodeToken {
785            text,
786            token_type: CodeTokenType::Comment,
787            position: TokenPosition {
788                line: token_start_line,
789                column: token_start_column,
790                start_offset,
791                end_offset,
792            },
793            language,
794        }))
795    }
796
797    /// Parse a block comment
798    fn try_parse_block_comment(
799        &self,
800        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
801        start_offset: usize,
802        start_delim: &str,
803        end_delim: &str,
804        token_start_line: usize,
805        token_start_column: usize,
806        language: Language,
807    ) -> Result<Option<CodeToken>> {
808        let mut text = String::new();
809        text.push_str(start_delim);
810
811        // Skip the remaining characters of the start delimiter
812        for _ in 1..start_delim.len() {
813            if let Some((_, ch)) = char_indices.next() {
814                text.push(ch);
815            }
816        }
817
818        // Read until end delimiter
819        let mut end_offset = start_offset;
820        let end_chars: Vec<char> = end_delim.chars().collect();
821        let mut buffer = Vec::new();
822
823        for (offset, ch) in char_indices.by_ref() {
824            text.push(ch);
825            end_offset = offset;
826            buffer.push(ch);
827
828            // Keep only the last few characters needed to match end delimiter
829            if buffer.len() > end_chars.len() {
830                buffer.remove(0);
831            }
832
833            // Check if we've found the end delimiter
834            if buffer.len() == end_chars.len() && buffer == end_chars {
835                break;
836            }
837        }
838
839        Ok(Some(CodeToken {
840            text,
841            token_type: CodeTokenType::Comment,
842            position: TokenPosition {
843                line: token_start_line,
844                column: token_start_column,
845                start_offset,
846                end_offset,
847            },
848            language,
849        }))
850    }
851
852    /// Parse a string literal
853    fn parse_string_literal(
854        &self,
855        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
856        start_offset: usize,
857        quote_char: char,
858        token_start_line: usize,
859        token_start_column: usize,
860        language: Language,
861    ) -> Result<CodeToken> {
862        let mut text = String::new();
863        text.push(quote_char);
864        let mut end_offset = start_offset;
865        let mut escaped = false;
866
867        for (offset, ch) in char_indices.by_ref() {
868            text.push(ch);
869            end_offset = offset;
870
871            if escaped {
872                escaped = false;
873                continue;
874            }
875
876            if ch == '\\' {
877                escaped = true;
878                continue;
879            }
880
881            if ch == quote_char {
882                break;
883            }
884        }
885
886        Ok(CodeToken {
887            text,
888            token_type: CodeTokenType::String,
889            position: TokenPosition {
890                line: token_start_line,
891                column: token_start_column,
892                start_offset,
893                end_offset,
894            },
895            language,
896        })
897    }
898
899    /// Parse a numeric literal
900    fn parse_numeric_literal(
901        &self,
902        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
903        start_offset: usize,
904        first_char: char,
905        token_start_line: usize,
906        token_start_column: usize,
907        language: Language,
908    ) -> Result<CodeToken> {
909        let mut text = String::new();
910        text.push(first_char);
911        let mut end_offset = start_offset;
912        let mut has_dot = first_char == '.';
913
914        while let Some((offset, ch)) = char_indices.peek() {
915            if ch.is_ascii_digit()
916                || (*ch == '.' && !has_dot)
917                || (*ch == 'e' || *ch == 'E')
918                || (*ch == 'x' || *ch == 'X')
919                || (*ch == '_')
920                || ch.is_ascii_hexdigit()
921            {
922                if *ch == '.' {
923                    has_dot = true;
924                }
925                text.push(*ch);
926                end_offset = *offset;
927                char_indices.next();
928            } else {
929                break;
930            }
931        }
932
933        Ok(CodeToken {
934            text,
935            token_type: CodeTokenType::Number,
936            position: TokenPosition {
937                line: token_start_line,
938                column: token_start_column,
939                start_offset,
940                end_offset,
941            },
942            language,
943        })
944    }
945
946    /// Parse an identifier or keyword
947    fn parse_identifier(
948        &self,
949        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
950        start_offset: usize,
951        first_char: char,
952        token_start_line: usize,
953        token_start_column: usize,
954        language: Language,
955    ) -> Result<CodeToken> {
956        let mut text = String::new();
957        text.push(first_char);
958        let mut end_offset = start_offset;
959
960        while let Some((offset, ch)) = char_indices.peek() {
961            if ch.is_alphanumeric() || *ch == '_' || *ch == '$' {
962                text.push(*ch);
963                end_offset = *offset;
964                char_indices.next();
965            } else {
966                break;
967            }
968        }
969
970        let token_type = if self.keywords.contains(&text) {
971            CodeTokenType::Keyword
972        } else {
973            CodeTokenType::Identifier
974        };
975
976        Ok(CodeToken {
977            text,
978            token_type,
979            position: TokenPosition {
980                line: token_start_line,
981                column: token_start_column,
982                start_offset,
983                end_offset,
984            },
985            language,
986        })
987    }
988
989    /// Parse an operator or punctuation
990    fn parse_operator_or_punctuation(
991        &self,
992        char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
993        start_offset: usize,
994        first_char: char,
995        token_start_line: usize,
996        token_start_column: usize,
997        language: Language,
998    ) -> Result<CodeToken> {
999        let mut text = String::new();
1000        text.push(first_char);
1001        let mut end_offset = start_offset;
1002
1003        // Try to form multi-character operators
1004        let operators = [
1005            "==", "!=", "<=", ">=", "&&", "||", "++", "--", "+=", "-=", "*=", "/=", "%=", "<<",
1006            ">>", "::", "->", "=>", "**", "//", "...", "..", ":=", "<=>",
1007        ];
1008
1009        for op in &operators {
1010            if op.starts_with(first_char) && op.len() > 1 {
1011                let chars = op.chars().skip(1);
1012                let mut matched = true;
1013                let mut lookahead = Vec::new();
1014
1015                for expected_char in chars {
1016                    if let Some((offset, ch)) = char_indices.peek() {
1017                        if *ch == expected_char {
1018                            lookahead.push((*offset, *ch));
1019                            char_indices.next();
1020                        } else {
1021                            matched = false;
1022                            break;
1023                        }
1024                    } else {
1025                        matched = false;
1026                        break;
1027                    }
1028                }
1029
1030                if matched {
1031                    text = op.to_string();
1032                    if let Some((offset, _)) = lookahead.last() {
1033                        end_offset = *offset;
1034                    }
1035                    break;
1036                } else {
1037                    // Put back the consumed characters
1038                    for (_, _ch) in lookahead.into_iter().rev() {
1039                        // Note: This is a simplified approach. In a real implementation,
1040                        // you'd need a more sophisticated way to put back characters.
1041                    }
1042                }
1043            }
1044        }
1045
1046        let token_type = match first_char {
1047            '(' | ')' | '[' | ']' | '{' | '}' | ';' | ',' | '.' | ':' => CodeTokenType::Punctuation,
1048            _ => CodeTokenType::Operator,
1049        };
1050
1051        Ok(CodeToken {
1052            text,
1053            token_type,
1054            position: TokenPosition {
1055                line: token_start_line,
1056                column: token_start_column,
1057                start_offset,
1058                end_offset,
1059            },
1060            language,
1061        })
1062    }
1063
1064    /// Get or create token ID
1065    #[allow(dead_code)]
1066    fn get_or_create_token_id(&mut self, token: &str) -> u32 {
1067        if let Some(&id) = self.token_to_id.get(token) {
1068            id
1069        } else {
1070            let id = self.token_to_id.len() as u32;
1071            self.token_to_id.insert(token.to_string(), id);
1072            self.id_to_token.insert(id, token.to_string());
1073            id
1074        }
1075    }
1076
    /// Get vocabulary size: the number of distinct tokens currently
    /// registered in the token -> ID map.
    pub fn vocab_size(&self) -> usize {
        self.token_to_id.len()
    }
1081
    /// Get the vocabulary ID for `token`, or `None` when the token is not
    /// in the vocabulary.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.token_to_id.get(token).copied()
    }
1086
    /// Get the token text for a vocabulary ID, or `None` when the ID is
    /// unknown. Returns an owned `String` cloned from the vocabulary.
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.id_to_token.get(&id).cloned()
    }
1091}
1092
1093impl Tokenizer for CodeTokenizer {
1094    fn encode(&self, text: &str) -> Result<TokenizedInput> {
1095        let code_tokens = self.tokenize_code(text)?;
1096        let mut input_ids = Vec::new();
1097
1098        for token in code_tokens {
1099            let token_text = if self.config.normalize_identifiers
1100                && token.token_type == CodeTokenType::Identifier
1101            {
1102                "[IDENTIFIER]".to_string()
1103            } else {
1104                token.text
1105            };
1106
1107            if let Some(id) = self.token_to_id(&token_text) {
1108                input_ids.push(id);
1109            } else if let Some(&unk_id) = self.special_tokens.get("[UNK]") {
1110                input_ids.push(unk_id);
1111            }
1112        }
1113
1114        let attention_mask = vec![1u8; input_ids.len()];
1115
1116        Ok(TokenizedInput {
1117            input_ids,
1118            attention_mask,
1119            token_type_ids: None,
1120            special_tokens_mask: None,
1121            offset_mapping: None,
1122            overflowing_tokens: None,
1123        })
1124    }
1125
1126    fn decode(&self, ids: &[u32]) -> Result<String> {
1127        let tokens: Vec<String> = ids.iter().filter_map(|&id| self.id_to_token(id)).collect();
1128        Ok(tokens.join(" "))
1129    }
1130
1131    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
1132        let combined = format!("{}\n{}", text_a, text_b);
1133        self.encode(&combined)
1134    }
1135
1136    fn vocab_size(&self) -> usize {
1137        self.token_to_id.len()
1138    }
1139
1140    fn get_vocab(&self) -> HashMap<String, u32> {
1141        self.token_to_id.clone()
1142    }
1143
1144    fn token_to_id(&self, token: &str) -> Option<u32> {
1145        self.token_to_id.get(token).copied()
1146    }
1147
1148    fn id_to_token(&self, id: u32) -> Option<String> {
1149        self.id_to_token.get(&id).cloned()
1150    }
1151}
1152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_language_detection() {
        assert_eq!(Language::from_extension("rs"), Some(Language::Rust));
        assert_eq!(Language::from_extension("py"), Some(Language::Python));
        assert_eq!(Language::from_extension("js"), Some(Language::JavaScript));
        assert_eq!(Language::from_extension("unknown"), None);
    }

    #[test]
    fn test_rust_tokenization() {
        let tokenizer = CodeTokenizer::for_language(Language::Rust);
        let code = "fn main() { let x = 42; }";
        let tokens =
            tokenizer.tokenize_code(code).expect("tokenizing valid Rust source should succeed");

        assert!(!tokens.is_empty());

        // Reserved words must be classified as keywords, not identifiers.
        let fn_token = tokens
            .iter()
            .find(|t| t.text == "fn")
            .expect("a `fn` token should be produced for the fn keyword");
        assert_eq!(fn_token.token_type, CodeTokenType::Keyword);

        let let_token = tokens
            .iter()
            .find(|t| t.text == "let")
            .expect("a `let` token should be produced for the let keyword");
        assert_eq!(let_token.token_type, CodeTokenType::Keyword);
    }

    #[test]
    fn test_string_literal_parsing() {
        let tokenizer = CodeTokenizer::for_language(Language::JavaScript);
        let code = r#"let name = "Hello \"World\"";"#;
        let tokens =
            tokenizer.tokenize_code(code).expect("tokenizing a JS string literal should succeed");

        // The literal, including escaped inner quotes, should surface as a
        // single String token delimited by the outer quotes.
        let string_token = tokens
            .iter()
            .find(|t| t.token_type == CodeTokenType::String)
            .expect("a String token should be produced for the quoted literal");
        assert!(string_token.text.starts_with('"'));
        assert!(string_token.text.ends_with('"'));
    }

    #[test]
    fn test_comment_parsing() {
        let config = CodeTokenizerConfig {
            language: Some(Language::Rust),
            preserve_comments: true,
            ..Default::default()
        };
        let tokenizer = CodeTokenizer::new(config);
        let code = "// This is a comment\nfn main() {}";
        let tokens = tokenizer
            .tokenize_code(code)
            .expect("tokenizing commented Rust source should succeed");

        // With preserve_comments enabled, the line comment must survive
        // tokenization as its own token.
        let comment_token = tokens
            .iter()
            .find(|t| t.token_type == CodeTokenType::Comment)
            .expect("a Comment token should be preserved when preserve_comments is set");
        assert!(comment_token.text.starts_with("//"));
    }

    #[test]
    fn test_numeric_literals() {
        let tokenizer = CodeTokenizer::for_language(Language::Python);
        let code = "x = 42; y = 3.14; z = 0xFF;";
        let tokens = tokenizer
            .tokenize_code(code)
            .expect("tokenizing numeric literals should succeed");

        // Integer, float, and hex literals should all be Number tokens.
        let numeric_tokens: Vec<_> =
            tokens.iter().filter(|t| t.token_type == CodeTokenType::Number).collect();

        assert!(numeric_tokens.len() >= 3);
    }

    #[test]
    fn test_code_tokenizer_encode() {
        let tokenizer = CodeTokenizer::for_language(Language::Python);
        let code = "def hello(): return 42";
        let result = tokenizer.encode(code).expect("Encoding failed");

        assert!(!result.input_ids.is_empty());
        assert_eq!(result.input_ids.len(), result.attention_mask.len());
    }
}