1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Top-level, serde-serializable description of a tokenizer: vocabulary,
/// special tokens, and the ordered rule pipelines applied around tokenization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomTokenizerFormat {
    /// Human-readable name of the format (e.g. "TrustformersCustom").
    pub format_name: String,
    /// Version string of the format schema (e.g. "1.0").
    pub format_version: String,
    /// The vocabulary: token texts, ids, and UNK configuration.
    pub vocabulary: CustomVocabulary,
    /// Special tokens (PAD/UNK/CLS/...) kept separate from the vocabulary.
    pub special_tokens: Vec<CustomSpecialToken>,
    /// Text normalization rules, applied in order before pre-tokenization.
    pub normalization_rules: Vec<NormalizationRule>,
    /// Splitting rules, applied in order to produce pre-tokens.
    pub pre_tokenization_rules: Vec<PreTokenizationRule>,
    /// Post-processing rules (special tokens, truncation, padding, ...).
    pub post_processing_rules: Vec<PostProcessingRule>,
    /// Free-form metadata (provenance, model type, etc.).
    pub metadata: HashMap<String, serde_json::Value>,
}
19
/// A tokenizer vocabulary: the full token list plus UNK/special-token lookup info.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomVocabulary {
    /// Kind of vocabulary (word-level, BPE, WordPiece, ...).
    pub vocab_type: VocabularyType,
    /// All regular tokens with their ids.
    pub tokens: Vec<CustomToken>,
    /// Declared vocabulary size; may disagree with `tokens.len()` (see
    /// `CustomFormatConverter::validate_format`, which warns on mismatch).
    pub size: usize,
    /// Text of the unknown-token fallback, if any (e.g. "[UNK]").
    pub unk_token: Option<String>,
    /// Mapping from special-token text to id.
    pub special_token_mapping: HashMap<String, u32>,
}
29
/// Classification of how a vocabulary segments text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VocabularyType {
    /// One token per whitespace-delimited word.
    WordLevel,
    /// Byte-pair-encoding subwords.
    SubwordBPE,
    /// WordPiece subwords.
    SubwordWordPiece,
    /// One token per character.
    CharacterLevel,
    /// SentencePiece model.
    SentencePiece,
    /// Any other scheme, identified by name.
    Custom(String),
}
40
/// A single vocabulary entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomToken {
    /// The token's surface text.
    pub text: String,
    /// Numeric id used in encoded sequences.
    pub id: u32,
    /// Optional frequency/score (SentencePiece scores are stored here).
    pub frequency: Option<f64>,
    /// True when this token is a special token.
    pub is_special: bool,
    /// Free-form per-token metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
50
/// A special token (PAD, UNK, CLS, ...) with its id and role.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomSpecialToken {
    /// Surface text of the special token (e.g. "[SEP]").
    pub token: String,
    /// Numeric id used in encoded sequences.
    pub id: u32,
    /// Role of the token in the pipeline.
    pub token_type: SpecialTokenType,
    /// Optional usage context note.
    pub context: Option<String>,
}
59
/// Role that a special token plays (mirrors the conventional HF token set).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SpecialTokenType {
    /// Padding token.
    Pad,
    /// Unknown-token fallback.
    Unk,
    /// Classification token (sequence start for BERT-style models).
    Cls,
    /// Separator token between segments.
    Sep,
    /// Mask token for masked-LM training.
    Mask,
    /// Beginning-of-sequence token.
    BOS,
    /// End-of-sequence token.
    EOS,
    /// Any other role, identified by name.
    UserDefined(String),
}
72
/// One step of the text-normalization pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationRule {
    /// What kind of normalization to apply.
    pub rule_type: NormalizationType,
    /// Regex pattern; used only by `NormalizationType::Regex` rules.
    pub pattern: Option<String>,
    /// Replacement text; used only by `NormalizationType::Regex` rules.
    pub replacement: Option<String>,
    /// Disabled rules are skipped without effect.
    pub enabled: bool,
}
81
/// Kinds of text normalization supported by `CustomFormatTokenizer::normalize_text`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NormalizationType {
    /// Lowercase the entire text.
    Lowercase,
    /// Strip combining marks after NFD decomposition.
    RemoveAccents,
    /// Collapse runs of whitespace to single spaces.
    NormalizeWhitespace,
    /// Unicode NFC normalization.
    NormalizeUnicode,
    /// Drop ASCII punctuation characters.
    RemovePunctuation,
    /// Regex replace. NOTE: the payload here is currently ignored by
    /// `normalize_text`, which reads the rule's `pattern`/`replacement` fields.
    Regex(String),
    /// Named custom rule; currently a no-op in `normalize_text`.
    Custom(String),
}
93
/// One step of the pre-tokenization (splitting) pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreTokenizationRule {
    /// What kind of split to apply.
    pub rule_type: PreTokenizationType,
    /// Optional pattern; not read by the current `pre_tokenize` implementation
    /// (regex rules carry their pattern in the enum payload instead).
    pub pattern: Option<String>,
    /// Disabled rules are skipped without effect.
    pub enabled: bool,
}
101
/// Kinds of splitting supported by `CustomFormatTokenizer::pre_tokenize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PreTokenizationType {
    /// Split on whitespace.
    WhitespaceSplit,
    /// Split around ASCII punctuation, keeping each punctuation char as a token.
    PunctuationSplit,
    /// Split on non-alphanumeric boundaries, discarding the separators.
    WordBoundary,
    /// Split on a regex pattern (payload is the pattern source).
    Regex(String),
    /// Named custom rule; currently a pass-through in `pre_tokenize`.
    Custom(String),
}
111
/// One step of the post-processing pipeline (carried as data; the current
/// `encode` path applies truncation itself and does not interpret these rules).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostProcessingRule {
    /// What kind of post-processing to apply.
    pub rule_type: PostProcessingType,
    /// Rule-specific parameters (e.g. a template string and token list).
    pub parameters: HashMap<String, serde_json::Value>,
    /// Disabled rules are skipped without effect.
    pub enabled: bool,
}
119
/// Kinds of post-processing a format may declare.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PostProcessingType {
    /// Insert special tokens (CLS/SEP/BOS/EOS) around the sequence.
    AddSpecialTokens,
    /// Truncate to a maximum length.
    Truncation,
    /// Pad to a fixed length.
    Padding,
    /// Produce an attention mask.
    AttentionMask,
    /// Produce token-type ids for sequence pairs.
    TokenTypeIds,
    /// Any other step, identified by name.
    Custom(String),
}
130
/// Runtime tokenizer built from a `CustomTokenizerFormat`.
///
/// Keeps precomputed token<->id maps (covering both vocabulary and special
/// tokens) so encode/decode lookups are O(1).
#[derive(Debug, Clone)]
pub struct CustomFormatTokenizer {
    /// The format this tokenizer was built from.
    format: CustomTokenizerFormat,
    /// Token text -> id, including special tokens.
    token_to_id: HashMap<String, u32>,
    /// Id -> token text, including special tokens.
    id_to_token: HashMap<u32, String>,
    /// Optional cap on encoded sequence length (defaults to 512).
    max_length: Option<usize>,
}
139
140impl CustomFormatTokenizer {
141 pub fn from_format(format: CustomTokenizerFormat) -> Result<Self> {
143 let mut token_to_id = HashMap::new();
144 let mut id_to_token = HashMap::new();
145
146 for token in &format.vocabulary.tokens {
148 token_to_id.insert(token.text.clone(), token.id);
149 id_to_token.insert(token.id, token.text.clone());
150 }
151
152 for special_token in &format.special_tokens {
154 token_to_id.insert(special_token.token.clone(), special_token.id);
155 id_to_token.insert(special_token.id, special_token.token.clone());
156 }
157
158 Ok(Self {
159 format,
160 token_to_id,
161 id_to_token,
162 max_length: Some(512),
163 })
164 }
165
166 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
168 let content = std::fs::read_to_string(path).map_err(|e| {
169 TrustformersError::other(anyhow::anyhow!("Failed to read file: {}", e).to_string())
170 })?;
171 let format: CustomTokenizerFormat = serde_json::from_str(&content).map_err(|e| {
172 TrustformersError::other(anyhow::anyhow!("Failed to parse format: {}", e).to_string())
173 })?;
174 Self::from_format(format)
175 }
176
177 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
179 let content = serde_json::to_string_pretty(&self.format).map_err(|e| {
180 TrustformersError::other(
181 anyhow::anyhow!("Failed to serialize format: {}", e).to_string(),
182 )
183 })?;
184 std::fs::write(path, content).map_err(|e| {
185 TrustformersError::other(anyhow::anyhow!("Failed to write file: {}", e).to_string())
186 })?;
187 Ok(())
188 }
189
190 pub fn with_max_length(mut self, max_length: Option<usize>) -> Self {
192 self.max_length = max_length;
193 self
194 }
195
196 pub fn vocab_size(&self) -> usize {
198 self.format.vocabulary.size
199 }
200
201 pub fn token_to_id(&self, token: &str) -> Option<u32> {
203 self.token_to_id.get(token).copied()
204 }
205
206 pub fn id_to_token(&self, id: u32) -> Option<String> {
208 self.id_to_token.get(&id).cloned()
209 }
210
211 pub fn get_vocab(&self) -> &HashMap<String, u32> {
213 &self.token_to_id
214 }
215
216 fn normalize_text(&self, text: &str) -> String {
218 let mut normalized = text.to_string();
219
220 for rule in &self.format.normalization_rules {
221 if !rule.enabled {
222 continue;
223 }
224
225 normalized = match &rule.rule_type {
226 NormalizationType::Lowercase => normalized.to_lowercase(),
227 NormalizationType::RemoveAccents => self.remove_accents(&normalized),
228 NormalizationType::NormalizeWhitespace => {
229 normalized.split_whitespace().collect::<Vec<_>>().join(" ")
230 },
231 NormalizationType::NormalizeUnicode => {
232 unicode_normalization::UnicodeNormalization::nfc(normalized.as_str()).collect()
233 },
234 NormalizationType::RemovePunctuation => {
235 normalized.chars().filter(|c| !c.is_ascii_punctuation()).collect()
236 },
237 NormalizationType::Regex(_pattern) => {
238 if let (Some(pattern), Some(replacement)) = (&rule.pattern, &rule.replacement) {
239 if let Ok(re) = regex::Regex::new(pattern) {
240 re.replace_all(&normalized, replacement).to_string()
241 } else {
242 normalized
243 }
244 } else {
245 normalized
246 }
247 },
248 NormalizationType::Custom(_) => {
249 normalized
251 },
252 };
253 }
254
255 normalized
256 }
257
258 fn remove_accents(&self, text: &str) -> String {
260 use unicode_normalization::UnicodeNormalization;
261 text.nfd()
262 .filter(|c| !unicode_normalization::char::is_combining_mark(*c))
263 .collect()
264 }
265
266 fn pre_tokenize(&self, text: &str) -> Vec<String> {
268 let mut tokens = vec![text.to_string()];
269
270 for rule in &self.format.pre_tokenization_rules {
271 if !rule.enabled {
272 continue;
273 }
274
275 let mut new_tokens = Vec::new();
276 for token in tokens {
277 match &rule.rule_type {
278 PreTokenizationType::WhitespaceSplit => {
279 new_tokens.extend(token.split_whitespace().map(|s| s.to_string()));
280 },
281 PreTokenizationType::PunctuationSplit => {
282 let mut current = String::new();
283 for ch in token.chars() {
284 if ch.is_ascii_punctuation() {
285 if !current.is_empty() {
286 new_tokens.push(current.clone());
287 current.clear();
288 }
289 new_tokens.push(ch.to_string());
290 } else {
291 current.push(ch);
292 }
293 }
294 if !current.is_empty() {
295 new_tokens.push(current);
296 }
297 },
298 PreTokenizationType::WordBoundary => {
299 let words: Vec<String> = token
301 .split(|c: char| !c.is_alphanumeric())
302 .filter(|s| !s.is_empty())
303 .map(|s| s.to_string())
304 .collect();
305 new_tokens.extend(words);
306 },
307 PreTokenizationType::Regex(pattern) => {
308 if let Ok(re) = regex::Regex::new(pattern) {
309 let splits: Vec<String> = re
310 .split(&token)
311 .filter(|s| !s.is_empty())
312 .map(|s| s.to_string())
313 .collect();
314 new_tokens.extend(splits);
315 } else {
316 new_tokens.push(token);
317 }
318 },
319 PreTokenizationType::Custom(_) => {
320 new_tokens.push(token);
322 },
323 }
324 }
325 tokens = new_tokens;
326 }
327
328 tokens
329 }
330
331 fn tokenize_subwords(&self, tokens: Vec<String>) -> Vec<String> {
333 let mut subwords = Vec::new();
334
335 for token in tokens {
336 let mut remaining = token.as_str();
338 while !remaining.is_empty() {
339 let mut found = false;
340 for len in (1..=remaining.len()).rev() {
342 let candidate = &remaining[..len];
343 if self.token_to_id.contains_key(candidate) {
344 subwords.push(candidate.to_string());
345 remaining = &remaining[len..];
346 found = true;
347 break;
348 }
349 }
350 if !found {
351 if let Some(unk_token) = &self.format.vocabulary.unk_token {
353 subwords.push(unk_token.clone());
354 }
355 remaining = &remaining[1..];
356 }
357 }
358 }
359
360 subwords
361 }
362}
363
364impl Tokenizer for CustomFormatTokenizer {
365 fn encode(&self, text: &str) -> Result<TokenizedInput> {
366 let normalized = self.normalize_text(text);
367 let pre_tokens = self.pre_tokenize(&normalized);
368 let subwords = self.tokenize_subwords(pre_tokens);
369
370 let mut input_ids = Vec::new();
371 for token in &subwords {
372 if let Some(id) = self.token_to_id(token) {
373 input_ids.push(id);
374 } else if let Some(unk_token) = &self.format.vocabulary.unk_token {
375 if let Some(unk_id) = self.token_to_id(unk_token) {
376 input_ids.push(unk_id);
377 }
378 }
379 }
380
381 if let Some(max_len) = self.max_length {
383 input_ids.truncate(max_len);
384 }
385
386 let attention_mask = vec![1u8; input_ids.len()];
387
388 Ok(TokenizedInput {
389 input_ids,
390 attention_mask,
391 token_type_ids: None,
392 special_tokens_mask: None,
393 offset_mapping: None,
394 overflowing_tokens: None,
395 })
396 }
397
398 fn decode(&self, ids: &[u32]) -> Result<String> {
399 let tokens: Vec<String> = ids.iter().filter_map(|&id| self.id_to_token(id)).collect();
400 Ok(tokens.join(" "))
401 }
402
403 fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
404 let combined = format!("{} {} {}", text_a, "[SEP]", text_b);
406 self.encode(&combined)
407 }
408
409 fn vocab_size(&self) -> usize {
410 self.format.vocabulary.size
411 }
412
413 fn get_vocab(&self) -> HashMap<String, u32> {
414 self.format
415 .vocabulary
416 .tokens
417 .iter()
418 .map(|token| (token.text.clone(), token.id))
419 .collect()
420 }
421
422 fn token_to_id(&self, token: &str) -> Option<u32> {
423 self.format.vocabulary.tokens.iter().find(|t| t.text == token).map(|t| t.id)
424 }
425
426 fn id_to_token(&self, id: u32) -> Option<String> {
427 self.format
428 .vocabulary
429 .tokens
430 .iter()
431 .find(|t| t.id == id)
432 .map(|t| t.text.clone())
433 }
434}
435
436pub struct CustomFormatConverter;
438
impl CustomFormatConverter {
    /// Converts a HuggingFace `tokenizer.json` string into a
    /// `CustomTokenizerFormat`.
    ///
    /// Reads `model.vocab` for regular tokens and `added_tokens` for special
    /// tokens; entries with non-integer ids are skipped silently. The result
    /// uses default normalization (NFC), whitespace pre-tokenization, and an
    /// add-special-tokens post-processing rule.
    ///
    /// NOTE(review): `unk_token` is hard-coded to "[UNK]" and the declared
    /// vocabulary size excludes `added_tokens` — confirm against the source
    /// tokenizer's actual configuration.
    ///
    /// # Errors
    /// Returns an error when `json_str` is not valid JSON.
    pub fn from_huggingface_json(json_str: &str) -> Result<CustomTokenizerFormat> {
        let hf_json: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
            TrustformersError::other(anyhow::anyhow!("Failed to parse HF JSON: {}", e).to_string())
        })?;

        let mut tokens = Vec::new();
        let mut special_tokens = Vec::new();

        // model.vocab is an object mapping token text -> integer id.
        if let Some(vocab) = hf_json["model"]["vocab"].as_object() {
            for (token_text, token_id) in vocab {
                if let Some(id) = token_id.as_u64() {
                    tokens.push(CustomToken {
                        text: token_text.clone(),
                        id: id as u32,
                        frequency: None,
                        is_special: false,
                        metadata: HashMap::new(),
                    });
                }
            }
        }

        // added_tokens carry the special tokens; their precise role is not
        // recoverable here, so they are tagged UserDefined("unknown").
        if let Some(added_tokens) = hf_json["added_tokens"].as_array() {
            for token in added_tokens {
                if let (Some(content), Some(id)) = (token["content"].as_str(), token["id"].as_u64())
                {
                    special_tokens.push(CustomSpecialToken {
                        token: content.to_string(),
                        id: id as u32,
                        token_type: SpecialTokenType::UserDefined("unknown".to_string()),
                        context: None,
                    });
                }
            }
        }

        let tokens_len = tokens.len();
        let vocabulary = CustomVocabulary {
            vocab_type: VocabularyType::SubwordBPE, tokens,
            size: tokens_len,
            unk_token: Some("[UNK]".to_string()),
            special_token_mapping: HashMap::new(),
        };

        Ok(CustomTokenizerFormat {
            format_name: "TrustformersCustom".to_string(),
            format_version: "1.0".to_string(),
            vocabulary,
            special_tokens,
            normalization_rules: vec![NormalizationRule {
                rule_type: NormalizationType::NormalizeUnicode,
                pattern: None,
                replacement: None,
                enabled: true,
            }],
            pre_tokenization_rules: vec![PreTokenizationRule {
                rule_type: PreTokenizationType::WhitespaceSplit,
                pattern: None,
                enabled: true,
            }],
            post_processing_rules: vec![PostProcessingRule {
                rule_type: PostProcessingType::AddSpecialTokens,
                parameters: HashMap::new(),
                enabled: true,
            }],
            metadata: HashMap::new(),
        })
    }

    /// Serializes a `CustomTokenizerFormat` as a HuggingFace-style
    /// `tokenizer.json` string (BPE model skeleton with empty merges).
    ///
    /// Only the vocabulary, `unk_token`, and special tokens are exported;
    /// normalization/pre-tokenization/post-processing rules are NOT
    /// translated into the HF pipeline sections (they stay empty/null).
    ///
    /// # Errors
    /// Returns an error when JSON serialization fails.
    pub fn to_huggingface_json(format: &CustomTokenizerFormat) -> Result<String> {
        let mut hf_json = serde_json::json!({
            "version": "1.0",
            "truncation": null,
            "padding": null,
            "added_tokens": [],
            "normalizer": {
                "type": "Sequence",
                "normalizers": []
            },
            "pre_tokenizer": {
                "type": "Sequence",
                "pre_tokenizers": []
            },
            "post_processor": null,
            "decoder": {
                "type": "BPEDecoder"
            },
            "model": {
                "type": "BPE",
                "dropout": null,
                "unk_token": format.vocabulary.unk_token,
                "continuing_subword_prefix": null,
                "end_of_word_suffix": null,
                "fuse_unk": false,
                "vocab": {},
                "merges": []
            }
        });

        // Fill model.vocab from the token list.
        let mut vocab_map = serde_json::Map::new();
        for token in &format.vocabulary.tokens {
            vocab_map.insert(
                token.text.clone(),
                serde_json::Value::Number(token.id.into()),
            );
        }
        hf_json["model"]["vocab"] = serde_json::Value::Object(vocab_map);

        // Export special tokens as HF added_tokens entries.
        let mut added_tokens = Vec::new();
        for special_token in &format.special_tokens {
            added_tokens.push(serde_json::json!({
                "id": special_token.id,
                "content": special_token.token,
                "single_word": false,
                "lstrip": false,
                "rstrip": false,
                "normalized": false,
                "special": true
            }));
        }
        hf_json["added_tokens"] = serde_json::Value::Array(added_tokens);

        serde_json::to_string_pretty(&hf_json).map_err(|e| {
            TrustformersError::other(
                anyhow::anyhow!("Failed to serialize HF JSON: {}", e).to_string(),
            )
        })
    }

    /// Builds a `CustomTokenizerFormat` from a SentencePiece model file by
    /// walking every id in the loaded model's vocabulary.
    ///
    /// Token scores are stored in `CustomToken::frequency`; tokens the model
    /// flags as special are additionally classified by their conventional
    /// surface text ("<pad>", "<unk>", "<s>", "</s>", "[CLS]", "[SEP]",
    /// "[MASK]"); anything else becomes `UserDefined`. Normalization /
    /// pre-tokenization / post-processing rules are derived from the model's
    /// reported flags (presumably mirroring its normalizer spec — confirm
    /// against `SentencePieceTokenizer`'s documentation).
    ///
    /// # Errors
    /// Propagates any error from loading the model file.
    pub fn from_sentencepiece_model(model_path: &Path) -> Result<CustomTokenizerFormat> {
        use crate::sentencepiece::SentencePieceTokenizer;

        let sp_tokenizer = SentencePieceTokenizer::from_model_file(model_path)?;

        let vocab_size = sp_tokenizer.vocab_size();
        let mut tokens = Vec::new();
        let mut special_tokens = Vec::new();
        let mut special_token_mapping = HashMap::new();

        for id in 0..vocab_size {
            let id_u32 = id as u32;
            if let Some(token_text) = sp_tokenizer.id_to_token(id_u32) {
                let score = sp_tokenizer.get_score(id_u32).unwrap_or(0.0);
                let is_special = sp_tokenizer.is_special_token_public(&token_text);

                let custom_token = CustomToken {
                    text: token_text.clone(),
                    id: id_u32,
                    frequency: Some(score as f64),
                    is_special,
                    metadata: HashMap::new(),
                };
                tokens.push(custom_token);

                if is_special {
                    // Classify by conventional surface text.
                    let token_type = if token_text == "<pad>" {
                        SpecialTokenType::Pad
                    } else if token_text == "<unk>" {
                        SpecialTokenType::Unk
                    } else if token_text == "<s>" {
                        SpecialTokenType::BOS
                    } else if token_text == "</s>" {
                        SpecialTokenType::EOS
                    } else if token_text == "[CLS]" {
                        SpecialTokenType::Cls
                    } else if token_text == "[SEP]" {
                        SpecialTokenType::Sep
                    } else if token_text == "[MASK]" {
                        SpecialTokenType::Mask
                    } else {
                        SpecialTokenType::UserDefined(token_text.clone())
                    };

                    special_tokens.push(CustomSpecialToken {
                        token: token_text.clone(),
                        id: id_u32,
                        token_type,
                        context: None,
                    });
                    special_token_mapping.insert(token_text, id_u32);
                }
            }
        }

        let vocabulary = CustomVocabulary {
            vocab_type: VocabularyType::SentencePiece,
            tokens,
            size: vocab_size,
            unk_token: sp_tokenizer.unk_token().map(|s| s.to_string()),
            special_token_mapping,
        };

        // Map model flags onto our normalization pipeline.
        let mut normalization_rules = Vec::new();

        if sp_tokenizer.uses_normalization() {
            normalization_rules.push(NormalizationRule {
                rule_type: NormalizationType::NormalizeUnicode,
                pattern: None,
                replacement: None,
                enabled: true,
            });
        }

        if sp_tokenizer.removes_extra_whitespaces() {
            normalization_rules.push(NormalizationRule {
                rule_type: NormalizationType::NormalizeWhitespace,
                pattern: None,
                replacement: None,
                enabled: true,
            });
        }

        let mut pre_tokenization_rules = Vec::new();
        if sp_tokenizer.treats_whitespace_as_suffix() {
            pre_tokenization_rules.push(PreTokenizationRule {
                rule_type: PreTokenizationType::WhitespaceSplit,
                pattern: None,
                enabled: true,
            });
        }

        // Declare an AddSpecialTokens step when the model has BOS or EOS ids.
        let mut post_processing_rules = Vec::new();
        if sp_tokenizer.bos_token_id().is_some() || sp_tokenizer.eos_token_id().is_some() {
            let mut parameters = HashMap::new();
            parameters.insert(
                "template".to_string(),
                serde_json::Value::String("$A".to_string()),
            );
            parameters.insert(
                "tokens".to_string(),
                serde_json::Value::Array(
                    special_tokens
                        .iter()
                        .map(|st| serde_json::Value::String(st.token.clone()))
                        .collect(),
                ),
            );

            post_processing_rules.push(PostProcessingRule {
                rule_type: PostProcessingType::AddSpecialTokens,
                parameters,
                enabled: true,
            });
        }

        // Record provenance so consumers can tell how this format was built.
        let mut metadata = HashMap::new();
        metadata.insert(
            "source".to_string(),
            serde_json::Value::String("SentencePiece".to_string()),
        );
        metadata.insert(
            "model_type".to_string(),
            serde_json::Value::String(sp_tokenizer.model_type_string()),
        );
        metadata.insert(
            "vocab_size".to_string(),
            serde_json::Value::Number(serde_json::Number::from(vocab_size)),
        );
        metadata.insert(
            "uses_byte_fallback".to_string(),
            serde_json::Value::Bool(sp_tokenizer.uses_byte_fallback()),
        );

        Ok(CustomTokenizerFormat {
            format_name: "SentencePiece".to_string(),
            format_version: "1.0".to_string(),
            vocabulary,
            special_tokens,
            normalization_rules,
            pre_tokenization_rules,
            post_processing_rules,
            metadata,
        })
    }

    /// Checks a format for internal consistency and returns human-readable
    /// warnings (never an error): declared-size mismatch, duplicate token ids,
    /// and special tokens whose ids are absent from the vocabulary.
    pub fn validate_format(format: &CustomTokenizerFormat) -> Result<Vec<String>> {
        let mut warnings = Vec::new();

        // Declared size vs. actual token count.
        if format.vocabulary.tokens.len() != format.vocabulary.size {
            warnings.push(format!(
                "Vocabulary size mismatch: declared {} but found {} tokens",
                format.vocabulary.size,
                format.vocabulary.tokens.len()
            ));
        }

        // Duplicate ids within the regular vocabulary.
        let mut seen_ids = std::collections::HashSet::new();
        for token in &format.vocabulary.tokens {
            if !seen_ids.insert(token.id) {
                warnings.push(format!("Duplicate token ID: {}", token.id));
            }
        }

        // Special tokens are expected to reuse ids from the vocabulary.
        for special_token in &format.special_tokens {
            if !seen_ids.contains(&special_token.id) {
                warnings.push(format!(
                    "Special token '{}' has ID {} not found in vocabulary",
                    special_token.token, special_token.id
                ));
            }
        }

        Ok(warnings)
    }
}
766
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain (non-special) vocabulary token for fixtures.
    fn word(text: &str, id: u32, frequency: Option<f64>) -> CustomToken {
        CustomToken {
            text: text.to_string(),
            id,
            frequency,
            is_special: false,
            metadata: HashMap::new(),
        }
    }

    /// Builds a minimal word-level format with the given tokens and declared
    /// size; callers tweak the remaining fields per test.
    fn base_format(tokens: Vec<CustomToken>, size: usize) -> CustomTokenizerFormat {
        CustomTokenizerFormat {
            format_name: "TestFormat".to_string(),
            format_version: "1.0".to_string(),
            vocabulary: CustomVocabulary {
                vocab_type: VocabularyType::WordLevel,
                tokens,
                size,
                unk_token: None,
                special_token_mapping: HashMap::new(),
            },
            special_tokens: vec![],
            normalization_rules: vec![],
            pre_tokenization_rules: vec![],
            post_processing_rules: vec![],
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_custom_format_creation() {
        // Two-word vocabulary plus a [UNK] special token.
        let mut format = base_format(
            vec![word("hello", 0, Some(0.1)), word("world", 1, Some(0.05))],
            2,
        );
        format.vocabulary.unk_token = Some("[UNK]".to_string());
        format.special_tokens = vec![CustomSpecialToken {
            token: "[UNK]".to_string(),
            id: 2,
            token_type: SpecialTokenType::Unk,
            context: None,
        }];

        let tokenizer =
            CustomFormatTokenizer::from_format(format).expect("Operation failed in test");
        assert_eq!(tokenizer.vocab_size(), 2);
        assert_eq!(tokenizer.token_to_id("hello"), Some(0));
        assert_eq!(tokenizer.id_to_token(1), Some("world".to_string()));
    }

    #[test]
    fn test_custom_tokenizer_encode() {
        // Whitespace pre-tokenization over a two-word vocabulary.
        let mut format = base_format(vec![word("hello", 0, None), word("world", 1, None)], 2);
        format.vocabulary.unk_token = Some("[UNK]".to_string());
        format.pre_tokenization_rules = vec![PreTokenizationRule {
            rule_type: PreTokenizationType::WhitespaceSplit,
            pattern: None,
            enabled: true,
        }];

        let tokenizer =
            CustomFormatTokenizer::from_format(format).expect("Operation failed in test");
        let result = tokenizer.encode("hello world").expect("Encoding failed");
        assert_eq!(result.input_ids, vec![0, 1]);
        assert_eq!(result.attention_mask, vec![1, 1]);
    }

    #[test]
    fn test_format_validation() {
        // Declared size (2) disagrees with the actual token count (1).
        let format = base_format(vec![word("hello", 0, None)], 2);

        let warnings =
            CustomFormatConverter::validate_format(&format).expect("Operation failed in test");
        assert!(!warnings.is_empty());
        assert!(warnings[0].contains("Vocabulary size mismatch"));
    }
}