use crate::albert::AlbertConfig;
use crate::bart::BartConfig;
use crate::bert::BertConfig;
use crate::common::error::RustBertError;
use crate::deberta::DebertaConfig;
use crate::deberta_v2::DebertaV2Config;
use crate::distilbert::DistilBertConfig;
use crate::electra::ElectraConfig;
use crate::fnet::FNetConfig;
use crate::gpt2::Gpt2Config;
use crate::gpt_j::GptJConfig;
use crate::gpt_neo::GptNeoConfig;
use crate::longformer::LongformerConfig;
use crate::longt5::LongT5Config;
use crate::m2m_100::M2M100Config;
use crate::marian::MarianConfig;
use crate::mbart::MBartConfig;
use crate::mobilebert::MobileBertConfig;
use crate::openai_gpt::OpenAiGptConfig;
use crate::pegasus::PegasusConfig;
use crate::pipelines::translation::Language;
use crate::prophetnet::ProphetNetConfig;
use crate::reformer::ReformerConfig;
use crate::resources::{Resource, ResourceProvider};
use crate::roberta::RobertaConfig;
use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
use crate::Config;
use rust_tokenizers::tokenizer::{
    AlbertTokenizer, BertTokenizer, DeBERTaTokenizer, DeBERTaV2Tokenizer, FNetTokenizer,
    Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer,
    NLLBTokenizer, OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer,
    RobertaTokenizer, T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer,
    XLNetTokenizer,
};
use rust_tokenizers::vocab::Vocab;
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use tch::nn::VarStore;
use tch::{Device, Kind, Tensor};

#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXModelConfig;

#[cfg(feature = "hf-tokenizers")]
use crate::pipelines::hf_tokenizers::HFTokenizer;

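/// Set of ONNX model resources for a pipeline. Encoder-only or decoder-only models
/// provide a single resource; encoder-decoder models typically provide an encoder,
/// a decoder, and optionally a decoder variant supporting cached past key/values.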
#[derive(Debug, Default)]
pub struct ONNXModelResources {
    /// Resource for the encoder ONNX model file.
    pub encoder_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Resource for the decoder ONNX model file.
    pub decoder_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Resource for the decoder ONNX model file with support for cached past key/values.
    pub decoder_with_past_resource: Option<Box<dyn ResourceProvider + Send>>,
}

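/// Wrapper around the model weights resource, covering both the Torch and ONNX backends.
///
/// A minimal sketch of building the Torch variant, assuming a local weights file
/// (the path below is a placeholder):
/// ```no_run
/// use rust_bert::pipelines::common::ModelResource;
/// use rust_bert::resources::LocalResource;
///
/// let weights = ModelResource::Torch(Box::new(LocalResource {
///     local_path: "path/to/rust_model.ot".into(),
/// }));
/// ```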
#[derive(Debug)]
pub enum ModelResource {
    /// Weights resource for a Torch-based (tch) model.
    Torch(Box<dyn ResourceProvider + Send>),
    /// Weights resources for an ONNX model (requires the `onnx` feature).
    #[cfg(feature = "onnx")]
    ONNX(ONNXModelResources),
}

impl ResourceProvider for ModelResource {
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        match self {
            ModelResource::Torch(ref resource) => resource.get_local_path(),
            #[cfg(feature = "onnx")]
            ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
        }
    }
    fn get_resource(&self) -> Result<Resource, RustBertError> {
        match self {
            ModelResource::Torch(ref resource) => resource.get_resource(),
            #[cfg(feature = "onnx")]
            ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
        }
    }
}

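/// Local paths to the ONNX model files, as resolved from an `ONNXModelResources`.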
pub struct ONNXLocalPaths {
    pub encoder_path: Option<PathBuf>,
    pub decoder_path: Option<PathBuf>,
    pub decoder_with_past_path: Option<PathBuf>,
}

impl ModelResource {
    pub fn get_torch_local_path(&self) -> Result<PathBuf, RustBertError> {
        match self {
            ModelResource::Torch(torch_resource) => torch_resource.get_local_path(),
            #[cfg(feature = "onnx")]
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Attempting to get the Torch local path but other weights variants were given: {self:?}"
            ))),
        }
    }

    #[cfg(feature = "onnx")]
    pub fn get_onnx_local_paths(&self) -> Result<ONNXLocalPaths, RustBertError> {
        let (encoder_path, decoder_path, decoder_with_past_path) = match self {
            ModelResource::ONNX(onnx_model_resources) => Ok((
                onnx_model_resources
                    .encoder_resource
                    .as_ref()
                    .map(|r| r.get_local_path()),
                onnx_model_resources
                    .decoder_resource
                    .as_ref()
                    .map(|r| r.get_local_path()),
                onnx_model_resources
                    .decoder_with_past_resource
                    .as_ref()
                    .map(|r| r.get_local_path()),
            )),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Attempting to get the ONNX local paths but other weights variants were given: {self:?}"
            ))),
        }?;
        Ok(ONNXLocalPaths {
            encoder_path: encoder_path.transpose()?,
            decoder_path: decoder_path.transpose()?,
            decoder_with_past_path: decoder_with_past_path.transpose()?,
        })
    }
}

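/// Returns the device to place the model on. ONNX models currently run on CPU
/// regardless of the requested device; Torch models use the device provided.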
pub(crate) fn get_device(_model_resource: ModelResource, device: Device) -> Device {
    #[cfg(feature = "onnx")]
    let device = if let ModelResource::ONNX(_) = _model_resource {
        Device::Cpu
    } else {
        device
    };

    #[cfg(not(feature = "onnx"))]
    let device = device;
    device
}

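/// Identifies the model architecture, used to dispatch configuration and tokenizer loading.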
#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub enum ModelType {
    Bart,
    #[serde(alias = "bert")]
    Bert,
    #[serde(alias = "distilbert")]
    DistilBert,
    Deberta,
    DebertaV2,
    #[serde(alias = "roberta")]
    Roberta,
    XLMRoberta,
    Electra,
    Marian,
    MobileBert,
    #[serde(alias = "t5")]
    T5,
    #[serde(alias = "longt5")]
    LongT5,
    #[serde(alias = "albert")]
    Albert,
    XLNet,
    GPT2,
    GPTJ,
    OpenAiGpt,
    Reformer,
    ProphetNet,
    Longformer,
    Pegasus,
    GPTNeo,
    MBart,
    M2M100,
    #[serde(alias = "m2m100")]
    NLLB,
    FNet,
    #[cfg(feature = "onnx")]
    ONNX,
}

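/// Wrapper around the configuration of the supported model architectures.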
pub enum ConfigOption {
    Bart(BartConfig),
    Bert(BertConfig),
    DistilBert(DistilBertConfig),
    Deberta(DebertaConfig),
    DebertaV2(DebertaV2Config),
    Electra(ElectraConfig),
    Marian(MarianConfig),
    MobileBert(MobileBertConfig),
    OpenAiGpt(OpenAiGptConfig),
    T5(T5Config),
    LongT5(LongT5Config),
    Albert(AlbertConfig),
    XLNet(XLNetConfig),
    GPT2(Gpt2Config),
    GPTJ(GptJConfig),
    Reformer(ReformerConfig),
    Roberta(RobertaConfig),
    ProphetNet(ProphetNetConfig),
    Longformer(LongformerConfig),
    Pegasus(PegasusConfig),
    GPTNeo(GptNeoConfig),
    MBart(MBartConfig),
    M2M100(M2M100Config),
    FNet(FNetConfig),
    #[cfg(feature = "onnx")]
    ONNX(ONNXModelConfig),
}

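/// Wrapper around the tokenizers of the supported model architectures.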
pub enum TokenizerOption {
    Bert(BertTokenizer),
    Deberta(DeBERTaTokenizer),
    DebertaV2(DeBERTaV2Tokenizer),
    Roberta(RobertaTokenizer),
    XLMRoberta(XLMRobertaTokenizer),
    Marian(MarianTokenizer),
    T5(T5Tokenizer),
    Albert(AlbertTokenizer),
    XLNet(XLNetTokenizer),
    GPT2(Gpt2Tokenizer),
    OpenAiGpt(OpenAiGptTokenizer),
    Reformer(ReformerTokenizer),
    ProphetNet(ProphetNetTokenizer),
    Pegasus(PegasusTokenizer),
    MBart50(MBart50Tokenizer),
    M2M100(M2M100Tokenizer),
    NLLB(NLLBTokenizer),
    FNet(FNetTokenizer),
    Bart(RobertaTokenizer),
    #[cfg(feature = "hf-tokenizers")]
    HFTokenizer(HFTokenizer),
}

impl ConfigOption {
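    /// Loads a model configuration from a local JSON file for the given model type.
    ///
    /// A minimal sketch, assuming a local configuration file (the path below is a
    /// placeholder; this method panics if the file cannot be read or deserialized):
    /// ```no_run
    /// use rust_bert::pipelines::common::{ConfigOption, ModelType};
    ///
    /// let config = ConfigOption::from_file(ModelType::Bert, "path/to/config.json");
    /// ```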
    pub fn from_file<P: AsRef<Path>>(model_type: ModelType, path: P) -> Self {
        match model_type {
            ModelType::Bart => ConfigOption::Bart(BartConfig::from_file(path)),
            ModelType::Bert => ConfigOption::Bert(BertConfig::from_file(path)),
            ModelType::Deberta => ConfigOption::Deberta(DebertaConfig::from_file(path)),
            ModelType::DebertaV2 => ConfigOption::DebertaV2(DebertaV2Config::from_file(path)),
            ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
            ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
            ModelType::Marian => ConfigOption::Marian(MarianConfig::from_file(path)),
            ModelType::MobileBert => ConfigOption::MobileBert(MobileBertConfig::from_file(path)),
            ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)),
            ModelType::LongT5 => ConfigOption::LongT5(LongT5Config::from_file(path)),
            ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)),
            ModelType::XLNet => ConfigOption::XLNet(XLNetConfig::from_file(path)),
            ModelType::GPT2 => ConfigOption::GPT2(Gpt2Config::from_file(path)),
            ModelType::GPTJ => ConfigOption::GPTJ(GptJConfig::from_file(path)),
            ModelType::GPTNeo => ConfigOption::GPTNeo(GptNeoConfig::from_file(path)),
            ModelType::OpenAiGpt => ConfigOption::OpenAiGpt(OpenAiGptConfig::from_file(path)),
            ModelType::Reformer => ConfigOption::Reformer(ReformerConfig::from_file(path)),
            ModelType::ProphetNet => ConfigOption::ProphetNet(ProphetNetConfig::from_file(path)),
            ModelType::Longformer => ConfigOption::Longformer(LongformerConfig::from_file(path)),
            ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)),
            ModelType::Roberta | ModelType::XLMRoberta => {
                ConfigOption::Roberta(RobertaConfig::from_file(path))
            }
            ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)),
            ModelType::M2M100 | ModelType::NLLB => {
                ConfigOption::M2M100(M2M100Config::from_file(path))
            }
            ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)),
            #[cfg(feature = "onnx")]
            ModelType::ONNX => ConfigOption::ONNX(ONNXModelConfig::from_file(path)),
        }
    }

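    /// Returns the label mapping (`id2label`) from the configuration.
    ///
    /// Panics if the configuration does not define a label dictionary, or if the
    /// model type does not support classification (e.g. GPT2, T5, Pegasus).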
    pub fn get_label_mapping(&self) -> &HashMap<i64, String> {
        match self {
            Self::Bart(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Bert(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Deberta(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::DebertaV2(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::DistilBert(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Electra(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Marian(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::MobileBert(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Albert(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::XLNet(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Reformer(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::ProphetNet(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Longformer(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::MBart(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::M2M100(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::FNet(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::Roberta(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            #[cfg(feature = "onnx")]
            Self::ONNX(config) => config
                .id2label
                .as_ref()
                .expect("No label dictionary (id2label) provided in configuration file"),
            Self::T5(_) => panic!("T5 does not use a label mapping"),
            Self::LongT5(_) => panic!("LongT5 does not use a label mapping"),
            Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"),
            Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
            Self::GPTJ(_) => panic!("GPT-J does not use a label mapping"),
            Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"),
            Self::Pegasus(_) => panic!("Pegasus does not use a label mapping"),
        }
    }

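    /// Returns the maximum input length defined by the configuration, or `None` for
    /// models without a fixed position limit (e.g. T5, LongT5, XLNet).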
    pub fn get_max_len(&self) -> Option<i64> {
        match self {
            Self::Bart(config) => Some(config.max_position_embeddings),
            Self::Bert(config) => Some(config.max_position_embeddings),
            Self::Deberta(config) => Some(config.max_position_embeddings),
            Self::DebertaV2(config) => Some(config.max_position_embeddings),
            Self::DistilBert(config) => Some(config.max_position_embeddings),
            Self::Electra(config) => Some(config.max_position_embeddings),
            Self::Marian(config) => Some(config.max_position_embeddings),
            Self::MobileBert(config) => Some(config.max_position_embeddings),
            Self::T5(_) => None,
            Self::LongT5(_) => None,
            Self::Albert(config) => Some(config.max_position_embeddings),
            Self::XLNet(_) => None,
            Self::GPT2(config) => Some(config.n_positions),
            Self::GPTJ(config) => Some(config.n_positions),
            Self::Reformer(config) => Some(config.max_position_embeddings),
            Self::ProphetNet(config) => Some(config.max_position_embeddings),
            Self::Longformer(config) => Some(config.max_position_embeddings),
            Self::Pegasus(config) => Some(config.max_position_embeddings),
            Self::OpenAiGpt(config) => Some(config.n_positions),
            Self::GPTNeo(config) => Some(config.max_position_embeddings),
            Self::MBart(config) => Some(config.max_position_embeddings),
            Self::M2M100(config) => Some(config.max_position_embeddings),
            Self::FNet(config) => Some(config.max_position_embeddings),
            Self::Roberta(config) => Some(config.max_position_embeddings),
            #[cfg(feature = "onnx")]
            Self::ONNX(config) => config.max_position_embeddings,
        }
    }

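    /// Returns the vocabulary size defined by the configuration.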
    pub fn get_vocab_size(&self) -> i64 {
        match self {
            Self::Bart(config) => config.vocab_size,
            Self::Bert(config) => config.vocab_size,
            Self::Deberta(config) => config.vocab_size,
            Self::DebertaV2(config) => config.vocab_size,
            Self::DistilBert(config) => config.vocab_size,
            Self::Electra(config) => config.vocab_size,
            Self::Marian(config) => config.vocab_size,
            Self::MobileBert(config) => config.vocab_size,
            Self::T5(config) => config.vocab_size,
            Self::LongT5(config) => config.vocab_size,
            Self::Albert(config) => config.vocab_size,
            Self::XLNet(config) => config.vocab_size,
            Self::GPT2(config) => config.vocab_size,
            Self::GPTJ(config) => config.vocab_size,
            Self::Reformer(config) => config.vocab_size,
            Self::ProphetNet(config) => config.vocab_size,
            Self::Longformer(config) => config.vocab_size,
            Self::Pegasus(config) => config.vocab_size,
            Self::OpenAiGpt(config) => config.vocab_size,
            Self::GPTNeo(config) => config.vocab_size,
            Self::MBart(config) => config.vocab_size,
            Self::M2M100(config) => config.vocab_size,
            Self::FNet(config) => config.vocab_size,
            Self::Roberta(config) => config.vocab_size,
            #[cfg(feature = "onnx")]
            Self::ONNX(config) => config.vocab_size,
        }
    }

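    /// Returns the decoder start token id if defined by the configuration (generation models only).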
    pub fn get_decoder_start_token_id(&self) -> Option<i64> {
        match self {
            Self::Bart(config) => config.decoder_start_token_id,
            Self::Bert(_) => None,
            Self::Deberta(_) => None,
            Self::DebertaV2(_) => None,
            Self::DistilBert(_) => None,
            Self::Electra(_) => None,
            Self::Marian(config) => config.decoder_start_token_id,
            Self::MobileBert(_) => None,
            Self::T5(config) => config.decoder_start_token_id,
            Self::LongT5(config) => config.decoder_start_token_id,
            Self::Albert(_) => None,
            Self::XLNet(_) => None,
            Self::GPT2(config) => config.decoder_start_token_id,
            Self::GPTJ(config) => config.decoder_start_token_id,
            Self::Reformer(config) => config.decoder_start_token_id,
            Self::ProphetNet(config) => config.decoder_start_token_id,
            Self::Longformer(_) => None,
            Self::Pegasus(config) => config.decoder_start_token_id,
            Self::OpenAiGpt(config) => config.decoder_start_token_id,
            Self::GPTNeo(config) => config.decoder_start_token_id,
            Self::MBart(config) => config.decoder_start_token_id,
            Self::M2M100(config) => config.decoder_start_token_id,
            Self::FNet(config) => config.decoder_start_token_id,
            Self::Roberta(_) => None,
            #[cfg(feature = "onnx")]
            Self::ONNX(config) => config.decoder_start_token_id,
        }
    }

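    /// Returns the forced BOS token id if defined by the configuration (generation models only).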
    pub fn get_forced_bos_token_id(&self) -> Option<i64> {
        match self {
            Self::Bart(config) => config.forced_bos_token_id,
            Self::Bert(_) => None,
            Self::Deberta(_) => None,
            Self::DebertaV2(_) => None,
            Self::DistilBert(_) => None,
            Self::Electra(_) => None,
            Self::Marian(config) => config.forced_bos_token_id,
            Self::MobileBert(_) => None,
            Self::T5(config) => config.forced_bos_token_id,
            Self::LongT5(config) => config.forced_bos_token_id,
            Self::Albert(_) => None,
            Self::XLNet(_) => None,
            Self::GPT2(config) => config.forced_bos_token_id,
            Self::GPTJ(config) => config.forced_bos_token_id,
            Self::Reformer(config) => config.forced_bos_token_id,
            Self::ProphetNet(config) => config.forced_bos_token_id,
            Self::Longformer(_) => None,
            Self::Pegasus(config) => config.forced_bos_token_id,
            Self::OpenAiGpt(config) => config.forced_bos_token_id,
            Self::GPTNeo(config) => config.forced_bos_token_id,
            Self::MBart(config) => config.forced_bos_token_id,
            Self::M2M100(config) => config.forced_bos_token_id,
            Self::FNet(_) => None,
            Self::Roberta(_) => None,
            #[cfg(feature = "onnx")]
            Self::ONNX(config) => config.forced_bos_token_id,
        }
    }

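    /// Returns the forced EOS token id if defined by the configuration (generation models only).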
    pub fn get_forced_eos_token_id(&self) -> Option<i64> {
        match self {
            Self::Bart(config) => config.forced_eos_token_id,
            Self::Bert(_) => None,
            Self::Deberta(_) => None,
            Self::DebertaV2(_) => None,
            Self::DistilBert(_) => None,
            Self::Electra(_) => None,
            Self::Marian(config) => config.forced_eos_token_id,
            Self::MobileBert(_) => None,
            Self::T5(config) => config.forced_eos_token_id,
            Self::LongT5(config) => config.forced_eos_token_id,
            Self::Albert(_) => None,
            Self::XLNet(_) => None,
            Self::GPT2(config) => config.forced_eos_token_id,
            Self::GPTJ(config) => config.forced_eos_token_id,
            Self::Reformer(config) => config.forced_eos_token_id,
            Self::ProphetNet(config) => config.forced_eos_token_id,
            Self::Longformer(_) => None,
            Self::Pegasus(config) => config.forced_eos_token_id,
            Self::OpenAiGpt(config) => config.forced_eos_token_id,
            Self::GPTNeo(config) => config.forced_eos_token_id,
            Self::MBart(config) => config.forced_eos_token_id,
            Self::M2M100(config) => config.forced_eos_token_id,
            Self::FNet(_) => None,
            Self::Roberta(_) => None,
            #[cfg(feature = "onnx")]
            Self::ONNX(config) => config.forced_eos_token_id,
        }
    }
}

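// The `TryFrom` implementations below extract a concrete configuration from a
// `ConfigOption`. A minimal usage sketch (assuming `config_option` wraps a BERT
// configuration): `let bert_config = BertConfig::try_from(&config_option)?;`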
impl TryFrom<&ConfigOption> for BertConfig {
    type Error = RustBertError;

    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        match config {
            ConfigOption::Bert(config) | ConfigOption::Roberta(config) => Ok(config.clone()),
            _ => Err(RustBertError::InvalidConfigurationError(
                "You can only supply a BertConfig for Bert or a RobertaConfig for Roberta!"
                    .to_string(),
            )),
        }
    }
}

impl TryFrom<&ConfigOption> for DistilBertConfig {
    type Error = RustBertError;

    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        if let ConfigOption::DistilBert(config) = config {
            Ok(config.clone())
        } else {
            Err(RustBertError::InvalidConfigurationError(
                "You can only supply a DistilBertConfig for DistilBert!".to_string(),
            ))
        }
    }
}

impl TryFrom<&ConfigOption> for AlbertConfig {
    type Error = RustBertError;

    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        if let ConfigOption::Albert(config) = config {
            Ok(config.clone())
        } else {
            Err(RustBertError::InvalidConfigurationError(
                "You can only supply an AlbertConfig for Albert!".to_string(),
            ))
        }
    }
}

impl TryFrom<&ConfigOption> for T5Config {
    type Error = RustBertError;

    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        if let ConfigOption::T5(config) = config {
            Ok(config.clone())
        } else {
            Err(RustBertError::InvalidConfigurationError(
                "You can only supply a T5Config for T5!".to_string(),
            ))
        }
    }
}

impl TokenizerOption {
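    /// Instantiates a tokenizer of the given model type from local vocabulary (and,
    /// where applicable, merges) files.
    ///
    /// A minimal sketch, assuming local BERT vocabulary files (paths are placeholders):
    /// ```no_run
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    ///
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::Bert,
    ///     "path/to/vocab.txt",
    ///     None, // merges file (only used by BPE-based tokenizers)
    ///     true, // lower-case the input
    ///     None, // strip_accents (defaults to the lower-case setting for BERT)
    ///     None, // add_prefix_space (not supported by BERT)
    /// )?;
    /// # Ok::<(), rust_bert::RustBertError>(())
    /// ```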
    pub fn from_file(
        model_type: ModelType,
        vocab_path: &str,
        merges_path: Option<&str>,
        lower_case: bool,
        strip_accents: impl Into<Option<bool>>,
        add_prefix_space: impl Into<Option<bool>>,
    ) -> Result<Self, RustBertError> {
        let strip_accents = strip_accents.into();
        let add_prefix_space = add_prefix_space.into();

        let tokenizer = match model_type {
            ModelType::Bert
            | ModelType::DistilBert
            | ModelType::Electra
            | ModelType::MobileBert => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Bert(BertTokenizer::from_file(
                    vocab_path,
                    lower_case,
                    strip_accents.unwrap_or(lower_case),
                )?)
            }
            ModelType::Deberta => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Deberta(DeBERTaTokenizer::from_file(
                    vocab_path,
                    merges_path.expect("No merges specified!"),
                    lower_case,
                )?)
            }
            ModelType::DebertaV2 => TokenizerOption::DebertaV2(DeBERTaV2Tokenizer::from_file(
                vocab_path,
                lower_case,
                strip_accents.unwrap_or(false),
                add_prefix_space.unwrap_or(false),
            )?),
            ModelType::Roberta | ModelType::Longformer => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Roberta(RobertaTokenizer::from_file(
                    vocab_path,
                    merges_path.expect("No merges specified!"),
                    lower_case,
                    add_prefix_space.unwrap_or(false),
                )?)
            }
            ModelType::Bart => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Bart(RobertaTokenizer::from_file(
                    vocab_path,
                    merges_path.expect("No merges specified!"),
                    lower_case,
                    add_prefix_space.unwrap_or(false),
                )?)
            }
            ModelType::Marian => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Marian(MarianTokenizer::from_files(
                    vocab_path,
                    merges_path.expect("No merges specified!"),
                    lower_case,
                )?)
            }
            ModelType::T5 | ModelType::LongT5 => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::T5(T5Tokenizer::from_file(vocab_path, lower_case)?)
            }
            ModelType::XLMRoberta => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::XLMRoberta(XLMRobertaTokenizer::from_file(vocab_path, lower_case)?)
            }
            ModelType::Albert => {
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Albert(AlbertTokenizer::from_file(
                    vocab_path,
                    lower_case,
                    strip_accents.unwrap_or(lower_case),
                )?)
            }
            ModelType::XLNet => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::XLNet(XLNetTokenizer::from_file(
                    vocab_path,
                    lower_case,
                    strip_accents.unwrap_or(false),
                )?)
            }
            ModelType::Reformer => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Reformer(ReformerTokenizer::from_file(vocab_path, lower_case)?)
            }
            ModelType::GPT2 | ModelType::GPTNeo | ModelType::GPTJ => {
                TokenizerOption::GPT2(Gpt2Tokenizer::from_file(
                    vocab_path,
                    merges_path.expect("No merges specified!"),
                    lower_case,
                )?)
            }
            ModelType::OpenAiGpt => TokenizerOption::OpenAiGpt(OpenAiGptTokenizer::from_file(
                vocab_path,
                merges_path.expect("No merges specified!"),
                lower_case,
            )?),
            ModelType::ProphetNet => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::ProphetNet(ProphetNetTokenizer::from_file(
                    vocab_path,
                    lower_case,
                    strip_accents.unwrap_or(lower_case),
                )?)
            }
            ModelType::Pegasus => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::Pegasus(PegasusTokenizer::from_file(vocab_path, lower_case)?)
            }
            ModelType::MBart => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::MBart50(MBart50Tokenizer::from_file(vocab_path, lower_case)?)
            }
            ModelType::M2M100 => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::M2M100(M2M100Tokenizer::from_files(
                    vocab_path,
                    merges_path.expect("No merges specified!"),
                    lower_case,
                )?)
            }
            ModelType::NLLB => {
                if add_prefix_space.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
                        add_prefix_space.unwrap(),
                        model_type
                    )));
                }
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
                TokenizerOption::NLLB(NLLBTokenizer::from_files(
                    vocab_path,
                    merges_path.expect("No merges specified."),
                )?)
            }
            ModelType::FNet => TokenizerOption::FNet(FNetTokenizer::from_file(
                vocab_path,
                lower_case,
                strip_accents.unwrap_or(false),
            )?),
            #[cfg(feature = "onnx")]
            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
                "Default Tokenizer not defined for generic ONNX models.".to_string(),
            ))?,
        };
        Ok(tokenizer)
    }

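    /// Instantiates a tokenizer from a Hugging Face `tokenizer.json` file and a
    /// special token map file (requires the `hf-tokenizers` feature).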
    #[cfg(feature = "hf-tokenizers")]
    pub fn from_hf_tokenizer_file<P: AsRef<Path>, S: AsRef<Path>>(
        tokenizer_file: P,
        special_token_map: S,
    ) -> Result<Self, RustBertError> {
        let hf_tokenizer = HFTokenizer::from_file(tokenizer_file, special_token_map)?;
        Ok(TokenizerOption::HFTokenizer(hf_tokenizer))
    }

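    /// Tokenizes and encodes a batch of texts, multi-threaded for the native tokenizers.
    ///
    /// A minimal sketch, assuming `tokenizer` was built with `TokenizerOption::from_file`:
    /// ```no_run
    /// # use rust_bert::pipelines::common::TokenizerOption;
    /// use rust_tokenizers::tokenizer::TruncationStrategy;
    ///
    /// # fn example(tokenizer: &TokenizerOption) {
    /// let encoded = tokenizer.encode_list(
    ///     &["This is a sentence", "And a second one"],
    ///     128, // maximum sequence length
    ///     &TruncationStrategy::LongestFirst,
    ///     0, // stride for overflowing tokens
    /// );
    /// # }
    /// ```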
    pub fn encode_list<S>(
        &self,
        text_list: &[S],
        max_len: usize,
        truncation_strategy: &TruncationStrategy,
        stride: usize,
    ) -> Vec<TokenizedInput>
    where
        S: AsRef<str> + Send + Sync,
    {
        match *self {
            Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::DebertaV2(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Bart(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Marian(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::T5(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::XLMRoberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Albert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::OpenAiGpt(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::ProphetNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                tokenizer,
                text_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.encode_list(text_list).unwrap(),
        }
    }

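    /// Tokenizes and encodes a batch of text pairs (e.g. premise/hypothesis inputs),
    /// multi-threaded for the native tokenizers.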
    pub fn encode_pair_list(
        &self,
        text_pair_list: &[(&str, &str)],
        max_len: usize,
        truncation_strategy: &TruncationStrategy,
        stride: usize,
    ) -> Vec<TokenizedInput> {
        match *self {
            Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::DebertaV2(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Bart(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Marian(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::T5(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::XLMRoberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Albert(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::OpenAiGpt(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::ProphetNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
                tokenizer,
                text_pair_list,
                max_len,
                truncation_strategy,
                stride,
            ),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair_list(text_pair_list).unwrap(),
        }
    }

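    /// Tokenizes and encodes a single text (with an optional second text for pair inputs).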
    pub fn encode_pair(
        &self,
        text_1: &str,
        text_2: Option<&str>,
        max_len: usize,
        truncation_strategy: &TruncationStrategy,
        stride: usize,
    ) -> TokenizedInput {
        match *self {
            Self::Bert(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Deberta(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::DebertaV2(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Roberta(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Bart(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Marian(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::T5(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::XLMRoberta(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Albert(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::XLNet(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::GPT2(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::OpenAiGpt(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Reformer(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::ProphetNet(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::Pegasus(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::MBart50(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::M2M100(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::NLLB(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            Self::FNet(ref tokenizer) => {
                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair(text_1, text_2).unwrap(),
        }
    }

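    /// Tokenizes a text into a vector of token strings.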
    pub fn tokenize(&self, text: &str) -> Vec<String> {
        match *self {
            Self::Bert(ref tokenizer) => tokenizer.tokenize(text),
            Self::Deberta(ref tokenizer) => tokenizer.tokenize(text),
            Self::DebertaV2(ref tokenizer) => tokenizer.tokenize(text),
            Self::Roberta(ref tokenizer) => tokenizer.tokenize(text),
            Self::Bart(ref tokenizer) => tokenizer.tokenize(text),
            Self::Marian(ref tokenizer) => tokenizer.tokenize(text),
            Self::T5(ref tokenizer) => tokenizer.tokenize(text),
            Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize(text),
            Self::Albert(ref tokenizer) => tokenizer.tokenize(text),
            Self::XLNet(ref tokenizer) => tokenizer.tokenize(text),
            Self::GPT2(ref tokenizer) => tokenizer.tokenize(text),
            Self::OpenAiGpt(ref tokenizer) => tokenizer.tokenize(text),
            Self::Reformer(ref tokenizer) => tokenizer.tokenize(text),
            Self::ProphetNet(ref tokenizer) => tokenizer.tokenize(text),
            Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text),
            Self::MBart50(ref tokenizer) => tokenizer.tokenize(text),
            Self::M2M100(ref tokenizer) => tokenizer.tokenize(text),
            Self::NLLB(ref tokenizer) => tokenizer.tokenize(text),
            Self::FNet(ref tokenizer) => tokenizer.tokenize(text),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize(text),
        }
    }

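    /// Tokenizes a text, returning the tokens together with their character offsets.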
    pub fn tokenize_with_offsets(&self, text: &str) -> TokensWithOffsets {
        match *self {
            Self::Bert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Deberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::DebertaV2(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Roberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Bart(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Marian(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::T5(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Albert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::XLNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::GPT2(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::OpenAiGpt(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Reformer(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::ProphetNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::NLLB(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
        }
    }

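    /// Tokenizes a batch of texts, multi-threaded for the native tokenizers.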
    pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>>
    where
        S: AsRef<str> + Send + Sync,
    {
        match *self {
            Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::DebertaV2(ref tokenizer) => {
                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
            }
            Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::Bart(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::Marian(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::T5(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::XLMRoberta(ref tokenizer) => {
                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
            }
            Self::Albert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::OpenAiGpt(ref tokenizer) => {
                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
            }
            Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::ProphetNet(ref tokenizer) => {
                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
            }
            Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_list(text),
        }
    }

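    /// Decodes a sequence of token ids back into a string.
    ///
    /// `skip_special_tokens` removes special tokens (e.g. `[CLS]`, `[SEP]`) from the
    /// output; `clean_up_tokenization_spaces` removes spacing artifacts introduced by
    /// tokenization. Note that the latter flag is not forwarded to the `hf-tokenizers` backend.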
    pub fn decode(
        &self,
        token_ids: &[i64],
        skip_special_tokens: bool,
        clean_up_tokenization_spaces: bool,
    ) -> String {
        match *self {
            Self::Bert(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Deberta(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::DebertaV2(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Roberta(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Bart(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Marian(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::T5(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::XLMRoberta(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Albert(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::XLNet(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::GPT2(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::OpenAiGpt(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Reformer(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::ProphetNet(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::Pegasus(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::MBart50(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::M2M100(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::NLLB(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            Self::FNet(ref tokenizer) => {
                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.decode(token_ids, skip_special_tokens),
        }
    }

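    /// Builds a `TokenizedInput` from pre-tokenized inputs, adding the special tokens
    /// (and segment ids) expected by the model.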
    pub fn build_input_with_special_tokens(
        &self,
        token_ids_with_offsets_1: TokenIdsWithOffsets,
        token_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenizedInput {
        let token_ids_with_special_tokens = match *self {
            Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Deberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::DebertaV2(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Bart(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::GPT2(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::OpenAiGpt(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Reformer(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::ProphetNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::Pegasus(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::MBart50(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::M2M100(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::NLLB(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            Self::FNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
                token_ids_with_offsets_1,
                token_ids_with_offsets_2,
            ),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => {
                return tokenizer.build_input_with_special_tokens(
                    token_ids_with_offsets_1,
                    token_ids_with_offsets_2,
                )
            }
        };
        TokenizedInput {
            token_ids: token_ids_with_special_tokens.token_ids,
            segment_ids: token_ids_with_special_tokens.segment_ids,
            special_tokens_mask: token_ids_with_special_tokens.special_tokens_mask,
            overflowing_tokens: vec![],
            num_truncated_tokens: 0,
            token_offsets: token_ids_with_special_tokens.token_offsets,
            reference_offsets: token_ids_with_special_tokens.reference_offsets,
            mask: token_ids_with_special_tokens.mask,
        }
    }

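    /// Returns the prompt prefix and forced BOS token id used by the translation
    /// pipelines (e.g. `">>fr<< "` for Marian, `"translate X to Y:"` for T5),
    /// validating the requested languages against the supported sets.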
    pub fn get_prefix_and_forced_bos_id(
        &self,
        source_language: Option<&Language>,
        target_language: Option<&Language>,
        supported_source_languages: &HashSet<Language>,
        supported_target_languages: &HashSet<Language>,
    ) -> Result<(Option<String>, Option<i64>), RustBertError> {
        if let Some(source_language) = source_language {
            if !supported_source_languages.contains(source_language) {
                return Err(RustBertError::ValueError(format!(
                    "{source_language} not in list of supported languages: {supported_source_languages:?}",
                )));
            }
        }

        if let Some(target_language) = target_language {
            if !supported_target_languages.contains(target_language) {
                return Err(RustBertError::ValueError(format!(
                    "{target_language} not in list of supported languages: {supported_target_languages:?}"
                )));
            }
        }

        Ok(match *self {
            Self::Marian(_) => {
                if supported_target_languages.len() > 1 {
                    (
                        Some(format!(
                            ">>{}<< ",
                            target_language
                                .and_then(|l| l.get_iso_639_1_code())
                                .ok_or_else(|| RustBertError::ValueError(format!(
                                    "Missing target language for Marian \
                                     (multiple languages supported by model: {supported_target_languages:?}, \
                                     need to specify target language)",
                                )))?
                        )),
                        None,
                    )
                } else {
                    (None, None)
                }
            }
            Self::T5(_) => (
                Some(format!(
                    "translate {} to {}:",
                    source_language.ok_or_else(|| RustBertError::ValueError(
                        "Missing source language for T5".to_string(),
                    ))?,
                    target_language.ok_or_else(|| RustBertError::ValueError(
                        "Missing target language for T5".to_string(),
                    ))?,
                )),
                None,
            ),
            Self::MBart50(_) => {
                (
                    Some(format!(
                        ">>{}<< ",
                        source_language
                            .and_then(|l| l.get_iso_639_1_code())
                            .ok_or_else(|| RustBertError::ValueError(format!(
                                "Missing source language for MBart \
                                 (multiple languages supported by model: {supported_source_languages:?}, \
                                 need to specify source language)"
                            )))?
                    )),
                    if let Some(target_language) = target_language {
                        Some(
                            self.convert_tokens_to_ids(&[format!(
                                ">>{}<<",
                                target_language.get_iso_639_1_code().ok_or_else(|| {
                                    RustBertError::ValueError(format!(
                                        "This language has no ISO 639-1 code. Languages supported by model: {supported_target_languages:?}."
                                    ))
                                })?
                            )])[0],
                        )
                    } else {
                        return Err(RustBertError::ValueError(format!(
                            "Missing target language for MBart \
                             (multiple languages supported by model: {supported_target_languages:?}, \
                             need to specify target language)"
                        )));
                    },
                )
            }
            Self::M2M100(_) => (
                Some(match source_language {
                    Some(value) => {
                        let language_code = value.get_iso_639_1_code().ok_or_else(|| {
                            RustBertError::ValueError(format!(
                                "This language has no ISO 639-1 language code representation. \
                                 Languages supported by the model: {supported_source_languages:?}"
                            ))
                        })?;
                        match language_code.len() {
                            2 => format!(">>{language_code}.<< "),
                            3 => format!(">>{language_code}<< "),
                            _ => {
                                return Err(RustBertError::ValueError(
                                    "Invalid ISO 639 language code".to_string(),
                                ));
                            }
                        }
                    }
                    None => {
                        return Err(RustBertError::ValueError(format!(
                            "Missing source language for M2M100 \
                             (multiple languages supported by model: {supported_source_languages:?}, \
                             need to specify source language)"
                        )));
                    }
                }),
                if let Some(target_language) = target_language {
                    let language_code = target_language.get_iso_639_1_code().ok_or_else(|| {
                        RustBertError::ValueError(format!(
                            "This language has no ISO 639-1 language code representation. \
                             Languages supported by the model: {supported_target_languages:?}"
                        ))
                    })?;
                    Some(
                        self.convert_tokens_to_ids(&[
                            match language_code.len() {
                                2 => format!(">>{language_code}.<<"),
                                3 => format!(">>{language_code}<<"),
                                _ => {
                                    return Err(RustBertError::ValueError(
                                        "Invalid ISO 639 language code".to_string(),
                                    ));
                                }
                            },
                        ])[0],
                    )
                } else {
                    return Err(RustBertError::ValueError(format!(
                        "Missing target language for M2M100 \
                         (multiple languages supported by model: {supported_target_languages:?}, \
                         need to specify target language)",
                    )));
                },
            ),
1707 Self::NLLB(_) => {
1708 let source_language = source_language
1709 .and_then(Language::get_nllb_code)
1710 .map(str::to_string)
1711 .ok_or_else(|| RustBertError::ValueError(
1712 format!("Missing source language for NLLB. Need to specify one from: {supported_source_languages:?}")
1713 ))?;
1714
1715 let target_language = target_language
1716 .and_then(Language::get_nllb_code)
1717 .map(str::to_string)
1718 .map(|code| self.convert_tokens_to_ids(&[code])[0])
1719 .ok_or_else(|| RustBertError::ValueError(
1720 format!("Missing target language for NLLB. Need to specify one from: {supported_target_languages:?}")
1721 ))?;
1722
1723 (Some(source_language), Some(target_language))
1724 }
1725 _ => (None, None),
1726 })
1727 }
1728
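    /// Interface method to convert a list of tokens into their vocabulary ids.
    ///
    /// Delegates to the `convert_tokens_to_ids` implementation of the wrapped tokenizer.
    ///
    /// # Example (illustrative sketch, assuming a `TokenizerOption` was already built)
    /// ```ignore
    /// let ids: Vec<i64> = tokenizer.convert_tokens_to_ids(&["hello", "world"]);
    /// ```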
    pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
    where
        S: AsRef<str>,
    {
        match *self {
            Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Deberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::DebertaV2(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Bart(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::GPT2(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::OpenAiGpt(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Reformer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::ProphetNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::NLLB(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
        }
    }

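    /// Interface method returning the id of the unknown (UNK) token for the wrapped tokenizer.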
    pub fn get_unk_id(&self) -> i64 {
        match *self {
            Self::Bert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Deberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::DebertaV2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Roberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Bart(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::XLMRoberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Marian(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::T5(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Albert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::XLNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::GPT2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::OpenAiGpt(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Reformer(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::ProphetNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::Pegasus(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::MBart50(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::M2M100(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::NLLB(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            Self::FNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                vocab.token_to_id(vocab.get_unknown_value())
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => {
                tokenizer.token_to_id(&tokenizer.special_token_map.unk_token)
            }
        }
    }

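    /// Interface method returning the id of the padding (PAD) token if the wrapped tokenizer
    /// defines one, and `None` otherwise (e.g. Reformer, GPT2 and OpenAiGpt have no PAD token).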
    pub fn get_pad_id(&self) -> Option<i64> {
        match *self {
            Self::Bert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::Deberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::DebertaV2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::Roberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::Bart(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::XLMRoberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::Marian(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::T5(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::Albert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::XLNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::ProphetNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::Pegasus(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::MBart50(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::M2M100(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::NLLB(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            Self::FNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_pad_value()))
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer
                .special_token_map
                .pad_token
                .as_ref()
                .map(|token| tokenizer.token_to_id(token)),
            Self::Reformer(_) => None,
            Self::GPT2(_) => None,
            Self::OpenAiGpt(_) => None,
        }
    }

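    /// Interface method returning the id of the separator (SEP) token if the wrapped tokenizer
    /// defines one, and `None` otherwise.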
    pub fn get_sep_id(&self) -> Option<i64> {
        match *self {
            Self::Bert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::Deberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::DebertaV2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::Roberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::Bart(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::XLMRoberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::Albert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::XLNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::ProphetNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::MBart50(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::M2M100(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::NLLB(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            Self::FNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_sep_value()))
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer
                .special_token_map
                .sep_token
                .as_ref()
                .map(|token| tokenizer.token_to_id(token)),
            Self::Marian(_) => None,
            Self::T5(_) => None,
            Self::GPT2(_) => None,
            Self::OpenAiGpt(_) => None,
            Self::Reformer(_) => None,
            Self::Pegasus(_) => None,
        }
    }

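    /// Interface method returning the id of the mask (MASK) token if the wrapped tokenizer
    /// defines one, and `None` otherwise.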
    pub fn get_mask_id(&self) -> Option<i64> {
        match *self {
            Self::Bert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::Deberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::DebertaV2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::Roberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::Bart(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::XLMRoberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::Albert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::XLNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::ProphetNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::MBart50(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::FNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            Self::Pegasus(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_mask_value()))
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer
                .special_token_map
                .mask_token
                .as_ref()
                .map(|token| tokenizer.token_to_id(token)),
            Self::Marian(_) => None,
            Self::M2M100(_) => None,
            Self::NLLB(_) => None,
            Self::T5(_) => None,
            Self::GPT2(_) => None,
            Self::OpenAiGpt(_) => None,
            Self::Reformer(_) => None,
        }
    }

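    /// Interface method returning the string value of the mask token if the wrapped tokenizer
    /// defines one, and `None` otherwise.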
    pub fn get_mask_value(&self) -> Option<&str> {
        match self {
            Self::Bert(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::Deberta(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::DebertaV2(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::Roberta(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::Bart(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::XLMRoberta(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::Albert(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::XLNet(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::ProphetNet(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::MBart50(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::FNet(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            Self::Pegasus(ref tokenizer) => {
                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer.special_token_map.mask_token.as_deref(),
            Self::M2M100(_) => None,
            Self::NLLB(_) => None,
            Self::Marian(_) => None,
            Self::T5(_) => None,
            Self::GPT2(_) => None,
            Self::OpenAiGpt(_) => None,
            Self::Reformer(_) => None,
        }
    }

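    /// Interface method returning the id of the beginning-of-sequence (BOS) token if available,
    /// and `None` otherwise. MBart50, Marian and Pegasus use a hard-coded BOS id of 0.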
    pub fn get_bos_id(&self) -> Option<i64> {
        match *self {
            Self::Roberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::Bart(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::DebertaV2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::XLMRoberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::Albert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::XLNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::M2M100(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::NLLB(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::GPT2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            Self::Deberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_bos_value()))
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer
                .special_token_map
                .bos_token
                .as_ref()
                .map(|token| tokenizer.token_to_id(token)),
            Self::MBart50(_) => Some(0),
            Self::FNet(_) => None,
            Self::Bert(_) => None,
            Self::Marian(_) => Some(0),
            Self::T5(_) => None,
            Self::ProphetNet(_) => None,
            Self::OpenAiGpt(_) => None,
            Self::Reformer(_) => None,
            Self::Pegasus(_) => Some(0),
        }
    }

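    /// Interface method returning the id of the end-of-sequence (EOS) token if the wrapped
    /// tokenizer defines one, and `None` otherwise.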
    pub fn get_eos_id(&self) -> Option<i64> {
        match *self {
            Self::Roberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::Bart(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::DebertaV2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::XLMRoberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::Albert(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::XLNet(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::MBart50(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::M2M100(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::NLLB(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::GPT2(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::Deberta(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::Marian(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::T5(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::Reformer(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            Self::Pegasus(ref tokenizer) => {
                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
                Some(vocab.token_to_id(vocab.get_eos_value()))
            }
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref tokenizer) => tokenizer
                .special_token_map
                .eos_token
                .as_ref()
                .map(|token| tokenizer.token_to_id(token)),
            Self::FNet(_) => None,
            Self::Bert(_) => None,
            Self::ProphetNet(_) => None,
            Self::OpenAiGpt(_) => None,
        }
    }

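    /// Tokenizes a batch of inputs, truncating each sequence to `max_length` (longest-first
    /// strategy) and right-padding all sequences with the PAD id up to the longest sequence in
    /// the batch. Returns the stacked token id and token type id tensors, moved to `device`.
    ///
    /// Panics if the input batch is empty or if the wrapped tokenizer has no PAD token.
    ///
    /// # Example (illustrative sketch, assuming a `TokenizerOption` with a PAD token)
    /// ```ignore
    /// let (token_ids, token_type_ids) =
    ///     tokenizer.tokenize_and_pad(["a short input", "a somewhat longer input"], 128, Device::Cpu);
    /// ```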
    pub fn tokenize_and_pad<'a, S>(
        &self,
        input: S,
        max_length: usize,
        device: Device,
    ) -> (Tensor, Tensor)
    where
        S: AsRef<[&'a str]>,
    {
        let mut tokenized_input: Vec<TokenizedInput> = self.encode_list(
            input.as_ref(),
            max_length,
            &TruncationStrategy::LongestFirst,
            0,
        );
        let max_len = tokenized_input
            .iter()
            .map(|input| input.token_ids.len())
            .max()
            .unwrap();
        let pad_id = self
            .get_pad_id()
            .expect("The tokenizer used for padded tokenization must include a PAD token");
        let tokenized_input_tensors: Vec<Tensor> = tokenized_input
            .iter_mut()
            .map(|input| {
                input.token_ids.resize(max_len, pad_id);
                Tensor::from_slice(&(input.token_ids))
            })
            .collect::<Vec<_>>();

        let token_type_ids: Vec<Tensor> = tokenized_input
            .iter_mut()
            .map(|input| {
                input
                    .segment_ids
                    .resize(max_len, *input.segment_ids.last().unwrap_or(&0));
                Tensor::from_slice(&(input.segment_ids))
            })
            .collect::<Vec<_>>();

        (
            Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(device),
            Tensor::stack(token_type_ids.as_slice(), 0)
                .to(device)
                .to_kind(Kind::Int64),
        )
    }

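    /// Interface method to add a number of extra token ids (used e.g. for T5-style sentinel
    /// tokens) to the vocabulary of the wrapped tokenizer.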
    pub fn add_extra_ids(&mut self, num_extra_ids: i64) {
        match *self {
            Self::Bert(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Deberta(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::DebertaV2(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Roberta(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Bart(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Marian(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::T5(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::XLMRoberta(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Albert(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::XLNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::GPT2(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::OpenAiGpt(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Reformer(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::ProphetNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::Pegasus(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::MBart50(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::M2M100(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::NLLB(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            Self::FNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
        }
    }

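    /// Interface method to register additional tokens in the vocabulary of the wrapped
    /// tokenizer.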
    pub fn add_tokens(&mut self, tokens: &[&str]) {
        match *self {
            Self::Bert(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Deberta(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::DebertaV2(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Roberta(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Bart(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Marian(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::T5(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::XLMRoberta(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Albert(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::XLNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::GPT2(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::OpenAiGpt(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Reformer(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::ProphetNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::Pegasus(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::MBart50(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::M2M100(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::NLLB(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            Self::FNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
            #[cfg(feature = "hf-tokenizers")]
            Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_tokens(tokens),
        }
    }
}

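/// Casts the weights of a `VarStore` to the requested kind. If no kind is given, weights are
/// cast to `Kind::Float` on CPU (where half precision is generally not usable) and left
/// unchanged on other devices.
///
/// # Example (illustrative sketch, assuming a loaded `VarStore` named `vs`)
/// ```ignore
/// cast_var_store(&mut vs, Some(Kind::Half), Device::Cuda(0)); // explicit cast to fp16
/// cast_var_store(&mut vs, None, Device::Cpu); // defaults to Kind::Float on CPU
/// ```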
pub fn cast_var_store(varstore: &mut VarStore, kind: Option<Kind>, device: Device) {
    match (kind, device) {
        (Some(kind), _) => varstore.set_kind(kind),
        (None, Device::Cpu) => varstore.set_kind(Kind::Float),
        (None, _) => {}
    }
}