1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet};
3use trustformers_core::errors::Result;
4use trustformers_core::traits::{TokenizedInput, Tokenizer};
5
/// Programming, markup, and data languages recognized by the code tokenizer.
///
/// `Copy`/`Eq`/`Hash` make variants cheap to pass around and usable as map
/// keys; serde derives let configurations containing a `Language` round-trip
/// through serialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Language {
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Java,
    CSharp,
    CPlusPlus,
    C,
    Go,
    Ruby,
    PHP,
    Swift,
    Kotlin,
    Scala,
    Haskell,
    Clojure,
    SQL,
    HTML,
    CSS,
    JSON,
    XML,
    YAML,
    Markdown,
    Shell,
    PowerShell,
    R,
    Matlab,
}
37
38impl Language {
39 pub fn extensions(&self) -> &'static [&'static str] {
41 match self {
42 Language::Rust => &["rs"],
43 Language::Python => &["py", "pyx", "pyi", "pyw"],
44 Language::JavaScript => &["js", "jsx", "mjs", "cjs"],
45 Language::TypeScript => &["ts", "tsx", "d.ts"],
46 Language::Java => &["java"],
47 Language::CSharp => &["cs"],
48 Language::CPlusPlus => &["cpp", "cxx", "cc", "hpp", "hxx", "hh"],
49 Language::C => &["c", "h"],
50 Language::Go => &["go"],
51 Language::Ruby => &["rb", "rbx", "rjs", "gemspec"],
52 Language::PHP => &["php", "phtml", "php3", "php4", "php5"],
53 Language::Swift => &["swift"],
54 Language::Kotlin => &["kt", "kts"],
55 Language::Scala => &["scala", "sc"],
56 Language::Haskell => &["hs", "lhs"],
57 Language::Clojure => &["clj", "cljs", "cljc", "edn"],
58 Language::SQL => &["sql"],
59 Language::HTML => &["html", "htm", "xhtml"],
60 Language::CSS => &["css", "scss", "sass", "less"],
61 Language::JSON => &["json", "jsonl", "ndjson"],
62 Language::XML => &["xml", "xsd", "xsl", "xslt"],
63 Language::YAML => &["yaml", "yml"],
64 Language::Markdown => &["md", "markdown", "mdown", "mkd"],
65 Language::Shell => &["sh", "bash", "zsh", "fish"],
66 Language::PowerShell => &["ps1", "psm1", "psd1"],
67 Language::R => &["r", "R"],
68 Language::Matlab => &["m"],
69 }
70 }
71
72 pub fn keywords(&self) -> &'static [&'static str] {
74 match self {
75 Language::Rust => &[
76 "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false",
77 "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut",
78 "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
79 "true", "type", "unsafe", "use", "where", "while", "async", "await", "dyn",
80 ],
81 Language::Python => &[
82 "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
83 "continue", "def", "del", "elif", "else", "except", "finally", "for", "from",
84 "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
85 "raise", "return", "try", "while", "with", "yield",
86 ],
87 Language::JavaScript | Language::TypeScript => &[
88 "break",
89 "case",
90 "catch",
91 "class",
92 "const",
93 "continue",
94 "debugger",
95 "default",
96 "delete",
97 "do",
98 "else",
99 "export",
100 "extends",
101 "false",
102 "finally",
103 "for",
104 "function",
105 "if",
106 "import",
107 "in",
108 "instanceof",
109 "new",
110 "null",
111 "return",
112 "super",
113 "switch",
114 "this",
115 "throw",
116 "true",
117 "try",
118 "typeof",
119 "var",
120 "void",
121 "while",
122 "with",
123 "yield",
124 "let",
125 "static",
126 "enum",
127 "implements",
128 "package",
129 "protected",
130 "interface",
131 "private",
132 "public",
133 "async",
134 "await",
135 ],
136 Language::Java => &[
137 "abstract",
138 "assert",
139 "boolean",
140 "break",
141 "byte",
142 "case",
143 "catch",
144 "char",
145 "class",
146 "const",
147 "continue",
148 "default",
149 "do",
150 "double",
151 "else",
152 "enum",
153 "extends",
154 "final",
155 "finally",
156 "float",
157 "for",
158 "goto",
159 "if",
160 "implements",
161 "import",
162 "instanceof",
163 "int",
164 "interface",
165 "long",
166 "native",
167 "new",
168 "package",
169 "private",
170 "protected",
171 "public",
172 "return",
173 "short",
174 "static",
175 "strictfp",
176 "super",
177 "switch",
178 "synchronized",
179 "this",
180 "throw",
181 "throws",
182 "transient",
183 "try",
184 "void",
185 "volatile",
186 "while",
187 ],
188 Language::CSharp => &[
189 "abstract",
190 "as",
191 "base",
192 "bool",
193 "break",
194 "byte",
195 "case",
196 "catch",
197 "char",
198 "checked",
199 "class",
200 "const",
201 "continue",
202 "decimal",
203 "default",
204 "delegate",
205 "do",
206 "double",
207 "else",
208 "enum",
209 "event",
210 "explicit",
211 "extern",
212 "false",
213 "finally",
214 "fixed",
215 "float",
216 "for",
217 "foreach",
218 "goto",
219 "if",
220 "implicit",
221 "in",
222 "int",
223 "interface",
224 "internal",
225 "is",
226 "lock",
227 "long",
228 "namespace",
229 "new",
230 "null",
231 "object",
232 "operator",
233 "out",
234 "override",
235 "params",
236 "private",
237 "protected",
238 "public",
239 "readonly",
240 "ref",
241 "return",
242 "sbyte",
243 "sealed",
244 "short",
245 "sizeof",
246 "stackalloc",
247 "static",
248 "string",
249 "struct",
250 "switch",
251 "this",
252 "throw",
253 "true",
254 "try",
255 "typeof",
256 "uint",
257 "ulong",
258 "unchecked",
259 "unsafe",
260 "ushort",
261 "using",
262 "virtual",
263 "void",
264 "volatile",
265 "while",
266 ],
267 Language::Go => &[
268 "break",
269 "case",
270 "chan",
271 "const",
272 "continue",
273 "default",
274 "defer",
275 "else",
276 "fallthrough",
277 "for",
278 "func",
279 "go",
280 "goto",
281 "if",
282 "import",
283 "interface",
284 "map",
285 "package",
286 "range",
287 "return",
288 "select",
289 "struct",
290 "switch",
291 "type",
292 "var",
293 ],
294 _ => &[], }
296 }
297
298 pub fn comment_patterns(&self) -> CommentPatterns {
300 match self {
301 Language::Rust
302 | Language::JavaScript
303 | Language::TypeScript
304 | Language::Java
305 | Language::CSharp
306 | Language::CPlusPlus
307 | Language::Go
308 | Language::Swift
309 | Language::Kotlin
310 | Language::Scala => CommentPatterns {
311 line_comment: Some("//"),
312 block_comment: Some(("/*", "*/")),
313 doc_comment: Some("///"),
314 },
315 Language::Python | Language::Ruby | Language::Shell => CommentPatterns {
316 line_comment: Some("#"),
317 block_comment: None,
318 doc_comment: Some("#"),
319 },
320 Language::C => CommentPatterns {
321 line_comment: None,
322 block_comment: Some(("/*", "*/")),
323 doc_comment: None,
324 },
325 Language::HTML | Language::XML => CommentPatterns {
326 line_comment: None,
327 block_comment: Some(("<!--", "-->")),
328 doc_comment: None,
329 },
330 Language::CSS => CommentPatterns {
331 line_comment: None,
332 block_comment: Some(("/*", "*/")),
333 doc_comment: None,
334 },
335 Language::SQL => CommentPatterns {
336 line_comment: Some("--"),
337 block_comment: Some(("/*", "*/")),
338 doc_comment: None,
339 },
340 Language::Haskell => CommentPatterns {
341 line_comment: Some("--"),
342 block_comment: Some(("{-", "-}")),
343 doc_comment: Some("-- |"),
344 },
345 _ => CommentPatterns {
346 line_comment: None,
347 block_comment: None,
348 doc_comment: None,
349 },
350 }
351 }
352
353 pub fn from_extension(ext: &str) -> Option<Language> {
355 let ext = ext.to_lowercase();
356 [
357 Language::Rust,
358 Language::Python,
359 Language::JavaScript,
360 Language::TypeScript,
361 Language::Java,
362 Language::CSharp,
363 Language::CPlusPlus,
364 Language::C,
365 Language::Go,
366 Language::Ruby,
367 Language::PHP,
368 Language::Swift,
369 Language::Kotlin,
370 Language::Scala,
371 Language::Haskell,
372 Language::Clojure,
373 Language::SQL,
374 Language::HTML,
375 Language::CSS,
376 Language::JSON,
377 Language::XML,
378 Language::YAML,
379 Language::Markdown,
380 Language::Shell,
381 Language::PowerShell,
382 Language::R,
383 Language::Matlab,
384 ]
385 .into_iter()
386 .find(|&lang| lang.extensions().contains(&ext.as_str()))
387 }
388}
389
/// Comment delimiters for a language; `None` means the language has no
/// comment form of that kind (as modeled in this file).
#[derive(Debug, Clone)]
pub struct CommentPatterns {
    // Prefix that starts a comment running to end of line (e.g. "//", "#").
    pub line_comment: Option<&'static str>,
    // (open, close) delimiter pair for block comments (e.g. ("/*", "*/")).
    pub block_comment: Option<(&'static str, &'static str)>,
    // Prefix for documentation comments (e.g. "///").
    // NOTE(review): not consulted by the lexer in this file — confirm callers.
    pub doc_comment: Option<&'static str>,
}
397
/// Classification assigned to each lexed token.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CodeTokenType {
    Keyword,
    Identifier,
    // Typed literal. NOTE(review): the lexer in this file emits the plain
    // `String`/`Number` variants below rather than `Literal`; this variant
    // appears reserved for richer classification — confirm before relying on it.
    Literal(LiteralType),
    Operator,
    Punctuation,
    Comment,
    // Only emitted when `preserve_whitespace` is enabled in the config.
    Whitespace,
    String,
    Number,
    Unknown,
}
412
/// Fine-grained literal kinds for [`CodeTokenType::Literal`].
// NOTE(review): no code path in this file constructs these variants; they
// look intended for a future, richer literal classifier — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LiteralType {
    String,
    Character,
    Integer,
    Float,
    Boolean,
    Null,
}
423
/// A single lexed token together with its classification and source location.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeToken {
    // Exact source text of the token (quotes/comment delimiters included).
    pub text: String,
    pub token_type: CodeTokenType,
    pub position: TokenPosition,
    // Language whose lexing rules produced this token.
    pub language: Language,
}
432
/// Source location of a token within the original input string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPosition {
    // 1-based line of the token's first character.
    pub line: usize,
    // 1-based column of the token's first character.
    pub column: usize,
    // Byte offset of the token's first character.
    pub start_offset: usize,
    // Byte offset of the token's last character (inclusive — not
    // one-past-the-end), per how the parsers in this file record it.
    pub end_offset: usize,
}
441
/// Options controlling lexing and encoding behavior of [`CodeTokenizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeTokenizerConfig {
    // Language whose keyword/comment rules apply; `None` falls back to
    // JavaScript rules inside `tokenize_code`.
    pub language: Option<Language>,
    // Emit whitespace runs as `Whitespace` tokens.
    pub preserve_whitespace: bool,
    // Emit `Comment` tokens (comments are consumed from the stream either way).
    pub preserve_comments: bool,
    // NOTE(review): not consulted anywhere in this file — confirm whether
    // position info was meant to be optional in the output.
    pub include_position_info: bool,
    // Replace identifier text with "[IDENTIFIER]" during `encode`.
    pub normalize_identifiers: bool,
    // NOTE(review): not enforced anywhere in this file — confirm intent.
    pub max_token_length: Option<usize>,
    // Extra keywords merged into the language keyword set at construction.
    pub custom_keywords: Option<HashSet<String>>,
}
453
454impl Default for CodeTokenizerConfig {
455 fn default() -> Self {
456 Self {
457 language: None,
458 preserve_whitespace: false,
459 preserve_comments: true,
460 include_position_info: false,
461 normalize_identifiers: false,
462 max_token_length: Some(128),
463 custom_keywords: None,
464 }
465 }
466}
467
/// Language-aware lexer that also maintains a small fixed vocabulary
/// (special tokens, keywords, operators, literal words) for id encoding.
pub struct CodeTokenizer {
    config: CodeTokenizerConfig,
    // Keyword set for the configured language plus any custom keywords.
    keywords: HashSet<String>,
    // Forward vocabulary: token text -> id.
    token_to_id: HashMap<String, u32>,
    // Reverse vocabulary: id -> token text.
    id_to_token: HashMap<u32, String>,
    // Ids of the "[PAD]"/"[UNK]"/... marker tokens, keyed by their text.
    special_tokens: HashMap<String, u32>,
}
476
477impl CodeTokenizer {
478 pub fn new(config: CodeTokenizerConfig) -> Self {
480 let mut tokenizer = Self {
481 config,
482 keywords: HashSet::new(),
483 token_to_id: HashMap::new(),
484 id_to_token: HashMap::new(),
485 special_tokens: HashMap::new(),
486 };
487
488 tokenizer.initialize_vocabulary();
489 tokenizer
490 }
491
492 pub fn for_language(language: Language) -> Self {
494 let config = CodeTokenizerConfig {
495 language: Some(language),
496 ..Default::default()
497 };
498 Self::new(config)
499 }
500
501 fn initialize_vocabulary(&mut self) {
503 let mut next_id = 0u32;
504
505 for special in &[
507 "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]",
508 ] {
509 self.add_token(special, &mut next_id);
510 self.special_tokens.insert(special.to_string(), next_id - 1);
511 }
512
513 if let Some(language) = self.config.language {
515 for keyword in language.keywords() {
516 self.keywords.insert(keyword.to_string());
517 self.add_token(keyword, &mut next_id);
518 }
519 }
520
521 if let Some(custom_keywords) = &self.config.custom_keywords {
523 let keywords_to_add: Vec<String> = custom_keywords.iter().cloned().collect();
524 for keyword in keywords_to_add {
525 self.keywords.insert(keyword.clone());
526 self.add_token(&keyword, &mut next_id);
527 }
528 }
529
530 for op in &[
532 "+", "-", "*", "/", "%", "=", "==", "!=", "<", ">", "<=", ">=", "&&", "||", "!", "&",
533 "|", "^", "~", "<<", ">>", "++", "--", "+=", "-=", "*=", "/=", "%=", "(", ")", "[",
534 "]", "{", "}", ";", ",", ".", ":", "::", "->", "=>", "?",
535 ] {
536 self.add_token(op, &mut next_id);
537 }
538
539 for literal in &["true", "false", "null", "undefined", "nil", "None"] {
541 self.add_token(literal, &mut next_id);
542 }
543 }
544
545 fn add_token(&mut self, token: &str, next_id: &mut u32) {
547 if !self.token_to_id.contains_key(token) {
548 self.token_to_id.insert(token.to_string(), *next_id);
549 self.id_to_token.insert(*next_id, token.to_string());
550 *next_id += 1;
551 }
552 }
553
554 pub fn tokenize_code(&self, code: &str) -> Result<Vec<CodeToken>> {
556 let language = self.config.language.unwrap_or(Language::JavaScript);
557 let comment_patterns = language.comment_patterns();
558
559 let mut tokens = Vec::new();
560 let mut current_line = 1;
561 let mut current_column = 1;
562 let mut char_indices = code.char_indices().peekable();
563
564 while let Some((start_offset, ch)) = char_indices.next() {
565 let token_start_line = current_line;
566 let token_start_column = current_column;
567
568 if ch == '\n' {
570 current_line += 1;
571 current_column = 1;
572 } else {
573 current_column += 1;
574 }
575
576 if ch.is_whitespace() {
578 if self.config.preserve_whitespace {
579 let (text, end_offset) =
580 self.consume_whitespace(&mut char_indices, start_offset, ch);
581 tokens.push(CodeToken {
582 text,
583 token_type: CodeTokenType::Whitespace,
584 position: TokenPosition {
585 line: token_start_line,
586 column: token_start_column,
587 start_offset,
588 end_offset,
589 },
590 language,
591 });
592 }
593 continue;
594 }
595
596 if let Some(token) = self.try_parse_comment(
598 &mut char_indices,
599 start_offset,
600 ch,
601 &comment_patterns,
602 token_start_line,
603 token_start_column,
604 language,
605 )? {
606 if self.config.preserve_comments {
607 tokens.push(token);
608 }
609 continue;
610 }
611
612 if ch == '"'
614 || ch == '\''
615 || (ch == '`' && matches!(language, Language::JavaScript | Language::TypeScript))
616 {
617 let token = self.parse_string_literal(
618 &mut char_indices,
619 start_offset,
620 ch,
621 token_start_line,
622 token_start_column,
623 language,
624 )?;
625 tokens.push(token);
626 continue;
627 }
628
629 if ch.is_ascii_digit()
631 || (ch == '.'
632 && char_indices.peek().map(|(_, c)| c.is_ascii_digit()).unwrap_or(false))
633 {
634 let token = self.parse_numeric_literal(
635 &mut char_indices,
636 start_offset,
637 ch,
638 token_start_line,
639 token_start_column,
640 language,
641 )?;
642 tokens.push(token);
643 continue;
644 }
645
646 if ch.is_alphabetic() || ch == '_' || ch == '$' {
648 let token = self.parse_identifier(
649 &mut char_indices,
650 start_offset,
651 ch,
652 token_start_line,
653 token_start_column,
654 language,
655 )?;
656 tokens.push(token);
657 continue;
658 }
659
660 let token = self.parse_operator_or_punctuation(
662 &mut char_indices,
663 start_offset,
664 ch,
665 token_start_line,
666 token_start_column,
667 language,
668 )?;
669 tokens.push(token);
670 }
671
672 Ok(tokens)
673 }
674
675 fn consume_whitespace(
677 &self,
678 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
679 start_offset: usize,
680 first_char: char,
681 ) -> (String, usize) {
682 let mut text = String::new();
683 text.push(first_char);
684 let mut end_offset = start_offset;
685
686 while let Some((offset, ch)) = char_indices.peek() {
687 if ch.is_whitespace() {
688 text.push(*ch);
689 end_offset = *offset;
690 char_indices.next();
691 } else {
692 break;
693 }
694 }
695
696 (text, end_offset)
697 }
698
699 fn try_parse_comment(
701 &self,
702 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
703 start_offset: usize,
704 first_char: char,
705 patterns: &CommentPatterns,
706 token_start_line: usize,
707 token_start_column: usize,
708 language: Language,
709 ) -> Result<Option<CodeToken>> {
710 if let Some(line_comment) = patterns.line_comment {
712 if first_char
713 == line_comment.chars().next().expect("line_comment pattern must be non-empty")
714 {
715 if let Some(token) = self.try_parse_line_comment(
716 char_indices,
717 start_offset,
718 line_comment,
719 token_start_line,
720 token_start_column,
721 language,
722 )? {
723 return Ok(Some(token));
724 }
725 }
726 }
727
728 if let Some((start_delim, end_delim)) = patterns.block_comment {
730 if first_char
731 == start_delim
732 .chars()
733 .next()
734 .expect("block comment start delimiter must be non-empty")
735 {
736 if let Some(token) = self.try_parse_block_comment(
737 char_indices,
738 start_offset,
739 start_delim,
740 end_delim,
741 token_start_line,
742 token_start_column,
743 language,
744 )? {
745 return Ok(Some(token));
746 }
747 }
748 }
749
750 Ok(None)
751 }
752
753 fn try_parse_line_comment(
755 &self,
756 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
757 start_offset: usize,
758 comment_start: &str,
759 token_start_line: usize,
760 token_start_column: usize,
761 language: Language,
762 ) -> Result<Option<CodeToken>> {
763 let mut text = String::new();
764 text.push_str(comment_start);
765
766 for _ in 1..comment_start.len() {
768 if let Some((_, ch)) = char_indices.next() {
769 text.push(ch);
770 }
771 }
772
773 let mut end_offset = start_offset;
775 while let Some((offset, ch)) = char_indices.peek() {
776 if *ch == '\n' {
777 break;
778 }
779 text.push(*ch);
780 end_offset = *offset;
781 char_indices.next();
782 }
783
784 Ok(Some(CodeToken {
785 text,
786 token_type: CodeTokenType::Comment,
787 position: TokenPosition {
788 line: token_start_line,
789 column: token_start_column,
790 start_offset,
791 end_offset,
792 },
793 language,
794 }))
795 }
796
797 fn try_parse_block_comment(
799 &self,
800 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
801 start_offset: usize,
802 start_delim: &str,
803 end_delim: &str,
804 token_start_line: usize,
805 token_start_column: usize,
806 language: Language,
807 ) -> Result<Option<CodeToken>> {
808 let mut text = String::new();
809 text.push_str(start_delim);
810
811 for _ in 1..start_delim.len() {
813 if let Some((_, ch)) = char_indices.next() {
814 text.push(ch);
815 }
816 }
817
818 let mut end_offset = start_offset;
820 let end_chars: Vec<char> = end_delim.chars().collect();
821 let mut buffer = Vec::new();
822
823 for (offset, ch) in char_indices.by_ref() {
824 text.push(ch);
825 end_offset = offset;
826 buffer.push(ch);
827
828 if buffer.len() > end_chars.len() {
830 buffer.remove(0);
831 }
832
833 if buffer.len() == end_chars.len() && buffer == end_chars {
835 break;
836 }
837 }
838
839 Ok(Some(CodeToken {
840 text,
841 token_type: CodeTokenType::Comment,
842 position: TokenPosition {
843 line: token_start_line,
844 column: token_start_column,
845 start_offset,
846 end_offset,
847 },
848 language,
849 }))
850 }
851
852 fn parse_string_literal(
854 &self,
855 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
856 start_offset: usize,
857 quote_char: char,
858 token_start_line: usize,
859 token_start_column: usize,
860 language: Language,
861 ) -> Result<CodeToken> {
862 let mut text = String::new();
863 text.push(quote_char);
864 let mut end_offset = start_offset;
865 let mut escaped = false;
866
867 for (offset, ch) in char_indices.by_ref() {
868 text.push(ch);
869 end_offset = offset;
870
871 if escaped {
872 escaped = false;
873 continue;
874 }
875
876 if ch == '\\' {
877 escaped = true;
878 continue;
879 }
880
881 if ch == quote_char {
882 break;
883 }
884 }
885
886 Ok(CodeToken {
887 text,
888 token_type: CodeTokenType::String,
889 position: TokenPosition {
890 line: token_start_line,
891 column: token_start_column,
892 start_offset,
893 end_offset,
894 },
895 language,
896 })
897 }
898
899 fn parse_numeric_literal(
901 &self,
902 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
903 start_offset: usize,
904 first_char: char,
905 token_start_line: usize,
906 token_start_column: usize,
907 language: Language,
908 ) -> Result<CodeToken> {
909 let mut text = String::new();
910 text.push(first_char);
911 let mut end_offset = start_offset;
912 let mut has_dot = first_char == '.';
913
914 while let Some((offset, ch)) = char_indices.peek() {
915 if ch.is_ascii_digit()
916 || (*ch == '.' && !has_dot)
917 || (*ch == 'e' || *ch == 'E')
918 || (*ch == 'x' || *ch == 'X')
919 || (*ch == '_')
920 || ch.is_ascii_hexdigit()
921 {
922 if *ch == '.' {
923 has_dot = true;
924 }
925 text.push(*ch);
926 end_offset = *offset;
927 char_indices.next();
928 } else {
929 break;
930 }
931 }
932
933 Ok(CodeToken {
934 text,
935 token_type: CodeTokenType::Number,
936 position: TokenPosition {
937 line: token_start_line,
938 column: token_start_column,
939 start_offset,
940 end_offset,
941 },
942 language,
943 })
944 }
945
946 fn parse_identifier(
948 &self,
949 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
950 start_offset: usize,
951 first_char: char,
952 token_start_line: usize,
953 token_start_column: usize,
954 language: Language,
955 ) -> Result<CodeToken> {
956 let mut text = String::new();
957 text.push(first_char);
958 let mut end_offset = start_offset;
959
960 while let Some((offset, ch)) = char_indices.peek() {
961 if ch.is_alphanumeric() || *ch == '_' || *ch == '$' {
962 text.push(*ch);
963 end_offset = *offset;
964 char_indices.next();
965 } else {
966 break;
967 }
968 }
969
970 let token_type = if self.keywords.contains(&text) {
971 CodeTokenType::Keyword
972 } else {
973 CodeTokenType::Identifier
974 };
975
976 Ok(CodeToken {
977 text,
978 token_type,
979 position: TokenPosition {
980 line: token_start_line,
981 column: token_start_column,
982 start_offset,
983 end_offset,
984 },
985 language,
986 })
987 }
988
989 fn parse_operator_or_punctuation(
991 &self,
992 char_indices: &mut std::iter::Peekable<std::str::CharIndices>,
993 start_offset: usize,
994 first_char: char,
995 token_start_line: usize,
996 token_start_column: usize,
997 language: Language,
998 ) -> Result<CodeToken> {
999 let mut text = String::new();
1000 text.push(first_char);
1001 let mut end_offset = start_offset;
1002
1003 let operators = [
1005 "==", "!=", "<=", ">=", "&&", "||", "++", "--", "+=", "-=", "*=", "/=", "%=", "<<",
1006 ">>", "::", "->", "=>", "**", "//", "...", "..", ":=", "<=>",
1007 ];
1008
1009 for op in &operators {
1010 if op.starts_with(first_char) && op.len() > 1 {
1011 let chars = op.chars().skip(1);
1012 let mut matched = true;
1013 let mut lookahead = Vec::new();
1014
1015 for expected_char in chars {
1016 if let Some((offset, ch)) = char_indices.peek() {
1017 if *ch == expected_char {
1018 lookahead.push((*offset, *ch));
1019 char_indices.next();
1020 } else {
1021 matched = false;
1022 break;
1023 }
1024 } else {
1025 matched = false;
1026 break;
1027 }
1028 }
1029
1030 if matched {
1031 text = op.to_string();
1032 if let Some((offset, _)) = lookahead.last() {
1033 end_offset = *offset;
1034 }
1035 break;
1036 } else {
1037 for (_, _ch) in lookahead.into_iter().rev() {
1039 }
1042 }
1043 }
1044 }
1045
1046 let token_type = match first_char {
1047 '(' | ')' | '[' | ']' | '{' | '}' | ';' | ',' | '.' | ':' => CodeTokenType::Punctuation,
1048 _ => CodeTokenType::Operator,
1049 };
1050
1051 Ok(CodeToken {
1052 text,
1053 token_type,
1054 position: TokenPosition {
1055 line: token_start_line,
1056 column: token_start_column,
1057 start_offset,
1058 end_offset,
1059 },
1060 language,
1061 })
1062 }
1063
1064 #[allow(dead_code)]
1066 fn get_or_create_token_id(&mut self, token: &str) -> u32 {
1067 if let Some(&id) = self.token_to_id.get(token) {
1068 id
1069 } else {
1070 let id = self.token_to_id.len() as u32;
1071 self.token_to_id.insert(token.to_string(), id);
1072 self.id_to_token.insert(id, token.to_string());
1073 id
1074 }
1075 }
1076
1077 pub fn vocab_size(&self) -> usize {
1079 self.token_to_id.len()
1080 }
1081
1082 pub fn token_to_id(&self, token: &str) -> Option<u32> {
1084 self.token_to_id.get(token).copied()
1085 }
1086
1087 pub fn id_to_token(&self, id: u32) -> Option<String> {
1089 self.id_to_token.get(&id).cloned()
1090 }
1091}
1092
1093impl Tokenizer for CodeTokenizer {
1094 fn encode(&self, text: &str) -> Result<TokenizedInput> {
1095 let code_tokens = self.tokenize_code(text)?;
1096 let mut input_ids = Vec::new();
1097
1098 for token in code_tokens {
1099 let token_text = if self.config.normalize_identifiers
1100 && token.token_type == CodeTokenType::Identifier
1101 {
1102 "[IDENTIFIER]".to_string()
1103 } else {
1104 token.text
1105 };
1106
1107 if let Some(id) = self.token_to_id(&token_text) {
1108 input_ids.push(id);
1109 } else if let Some(&unk_id) = self.special_tokens.get("[UNK]") {
1110 input_ids.push(unk_id);
1111 }
1112 }
1113
1114 let attention_mask = vec![1u8; input_ids.len()];
1115
1116 Ok(TokenizedInput {
1117 input_ids,
1118 attention_mask,
1119 token_type_ids: None,
1120 special_tokens_mask: None,
1121 offset_mapping: None,
1122 overflowing_tokens: None,
1123 })
1124 }
1125
1126 fn decode(&self, ids: &[u32]) -> Result<String> {
1127 let tokens: Vec<String> = ids.iter().filter_map(|&id| self.id_to_token(id)).collect();
1128 Ok(tokens.join(" "))
1129 }
1130
1131 fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
1132 let combined = format!("{}\n{}", text_a, text_b);
1133 self.encode(&combined)
1134 }
1135
1136 fn vocab_size(&self) -> usize {
1137 self.token_to_id.len()
1138 }
1139
1140 fn get_vocab(&self) -> HashMap<String, u32> {
1141 self.token_to_id.clone()
1142 }
1143
1144 fn token_to_id(&self, token: &str) -> Option<u32> {
1145 self.token_to_id.get(token).copied()
1146 }
1147
1148 fn id_to_token(&self, id: u32) -> Option<String> {
1149 self.id_to_token.get(&id).cloned()
1150 }
1151}
1152
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension lookup is case-insensitive and rejects unknown extensions.
    #[test]
    fn test_language_detection() {
        assert_eq!(Language::from_extension("rs"), Some(Language::Rust));
        assert_eq!(Language::from_extension("RS"), Some(Language::Rust));
        assert_eq!(Language::from_extension("py"), Some(Language::Python));
        assert_eq!(Language::from_extension("js"), Some(Language::JavaScript));
        assert_eq!(Language::from_extension("unknown"), None);
    }

    /// Rust keywords are classified as `Keyword`, not `Identifier`.
    #[test]
    fn test_rust_tokenization() {
        let tokenizer = CodeTokenizer::for_language(Language::Rust);
        let code = "fn main() { let x = 42; }";
        let tokens = tokenizer.tokenize_code(code).expect("tokenizing valid Rust should succeed");

        assert!(!tokens.is_empty());

        let fn_token =
            tokens.iter().find(|t| t.text == "fn").expect("`fn` token should be present");
        assert_eq!(fn_token.token_type, CodeTokenType::Keyword);

        let let_token =
            tokens.iter().find(|t| t.text == "let").expect("`let` token should be present");
        assert_eq!(let_token.token_type, CodeTokenType::Keyword);
    }

    /// Escaped quotes stay inside a single string token.
    #[test]
    fn test_string_literal_parsing() {
        let tokenizer = CodeTokenizer::for_language(Language::JavaScript);
        let code = r#"let name = "Hello \"World\"";"#;
        let tokens = tokenizer.tokenize_code(code).expect("tokenizing valid JS should succeed");

        let string_token = tokens
            .iter()
            .find(|t| t.token_type == CodeTokenType::String)
            .expect("a String token should be produced");
        assert!(string_token.text.starts_with('"'));
        assert!(string_token.text.ends_with('"'));
    }

    /// Line comments are emitted as tokens when `preserve_comments` is on.
    #[test]
    fn test_comment_parsing() {
        let config = CodeTokenizerConfig {
            language: Some(Language::Rust),
            preserve_comments: true,
            ..Default::default()
        };
        let tokenizer = CodeTokenizer::new(config);
        let code = "// This is a comment\nfn main() {}";
        let tokens = tokenizer.tokenize_code(code).expect("tokenizing commented code should succeed");

        let comment_token = tokens
            .iter()
            .find(|t| t.token_type == CodeTokenType::Comment)
            .expect("a Comment token should be produced");
        assert!(comment_token.text.starts_with("//"));
    }

    /// Integer, float, and hex literals all come out as `Number` tokens.
    #[test]
    fn test_numeric_literals() {
        let tokenizer = CodeTokenizer::for_language(Language::Python);
        let code = "x = 42; y = 3.14; z = 0xFF;";
        let tokens = tokenizer.tokenize_code(code).expect("tokenizing numeric code should succeed");

        let numeric_tokens: Vec<_> =
            tokens.iter().filter(|t| t.token_type == CodeTokenType::Number).collect();

        assert!(numeric_tokens.len() >= 3);
    }

    /// `encode` produces ids with a matching all-ones attention mask.
    #[test]
    fn test_code_tokenizer_encode() {
        let tokenizer = CodeTokenizer::for_language(Language::Python);
        let code = "def hello(): return 42";
        let result = tokenizer.encode(code).expect("encoding valid Python should succeed");

        assert!(!result.input_ids.is_empty());
        assert_eq!(result.input_ids.len(), result.attention_mask.len());
    }
}