1use crate::alignment::{AlignedSpan, AlignmentConfig, AlignmentEngine, TokenAlignment};
2use std::collections::HashMap;
3use std::path::Path;
4use std::str::FromStr;
5use std::sync::Arc;
6use trustformers_core::errors::{Result, TrustformersError};
8use trustformers_core::traits::{TokenizedInput, Tokenizer};
9use trustformers_core::{Encoding, Tokenizer as HFTokenizer, TokenizerError};
10
/// Tokenizer output augmented with per-token character offsets into the
/// original text (analogous to Hugging Face `return_offsets_mapping=True`).
#[derive(Debug, Clone)]
pub struct TokenizedInputWithOffsets {
    // Token ids produced by the underlying tokenizer.
    pub input_ids: Vec<u32>,
    // Attention mask values, narrowed to u8 (one entry per token).
    pub attention_mask: Vec<u8>,
    // Segment ids; `None` when the encoding produced no type ids.
    pub token_type_ids: Option<Vec<u32>>,
    // (start, end) byte offsets of each token in the source text;
    // `None` when the encoding carried no offsets.
    pub offset_mapping: Option<Vec<(usize, usize)>>,
    // 1 where the token is a special token, 0 otherwise; `None` when absent.
    pub special_tokens_mask: Option<Vec<u8>>,
}
19
/// Tokenizer output augmented with both character offsets and token-to-word
/// alignment information produced by the [`AlignmentEngine`].
#[derive(Debug, Clone)]
pub struct TokenizedInputWithAlignment {
    // Token ids produced by the underlying tokenizer.
    pub input_ids: Vec<u32>,
    // Attention mask values, narrowed to u8 (one entry per token).
    pub attention_mask: Vec<u8>,
    // Segment ids; `None` when the encoding produced no type ids.
    pub token_type_ids: Option<Vec<u32>>,
    // (start, end) byte offsets of each token in the source text.
    pub offset_mapping: Option<Vec<(usize, usize)>>,
    // 1 where the token is a special token, 0 otherwise; `None` when absent.
    pub special_tokens_mask: Option<Vec<u8>>,
    // Per-token alignment records mapping tokens back to source words.
    pub word_alignments: Vec<TokenAlignment>,
    // The words extracted from the source text by the alignment engine.
    pub words: Vec<crate::alignment::Word>,
}
30
31impl From<TokenizedInputWithOffsets> for TokenizedInput {
32 fn from(input: TokenizedInputWithOffsets) -> Self {
33 TokenizedInput {
34 input_ids: input.input_ids,
35 attention_mask: input.attention_mask,
36 token_type_ids: input.token_type_ids,
37 special_tokens_mask: input.special_tokens_mask,
38 offset_mapping: input.offset_mapping,
39 overflowing_tokens: None,
40 }
41 }
42}
43
44impl From<TokenizedInputWithAlignment> for TokenizedInput {
45 fn from(input: TokenizedInputWithAlignment) -> Self {
46 TokenizedInput {
47 input_ids: input.input_ids,
48 attention_mask: input.attention_mask,
49 token_type_ids: input.token_type_ids,
50 special_tokens_mask: input.special_tokens_mask,
51 offset_mapping: input.offset_mapping,
52 overflowing_tokens: None,
53 }
54 }
55}
56
57impl From<TokenizedInputWithAlignment> for TokenizedInputWithOffsets {
58 fn from(input: TokenizedInputWithAlignment) -> Self {
59 TokenizedInputWithOffsets {
60 input_ids: input.input_ids,
61 attention_mask: input.attention_mask,
62 token_type_ids: input.token_type_ids,
63 offset_mapping: input.offset_mapping,
64 special_tokens_mask: input.special_tokens_mask,
65 }
66 }
67}
68
/// Wrapper around a Hugging Face `tokenizers` tokenizer, optionally paired
/// with a word-alignment engine. Cloning is cheap: the inner tokenizer is
/// shared behind an `Arc`.
#[derive(Debug, Clone)]
pub struct TokenizerImpl {
    // Shared underlying HF tokenizer.
    tokenizer: Arc<HFTokenizer>,
    // NOTE(review): set via `with_config` but not consulted by any method
    // in this file — confirm lowercasing is applied elsewhere.
    do_lower_case: bool,
    // NOTE(review): stored (default 512) but no visible truncation uses it
    // — confirm intended behavior.
    max_length: Option<usize>,
    // Present only after `with_word_alignment`/`with_alignment_config`.
    alignment_engine: Option<AlignmentEngine>,
}
76
impl TokenizerImpl {
    /// Loads a tokenizer from a `tokenizer.json` file on disk.
    ///
    /// Defaults: no lowercasing, `max_length = Some(512)`, no alignment
    /// engine (enable one via `with_word_alignment`/`with_alignment_config`).
    pub fn from_file(path: &Path) -> Result<Self> {
        let tokenizer = HFTokenizer::from_file(path)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            do_lower_case: false,
            max_length: Some(512),
            alignment_engine: None,
        })
    }

    /// Loads a named model's tokenizer from the local cache (no revision).
    pub fn from_pretrained(name: &str) -> Result<Self> {
        Self::from_pretrained_with_revision(name, None)
    }

    /// Loads a named model's tokenizer from the local Hugging Face cache,
    /// optionally pinned to a revision. Fails if the file is not already
    /// cached — no downloading is performed.
    pub fn from_pretrained_with_revision(name: &str, revision: Option<&str>) -> Result<Self> {
        // Cache root resolution order: HF_HOME, then TRANSFORMERS_CACHE,
        // then ~/.cache/huggingface/transformers (falling back to /tmp
        // when HOME is unset).
        let cache_dir = std::env::var("HF_HOME")
            .or_else(|_| std::env::var("TRANSFORMERS_CACHE"))
            .unwrap_or_else(|_| {
                format!(
                    "{}/.cache/huggingface/transformers",
                    std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string())
                )
            });

        // NOTE(review): this flat "<cache>/<name>/[refs/<rev>/]tokenizer.json"
        // layout does not match the hub's "models--<org>--<name>/snapshots/"
        // cache structure — confirm which layout this project populates.
        let tokenizer_path = match revision {
            Some(rev) => format!("{}/{}/refs/{}/tokenizer.json", cache_dir, name, rev),
            None => format!("{}/{}/tokenizer.json", cache_dir, name),
        };
        let path = Path::new(&tokenizer_path);

        if path.exists() {
            Self::from_file(path)
        } else {
            Err(TrustformersError::other(anyhow::anyhow!(
                "Model '{}' not found locally. Please download it first or implement model downloading.",
                name
            ).to_string()))
        }
    }

    /// Builds a tokenizer directly from a `tokenizer.json` string.
    pub fn from_tokenizer_json(json_str: &str) -> Result<Self> {
        let tokenizer = HFTokenizer::from_str(json_str).map_err(|e: TokenizerError| {
            TrustformersError::other(anyhow::anyhow!(e).to_string())
        })?;
        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            do_lower_case: false,
            max_length: Some(512),
            alignment_engine: None,
        })
    }

    /// Serializes the tokenizer to `path` as (non-pretty) JSON.
    pub fn save_to_file(&self, path: &Path) -> Result<()> {
        let json = self
            .tokenizer
            .to_string(false) // false = compact (non-pretty) JSON
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
        std::fs::write(path, json)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
        Ok(())
    }

    /// Returns the tokenizer's compact JSON representation.
    pub fn to_json(&self) -> Result<String> {
        self.tokenizer
            .to_string(false)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))
    }

    /// Builder-style setter for the (currently unconsumed — see struct
    /// field notes) lowercasing and max-length configuration.
    pub fn with_config(mut self, do_lower_case: bool, max_length: Option<usize>) -> Self {
        self.do_lower_case = do_lower_case;
        self.max_length = max_length;
        self
    }

    /// Encodes `text` and returns ids plus offset/special-token metadata.
    pub fn encode_with_offsets(
        &self,
        text: &str,
        add_special_tokens: bool,
    ) -> Result<TokenizedInputWithOffsets> {
        let encoding = self
            .tokenizer
            .encode(text, add_special_tokens)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
        Ok(self.encoding_to_tokenized_input_with_offsets(encoding))
    }

    /// Encodes a sentence pair and returns ids plus offset/special-token
    /// metadata. Offsets are relative to each source sequence.
    pub fn encode_pair_with_offsets(
        &self,
        text: &str,
        text2: &str,
        add_special_tokens: bool,
    ) -> Result<TokenizedInputWithOffsets> {
        let encoding = self
            .tokenizer
            .encode((text, text2), add_special_tokens)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;
        Ok(self.encoding_to_tokenized_input_with_offsets(encoding))
    }

    /// Decodes ids to text, optionally stripping special tokens.
    pub fn decode_with_special_tokens(
        &self,
        ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        self.tokenizer
            .decode(ids, skip_special_tokens)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))
    }

    /// Returns the vocabulary without added tokens (`get_vocab(false)`).
    pub fn get_vocab(&self) -> HashMap<String, u32> {
        self.tokenizer.get_vocab(false)
    }

    /// Looks up a token's id, if present in the vocabulary.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    /// Looks up the token string for an id, if in range.
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }

    /// Builder-style: installs an alignment engine with a custom config.
    pub fn with_alignment_config(mut self, config: AlignmentConfig) -> Self {
        self.alignment_engine = Some(AlignmentEngine::new(config));
        self
    }

    /// Builder-style: installs an alignment engine with default config.
    pub fn with_word_alignment(mut self) -> Self {
        self.alignment_engine = Some(AlignmentEngine::new(AlignmentConfig::default()));
        self
    }

    /// Mutable access to the alignment engine, if one is installed.
    pub fn alignment_engine_mut(&mut self) -> Option<&mut AlignmentEngine> {
        self.alignment_engine.as_mut()
    }

    /// Encodes `text` and computes token-to-word alignments.
    ///
    /// # Errors
    /// Fails if encoding fails or no alignment engine is configured.
    pub fn encode_with_alignment(
        &mut self,
        text: &str,
        add_special_tokens: bool,
    ) -> Result<TokenizedInputWithAlignment> {
        let encoding = self
            .tokenizer
            .encode(text, add_special_tokens)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;

        self.encoding_to_tokenized_input_with_alignment(text, encoding)
    }

    /// Encodes a sentence pair and computes alignments against the two
    /// texts joined with a single space.
    ///
    /// NOTE(review): the pair encoding's offsets are relative to each
    /// sequence individually, but alignment runs on the concatenated
    /// string — second-sequence tokens may therefore map to the wrong
    /// character ranges. Verify against `AlignmentEngine` expectations.
    pub fn encode_pair_with_alignment(
        &mut self,
        text: &str,
        text2: &str,
        add_special_tokens: bool,
    ) -> Result<TokenizedInputWithAlignment> {
        let encoding = self
            .tokenizer
            .encode((text, text2), add_special_tokens)
            .map_err(|e| TrustformersError::other(anyhow::anyhow!(e).to_string()))?;

        let combined_text = format!("{} {}", text, text2);
        self.encoding_to_tokenized_input_with_alignment(&combined_text, encoding)
    }

    /// Encodes `text`, then extracts the requested character spans as
    /// token-aligned spans via the alignment engine.
    pub fn extract_aligned_spans(
        &mut self,
        text: &str,
        spans: &[(usize, usize)],
        add_special_tokens: bool,
    ) -> Result<Vec<AlignedSpan>> {
        let input_with_alignment = self.encode_with_alignment(text, add_special_tokens)?;

        // The else-branch is effectively unreachable: encode_with_alignment
        // already errors when no engine is configured. Kept for safety.
        if let Some(engine) = &mut self.alignment_engine {
            engine.extract_spans(text, &input_with_alignment.word_alignments, spans)
        } else {
            Err(TrustformersError::other(
                "Word alignment engine not configured".to_string(),
            ))
        }
    }

    /// Encodes `text`, then maps labeled entity spans `(start, end, label)`
    /// onto tokens via the alignment engine.
    pub fn preserve_entities(
        &mut self,
        text: &str,
        entities: &[(usize, usize, String)],
        add_special_tokens: bool,
    ) -> Result<Vec<AlignedSpan>> {
        let input_with_alignment = self.encode_with_alignment(text, add_special_tokens)?;

        // See extract_aligned_spans: the else-branch is a defensive guard.
        if let Some(engine) = &mut self.alignment_engine {
            engine.preserve_entities(text, &input_with_alignment.word_alignments, entities)
        } else {
            Err(TrustformersError::other(
                "Word alignment engine not configured".to_string(),
            ))
        }
    }

    /// Delegates to the alignment engine; `None` when no engine is
    /// configured or the token has no word boundaries.
    pub fn get_word_boundaries_for_token(
        &self,
        alignments: &[TokenAlignment],
        token_index: usize,
    ) -> Option<(usize, usize)> {
        if let Some(engine) = &self.alignment_engine {
            engine.get_word_boundaries_for_token(alignments, token_index)
        } else {
            None
        }
    }

    /// Delegates to the alignment engine; `false` when no engine is
    /// configured.
    pub fn tokens_form_complete_word(
        &self,
        alignments: &[TokenAlignment],
        token_indices: &[usize],
    ) -> bool {
        if let Some(engine) = &self.alignment_engine {
            engine.tokens_form_complete_word(alignments, token_indices)
        } else {
            false
        }
    }

    /// Converts an HF `Encoding` into the plain trait-level output
    /// (no offsets, no special-tokens mask).
    fn encoding_to_tokenized_input(&self, encoding: Encoding) -> TokenizedInput {
        TokenizedInput {
            input_ids: encoding.get_ids().to_vec(),
            attention_mask: encoding.get_attention_mask().iter().map(|&x| x as u8).collect(),
            token_type_ids: if encoding.get_type_ids().is_empty() {
                None
            } else {
                Some(encoding.get_type_ids().to_vec())
            },
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        }
    }

    /// Converts an HF `Encoding` into the offsets-carrying output. Empty
    /// offset/special-token vectors are normalized to `None`.
    fn encoding_to_tokenized_input_with_offsets(
        &self,
        encoding: Encoding,
    ) -> TokenizedInputWithOffsets {
        let offset_mapping = if !encoding.get_offsets().is_empty() {
            Some(encoding.get_offsets().to_vec())
        } else {
            None
        };

        let special_tokens_mask = if !encoding.get_special_tokens_mask().is_empty() {
            Some(encoding.get_special_tokens_mask().iter().map(|&x| x as u8).collect())
        } else {
            None
        };

        TokenizedInputWithOffsets {
            input_ids: encoding.get_ids().to_vec(),
            attention_mask: encoding.get_attention_mask().iter().map(|&x| x as u8).collect(),
            token_type_ids: if encoding.get_type_ids().is_empty() {
                None
            } else {
                Some(encoding.get_type_ids().to_vec())
            },
            offset_mapping,
            special_tokens_mask,
        }
    }

    /// Converts an HF `Encoding` into the alignment-carrying output by
    /// running the alignment engine over the encoding's offsets.
    ///
    /// # Errors
    /// Fails when no alignment engine is configured, or when the engine's
    /// token-to-word alignment fails. When the encoding has no offsets,
    /// alignments and words are simply empty.
    fn encoding_to_tokenized_input_with_alignment(
        &mut self,
        text: &str,
        encoding: Encoding,
    ) -> Result<TokenizedInputWithAlignment> {
        let offset_mapping = if !encoding.get_offsets().is_empty() {
            Some(encoding.get_offsets().to_vec())
        } else {
            None
        };

        let special_tokens_mask = if !encoding.get_special_tokens_mask().is_empty() {
            Some(encoding.get_special_tokens_mask().iter().map(|&x| x as u8).collect())
        } else {
            None
        };

        let (word_alignments, words) = if let Some(engine) = &mut self.alignment_engine {
            if let Some(ref offsets) = offset_mapping {
                let alignments =
                    engine.align_tokens_to_words(text, offsets, special_tokens_mask.as_deref())?;
                let words = engine.extract_words(text);
                (alignments, words)
            } else {
                // No offsets available: nothing to align against.
                (Vec::new(), Vec::new())
            }
        } else {
            return Err(TrustformersError::other(
                "Word alignment engine not configured".to_string(),
            ));
        };

        Ok(TokenizedInputWithAlignment {
            input_ids: encoding.get_ids().to_vec(),
            attention_mask: encoding.get_attention_mask().iter().map(|&x| x as u8).collect(),
            token_type_ids: if encoding.get_type_ids().is_empty() {
                None
            } else {
                Some(encoding.get_type_ids().to_vec())
            },
            offset_mapping,
            special_tokens_mask,
            word_alignments,
            words,
        })
    }
}
406
407impl Tokenizer for TokenizerImpl {
408 fn encode(&self, text: &str) -> Result<TokenizedInput> {
409 let encoding = self.tokenizer.encode(text, false).map_err(|e| {
410 trustformers_core::errors::TrustformersError::other(anyhow::anyhow!(e).to_string())
411 })?;
412 Ok(self.encoding_to_tokenized_input(encoding))
413 }
414
415 fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
416 let encoding = self.tokenizer.encode((text, text2), false).map_err(|e| {
417 trustformers_core::errors::TrustformersError::other(anyhow::anyhow!(e).to_string())
418 })?;
419 Ok(self.encoding_to_tokenized_input(encoding))
420 }
421
422 fn decode(&self, ids: &[u32]) -> Result<String> {
423 self.tokenizer.decode(ids, false).map_err(|e| {
424 trustformers_core::errors::TrustformersError::other(anyhow::anyhow!(e).to_string())
425 })
426 }
427
428 fn vocab_size(&self) -> usize {
429 self.tokenizer.get_vocab_size(false)
430 }
431
432 fn get_vocab(&self) -> HashMap<String, u32> {
433 self.tokenizer.get_vocab(false)
434 }
435
436 fn token_to_id(&self, token: &str) -> Option<u32> {
437 self.tokenizer.token_to_id(token)
438 }
439
440 fn id_to_token(&self, id: u32) -> Option<String> {
441 self.tokenizer.id_to_token(id)
442 }
443}
444
/// Runtime-dispatched wrapper over the concrete tokenizer implementations
/// this crate provides, plus the Hugging Face-backed `TokenizerImpl`.
#[derive(Debug, Clone)]
pub enum TokenizerWrapper {
    WordPiece(crate::wordpiece::WordPieceTokenizer),
    BPE(crate::bpe::BPETokenizer),
    Unigram(crate::unigram::UnigramTokenizer),
    Char(crate::char::CharTokenizer),
    HuggingFace(TokenizerImpl),
}
453
454impl Tokenizer for TokenizerWrapper {
455 fn encode(&self, text: &str) -> Result<TokenizedInput> {
456 match self {
457 TokenizerWrapper::WordPiece(t) => t.encode(text),
458 TokenizerWrapper::BPE(t) => t.encode(text),
459 TokenizerWrapper::Unigram(t) => t.encode(text),
460 TokenizerWrapper::Char(t) => t.encode(text),
461 TokenizerWrapper::HuggingFace(t) => t.encode(text),
462 }
463 }
464
465 fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
466 match self {
467 TokenizerWrapper::WordPiece(t) => t.encode_pair(text, text2),
468 TokenizerWrapper::BPE(t) => t.encode_pair(text, text2),
469 TokenizerWrapper::Unigram(t) => t.encode_pair(text, text2),
470 TokenizerWrapper::Char(t) => t.encode_pair(text, text2),
471 TokenizerWrapper::HuggingFace(t) => t.encode_pair(text, text2),
472 }
473 }
474
475 fn decode(&self, ids: &[u32]) -> Result<String> {
476 match self {
477 TokenizerWrapper::WordPiece(t) => t.decode(ids),
478 TokenizerWrapper::BPE(t) => t.decode(ids),
479 TokenizerWrapper::Unigram(t) => t.decode(ids),
480 TokenizerWrapper::Char(t) => t.decode(ids),
481 TokenizerWrapper::HuggingFace(t) => t.decode(ids),
482 }
483 }
484
485 fn vocab_size(&self) -> usize {
486 match self {
487 TokenizerWrapper::WordPiece(t) => t.vocab_size(),
488 TokenizerWrapper::BPE(t) => t.vocab_size(),
489 TokenizerWrapper::Unigram(t) => t.vocab_size(),
490 TokenizerWrapper::Char(t) => t.vocab_size(),
491 TokenizerWrapper::HuggingFace(t) => t.vocab_size(),
492 }
493 }
494
495 fn get_vocab(&self) -> HashMap<String, u32> {
496 match self {
497 TokenizerWrapper::WordPiece(t) => t.get_vocab(),
498 TokenizerWrapper::BPE(t) => t.get_vocab(),
499 TokenizerWrapper::Unigram(t) => t.get_vocab(),
500 TokenizerWrapper::Char(t) => t.get_vocab(),
501 TokenizerWrapper::HuggingFace(t) => t.get_vocab(),
502 }
503 }
504
505 fn token_to_id(&self, token: &str) -> Option<u32> {
506 match self {
507 TokenizerWrapper::WordPiece(t) => t.token_to_id(token),
508 TokenizerWrapper::BPE(t) => t.token_to_id(token),
509 TokenizerWrapper::Unigram(t) => t.token_to_id(token),
510 TokenizerWrapper::Char(t) => t.token_to_id(token),
511 TokenizerWrapper::HuggingFace(t) => t.token_to_id(token),
512 }
513 }
514
515 fn id_to_token(&self, id: u32) -> Option<String> {
516 match self {
517 TokenizerWrapper::WordPiece(t) => t.id_to_token(id),
518 TokenizerWrapper::BPE(t) => t.id_to_token(id),
519 TokenizerWrapper::Unigram(t) => t.id_to_token(id),
520 TokenizerWrapper::Char(t) => t.id_to_token(id),
521 TokenizerWrapper::HuggingFace(t) => t.id_to_token(id),
522 }
523 }
524}
525
526impl TokenizerWrapper {
527 pub fn from_pretrained<P: AsRef<Path>>(model_name_or_path: P) -> Result<Self> {
529 let path = model_name_or_path.as_ref();
530
531 let tokenizer_json_path = path.join("tokenizer.json");
533 if tokenizer_json_path.exists() {
534 let tokenizer = TokenizerImpl::from_file(&tokenizer_json_path)?;
535 return Ok(TokenizerWrapper::HuggingFace(tokenizer));
536 }
537
538 let config_path = path.join("tokenizer_config.json");
540 if config_path.exists() {
541 let config_str = std::fs::read_to_string(&config_path)
542 .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
543 let config: serde_json::Value = serde_json::from_str(&config_str)
544 .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
545
546 if let Some(tokenizer_type) = config.get("tokenizer_type").and_then(|v| v.as_str()) {
547 match tokenizer_type {
548 "WordPiece" => {
549 let vocab = std::collections::HashMap::new();
551 let tokenizer = crate::wordpiece::WordPieceTokenizer::new(vocab, false);
552 return Ok(TokenizerWrapper::WordPiece(tokenizer));
553 },
554 "BPE" => {
555 let vocab = std::collections::HashMap::new();
557 let merges = Vec::new();
558 let tokenizer = crate::bpe::BPETokenizer::new(vocab, merges);
559 return Ok(TokenizerWrapper::BPE(tokenizer));
560 },
561 "Unigram" => {
562 let vocab = std::collections::HashMap::new();
564 let scores = std::collections::HashMap::new();
565 let tokenizer = crate::unigram::UnigramTokenizer::new(vocab, scores)?;
566 return Ok(TokenizerWrapper::Unigram(tokenizer));
567 },
568 "Character" => {
569 let vocab = std::collections::HashMap::new();
571 let tokenizer = crate::char::CharTokenizer::new(vocab);
572 return Ok(TokenizerWrapper::Char(tokenizer));
573 },
574 _ => {
575 return Err(TrustformersError::invalid_input(format!(
576 "Unsupported tokenizer type: {}",
577 tokenizer_type
578 )));
579 },
580 }
581 }
582 }
583
584 match TokenizerImpl::from_pretrained(path.to_string_lossy().as_ref()) {
587 Ok(tokenizer) => Ok(TokenizerWrapper::HuggingFace(tokenizer)),
588 Err(_) => {
589 let vocab = std::collections::HashMap::new();
591 let merges = Vec::new();
592 Ok(TokenizerWrapper::BPE(crate::bpe::BPETokenizer::new(
593 vocab, merges,
594 )))
595 },
596 }
597 }
598
599 pub fn save_pretrained<P: AsRef<Path>>(&self, path: P) -> Result<()> {
601 let path = path.as_ref();
602
603 std::fs::create_dir_all(path)
605 .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
606
607 match self {
608 TokenizerWrapper::HuggingFace(tokenizer) => {
609 let tokenizer_path = path.join("tokenizer.json");
611 tokenizer.save_to_file(&tokenizer_path)
612 },
613 TokenizerWrapper::WordPiece(_) => {
614 let config_path = path.join("tokenizer_config.json");
616 let config = serde_json::json!({
617 "tokenizer_type": "WordPiece",
618 "model_type": "WordPiece",
619 "version": "1.0"
620 });
621 std::fs::write(
622 config_path,
623 serde_json::to_string_pretty(&config)
624 .expect("hardcoded JSON config must serialize"),
625 )
626 .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
627
628 Ok(())
631 },
632 TokenizerWrapper::BPE(_) => {
633 let config_path = path.join("tokenizer_config.json");
635 let config = serde_json::json!({
636 "tokenizer_type": "BPE",
637 "model_type": "BPE",
638 "version": "1.0"
639 });
640 std::fs::write(
641 config_path,
642 serde_json::to_string_pretty(&config)
643 .expect("hardcoded JSON config must serialize"),
644 )
645 .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
646 Ok(())
647 },
648 TokenizerWrapper::Unigram(_) => {
649 let config_path = path.join("tokenizer_config.json");
651 let config = serde_json::json!({
652 "tokenizer_type": "Unigram",
653 "model_type": "Unigram",
654 "version": "1.0"
655 });
656 std::fs::write(
657 config_path,
658 serde_json::to_string_pretty(&config)
659 .expect("hardcoded JSON config must serialize"),
660 )
661 .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
662 Ok(())
663 },
664 TokenizerWrapper::Char(_) => {
665 let config_path = path.join("tokenizer_config.json");
667 let config = serde_json::json!({
668 "tokenizer_type": "Character",
669 "model_type": "Character",
670 "version": "1.0"
671 });
672 std::fs::write(
673 config_path,
674 serde_json::to_string_pretty(&config)
675 .expect("hardcoded JSON config must serialize"),
676 )
677 .map_err(|e| TrustformersError::other(format!("I/O error: {}", e)))?;
678 Ok(())
679 },
680 }
681 }
682}
683
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies that the From conversion preserves ids, mask, and type ids
    // when downgrading an offsets-carrying result to TokenizedInput.
    #[test]
    fn test_tokenized_input_with_offsets_conversion() {
        let input_with_offsets = TokenizedInputWithOffsets {
            input_ids: vec![101, 2023, 2003, 102],
            attention_mask: vec![1, 1, 1, 1],
            token_type_ids: Some(vec![0, 0, 0, 0]),
            offset_mapping: Some(vec![(0, 0), (0, 4), (5, 7), (0, 0)]),
            special_tokens_mask: Some(vec![1, 0, 0, 1]),
        };

        let regular_input: TokenizedInput = input_with_offsets.into();

        assert_eq!(regular_input.input_ids, vec![101, 2023, 2003, 102]);
        assert_eq!(regular_input.attention_mask, vec![1, 1, 1, 1]);
        assert_eq!(regular_input.token_type_ids, Some(vec![0, 0, 0, 0]));
    }

    // Round-trips text through the Char-backed wrapper: encode → decode,
    // and checks basic vocabulary sanity.
    #[test]
    fn test_tokenizer_wrapper_char() {
        let text = "Hello World!";
        let tokenizer = crate::char::CharTokenizer::from_text(text, 1000);
        let wrapper = TokenizerWrapper::Char(tokenizer);

        let encoded = wrapper.encode(text).expect("Encoding failed");
        let decoded = wrapper.decode(&encoded.input_ids).expect("Decoding failed");

        assert!(!encoded.input_ids.is_empty());
        assert!(decoded.contains("Hello"));
        assert!(wrapper.vocab_size() > 0);
    }

    // Builds a minimal WordLevel tokenizer from an inline tokenizer.json
    // and checks vocab size plus token<->id lookups.
    #[test]
    fn test_tokenizer_from_json_string() {
        let json_str = r#"{
            "version": "1.0",
            "truncation": null,
            "padding": null,
            "added_tokens": [
                {
                    "id": 0,
                    "content": "[PAD]",
                    "single_word": false,
                    "lstrip": false,
                    "rstrip": false,
                    "normalized": false,
                    "special": true
                },
                {
                    "id": 1,
                    "content": "[UNK]",
                    "single_word": false,
                    "lstrip": false,
                    "rstrip": false,
                    "normalized": false,
                    "special": true
                }
            ],
            "normalizer": null,
            "pre_tokenizer": {
                "type": "Whitespace"
            },
            "post_processor": null,
            "decoder": null,
            "model": {
                "type": "WordLevel",
                "vocab": {
                    "[PAD]": 0,
                    "[UNK]": 1,
                    "hello": 2,
                    "world": 3
                },
                "unk_token": "[UNK]"
            }
        }"#;

        let result = TokenizerImpl::from_tokenizer_json(json_str);
        assert!(result.is_ok());

        if let Ok(tokenizer) = result {
            assert_eq!(tokenizer.vocab_size(), 4);
            assert_eq!(tokenizer.token_to_id("hello"), Some(2));
            assert_eq!(tokenizer.id_to_token(3), Some("world".to_string()));
        }
    }
}