Skip to main content

trustformers_tokenizers/
tokenizer.rs

1use crate::alignment::{AlignedSpan, AlignmentConfig, AlignmentEngine, TokenAlignment};
2use std::collections::HashMap;
3use std::path::Path;
4use std::str::FromStr;
5use std::sync::Arc;
6// SciRS2 Integration Policy: Use re-exported tokenizers types from trustformers_core
7use trustformers_core::errors::{Result, TrustformersError};
8use trustformers_core::traits::{TokenizedInput, Tokenizer};
9use trustformers_core::{Encoding, Tokenizer as HFTokenizer, TokenizerError};
10
/// Encoder output that, besides ids and masks, carries per-token offsets and
/// the special-token mask reported by the underlying tokenizer.
#[derive(Debug, Clone)]
pub struct TokenizedInputWithOffsets {
    /// Vocabulary id of each token, in sequence order.
    pub input_ids: Vec<u32>,
    /// Attention mask with one entry per token.
    pub attention_mask: Vec<u8>,
    /// Segment / token-type ids, or `None` when the encoder produced none.
    pub token_type_ids: Option<Vec<u32>>,
    /// `(start, end)` span of each token in its source text, or `None` when unavailable.
    pub offset_mapping: Option<Vec<(usize, usize)>>,
    /// Per-token special-token flags as reported by the encoder, or `None` when unavailable.
    pub special_tokens_mask: Option<Vec<u8>>,
}
19
20#[derive(Debug, Clone)]
21pub struct TokenizedInputWithAlignment {
22    pub input_ids: Vec<u32>,
23    pub attention_mask: Vec<u8>,
24    pub token_type_ids: Option<Vec<u32>>,
25    pub offset_mapping: Option<Vec<(usize, usize)>>,
26    pub special_tokens_mask: Option<Vec<u8>>,
27    pub word_alignments: Vec<TokenAlignment>,
28    pub words: Vec<crate::alignment::Word>,
29}
30
31impl From<TokenizedInputWithOffsets> for TokenizedInput {
32    fn from(input: TokenizedInputWithOffsets) -> Self {
33        TokenizedInput {
34            input_ids: input.input_ids,
35            attention_mask: input.attention_mask,
36            token_type_ids: input.token_type_ids,
37            special_tokens_mask: input.special_tokens_mask,
38            offset_mapping: input.offset_mapping,
39            overflowing_tokens: None,
40        }
41    }
42}
43
44impl From<TokenizedInputWithAlignment> for TokenizedInput {
45    fn from(input: TokenizedInputWithAlignment) -> Self {
46        TokenizedInput {
47            input_ids: input.input_ids,
48            attention_mask: input.attention_mask,
49            token_type_ids: input.token_type_ids,
50            special_tokens_mask: input.special_tokens_mask,
51            offset_mapping: input.offset_mapping,
52            overflowing_tokens: None,
53        }
54    }
55}
56
57impl From<TokenizedInputWithAlignment> for TokenizedInputWithOffsets {
58    fn from(input: TokenizedInputWithAlignment) -> Self {
59        TokenizedInputWithOffsets {
60            input_ids: input.input_ids,
61            attention_mask: input.attention_mask,
62            token_type_ids: input.token_type_ids,
63            offset_mapping: input.offset_mapping,
64            special_tokens_mask: input.special_tokens_mask,
65        }
66    }
67}
68
69#[derive(Debug, Clone)]
70pub struct TokenizerImpl {
71    tokenizer: Arc<HFTokenizer>,
72    do_lower_case: bool,
73    max_length: Option<usize>,
74    alignment_engine: Option<AlignmentEngine>,
75}
76
77impl TokenizerImpl {
78    pub fn from_file(path: &Path) -> Result<Self> {
79        let tokenizer = HFTokenizer::from_file(path)
80            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
81        Ok(Self {
82            tokenizer: Arc::new(tokenizer),
83            do_lower_case: false,
84            max_length: Some(512),
85            alignment_engine: None,
86        })
87    }
88
89    pub fn from_pretrained(name: &str) -> Result<Self> {
90        Self::from_pretrained_with_revision(name, None)
91    }
92
93    pub fn from_pretrained_with_revision(name: &str, revision: Option<&str>) -> Result<Self> {
94        // Simplified version - in practice, this would download from HuggingFace Hub
95        // For now, try to load from a local cache path
96        let cache_dir = std::env::var("HF_HOME")
97            .or_else(|_| std::env::var("TRANSFORMERS_CACHE"))
98            .unwrap_or_else(|_| {
99                format!(
100                    "{}/.cache/huggingface/transformers",
101                    std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string())
102                )
103            });
104
105        // Include revision in path if specified
106        let tokenizer_path = match revision {
107            Some(rev) => format!("{}/{}/refs/{}/tokenizer.json", cache_dir, name, rev),
108            None => format!("{}/{}/tokenizer.json", cache_dir, name),
109        };
110        let path = Path::new(&tokenizer_path);
111
112        if path.exists() {
113            Self::from_file(path)
114        } else {
115            Err(TrustformersError::other(anyhow::anyhow!(
116                "Model '{}' not found locally. Please download it first or implement model downloading.",
117                name
118            ).to_string()))
119        }
120    }
121
122    pub fn from_tokenizer_json(json_str: &str) -> Result<Self> {
123        let tokenizer = HFTokenizer::from_str(json_str).map_err(|e: TokenizerError| {
124            TrustformersError::other(anyhow::anyhow!(e).to_string())
125        })?;
126        Ok(Self {
127            tokenizer: Arc::new(tokenizer),
128            do_lower_case: false,
129            max_length: Some(512),
130            alignment_engine: None,
131        })
132    }
133
134    pub fn save_to_file(&self, path: &Path) -> Result<()> {
135        let json = self
136            .tokenizer
137            .to_string(false)
138            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
139        std::fs::write(path, json)
140            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
141        Ok(())
142    }
143
144    pub fn to_json(&self) -> Result<String> {
145        self.tokenizer
146            .to_string(false)
147            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))
148    }
149
150    pub fn with_config(mut self, do_lower_case: bool, max_length: Option<usize>) -> Self {
151        self.do_lower_case = do_lower_case;
152        self.max_length = max_length;
153        self
154    }
155
156    pub fn encode_with_offsets(
157        &self,
158        text: &str,
159        add_special_tokens: bool,
160    ) -> Result<TokenizedInputWithOffsets> {
161        let encoding = self
162            .tokenizer
163            .encode(text, add_special_tokens)
164            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
165        Ok(self.encoding_to_tokenized_input_with_offsets(encoding))
166    }
167
168    pub fn encode_pair_with_offsets(
169        &self,
170        text: &str,
171        text2: &str,
172        add_special_tokens: bool,
173    ) -> Result<TokenizedInputWithOffsets> {
174        let encoding = self
175            .tokenizer
176            .encode((text, text2), add_special_tokens)
177            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
178        Ok(self.encoding_to_tokenized_input_with_offsets(encoding))
179    }
180
181    pub fn decode_with_special_tokens(
182        &self,
183        ids: &[u32],
184        skip_special_tokens: bool,
185    ) -> Result<String> {
186        self.tokenizer
187            .decode(ids, skip_special_tokens)
188            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))
189    }
190
191    pub fn get_vocab(&self) -> HashMap<String, u32> {
192        self.tokenizer.get_vocab(false)
193    }
194
195    pub fn token_to_id(&self, token: &str) -> Option<u32> {
196        self.tokenizer.token_to_id(token)
197    }
198
199    pub fn id_to_token(&self, id: u32) -> Option<String> {
200        self.tokenizer.id_to_token(id)
201    }
202
203    /// Configure word alignment engine
204    pub fn with_alignment_config(mut self, config: AlignmentConfig) -> Self {
205        self.alignment_engine = Some(AlignmentEngine::new(config));
206        self
207    }
208
209    /// Enable word alignment with default configuration
210    pub fn with_word_alignment(mut self) -> Self {
211        self.alignment_engine = Some(AlignmentEngine::new(AlignmentConfig::default()));
212        self
213    }
214
215    /// Get mutable reference to alignment engine
216    pub fn alignment_engine_mut(&mut self) -> Option<&mut AlignmentEngine> {
217        self.alignment_engine.as_mut()
218    }
219
220    /// Encode text with word alignment
221    pub fn encode_with_alignment(
222        &mut self,
223        text: &str,
224        add_special_tokens: bool,
225    ) -> Result<TokenizedInputWithAlignment> {
226        let encoding = self
227            .tokenizer
228            .encode(text, add_special_tokens)
229            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
230
231        self.encoding_to_tokenized_input_with_alignment(text, encoding)
232    }
233
234    /// Encode text pair with word alignment
235    pub fn encode_pair_with_alignment(
236        &mut self,
237        text: &str,
238        text2: &str,
239        add_special_tokens: bool,
240    ) -> Result<TokenizedInputWithAlignment> {
241        let encoding = self
242            .tokenizer
243            .encode((text, text2), add_special_tokens)
244            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
245
246        // Combine texts for alignment
247        let combined_text = format!("{} {}", text, text2);
248        self.encoding_to_tokenized_input_with_alignment(&combined_text, encoding)
249    }
250
251    /// Extract spans with word alignment
252    pub fn extract_aligned_spans(
253        &mut self,
254        text: &str,
255        spans: &[(usize, usize)],
256        add_special_tokens: bool,
257    ) -> Result<Vec<AlignedSpan>> {
258        let input_with_alignment = self.encode_with_alignment(text, add_special_tokens)?;
259
260        if let Some(engine) = &mut self.alignment_engine {
261            engine.extract_spans(text, &input_with_alignment.word_alignments, spans)
262        } else {
263            Err(TrustformersError::other(
264                "Word alignment engine not configured".to_string(),
265            ))
266        }
267    }
268
269    /// Preserve entity boundaries in tokenization
270    pub fn preserve_entities(
271        &mut self,
272        text: &str,
273        entities: &[(usize, usize, String)],
274        add_special_tokens: bool,
275    ) -> Result<Vec<AlignedSpan>> {
276        let input_with_alignment = self.encode_with_alignment(text, add_special_tokens)?;
277
278        if let Some(engine) = &mut self.alignment_engine {
279            engine.preserve_entities(text, &input_with_alignment.word_alignments, entities)
280        } else {
281            Err(TrustformersError::other(
282                "Word alignment engine not configured".to_string(),
283            ))
284        }
285    }
286
287    /// Get word boundaries for a specific token
288    pub fn get_word_boundaries_for_token(
289        &self,
290        alignments: &[TokenAlignment],
291        token_index: usize,
292    ) -> Option<(usize, usize)> {
293        if let Some(engine) = &self.alignment_engine {
294            engine.get_word_boundaries_for_token(alignments, token_index)
295        } else {
296            None
297        }
298    }
299
300    /// Check if tokens form a complete word
301    pub fn tokens_form_complete_word(
302        &self,
303        alignments: &[TokenAlignment],
304        token_indices: &[usize],
305    ) -> bool {
306        if let Some(engine) = &self.alignment_engine {
307            engine.tokens_form_complete_word(alignments, token_indices)
308        } else {
309            false
310        }
311    }
312
313    fn encoding_to_tokenized_input(&self, encoding: Encoding) -> TokenizedInput {
314        TokenizedInput {
315            input_ids: encoding.get_ids().to_vec(),
316            attention_mask: encoding.get_attention_mask().iter().map(|&x| x as u8).collect(),
317            token_type_ids: if encoding.get_type_ids().is_empty() {
318                None
319            } else {
320                Some(encoding.get_type_ids().to_vec())
321            },
322            special_tokens_mask: None,
323            offset_mapping: None,
324            overflowing_tokens: None,
325        }
326    }
327
328    fn encoding_to_tokenized_input_with_offsets(
329        &self,
330        encoding: Encoding,
331    ) -> TokenizedInputWithOffsets {
332        let offset_mapping = if !encoding.get_offsets().is_empty() {
333            Some(encoding.get_offsets().to_vec())
334        } else {
335            None
336        };
337
338        let special_tokens_mask = if !encoding.get_special_tokens_mask().is_empty() {
339            Some(encoding.get_special_tokens_mask().iter().map(|&x| x as u8).collect())
340        } else {
341            None
342        };
343
344        TokenizedInputWithOffsets {
345            input_ids: encoding.get_ids().to_vec(),
346            attention_mask: encoding.get_attention_mask().iter().map(|&x| x as u8).collect(),
347            token_type_ids: if encoding.get_type_ids().is_empty() {
348                None
349            } else {
350                Some(encoding.get_type_ids().to_vec())
351            },
352            offset_mapping,
353            special_tokens_mask,
354        }
355    }
356
357    fn encoding_to_tokenized_input_with_alignment(
358        &mut self,
359        text: &str,
360        encoding: Encoding,
361    ) -> Result<TokenizedInputWithAlignment> {
362        let offset_mapping = if !encoding.get_offsets().is_empty() {
363            Some(encoding.get_offsets().to_vec())
364        } else {
365            None
366        };
367
368        let special_tokens_mask = if !encoding.get_special_tokens_mask().is_empty() {
369            Some(encoding.get_special_tokens_mask().iter().map(|&x| x as u8).collect())
370        } else {
371            None
372        };
373
374        // Perform word alignment if engine is available
375        let (word_alignments, words) = if let Some(engine) = &mut self.alignment_engine {
376            if let Some(ref offsets) = offset_mapping {
377                let alignments =
378                    engine.align_tokens_to_words(text, offsets, special_tokens_mask.as_deref())?;
379                let words = engine.extract_words(text);
380                (alignments, words)
381            } else {
382                // If no offsets available, create empty alignments
383                (Vec::new(), Vec::new())
384            }
385        } else {
386            return Err(TrustformersError::other(
387                "Word alignment engine not configured".to_string(),
388            ));
389        };
390
391        Ok(TokenizedInputWithAlignment {
392            input_ids: encoding.get_ids().to_vec(),
393            attention_mask: encoding.get_attention_mask().iter().map(|&x| x as u8).collect(),
394            token_type_ids: if encoding.get_type_ids().is_empty() {
395                None
396            } else {
397                Some(encoding.get_type_ids().to_vec())
398            },
399            offset_mapping,
400            special_tokens_mask,
401            word_alignments,
402            words,
403        })
404    }
405}
406
407impl Tokenizer for TokenizerImpl {
408    fn encode(&self, text: &str) -> Result<TokenizedInput> {
409        let encoding = self.tokenizer.encode(text, false).map_err(|e| {
410            trustformers_core::errors::TrustformersError::other(anyhow::anyhow!(e).to_string())
411        })?;
412        Ok(self.encoding_to_tokenized_input(encoding))
413    }
414
415    fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
416        let encoding = self.tokenizer.encode((text, text2), false).map_err(|e| {
417            trustformers_core::errors::TrustformersError::other(anyhow::anyhow!(e).to_string())
418        })?;
419        Ok(self.encoding_to_tokenized_input(encoding))
420    }
421
422    fn decode(&self, ids: &[u32]) -> Result<String> {
423        self.tokenizer.decode(ids, false).map_err(|e| {
424            trustformers_core::errors::TrustformersError::other(anyhow::anyhow!(e).to_string())
425        })
426    }
427
428    fn vocab_size(&self) -> usize {
429        self.tokenizer.get_vocab_size(false)
430    }
431
432    fn get_vocab(&self) -> HashMap<String, u32> {
433        self.tokenizer.get_vocab(false)
434    }
435
436    fn token_to_id(&self, token: &str) -> Option<u32> {
437        self.tokenizer.token_to_id(token)
438    }
439
440    fn id_to_token(&self, id: u32) -> Option<String> {
441        self.tokenizer.id_to_token(id)
442    }
443}
444
445#[derive(Debug, Clone)]
446pub enum TokenizerWrapper {
447    WordPiece(crate::wordpiece::WordPieceTokenizer),
448    BPE(crate::bpe::BPETokenizer),
449    Unigram(crate::unigram::UnigramTokenizer),
450    Char(crate::char::CharTokenizer),
451    HuggingFace(TokenizerImpl),
452}
453
454impl Tokenizer for TokenizerWrapper {
455    fn encode(&self, text: &str) -> Result<TokenizedInput> {
456        match self {
457            TokenizerWrapper::WordPiece(t) => t.encode(text),
458            TokenizerWrapper::BPE(t) => t.encode(text),
459            TokenizerWrapper::Unigram(t) => t.encode(text),
460            TokenizerWrapper::Char(t) => t.encode(text),
461            TokenizerWrapper::HuggingFace(t) => t.encode(text),
462        }
463    }
464
465    fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
466        match self {
467            TokenizerWrapper::WordPiece(t) => t.encode_pair(text, text2),
468            TokenizerWrapper::BPE(t) => t.encode_pair(text, text2),
469            TokenizerWrapper::Unigram(t) => t.encode_pair(text, text2),
470            TokenizerWrapper::Char(t) => t.encode_pair(text, text2),
471            TokenizerWrapper::HuggingFace(t) => t.encode_pair(text, text2),
472        }
473    }
474
475    fn decode(&self, ids: &[u32]) -> Result<String> {
476        match self {
477            TokenizerWrapper::WordPiece(t) => t.decode(ids),
478            TokenizerWrapper::BPE(t) => t.decode(ids),
479            TokenizerWrapper::Unigram(t) => t.decode(ids),
480            TokenizerWrapper::Char(t) => t.decode(ids),
481            TokenizerWrapper::HuggingFace(t) => t.decode(ids),
482        }
483    }
484
485    fn vocab_size(&self) -> usize {
486        match self {
487            TokenizerWrapper::WordPiece(t) => t.vocab_size(),
488            TokenizerWrapper::BPE(t) => t.vocab_size(),
489            TokenizerWrapper::Unigram(t) => t.vocab_size(),
490            TokenizerWrapper::Char(t) => t.vocab_size(),
491            TokenizerWrapper::HuggingFace(t) => t.vocab_size(),
492        }
493    }
494
495    fn get_vocab(&self) -> HashMap<String, u32> {
496        match self {
497            TokenizerWrapper::WordPiece(t) => t.get_vocab(),
498            TokenizerWrapper::BPE(t) => t.get_vocab(),
499            TokenizerWrapper::Unigram(t) => t.get_vocab(),
500            TokenizerWrapper::Char(t) => t.get_vocab(),
501            TokenizerWrapper::HuggingFace(t) => t.get_vocab(),
502        }
503    }
504
505    fn token_to_id(&self, token: &str) -> Option<u32> {
506        match self {
507            TokenizerWrapper::WordPiece(t) => t.token_to_id(token),
508            TokenizerWrapper::BPE(t) => t.token_to_id(token),
509            TokenizerWrapper::Unigram(t) => t.token_to_id(token),
510            TokenizerWrapper::Char(t) => t.token_to_id(token),
511            TokenizerWrapper::HuggingFace(t) => t.token_to_id(token),
512        }
513    }
514
515    fn id_to_token(&self, id: u32) -> Option<String> {
516        match self {
517            TokenizerWrapper::WordPiece(t) => t.id_to_token(id),
518            TokenizerWrapper::BPE(t) => t.id_to_token(id),
519            TokenizerWrapper::Unigram(t) => t.id_to_token(id),
520            TokenizerWrapper::Char(t) => t.id_to_token(id),
521            TokenizerWrapper::HuggingFace(t) => t.id_to_token(id),
522        }
523    }
524}
525
526impl TokenizerWrapper {
527    /// Load a tokenizer from a pretrained model or path
528    pub fn from_pretrained<P: AsRef<Path>>(model_name_or_path: P) -> Result<Self> {
529        let path = model_name_or_path.as_ref();
530
531        // First try to load as a HuggingFace tokenizer from tokenizer.json
532        let tokenizer_json_path = path.join("tokenizer.json");
533        if tokenizer_json_path.exists() {
534            let tokenizer = TokenizerImpl::from_file(&tokenizer_json_path)?;
535            return Ok(TokenizerWrapper::HuggingFace(tokenizer));
536        }
537
538        // Try to load from tokenizer config
539        let config_path = path.join("tokenizer_config.json");
540        if config_path.exists() {
541            let config_str = std::fs::read_to_string(&config_path)
542                .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
543            let config: serde_json::Value = serde_json::from_str(&config_str)
544                .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
545
546            if let Some(tokenizer_type) = config.get("tokenizer_type").and_then(|v| v.as_str()) {
547                match tokenizer_type {
548                    "WordPiece" => {
549                        // Create a basic WordPiece tokenizer with minimal config
550                        let vocab = std::collections::HashMap::new();
551                        let tokenizer = crate::wordpiece::WordPieceTokenizer::new(vocab, false);
552                        return Ok(TokenizerWrapper::WordPiece(tokenizer));
553                    },
554                    "BPE" => {
555                        // Create a basic BPE tokenizer
556                        let vocab = std::collections::HashMap::new();
557                        let merges = Vec::new();
558                        let tokenizer = crate::bpe::BPETokenizer::new(vocab, merges);
559                        return Ok(TokenizerWrapper::BPE(tokenizer));
560                    },
561                    "Unigram" => {
562                        // Create a basic Unigram tokenizer
563                        let vocab = std::collections::HashMap::new();
564                        let scores = std::collections::HashMap::new();
565                        let tokenizer = crate::unigram::UnigramTokenizer::new(vocab, scores)?;
566                        return Ok(TokenizerWrapper::Unigram(tokenizer));
567                    },
568                    "Character" => {
569                        // Create a basic Character tokenizer
570                        let vocab = std::collections::HashMap::new();
571                        let tokenizer = crate::char::CharTokenizer::new(vocab);
572                        return Ok(TokenizerWrapper::Char(tokenizer));
573                    },
574                    _ => {
575                        return Err(TrustformersError::invalid_input(format!(
576                            "Unsupported tokenizer type: {}",
577                            tokenizer_type
578                        )));
579                    },
580                }
581            }
582        }
583
584        // If no config found, try to load as a HuggingFace tokenizer directly
585        // (in case the path is a model name for hub download)
586        match TokenizerImpl::from_pretrained(path.to_string_lossy().as_ref()) {
587            Ok(tokenizer) => Ok(TokenizerWrapper::HuggingFace(tokenizer)),
588            Err(_) => {
589                // As a last resort, create a basic BPE tokenizer
590                let vocab = std::collections::HashMap::new();
591                let merges = Vec::new();
592                Ok(TokenizerWrapper::BPE(crate::bpe::BPETokenizer::new(
593                    vocab, merges,
594                )))
595            },
596        }
597    }
598
599    /// Save the tokenizer to a directory path
600    pub fn save_pretrained<P: AsRef<Path>>(&self, path: P) -> Result<()> {
601        let path = path.as_ref();
602
603        // Create directory if it doesn't exist
604        std::fs::create_dir_all(path)
605            .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
606
607        match self {
608            TokenizerWrapper::HuggingFace(tokenizer) => {
609                // For HuggingFace tokenizers, use the existing save_to_file method
610                let tokenizer_path = path.join("tokenizer.json");
611                tokenizer.save_to_file(&tokenizer_path)
612            },
613            TokenizerWrapper::WordPiece(_) => {
614                // For WordPiece, create a simple config file indicating the type
615                let config_path = path.join("tokenizer_config.json");
616                let config = serde_json::json!({
617                    "tokenizer_type": "WordPiece",
618                    "model_type": "WordPiece",
619                    "version": "1.0"
620                });
621                std::fs::write(
622                    config_path,
623                    serde_json::to_string_pretty(&config)
624                        .expect("hardcoded JSON config must serialize"),
625                )
626                .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
627
628                // Note: Full WordPiece serialization would require implementing
629                // vocabulary and config serialization for WordPieceTokenizer
630                Ok(())
631            },
632            TokenizerWrapper::BPE(_) => {
633                // For BPE, create a simple config file indicating the type
634                let config_path = path.join("tokenizer_config.json");
635                let config = serde_json::json!({
636                    "tokenizer_type": "BPE",
637                    "model_type": "BPE",
638                    "version": "1.0"
639                });
640                std::fs::write(
641                    config_path,
642                    serde_json::to_string_pretty(&config)
643                        .expect("hardcoded JSON config must serialize"),
644                )
645                .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
646                Ok(())
647            },
648            TokenizerWrapper::Unigram(_) => {
649                // For Unigram, create a simple config file indicating the type
650                let config_path = path.join("tokenizer_config.json");
651                let config = serde_json::json!({
652                    "tokenizer_type": "Unigram",
653                    "model_type": "Unigram",
654                    "version": "1.0"
655                });
656                std::fs::write(
657                    config_path,
658                    serde_json::to_string_pretty(&config)
659                        .expect("hardcoded JSON config must serialize"),
660                )
661                .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
662                Ok(())
663            },
664            TokenizerWrapper::Char(_) => {
665                // For Char, create a simple config file indicating the type
666                let config_path = path.join("tokenizer_config.json");
667                let config = serde_json::json!({
668                    "tokenizer_type": "Character",
669                    "model_type": "Character",
670                    "version": "1.0"
671                });
672                std::fs::write(
673                    config_path,
674                    serde_json::to_string_pretty(&config)
675                        .expect("hardcoded JSON config must serialize"),
676                )
677                .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
678                Ok(())
679            },
680        }
681    }
682}
683
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenized_input_with_offsets_conversion() {
        let input_with_offsets = TokenizedInputWithOffsets {
            input_ids: vec![101, 2023, 2003, 102],
            attention_mask: vec![1, 1, 1, 1],
            token_type_ids: Some(vec![0, 0, 0, 0]),
            offset_mapping: Some(vec![(0, 0), (0, 4), (5, 7), (0, 0)]),
            special_tokens_mask: Some(vec![1, 0, 0, 1]),
        };

        let regular_input: TokenizedInput = input_with_offsets.into();

        assert_eq!(regular_input.input_ids, vec![101, 2023, 2003, 102]);
        assert_eq!(regular_input.attention_mask, vec![1, 1, 1, 1]);
        assert_eq!(regular_input.token_type_ids, Some(vec![0, 0, 0, 0]));
        // Offsets and the special-token mask must survive the conversion too.
        assert_eq!(
            regular_input.offset_mapping,
            Some(vec![(0, 0), (0, 4), (5, 7), (0, 0)])
        );
        assert_eq!(regular_input.special_tokens_mask, Some(vec![1, 0, 0, 1]));
        assert!(regular_input.overflowing_tokens.is_none());
    }

    #[test]
    fn test_tokenized_input_with_alignment_conversion() {
        let with_alignment = TokenizedInputWithAlignment {
            input_ids: vec![7, 8],
            attention_mask: vec![1, 1],
            token_type_ids: None,
            offset_mapping: Some(vec![(0, 3), (4, 7)]),
            special_tokens_mask: Some(vec![0, 0]),
            word_alignments: Vec::new(),
            words: Vec::new(),
        };

        let with_offsets: TokenizedInputWithOffsets = with_alignment.clone().into();
        assert_eq!(with_offsets.input_ids, vec![7, 8]);
        assert_eq!(with_offsets.token_type_ids, None);
        assert_eq!(with_offsets.offset_mapping, Some(vec![(0, 3), (4, 7)]));
        assert_eq!(with_offsets.special_tokens_mask, Some(vec![0, 0]));

        let regular: TokenizedInput = with_alignment.into();
        assert_eq!(regular.input_ids, vec![7, 8]);
        assert_eq!(regular.attention_mask, vec![1, 1]);
        assert_eq!(regular.offset_mapping, Some(vec![(0, 3), (4, 7)]));
    }

    #[test]
    fn test_tokenizer_wrapper_char() {
        let text = "Hello World!";
        let tokenizer = crate::char::CharTokenizer::from_text(text, 1000);
        let wrapper = TokenizerWrapper::Char(tokenizer);

        let encoded = wrapper.encode(text).expect("Encoding failed");
        let decoded = wrapper.decode(&encoded.input_ids).expect("Decoding failed");

        assert!(!encoded.input_ids.is_empty());
        assert!(decoded.contains("Hello"));
        assert!(wrapper.vocab_size() > 0);
    }

    #[test]
    fn test_tokenizer_from_json_string() {
        // Simple minimal tokenizer JSON for testing
        let json_str = r#"{
            "version": "1.0",
            "truncation": null,
            "padding": null,
            "added_tokens": [
                {
                    "id": 0,
                    "content": "[PAD]",
                    "single_word": false,
                    "lstrip": false,
                    "rstrip": false,
                    "normalized": false,
                    "special": true
                },
                {
                    "id": 1,
                    "content": "[UNK]",
                    "single_word": false,
                    "lstrip": false,
                    "rstrip": false,
                    "normalized": false,
                    "special": true
                }
            ],
            "normalizer": null,
            "pre_tokenizer": {
                "type": "Whitespace"
            },
            "post_processor": null,
            "decoder": null,
            "model": {
                "type": "WordLevel",
                "vocab": {
                    "[PAD]": 0,
                    "[UNK]": 1,
                    "hello": 2,
                    "world": 3
                },
                "unk_token": "[UNK]"
            }
        }"#;

        // `expect` instead of `if let Ok(..)` so a parse failure can never
        // silently skip the assertions below.
        let tokenizer = TokenizerImpl::from_tokenizer_json(json_str)
            .expect("minimal WordLevel tokenizer JSON should parse");

        assert_eq!(tokenizer.vocab_size(), 4);
        assert_eq!(tokenizer.token_to_id("hello"), Some(2));
        assert_eq!(tokenizer.id_to_token(3), Some("world".to_string()));
    }
}
769            assert_eq!(tokenizer.token_to_id("hello"), Some(2));
770            assert_eq!(tokenizer.id_to_token(3), Some("world".to_string()));
771        }
772    }
773}