Skip to main content

trustformers_tokenizers/
custom_format.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Represents a custom tokenizer format specification.
///
/// This is the serializable description a `CustomFormatTokenizer` is built from
/// and the value that is written to / read from disk as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomTokenizerFormat {
    /// Name identifying the format (the converters in this file emit
    /// `"TrustformersCustom"` and `"SentencePiece"`).
    pub format_name: String,
    /// Version string of the format (the converters in this file emit `"1.0"`).
    pub format_version: String,
    /// Vocabulary definition: token list, declared size and unknown token.
    pub vocabulary: CustomVocabulary,
    /// Special tokens; the tokenizer adds them to its lookup tables after the
    /// regular vocabulary tokens.
    pub special_tokens: Vec<CustomSpecialToken>,
    /// Normalization rules, applied in order to the raw input text.
    pub normalization_rules: Vec<NormalizationRule>,
    /// Pre-tokenization rules, applied in order to split the normalized text.
    pub pre_tokenization_rules: Vec<PreTokenizationRule>,
    /// Post-processing rules.
    /// NOTE(review): nothing visible in this file applies these while encoding —
    /// confirm whether they are consumed elsewhere.
    pub post_processing_rules: Vec<PostProcessingRule>,
    /// Free-form metadata stored as JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
19
/// Custom vocabulary definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomVocabulary {
    /// Kind of vocabulary this is (see [`VocabularyType`]).
    pub vocab_type: VocabularyType,
    /// The vocabulary entries.
    pub tokens: Vec<CustomToken>,
    /// Declared vocabulary size. This is the value reported by `vocab_size()`;
    /// `CustomFormatConverter::validate_format` warns when it differs from
    /// `tokens.len()`.
    pub size: usize,
    /// Text of the unknown token, emitted for input that matches no vocabulary
    /// entry. `None` means such input is dropped.
    pub unk_token: Option<String>,
    /// Special-token text to ID. The SentencePiece converter in this file fills
    /// it; the HuggingFace converter leaves it empty.
    pub special_token_mapping: HashMap<String, u32>,
}
29
/// Types of vocabularies supported.
///
/// NOTE(review): the tokenizer code visible in this file does not branch on this
/// value; it appears to be descriptive metadata — confirm against other users.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VocabularyType {
    /// Word-level vocabulary.
    WordLevel,
    /// Subword vocabulary of the BPE kind (the HuggingFace converter defaults to this).
    SubwordBPE,
    /// Subword vocabulary of the WordPiece kind.
    SubwordWordPiece,
    /// Character-level vocabulary.
    CharacterLevel,
    /// Vocabulary imported from a SentencePiece model.
    SentencePiece,
    /// Any other vocabulary kind, identified by a caller-chosen name.
    Custom(String),
}
40
/// Custom token with metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomToken {
    /// Surface text of the token.
    pub text: String,
    /// Numeric ID of the token.
    pub id: u32,
    /// Optional frequency. The SentencePiece converter in this file stores the
    /// piece score here rather than a true frequency.
    pub frequency: Option<f64>,
    /// Whether the token is flagged as special.
    pub is_special: bool,
    /// Free-form per-token metadata stored as JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
50
/// Custom special token definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomSpecialToken {
    /// Surface text of the special token (for example `"[UNK]"`).
    pub token: String,
    /// Numeric ID of the special token.
    pub id: u32,
    /// Role the token plays (see [`SpecialTokenType`]).
    pub token_type: SpecialTokenType,
    /// Optional context string. The converters in this file always set `None`.
    /// NOTE(review): intended meaning is not visible here — confirm.
    pub context: Option<String>,
}
59
/// Types of special tokens.
///
/// The token texts mentioned on the variants are the ones the SentencePiece
/// converter in this file maps to each variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SpecialTokenType {
    /// Padding token (`<pad>`).
    Pad,
    /// Unknown token (`<unk>`).
    Unk,
    /// Classification token (`[CLS]`).
    Cls,
    /// Separator token (`[SEP]`).
    Sep,
    /// Mask token (`[MASK]`).
    Mask,
    /// Beginning-of-sequence token (`<s>`).
    BOS,
    /// End-of-sequence token (`</s>`).
    EOS,
    /// Any other special token, carrying a caller-chosen label.
    UserDefined(String),
}
72
/// Normalization rule for text preprocessing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationRule {
    /// Which normalization to perform.
    pub rule_type: NormalizationType,
    /// Regex pattern; read only for `NormalizationType::Regex` rules.
    pub pattern: Option<String>,
    /// Replacement text; read only for `NormalizationType::Regex` rules.
    pub replacement: Option<String>,
    /// Disabled rules are skipped by the tokenizer.
    pub enabled: bool,
}
81
/// Types of normalization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NormalizationType {
    /// Convert the text to lowercase.
    Lowercase,
    /// Decompose to NFD and drop combining marks.
    RemoveAccents,
    /// Collapse every run of whitespace into a single space and trim both ends.
    NormalizeWhitespace,
    /// Recompose the text to Unicode NFC.
    NormalizeUnicode,
    /// Remove ASCII punctuation characters (non-ASCII punctuation is kept).
    RemovePunctuation,
    /// Regex-based replacement driven by the owning rule's `pattern` and
    /// `replacement` fields.
    Regex(String),
    /// Caller-defined normalization; the tokenizer in this file leaves the text
    /// unchanged for it.
    Custom(String),
}
93
/// Pre-tokenization rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreTokenizationRule {
    /// Which splitting strategy to apply.
    pub rule_type: PreTokenizationType,
    /// Optional pattern.
    /// NOTE(review): the regex pre-tokenization in this file reads its pattern
    /// from the `PreTokenizationType::Regex` payload, not from this field —
    /// confirm the intended use.
    pub pattern: Option<String>,
    /// Disabled rules are skipped by the tokenizer.
    pub enabled: bool,
}
101
/// Types of pre-tokenization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PreTokenizationType {
    /// Split on whitespace; the whitespace itself is discarded.
    WhitespaceSplit,
    /// Make every ASCII punctuation character its own piece and split the
    /// surrounding text around it.
    PunctuationSplit,
    /// Split on every non-alphanumeric character, discarding those characters
    /// and any empty pieces.
    WordBoundary,
    /// Split on matches of the given regex, discarding empty pieces. An invalid
    /// pattern leaves the input unsplit.
    Regex(String),
    /// Caller-defined splitting; the tokenizer in this file passes pieces
    /// through unchanged for it.
    Custom(String),
}
111
/// Post-processing rule.
///
/// NOTE(review): the tokenizer in this file stores these rules but no visible
/// code applies them — confirm where they take effect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostProcessingRule {
    /// Which post-processing step this rule describes.
    pub rule_type: PostProcessingType,
    /// Step-specific parameters as JSON values (the SentencePiece converter
    /// writes `"template"` and `"tokens"` entries).
    pub parameters: HashMap<String, serde_json::Value>,
    /// Whether the rule is active.
    pub enabled: bool,
}
119
/// Types of post-processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PostProcessingType {
    /// Insert special tokens around the encoded sequence.
    AddSpecialTokens,
    /// Truncate the sequence.
    Truncation,
    /// Pad the sequence.
    Padding,
    /// Produce an attention mask.
    AttentionMask,
    /// Produce token type IDs.
    TokenTypeIds,
    /// Caller-defined step, identified by name.
    Custom(String),
}
130
/// Custom format tokenizer implementation.
///
/// Wraps a [`CustomTokenizerFormat`] together with lookup tables derived from it.
#[derive(Debug, Clone)]
pub struct CustomFormatTokenizer {
    /// The format description this tokenizer was built from.
    format: CustomTokenizerFormat,
    /// Token text to ID, covering vocabulary tokens and special tokens.
    token_to_id: HashMap<String, u32>,
    /// ID to token text, covering vocabulary tokens and special tokens.
    id_to_token: HashMap<u32, String>,
    /// Maximum number of IDs `encode` returns; `None` disables truncation.
    max_length: Option<usize>,
}
139
140impl CustomFormatTokenizer {
141    /// Create a new tokenizer from a custom format
142    pub fn from_format(format: CustomTokenizerFormat) -> Result<Self> {
143        let mut token_to_id = HashMap::new();
144        let mut id_to_token = HashMap::new();
145
146        // Build vocabulary maps
147        for token in &format.vocabulary.tokens {
148            token_to_id.insert(token.text.clone(), token.id);
149            id_to_token.insert(token.id, token.text.clone());
150        }
151
152        // Add special tokens
153        for special_token in &format.special_tokens {
154            token_to_id.insert(special_token.token.clone(), special_token.id);
155            id_to_token.insert(special_token.id, special_token.token.clone());
156        }
157
158        Ok(Self {
159            format,
160            token_to_id,
161            id_to_token,
162            max_length: Some(512),
163        })
164    }
165
166    /// Load tokenizer from custom format file
167    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
168        let content = std::fs::read_to_string(path).map_err(|e| {
169            TrustformersError::other(anyhow::anyhow!("Failed to read file: {}", e).to_string())
170        })?;
171        let format: CustomTokenizerFormat = serde_json::from_str(&content).map_err(|e| {
172            TrustformersError::other(anyhow::anyhow!("Failed to parse format: {}", e).to_string())
173        })?;
174        Self::from_format(format)
175    }
176
177    /// Save tokenizer to custom format file
178    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
179        let content = serde_json::to_string_pretty(&self.format).map_err(|e| {
180            TrustformersError::other(
181                anyhow::anyhow!("Failed to serialize format: {}", e).to_string(),
182            )
183        })?;
184        std::fs::write(path, content).map_err(|e| {
185            TrustformersError::other(anyhow::anyhow!("Failed to write file: {}", e).to_string())
186        })?;
187        Ok(())
188    }
189
190    /// Set maximum sequence length
191    pub fn with_max_length(mut self, max_length: Option<usize>) -> Self {
192        self.max_length = max_length;
193        self
194    }
195
196    /// Get vocabulary size
197    pub fn vocab_size(&self) -> usize {
198        self.format.vocabulary.size
199    }
200
201    /// Get token ID
202    pub fn token_to_id(&self, token: &str) -> Option<u32> {
203        self.token_to_id.get(token).copied()
204    }
205
206    /// Get token from ID
207    pub fn id_to_token(&self, id: u32) -> Option<String> {
208        self.id_to_token.get(&id).cloned()
209    }
210
211    /// Get vocabulary
212    pub fn get_vocab(&self) -> &HashMap<String, u32> {
213        &self.token_to_id
214    }
215
216    /// Apply normalization rules
217    fn normalize_text(&self, text: &str) -> String {
218        let mut normalized = text.to_string();
219
220        for rule in &self.format.normalization_rules {
221            if !rule.enabled {
222                continue;
223            }
224
225            normalized = match &rule.rule_type {
226                NormalizationType::Lowercase => normalized.to_lowercase(),
227                NormalizationType::RemoveAccents => self.remove_accents(&normalized),
228                NormalizationType::NormalizeWhitespace => {
229                    normalized.split_whitespace().collect::<Vec<_>>().join(" ")
230                },
231                NormalizationType::NormalizeUnicode => {
232                    unicode_normalization::UnicodeNormalization::nfc(normalized.as_str()).collect()
233                },
234                NormalizationType::RemovePunctuation => {
235                    normalized.chars().filter(|c| !c.is_ascii_punctuation()).collect()
236                },
237                NormalizationType::Regex(_pattern) => {
238                    if let (Some(pattern), Some(replacement)) = (&rule.pattern, &rule.replacement) {
239                        if let Ok(re) = regex::Regex::new(pattern) {
240                            re.replace_all(&normalized, replacement).to_string()
241                        } else {
242                            normalized
243                        }
244                    } else {
245                        normalized
246                    }
247                },
248                NormalizationType::Custom(_) => {
249                    // Custom normalization would be implemented based on specific needs
250                    normalized
251                },
252            };
253        }
254
255        normalized
256    }
257
258    /// Remove accents from text
259    fn remove_accents(&self, text: &str) -> String {
260        use unicode_normalization::UnicodeNormalization;
261        text.nfd()
262            .filter(|c| !unicode_normalization::char::is_combining_mark(*c))
263            .collect()
264    }
265
266    /// Apply pre-tokenization rules
267    fn pre_tokenize(&self, text: &str) -> Vec<String> {
268        let mut tokens = vec![text.to_string()];
269
270        for rule in &self.format.pre_tokenization_rules {
271            if !rule.enabled {
272                continue;
273            }
274
275            let mut new_tokens = Vec::new();
276            for token in tokens {
277                match &rule.rule_type {
278                    PreTokenizationType::WhitespaceSplit => {
279                        new_tokens.extend(token.split_whitespace().map(|s| s.to_string()));
280                    },
281                    PreTokenizationType::PunctuationSplit => {
282                        let mut current = String::new();
283                        for ch in token.chars() {
284                            if ch.is_ascii_punctuation() {
285                                if !current.is_empty() {
286                                    new_tokens.push(current.clone());
287                                    current.clear();
288                                }
289                                new_tokens.push(ch.to_string());
290                            } else {
291                                current.push(ch);
292                            }
293                        }
294                        if !current.is_empty() {
295                            new_tokens.push(current);
296                        }
297                    },
298                    PreTokenizationType::WordBoundary => {
299                        // Simple word boundary implementation
300                        let words: Vec<String> = token
301                            .split(|c: char| !c.is_alphanumeric())
302                            .filter(|s| !s.is_empty())
303                            .map(|s| s.to_string())
304                            .collect();
305                        new_tokens.extend(words);
306                    },
307                    PreTokenizationType::Regex(pattern) => {
308                        if let Ok(re) = regex::Regex::new(pattern) {
309                            let splits: Vec<String> = re
310                                .split(&token)
311                                .filter(|s| !s.is_empty())
312                                .map(|s| s.to_string())
313                                .collect();
314                            new_tokens.extend(splits);
315                        } else {
316                            new_tokens.push(token);
317                        }
318                    },
319                    PreTokenizationType::Custom(_) => {
320                        // Custom pre-tokenization would be implemented based on specific needs
321                        new_tokens.push(token);
322                    },
323                }
324            }
325            tokens = new_tokens;
326        }
327
328        tokens
329    }
330
331    /// Tokenize text into subwords
332    fn tokenize_subwords(&self, tokens: Vec<String>) -> Vec<String> {
333        let mut subwords = Vec::new();
334
335        for token in tokens {
336            // Simple greedy tokenization - can be improved with more sophisticated algorithms
337            let mut remaining = token.as_str();
338            while !remaining.is_empty() {
339                let mut found = false;
340                // Try to find the longest matching token
341                for len in (1..=remaining.len()).rev() {
342                    let candidate = &remaining[..len];
343                    if self.token_to_id.contains_key(candidate) {
344                        subwords.push(candidate.to_string());
345                        remaining = &remaining[len..];
346                        found = true;
347                        break;
348                    }
349                }
350                if !found {
351                    // Use UNK token or skip character
352                    if let Some(unk_token) = &self.format.vocabulary.unk_token {
353                        subwords.push(unk_token.clone());
354                    }
355                    remaining = &remaining[1..];
356                }
357            }
358        }
359
360        subwords
361    }
362}
363
364impl Tokenizer for CustomFormatTokenizer {
365    fn encode(&self, text: &str) -> Result<TokenizedInput> {
366        let normalized = self.normalize_text(text);
367        let pre_tokens = self.pre_tokenize(&normalized);
368        let subwords = self.tokenize_subwords(pre_tokens);
369
370        let mut input_ids = Vec::new();
371        for token in &subwords {
372            if let Some(id) = self.token_to_id(token) {
373                input_ids.push(id);
374            } else if let Some(unk_token) = &self.format.vocabulary.unk_token {
375                if let Some(unk_id) = self.token_to_id(unk_token) {
376                    input_ids.push(unk_id);
377                }
378            }
379        }
380
381        // Apply max length constraint
382        if let Some(max_len) = self.max_length {
383            input_ids.truncate(max_len);
384        }
385
386        let attention_mask = vec![1u8; input_ids.len()];
387
388        Ok(TokenizedInput {
389            input_ids,
390            attention_mask,
391            token_type_ids: None,
392            special_tokens_mask: None,
393            offset_mapping: None,
394            overflowing_tokens: None,
395        })
396    }
397
398    fn decode(&self, ids: &[u32]) -> Result<String> {
399        let tokens: Vec<String> = ids.iter().filter_map(|&id| self.id_to_token(id)).collect();
400        Ok(tokens.join(" "))
401    }
402
403    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
404        // Simple concatenation with separator
405        let combined = format!("{} {} {}", text_a, "[SEP]", text_b);
406        self.encode(&combined)
407    }
408
409    fn vocab_size(&self) -> usize {
410        self.format.vocabulary.size
411    }
412
413    fn get_vocab(&self) -> HashMap<String, u32> {
414        self.format
415            .vocabulary
416            .tokens
417            .iter()
418            .map(|token| (token.text.clone(), token.id))
419            .collect()
420    }
421
422    fn token_to_id(&self, token: &str) -> Option<u32> {
423        self.format.vocabulary.tokens.iter().find(|t| t.text == token).map(|t| t.id)
424    }
425
426    fn id_to_token(&self, id: u32) -> Option<String> {
427        self.format
428            .vocabulary
429            .tokens
430            .iter()
431            .find(|t| t.id == id)
432            .map(|t| t.text.clone())
433    }
434}
435
/// Custom format converter for converting between different tokenizer formats.
///
/// A stateless unit type: every operation is an associated function that takes
/// or returns a [`CustomTokenizerFormat`].
pub struct CustomFormatConverter;
438
439impl CustomFormatConverter {
440    /// Convert HuggingFace tokenizer.json to custom format
441    pub fn from_huggingface_json(json_str: &str) -> Result<CustomTokenizerFormat> {
442        let hf_json: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
443            TrustformersError::other(anyhow::anyhow!("Failed to parse HF JSON: {}", e).to_string())
444        })?;
445
446        let mut tokens = Vec::new();
447        let mut special_tokens = Vec::new();
448
449        // Extract vocabulary
450        if let Some(vocab) = hf_json["model"]["vocab"].as_object() {
451            for (token_text, token_id) in vocab {
452                if let Some(id) = token_id.as_u64() {
453                    tokens.push(CustomToken {
454                        text: token_text.clone(),
455                        id: id as u32,
456                        frequency: None,
457                        is_special: false,
458                        metadata: HashMap::new(),
459                    });
460                }
461            }
462        }
463
464        // Extract special tokens
465        if let Some(added_tokens) = hf_json["added_tokens"].as_array() {
466            for token in added_tokens {
467                if let (Some(content), Some(id)) = (token["content"].as_str(), token["id"].as_u64())
468                {
469                    special_tokens.push(CustomSpecialToken {
470                        token: content.to_string(),
471                        id: id as u32,
472                        token_type: SpecialTokenType::UserDefined("unknown".to_string()),
473                        context: None,
474                    });
475                }
476            }
477        }
478
479        let tokens_len = tokens.len();
480        let vocabulary = CustomVocabulary {
481            vocab_type: VocabularyType::SubwordBPE, // Default assumption
482            tokens,
483            size: tokens_len,
484            unk_token: Some("[UNK]".to_string()),
485            special_token_mapping: HashMap::new(),
486        };
487
488        Ok(CustomTokenizerFormat {
489            format_name: "TrustformersCustom".to_string(),
490            format_version: "1.0".to_string(),
491            vocabulary,
492            special_tokens,
493            normalization_rules: vec![NormalizationRule {
494                rule_type: NormalizationType::NormalizeUnicode,
495                pattern: None,
496                replacement: None,
497                enabled: true,
498            }],
499            pre_tokenization_rules: vec![PreTokenizationRule {
500                rule_type: PreTokenizationType::WhitespaceSplit,
501                pattern: None,
502                enabled: true,
503            }],
504            post_processing_rules: vec![PostProcessingRule {
505                rule_type: PostProcessingType::AddSpecialTokens,
506                parameters: HashMap::new(),
507                enabled: true,
508            }],
509            metadata: HashMap::new(),
510        })
511    }
512
513    /// Convert custom format to HuggingFace tokenizer.json
514    pub fn to_huggingface_json(format: &CustomTokenizerFormat) -> Result<String> {
515        let mut hf_json = serde_json::json!({
516            "version": "1.0",
517            "truncation": null,
518            "padding": null,
519            "added_tokens": [],
520            "normalizer": {
521                "type": "Sequence",
522                "normalizers": []
523            },
524            "pre_tokenizer": {
525                "type": "Sequence",
526                "pre_tokenizers": []
527            },
528            "post_processor": null,
529            "decoder": {
530                "type": "BPEDecoder"
531            },
532            "model": {
533                "type": "BPE",
534                "dropout": null,
535                "unk_token": format.vocabulary.unk_token,
536                "continuing_subword_prefix": null,
537                "end_of_word_suffix": null,
538                "fuse_unk": false,
539                "vocab": {},
540                "merges": []
541            }
542        });
543
544        // Add vocabulary
545        let mut vocab_map = serde_json::Map::new();
546        for token in &format.vocabulary.tokens {
547            vocab_map.insert(
548                token.text.clone(),
549                serde_json::Value::Number(token.id.into()),
550            );
551        }
552        hf_json["model"]["vocab"] = serde_json::Value::Object(vocab_map);
553
554        // Add special tokens
555        let mut added_tokens = Vec::new();
556        for special_token in &format.special_tokens {
557            added_tokens.push(serde_json::json!({
558                "id": special_token.id,
559                "content": special_token.token,
560                "single_word": false,
561                "lstrip": false,
562                "rstrip": false,
563                "normalized": false,
564                "special": true
565            }));
566        }
567        hf_json["added_tokens"] = serde_json::Value::Array(added_tokens);
568
569        serde_json::to_string_pretty(&hf_json).map_err(|e| {
570            TrustformersError::other(
571                anyhow::anyhow!("Failed to serialize HF JSON: {}", e).to_string(),
572            )
573        })
574    }
575
576    /// Convert SentencePiece model to custom format
577    pub fn from_sentencepiece_model(model_path: &Path) -> Result<CustomTokenizerFormat> {
578        use crate::sentencepiece::SentencePieceTokenizer;
579
580        // Load the SentencePiece model
581        let sp_tokenizer = SentencePieceTokenizer::from_model_file(model_path)?;
582
583        // Get vocabulary from the tokenizer
584        let vocab_size = sp_tokenizer.vocab_size();
585        let mut tokens = Vec::new();
586        let mut special_tokens = Vec::new();
587        let mut special_token_mapping = HashMap::new();
588
589        // Extract tokens and their metadata
590        for id in 0..vocab_size {
591            let id_u32 = id as u32;
592            if let Some(token_text) = sp_tokenizer.id_to_token(id_u32) {
593                let score = sp_tokenizer.get_score(id_u32).unwrap_or(0.0);
594                let is_special = sp_tokenizer.is_special_token_public(&token_text);
595
596                let custom_token = CustomToken {
597                    text: token_text.clone(),
598                    id: id_u32,
599                    frequency: Some(score as f64),
600                    is_special,
601                    metadata: HashMap::new(),
602                };
603                tokens.push(custom_token);
604
605                // Handle special tokens
606                if is_special {
607                    let token_type = if token_text == "<pad>" {
608                        SpecialTokenType::Pad
609                    } else if token_text == "<unk>" {
610                        SpecialTokenType::Unk
611                    } else if token_text == "<s>" {
612                        SpecialTokenType::BOS
613                    } else if token_text == "</s>" {
614                        SpecialTokenType::EOS
615                    } else if token_text == "[CLS]" {
616                        SpecialTokenType::Cls
617                    } else if token_text == "[SEP]" {
618                        SpecialTokenType::Sep
619                    } else if token_text == "[MASK]" {
620                        SpecialTokenType::Mask
621                    } else {
622                        SpecialTokenType::UserDefined(token_text.clone())
623                    };
624
625                    special_tokens.push(CustomSpecialToken {
626                        token: token_text.clone(),
627                        id: id_u32,
628                        token_type,
629                        context: None,
630                    });
631                    special_token_mapping.insert(token_text, id_u32);
632                }
633            }
634        }
635
636        // Create custom vocabulary
637        let vocabulary = CustomVocabulary {
638            vocab_type: VocabularyType::SentencePiece,
639            tokens,
640            size: vocab_size,
641            unk_token: sp_tokenizer.unk_token().map(|s| s.to_string()),
642            special_token_mapping,
643        };
644
645        // Create normalization rules based on SentencePiece configuration
646        let mut normalization_rules = Vec::new();
647
648        if sp_tokenizer.uses_normalization() {
649            normalization_rules.push(NormalizationRule {
650                rule_type: NormalizationType::NormalizeUnicode,
651                pattern: None,
652                replacement: None,
653                enabled: true,
654            });
655        }
656
657        if sp_tokenizer.removes_extra_whitespaces() {
658            normalization_rules.push(NormalizationRule {
659                rule_type: NormalizationType::NormalizeWhitespace,
660                pattern: None,
661                replacement: None,
662                enabled: true,
663            });
664        }
665
666        // Create pre-tokenization rules
667        let mut pre_tokenization_rules = Vec::new();
668        if sp_tokenizer.treats_whitespace_as_suffix() {
669            pre_tokenization_rules.push(PreTokenizationRule {
670                rule_type: PreTokenizationType::WhitespaceSplit,
671                pattern: None,
672                enabled: true,
673            });
674        }
675
676        // Create post-processing rules for special tokens
677        let mut post_processing_rules = Vec::new();
678        if sp_tokenizer.bos_token_id().is_some() || sp_tokenizer.eos_token_id().is_some() {
679            let mut parameters = HashMap::new();
680            parameters.insert(
681                "template".to_string(),
682                serde_json::Value::String("$A".to_string()),
683            );
684            parameters.insert(
685                "tokens".to_string(),
686                serde_json::Value::Array(
687                    special_tokens
688                        .iter()
689                        .map(|st| serde_json::Value::String(st.token.clone()))
690                        .collect(),
691                ),
692            );
693
694            post_processing_rules.push(PostProcessingRule {
695                rule_type: PostProcessingType::AddSpecialTokens,
696                parameters,
697                enabled: true,
698            });
699        }
700
701        // Create metadata
702        let mut metadata = HashMap::new();
703        metadata.insert(
704            "source".to_string(),
705            serde_json::Value::String("SentencePiece".to_string()),
706        );
707        metadata.insert(
708            "model_type".to_string(),
709            serde_json::Value::String(sp_tokenizer.model_type_string()),
710        );
711        metadata.insert(
712            "vocab_size".to_string(),
713            serde_json::Value::Number(serde_json::Number::from(vocab_size)),
714        );
715        metadata.insert(
716            "uses_byte_fallback".to_string(),
717            serde_json::Value::Bool(sp_tokenizer.uses_byte_fallback()),
718        );
719
720        Ok(CustomTokenizerFormat {
721            format_name: "SentencePiece".to_string(),
722            format_version: "1.0".to_string(),
723            vocabulary,
724            special_tokens,
725            normalization_rules,
726            pre_tokenization_rules,
727            post_processing_rules,
728            metadata,
729        })
730    }
731
732    /// Validate custom format
733    pub fn validate_format(format: &CustomTokenizerFormat) -> Result<Vec<String>> {
734        let mut warnings = Vec::new();
735
736        // Check vocabulary consistency
737        if format.vocabulary.tokens.len() != format.vocabulary.size {
738            warnings.push(format!(
739                "Vocabulary size mismatch: declared {} but found {} tokens",
740                format.vocabulary.size,
741                format.vocabulary.tokens.len()
742            ));
743        }
744
745        // Check for duplicate token IDs
746        let mut seen_ids = std::collections::HashSet::new();
747        for token in &format.vocabulary.tokens {
748            if !seen_ids.insert(token.id) {
749                warnings.push(format!("Duplicate token ID: {}", token.id));
750            }
751        }
752
753        // Check special tokens
754        for special_token in &format.special_tokens {
755            if !seen_ids.contains(&special_token.id) {
756                warnings.push(format!(
757                    "Special token '{}' has ID {} not found in vocabulary",
758                    special_token.token, special_token.id
759                ));
760            }
761        }
762
763        Ok(warnings)
764    }
765}
766
767#[cfg(test)]
768mod tests {
769    use super::*;
770
    /// Building a tokenizer from a two-token vocabulary plus one special token
    /// reports the declared size and resolves lookups in both directions.
    #[test]
    fn test_custom_format_creation() {
        let format = CustomTokenizerFormat {
            format_name: "TestFormat".to_string(),
            format_version: "1.0".to_string(),
            vocabulary: CustomVocabulary {
                vocab_type: VocabularyType::WordLevel,
                tokens: vec![
                    CustomToken {
                        text: "hello".to_string(),
                        id: 0,
                        frequency: Some(0.1),
                        is_special: false,
                        metadata: HashMap::new(),
                    },
                    CustomToken {
                        text: "world".to_string(),
                        id: 1,
                        frequency: Some(0.05),
                        is_special: false,
                        metadata: HashMap::new(),
                    },
                ],
                size: 2,
                unk_token: Some("[UNK]".to_string()),
                special_token_mapping: HashMap::new(),
            },
            special_tokens: vec![CustomSpecialToken {
                token: "[UNK]".to_string(),
                id: 2,
                token_type: SpecialTokenType::Unk,
                context: None,
            }],
            normalization_rules: vec![],
            pre_tokenization_rules: vec![],
            post_processing_rules: vec![],
            metadata: HashMap::new(),
        };

        let tokenizer =
            CustomFormatTokenizer::from_format(format).expect("Operation failed in test");
        // The declared size counts vocabulary tokens only, not the special token.
        assert_eq!(tokenizer.vocab_size(), 2);
        assert_eq!(tokenizer.token_to_id("hello"), Some(0));
        assert_eq!(tokenizer.id_to_token(1), Some("world".to_string()));
    }
816
    /// Encoding with a whitespace pre-tokenization rule maps each known word to
    /// its ID and yields an all-ones attention mask of the same length.
    #[test]
    fn test_custom_tokenizer_encode() {
        let format = CustomTokenizerFormat {
            format_name: "TestFormat".to_string(),
            format_version: "1.0".to_string(),
            vocabulary: CustomVocabulary {
                vocab_type: VocabularyType::WordLevel,
                tokens: vec![
                    CustomToken {
                        text: "hello".to_string(),
                        id: 0,
                        frequency: None,
                        is_special: false,
                        metadata: HashMap::new(),
                    },
                    CustomToken {
                        text: "world".to_string(),
                        id: 1,
                        frequency: None,
                        is_special: false,
                        metadata: HashMap::new(),
                    },
                ],
                size: 2,
                unk_token: Some("[UNK]".to_string()),
                special_token_mapping: HashMap::new(),
            },
            special_tokens: vec![],
            normalization_rules: vec![],
            pre_tokenization_rules: vec![PreTokenizationRule {
                rule_type: PreTokenizationType::WhitespaceSplit,
                pattern: None,
                enabled: true,
            }],
            post_processing_rules: vec![],
            metadata: HashMap::new(),
        };

        let tokenizer =
            CustomFormatTokenizer::from_format(format).expect("Operation failed in test");
        let result = tokenizer.encode("hello world").expect("Encoding failed");
        assert_eq!(result.input_ids, vec![0, 1]);
        assert_eq!(result.attention_mask, vec![1, 1]);
    }
861
    /// Validation reports a size mismatch as its first warning when the
    /// declared size differs from the number of tokens.
    #[test]
    fn test_format_validation() {
        let format = CustomTokenizerFormat {
            format_name: "TestFormat".to_string(),
            format_version: "1.0".to_string(),
            vocabulary: CustomVocabulary {
                vocab_type: VocabularyType::WordLevel,
                tokens: vec![CustomToken {
                    text: "hello".to_string(),
                    id: 0,
                    frequency: None,
                    is_special: false,
                    metadata: HashMap::new(),
                }],
                size: 2, // Mismatch: claims 2 but only has 1 token
                unk_token: None,
                special_token_mapping: HashMap::new(),
            },
            special_tokens: vec![],
            normalization_rules: vec![],
            pre_tokenization_rules: vec![],
            post_processing_rules: vec![],
            metadata: HashMap::new(),
        };

        let warnings =
            CustomFormatConverter::validate_format(&format).expect("Operation failed in test");
        assert!(!warnings.is_empty());
        // The size check runs first, so its message is at index 0.
        assert!(warnings[0].contains("Vocabulary size mismatch"));
    }
892}