rust_bert/pipelines/common.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019-2020 Guillaume Becquin
3// Copyright 2020 Maarten van Gompel
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
//! # Common blocks for generic pipelines (e.g. token classification or sequence classification)
//! Provides enums holding configuration or tokenization resources that can be used to create
//! generic pipelines. The model component is defined in the generic pipeline itself, as the
//! pre-processing, forward pass and post-processing differ between pipelines, while the basic
//! configuration and tokenization objects do not.
19use crate::albert::AlbertConfig;
20use crate::bart::BartConfig;
21use crate::bert::BertConfig;
22use crate::common::error::RustBertError;
23use crate::deberta::DebertaConfig;
24use crate::deberta_v2::DebertaV2Config;
25use crate::distilbert::DistilBertConfig;
26use crate::electra::ElectraConfig;
27use crate::fnet::FNetConfig;
28use crate::gpt2::Gpt2Config;
29use crate::gpt_j::GptJConfig;
30use crate::gpt_neo::GptNeoConfig;
31use crate::longformer::LongformerConfig;
32use crate::longt5::LongT5Config;
33use crate::m2m_100::M2M100Config;
34use crate::marian::MarianConfig;
35use crate::mbart::MBartConfig;
36use crate::mobilebert::MobileBertConfig;
37use crate::openai_gpt::OpenAiGptConfig;
38use crate::pegasus::PegasusConfig;
39use crate::pipelines::translation::Language;
40use crate::prophetnet::ProphetNetConfig;
41use crate::reformer::ReformerConfig;
42use crate::resources::{Resource, ResourceProvider};
43use crate::roberta::RobertaConfig;
44use crate::t5::T5Config;
45use crate::xlnet::XLNetConfig;
46use crate::Config;
47use rust_tokenizers::tokenizer::{
48    AlbertTokenizer, BertTokenizer, DeBERTaTokenizer, DeBERTaV2Tokenizer, FNetTokenizer,
49    Gpt2Tokenizer, M2M100Tokenizer, MBart50Tokenizer, MarianTokenizer, MultiThreadedTokenizer,
50    NLLBTokenizer, OpenAiGptTokenizer, PegasusTokenizer, ProphetNetTokenizer, ReformerTokenizer,
51    RobertaTokenizer, T5Tokenizer, Tokenizer, TruncationStrategy, XLMRobertaTokenizer,
52    XLNetTokenizer,
53};
54use rust_tokenizers::vocab::Vocab;
55use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
56use serde::{Deserialize, Serialize};
57use std::collections::{HashMap, HashSet};
58use std::convert::TryFrom;
59
60use std::fmt::Debug;
61
62use std::path::{Path, PathBuf};
63use tch::nn::VarStore;
64use tch::{Device, Kind, Tensor};
65
66#[cfg(feature = "onnx")]
67use crate::pipelines::onnx::ONNXModelConfig;
68
69#[cfg(feature = "hf-tokenizers")]
70use crate::pipelines::hf_tokenizers::HFTokenizer;
71
72#[derive(Debug, Default)]
73/// Container for ONNX model resources, containing 3 optional resources (Encoder, Decoder and Decoder with past)
74pub struct ONNXModelResources {
75    /// Model encoder resource
76    pub encoder_resource: Option<Box<dyn ResourceProvider + Send>>,
77    /// Model encoder resource
78    pub decoder_resource: Option<Box<dyn ResourceProvider + Send>>,
79    /// Model encoder resource
80    pub decoder_with_past_resource: Option<Box<dyn ResourceProvider + Send>>,
81}
82
83#[derive(Debug)]
84/// Variants to store either a Torch model resource or ONNX resources
85pub enum ModelResource {
86    Torch(Box<dyn ResourceProvider + Send>),
87    #[cfg(feature = "onnx")]
88    ONNX(ONNXModelResources),
89}
90
91impl ResourceProvider for ModelResource {
92    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
93        match self {
94            ModelResource::Torch(ref resource) => resource.get_local_path(),
95            #[cfg(feature = "onnx")]
96            ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
97        }
98    }
99    fn get_resource(&self) -> Result<Resource, RustBertError> {
100        match self {
101            ModelResource::Torch(ref resource) => resource.get_resource(),
102            #[cfg(feature = "onnx")]
103            ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
104        }
105    }
106}
107
/// Local paths of the files making up an ONNX model, as returned by
/// `ModelResource::get_onnx_local_paths`.
///
/// Each field is `None` when the corresponding resource was not provided.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ONNXLocalPaths {
    /// Local path to the model encoder file
    pub encoder_path: Option<PathBuf>,
    /// Local path to the model decoder file
    pub decoder_path: Option<PathBuf>,
    /// Local path to the model decoder with past file
    pub decoder_with_past_path: Option<PathBuf>,
}
113
114impl ModelResource {
115    /// Provides the torch resource local path.
116    /// Returns an error if the variant is not a `ModelResources::TORCH`
117    pub fn get_torch_local_path(&self) -> Result<PathBuf, RustBertError> {
118        match self {
119            ModelResource::Torch(torch_resource) => torch_resource.get_local_path(),
120            #[cfg(feature = "onnx")]
121            _ => Err(RustBertError::InvalidConfigurationError(format!("Attempting to get the Torch local path but other weights variants were given: {:?}", self)))
122        }
123    }
124
125    #[cfg(feature = "onnx")]
126    pub fn get_onnx_local_paths(&self) -> Result<ONNXLocalPaths, RustBertError> {
127        let (encoder_path, decoder_path, decoder_with_past_path) = match self {
128            ModelResource::ONNX(onnx_model_resources) => Ok((
129                onnx_model_resources
130                    .encoder_resource.as_ref()
131                    .map(|r| r.get_local_path()),
132                onnx_model_resources
133                    .decoder_resource.as_ref()
134                    .map(|r| r.get_local_path()),
135                onnx_model_resources
136                    .decoder_with_past_resource.as_ref()
137                    .map(|r| r.get_local_path()),
138            )),
139            _ => Err(RustBertError::InvalidConfigurationError(format!("Attempting to get the ONNX local paths but other weights variants were given: {:?}", self)))
140        }?;
141        Ok(ONNXLocalPaths {
142            encoder_path: encoder_path.transpose()?,
143            decoder_path: decoder_path.transpose()?,
144            decoder_with_past_path: decoder_with_past_path.transpose()?,
145        })
146    }
147}
148
149pub(crate) fn get_device(_model_resource: ModelResource, device: Device) -> Device {
150    #[cfg(feature = "onnx")]
151    let device = if let ModelResource::ONNX(_) = _model_resource {
152        Device::Cpu
153    } else {
154        device
155    };
156
157    #[cfg(not(feature = "onnx"))]
158    let device = device;
159    device
160}
161
162#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq, Eq)]
163/// # Identifies the type of model
164pub enum ModelType {
165    Bart,
166    #[serde(alias = "bert")]
167    Bert,
168    #[serde(alias = "distilbert")]
169    DistilBert,
170    Deberta,
171    DebertaV2,
172    #[serde(alias = "roberta")]
173    Roberta,
174    XLMRoberta,
175    Electra,
176    Marian,
177    MobileBert,
178    #[serde(alias = "t5")]
179    T5,
180    #[serde(alias = "longt5")]
181    LongT5,
182    #[serde(alias = "albert")]
183    Albert,
184    XLNet,
185    GPT2,
186    GPTJ,
187    OpenAiGpt,
188    Reformer,
189    ProphetNet,
190    Longformer,
191    Pegasus,
192    GPTNeo,
193    MBart,
194    M2M100,
195    #[serde(alias = "m2m100")]
196    NLLB,
197    FNet,
198    #[cfg(feature = "onnx")]
199    ONNX,
200}
201
202/// # Abstraction that holds a model configuration, can be of any of the supported models
203pub enum ConfigOption {
204    /// Bart configuration
205    Bart(BartConfig),
206    /// Bert configuration
207    Bert(BertConfig),
208    /// DistilBert configuration
209    DistilBert(DistilBertConfig),
210    /// DeBERTa configuration
211    Deberta(DebertaConfig),
212    /// DeBERTa V2 configuration
213    DebertaV2(DebertaV2Config),
214    /// Electra configuration
215    Electra(ElectraConfig),
216    /// Marian configuration
217    Marian(MarianConfig),
218    /// MobileBert configuration
219    MobileBert(MobileBertConfig),
220    /// OpenAI GPT configuration
221    OpenAiGpt(OpenAiGptConfig),
222    /// T5 configuration
223    T5(T5Config),
224    /// LongT5 configuration
225    LongT5(LongT5Config),
226    /// Albert configuration
227    Albert(AlbertConfig),
228    /// XLNet configuration
229    XLNet(XLNetConfig),
230    /// GPT2 configuration
231    GPT2(Gpt2Config),
232    /// GPT-J configuration
233    GPTJ(GptJConfig),
234    /// Reformer configuration
235    Reformer(ReformerConfig),
236    /// RoBERTa configuration
237    Roberta(RobertaConfig),
238    /// ProphetNet configuration
239    ProphetNet(ProphetNetConfig),
240    /// Longformer configuration
241    Longformer(LongformerConfig),
242    /// Pegasus configuration
243    Pegasus(PegasusConfig),
244    /// GPT-Neo configuration
245    GPTNeo(GptNeoConfig),
246    /// MBart configuration
247    MBart(MBartConfig),
248    /// M2M100 configuration
249    M2M100(M2M100Config),
250    /// FNet configuration
251    FNet(FNetConfig),
252    /// ONNX Model configuration
253    #[cfg(feature = "onnx")]
254    ONNX(ONNXModelConfig),
255}
256
257/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
258pub enum TokenizerOption {
259    /// Bert Tokenizer
260    Bert(BertTokenizer),
261    /// DeBERTa Tokenizer
262    Deberta(DeBERTaTokenizer),
263    /// DeBERTa V2 Tokenizer
264    DebertaV2(DeBERTaV2Tokenizer),
265    /// Roberta Tokenizer
266    Roberta(RobertaTokenizer),
267    /// XLMRoberta Tokenizer
268    XLMRoberta(XLMRobertaTokenizer),
269    /// Marian Tokenizer
270    Marian(MarianTokenizer),
271    /// T5 Tokenizer
272    T5(T5Tokenizer),
273    /// Albert Tokenizer
274    Albert(AlbertTokenizer),
275    /// XLNet Tokenizer
276    XLNet(XLNetTokenizer),
277    /// GPT2 Tokenizer
278    GPT2(Gpt2Tokenizer),
279    /// GPT Tokenizer
280    OpenAiGpt(OpenAiGptTokenizer),
281    /// Reformer Tokenizer
282    Reformer(ReformerTokenizer),
283    /// ProphetNet Tokenizer
284    ProphetNet(ProphetNetTokenizer),
285    /// Pegasus Tokenizer
286    Pegasus(PegasusTokenizer),
287    /// MBart50 Tokenizer
288    MBart50(MBart50Tokenizer),
289    /// M2M100 Tokenizer
290    M2M100(M2M100Tokenizer),
291    /// NLLB tokenizer.
292    NLLB(NLLBTokenizer),
293    /// FNet Tokenizer
294    FNet(FNetTokenizer),
295    /// Bart Tokenizer
296    Bart(RobertaTokenizer),
297    /// HF Tokenizer
298    #[cfg(feature = "hf-tokenizers")]
299    HFTokenizer(HFTokenizer),
300}
301
302impl ConfigOption {
303    /// Interface method to load a configuration from file
304    pub fn from_file<P: AsRef<Path>>(model_type: ModelType, path: P) -> Self {
305        match model_type {
306            ModelType::Bart => ConfigOption::Bart(BartConfig::from_file(path)),
307            ModelType::Bert => ConfigOption::Bert(BertConfig::from_file(path)),
308            ModelType::Deberta => ConfigOption::Deberta(DebertaConfig::from_file(path)),
309            ModelType::DebertaV2 => ConfigOption::DebertaV2(DebertaV2Config::from_file(path)),
310            ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
311            ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
312            ModelType::Marian => ConfigOption::Marian(MarianConfig::from_file(path)),
313            ModelType::MobileBert => ConfigOption::MobileBert(MobileBertConfig::from_file(path)),
314            ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)),
315            ModelType::LongT5 => ConfigOption::LongT5(LongT5Config::from_file(path)),
316            ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)),
317            ModelType::XLNet => ConfigOption::XLNet(XLNetConfig::from_file(path)),
318            ModelType::GPT2 => ConfigOption::GPT2(Gpt2Config::from_file(path)),
319            ModelType::GPTJ => ConfigOption::GPTJ(GptJConfig::from_file(path)),
320            ModelType::GPTNeo => ConfigOption::GPTNeo(GptNeoConfig::from_file(path)),
321            ModelType::OpenAiGpt => ConfigOption::OpenAiGpt(OpenAiGptConfig::from_file(path)),
322            ModelType::Reformer => ConfigOption::Reformer(ReformerConfig::from_file(path)),
323            ModelType::ProphetNet => ConfigOption::ProphetNet(ProphetNetConfig::from_file(path)),
324            ModelType::Longformer => ConfigOption::Longformer(LongformerConfig::from_file(path)),
325            ModelType::Pegasus => ConfigOption::Pegasus(PegasusConfig::from_file(path)),
326            ModelType::Roberta | ModelType::XLMRoberta => {
327                ConfigOption::Roberta(RobertaConfig::from_file(path))
328            }
329            ModelType::MBart => ConfigOption::MBart(MBartConfig::from_file(path)),
330            ModelType::M2M100 | ModelType::NLLB => {
331                ConfigOption::M2M100(M2M100Config::from_file(path))
332            }
333            ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)),
334            #[cfg(feature = "onnx")]
335            ModelType::ONNX => ConfigOption::ONNX(ONNXModelConfig::from_file(path)),
336        }
337    }
338
339    pub fn get_label_mapping(&self) -> &HashMap<i64, String> {
340        match self {
341            Self::Bart(config) => config
342                .id2label
343                .as_ref()
344                .expect("No label dictionary (id2label) provided in configuration file"),
345            Self::Bert(config) => config
346                .id2label
347                .as_ref()
348                .expect("No label dictionary (id2label) provided in configuration file"),
349            Self::Deberta(config) => config
350                .id2label
351                .as_ref()
352                .expect("No label dictionary (id2label) provided in configuration file"),
353            Self::DebertaV2(config) => config
354                .id2label
355                .as_ref()
356                .expect("No label dictionary (id2label) provided in configuration file"),
357            Self::DistilBert(config) => config
358                .id2label
359                .as_ref()
360                .expect("No label dictionary (id2label) provided in configuration file"),
361            Self::Electra(config) => config
362                .id2label
363                .as_ref()
364                .expect("No label dictionary (id2label) provided in configuration file"),
365            Self::Marian(config) => config
366                .id2label
367                .as_ref()
368                .expect("No label dictionary (id2label) provided in configuration file"),
369            Self::MobileBert(config) => config
370                .id2label
371                .as_ref()
372                .expect("No label dictionary (id2label) provided in configuration file"),
373            Self::Albert(config) => config
374                .id2label
375                .as_ref()
376                .expect("No label dictionary (id2label) provided in configuration file"),
377            Self::XLNet(config) => config
378                .id2label
379                .as_ref()
380                .expect("No label dictionary (id2label) provided in configuration file"),
381            Self::Reformer(config) => config
382                .id2label
383                .as_ref()
384                .expect("No label dictionary (id2label) provided in configuration file"),
385            Self::ProphetNet(config) => config
386                .id2label
387                .as_ref()
388                .expect("No label dictionary (id2label) provided in configuration file"),
389            Self::Longformer(config) => config
390                .id2label
391                .as_ref()
392                .expect("No label dictionary (id2label) provided in configuration file"),
393            Self::MBart(config) => config
394                .id2label
395                .as_ref()
396                .expect("No label dictionary (id2label) provided in configuration file"),
397            Self::M2M100(config) => config
398                .id2label
399                .as_ref()
400                .expect("No label dictionary (id2label) provided in configuration file"),
401            Self::FNet(config) => config
402                .id2label
403                .as_ref()
404                .expect("No label dictionary (id2label) provided in configuration file"),
405            Self::Roberta(config) => config
406                .id2label
407                .as_ref()
408                .expect("No label dictionary (id2label) provided in configuration file"),
409            #[cfg(feature = "onnx")]
410            Self::ONNX(config) => config
411                .id2label
412                .as_ref()
413                .expect("No label dictionary (id2label) provided in configuration file"),
414            Self::T5(_) => panic!("T5 does not use a label mapping"),
415            Self::LongT5(_) => panic!("LongT5 does not use a label mapping"),
416            Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"),
417            Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
418            Self::GPTJ(_) => panic!("GPT-J does not use a label mapping"),
419            Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"),
420            Self::Pegasus(_) => panic!("Pegasus does not use a label mapping"),
421        }
422    }
423
424    pub fn get_max_len(&self) -> Option<i64> {
425        match self {
426            Self::Bart(config) => Some(config.max_position_embeddings),
427            Self::Bert(config) => Some(config.max_position_embeddings),
428            Self::Deberta(config) => Some(config.max_position_embeddings),
429            Self::DebertaV2(config) => Some(config.max_position_embeddings),
430            Self::DistilBert(config) => Some(config.max_position_embeddings),
431            Self::Electra(config) => Some(config.max_position_embeddings),
432            Self::Marian(config) => Some(config.max_position_embeddings),
433            Self::MobileBert(config) => Some(config.max_position_embeddings),
434            Self::T5(_) => None,
435            Self::LongT5(_) => None,
436            Self::Albert(config) => Some(config.max_position_embeddings),
437            Self::XLNet(_) => None,
438            Self::GPT2(config) => Some(config.n_positions),
439            Self::GPTJ(config) => Some(config.n_positions),
440            Self::Reformer(config) => Some(config.max_position_embeddings),
441            Self::ProphetNet(config) => Some(config.max_position_embeddings),
442            Self::Longformer(config) => Some(config.max_position_embeddings),
443            Self::Pegasus(config) => Some(config.max_position_embeddings),
444            Self::OpenAiGpt(config) => Some(config.n_positions),
445            Self::GPTNeo(config) => Some(config.max_position_embeddings),
446            Self::MBart(config) => Some(config.max_position_embeddings),
447            Self::M2M100(config) => Some(config.max_position_embeddings),
448            Self::FNet(config) => Some(config.max_position_embeddings),
449            Self::Roberta(config) => Some(config.max_position_embeddings),
450            #[cfg(feature = "onnx")]
451            Self::ONNX(config) => config.max_position_embeddings,
452        }
453    }
454
455    pub fn get_vocab_size(&self) -> i64 {
456        match self {
457            Self::Bart(config) => config.vocab_size,
458            Self::Bert(config) => config.vocab_size,
459            Self::Deberta(config) => config.vocab_size,
460            Self::DebertaV2(config) => config.vocab_size,
461            Self::DistilBert(config) => config.vocab_size,
462            Self::Electra(config) => config.vocab_size,
463            Self::Marian(config) => config.vocab_size,
464            Self::MobileBert(config) => config.vocab_size,
465            Self::T5(config) => config.vocab_size,
466            Self::LongT5(config) => config.vocab_size,
467            Self::Albert(config) => config.vocab_size,
468            Self::XLNet(config) => config.vocab_size,
469            Self::GPT2(config) => config.vocab_size,
470            Self::GPTJ(config) => config.vocab_size,
471            Self::Reformer(config) => config.vocab_size,
472            Self::ProphetNet(config) => config.vocab_size,
473            Self::Longformer(config) => config.vocab_size,
474            Self::Pegasus(config) => config.vocab_size,
475            Self::OpenAiGpt(config) => config.vocab_size,
476            Self::GPTNeo(config) => config.vocab_size,
477            Self::MBart(config) => config.vocab_size,
478            Self::M2M100(config) => config.vocab_size,
479            Self::FNet(config) => config.vocab_size,
480            Self::Roberta(config) => config.vocab_size,
481            #[cfg(feature = "onnx")]
482            Self::ONNX(config) => config.vocab_size,
483        }
484    }
485
486    pub fn get_decoder_start_token_id(&self) -> Option<i64> {
487        match self {
488            Self::Bart(config) => config.decoder_start_token_id,
489            Self::Bert(_) => None,
490            Self::Deberta(_) => None,
491            Self::DebertaV2(_) => None,
492            Self::DistilBert(_) => None,
493            Self::Electra(_) => None,
494            Self::Marian(config) => config.decoder_start_token_id,
495            Self::MobileBert(_) => None,
496            Self::T5(config) => config.decoder_start_token_id,
497            Self::LongT5(config) => config.decoder_start_token_id,
498            Self::Albert(_) => None,
499            Self::XLNet(_) => None,
500            Self::GPT2(config) => config.decoder_start_token_id,
501            Self::GPTJ(config) => config.decoder_start_token_id,
502            Self::Reformer(config) => config.decoder_start_token_id,
503            Self::ProphetNet(config) => config.decoder_start_token_id,
504            Self::Longformer(_) => None,
505            Self::Pegasus(config) => config.decoder_start_token_id,
506            Self::OpenAiGpt(config) => config.decoder_start_token_id,
507            Self::GPTNeo(config) => config.decoder_start_token_id,
508            Self::MBart(config) => config.decoder_start_token_id,
509            Self::M2M100(config) => config.decoder_start_token_id,
510            Self::FNet(config) => config.decoder_start_token_id,
511            Self::Roberta(_) => None,
512            #[cfg(feature = "onnx")]
513            Self::ONNX(config) => config.decoder_start_token_id,
514        }
515    }
516
517    pub fn get_forced_bos_token_id(&self) -> Option<i64> {
518        match self {
519            Self::Bart(config) => config.forced_bos_token_id,
520            Self::Bert(_) => None,
521            Self::Deberta(_) => None,
522            Self::DebertaV2(_) => None,
523            Self::DistilBert(_) => None,
524            Self::Electra(_) => None,
525            Self::Marian(config) => config.forced_bos_token_id,
526            Self::MobileBert(_) => None,
527            Self::T5(config) => config.forced_bos_token_id,
528            Self::LongT5(config) => config.forced_bos_token_id,
529            Self::Albert(_) => None,
530            Self::XLNet(_) => None,
531            Self::GPT2(config) => config.forced_bos_token_id,
532            Self::GPTJ(config) => config.forced_bos_token_id,
533            Self::Reformer(config) => config.forced_bos_token_id,
534            Self::ProphetNet(config) => config.forced_bos_token_id,
535            Self::Longformer(_) => None,
536            Self::Pegasus(config) => config.forced_bos_token_id,
537            Self::OpenAiGpt(config) => config.forced_bos_token_id,
538            Self::GPTNeo(config) => config.forced_bos_token_id,
539            Self::MBart(config) => config.forced_bos_token_id,
540            Self::M2M100(config) => config.forced_bos_token_id,
541            Self::FNet(_) => None,
542            Self::Roberta(_) => None,
543            #[cfg(feature = "onnx")]
544            Self::ONNX(config) => config.forced_bos_token_id,
545        }
546    }
547
548    pub fn get_forced_eos_token_id(&self) -> Option<i64> {
549        match self {
550            Self::Bart(config) => config.forced_eos_token_id,
551            Self::Bert(_) => None,
552            Self::Deberta(_) => None,
553            Self::DebertaV2(_) => None,
554            Self::DistilBert(_) => None,
555            Self::Electra(_) => None,
556            Self::Marian(config) => config.forced_eos_token_id,
557            Self::MobileBert(_) => None,
558            Self::T5(config) => config.forced_eos_token_id,
559            Self::LongT5(config) => config.forced_eos_token_id,
560            Self::Albert(_) => None,
561            Self::XLNet(_) => None,
562            Self::GPT2(config) => config.forced_eos_token_id,
563            Self::GPTJ(config) => config.forced_eos_token_id,
564            Self::Reformer(config) => config.forced_eos_token_id,
565            Self::ProphetNet(config) => config.forced_eos_token_id,
566            Self::Longformer(_) => None,
567            Self::Pegasus(config) => config.forced_eos_token_id,
568            Self::OpenAiGpt(config) => config.forced_eos_token_id,
569            Self::GPTNeo(config) => config.forced_eos_token_id,
570            Self::MBart(config) => config.forced_eos_token_id,
571            Self::M2M100(config) => config.forced_eos_token_id,
572            Self::FNet(_) => None,
573            Self::Roberta(_) => None,
574            #[cfg(feature = "onnx")]
575            Self::ONNX(config) => config.forced_eos_token_id,
576        }
577    }
578}
579
580impl TryFrom<&ConfigOption> for BertConfig {
581    type Error = RustBertError;
582
583    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
584        match config {
585            ConfigOption::Bert(config) | ConfigOption::Roberta(config) => Ok(config.clone()),
586            _ => Err(RustBertError::InvalidConfigurationError(
587                "You can only supply a BertConfig for Bert or a RobertaConfig for Roberta!"
588                    .to_string(),
589            )),
590        }
591    }
592}
593
594impl TryFrom<&ConfigOption> for DistilBertConfig {
595    type Error = RustBertError;
596
597    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
598        if let ConfigOption::DistilBert(config) = config {
599            Ok(config.clone())
600        } else {
601            Err(RustBertError::InvalidConfigurationError(
602                "You can only supply a DistilBertConfig for DistilBert!".to_string(),
603            ))
604        }
605    }
606}
607
608impl TryFrom<&ConfigOption> for AlbertConfig {
609    type Error = RustBertError;
610
611    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
612        if let ConfigOption::Albert(config) = config {
613            Ok(config.clone())
614        } else {
615            Err(RustBertError::InvalidConfigurationError(
616                "You can only supply an AlbertConfig for Albert!".to_string(),
617            ))
618        }
619    }
620}
621
622impl TryFrom<&ConfigOption> for T5Config {
623    type Error = RustBertError;
624
625    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
626        if let ConfigOption::T5(config) = config {
627            Ok(config.clone())
628        } else {
629            Err(RustBertError::InvalidConfigurationError(
630                "You can only supply a T5Config for T5!".to_string(),
631            ))
632        }
633    }
634}
635
636impl TokenizerOption {
637    /// Interface method to load a tokenizer from file
638    pub fn from_file(
639        model_type: ModelType,
640        vocab_path: &str,
641        merges_path: Option<&str>,
642        lower_case: bool,
643        strip_accents: impl Into<Option<bool>>,
644        add_prefix_space: impl Into<Option<bool>>,
645    ) -> Result<Self, RustBertError> {
646        let strip_accents = strip_accents.into();
647        let add_prefix_space = add_prefix_space.into();
648
649        let tokenizer = match model_type {
650            ModelType::Bert
651            | ModelType::DistilBert
652            | ModelType::Electra
653            | ModelType::MobileBert => {
654                if add_prefix_space.is_some() {
655                    return Err(RustBertError::InvalidConfigurationError(
656                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
657                                add_prefix_space.unwrap(),
658                                model_type)));
659                }
660                TokenizerOption::Bert(BertTokenizer::from_file(
661                    vocab_path,
662                    lower_case,
663                    strip_accents.unwrap_or(lower_case),
664                )?)
665            }
666            ModelType::Deberta => {
667                if strip_accents.is_some() {
668                    return Err(RustBertError::InvalidConfigurationError(format!(
669                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
670                        strip_accents.unwrap(),
671                        model_type
672                    )));
673                }
674                if add_prefix_space.is_some() {
675                    return Err(RustBertError::InvalidConfigurationError(
676                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
677                                add_prefix_space.unwrap(),
678                                model_type)));
679                }
680                TokenizerOption::Deberta(DeBERTaTokenizer::from_file(
681                    vocab_path,
682                    merges_path.expect("No merges specified!"),
683                    lower_case,
684                )?)
685            }
686            ModelType::DebertaV2 => TokenizerOption::DebertaV2(DeBERTaV2Tokenizer::from_file(
687                vocab_path,
688                lower_case,
689                strip_accents.unwrap_or(false),
690                add_prefix_space.unwrap_or(false),
691            )?),
692            ModelType::Roberta | ModelType::Longformer => {
693                if strip_accents.is_some() {
694                    return Err(RustBertError::InvalidConfigurationError(format!(
695                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
696                        strip_accents.unwrap(),
697                        model_type
698                    )));
699                }
700                TokenizerOption::Roberta(RobertaTokenizer::from_file(
701                    vocab_path,
702                    merges_path.expect("No merges specified!"),
703                    lower_case,
704                    add_prefix_space.unwrap_or(false),
705                )?)
706            }
707            ModelType::Bart => {
708                if strip_accents.is_some() {
709                    return Err(RustBertError::InvalidConfigurationError(format!(
710                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
711                        strip_accents.unwrap(),
712                        model_type
713                    )));
714                }
715                TokenizerOption::Bart(RobertaTokenizer::from_file(
716                    vocab_path,
717                    merges_path.expect("No merges specified!"),
718                    lower_case,
719                    add_prefix_space.unwrap_or(false),
720                )?)
721            }
722            ModelType::Marian => {
723                if strip_accents.is_some() {
724                    return Err(RustBertError::InvalidConfigurationError(format!(
725                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
726                        strip_accents.unwrap(),
727                        model_type
728                    )));
729                }
730                if add_prefix_space.is_some() {
731                    return Err(RustBertError::InvalidConfigurationError(
732                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
733                                add_prefix_space.unwrap(),
734                                model_type)));
735                }
736                TokenizerOption::Marian(MarianTokenizer::from_files(
737                    vocab_path,
738                    merges_path.expect("No merges specified!"),
739                    lower_case,
740                )?)
741            }
742            ModelType::T5 | ModelType::LongT5 => {
743                if strip_accents.is_some() {
744                    return Err(RustBertError::InvalidConfigurationError(format!(
745                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
746                        strip_accents.unwrap(),
747                        model_type
748                    )));
749                }
750                if add_prefix_space.is_some() {
751                    return Err(RustBertError::InvalidConfigurationError(
752                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
753                                add_prefix_space.unwrap(),
754                                model_type)));
755                }
756                TokenizerOption::T5(T5Tokenizer::from_file(vocab_path, lower_case)?)
757            }
758            ModelType::XLMRoberta => {
759                if strip_accents.is_some() {
760                    return Err(RustBertError::InvalidConfigurationError(format!(
761                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
762                        strip_accents.unwrap(),
763                        model_type
764                    )));
765                }
766                if add_prefix_space.is_some() {
767                    return Err(RustBertError::InvalidConfigurationError(
768                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
769                                add_prefix_space.unwrap(),
770                                model_type)));
771                }
772                TokenizerOption::XLMRoberta(XLMRobertaTokenizer::from_file(vocab_path, lower_case)?)
773            }
774            ModelType::Albert => {
775                if strip_accents.is_some() {
776                    return Err(RustBertError::InvalidConfigurationError(format!(
777                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
778                        strip_accents.unwrap(),
779                        model_type
780                    )));
781                }
782                TokenizerOption::Albert(AlbertTokenizer::from_file(
783                    vocab_path,
784                    lower_case,
785                    strip_accents.unwrap_or(lower_case),
786                )?)
787            }
788            ModelType::XLNet => {
789                if add_prefix_space.is_some() {
790                    return Err(RustBertError::InvalidConfigurationError(
791                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
792                                add_prefix_space.unwrap(),
793                                model_type)));
794                }
795                TokenizerOption::XLNet(XLNetTokenizer::from_file(
796                    vocab_path,
797                    lower_case,
798                    strip_accents.unwrap_or(false),
799                )?)
800            }
801            ModelType::Reformer => {
802                if add_prefix_space.is_some() {
803                    return Err(RustBertError::InvalidConfigurationError(
804                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
805                                add_prefix_space.unwrap(),
806                                model_type)));
807                }
808                if strip_accents.is_some() {
809                    return Err(RustBertError::InvalidConfigurationError(
810                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
811                                add_prefix_space.unwrap(),
812                                model_type)));
813                }
814                TokenizerOption::Reformer(ReformerTokenizer::from_file(vocab_path, lower_case)?)
815            }
816            ModelType::GPT2 | ModelType::GPTNeo | ModelType::GPTJ => {
817                TokenizerOption::GPT2(Gpt2Tokenizer::from_file(
818                    vocab_path,
819                    merges_path.expect("No merges specified!"),
820                    lower_case,
821                )?)
822            }
823            ModelType::OpenAiGpt => TokenizerOption::OpenAiGpt(OpenAiGptTokenizer::from_file(
824                vocab_path,
825                merges_path.expect("No merges specified!"),
826                lower_case,
827            )?),
828            ModelType::ProphetNet => {
829                if add_prefix_space.is_some() {
830                    return Err(RustBertError::InvalidConfigurationError(
831                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
832                                add_prefix_space.unwrap(),
833                                model_type)));
834                }
835                TokenizerOption::ProphetNet(ProphetNetTokenizer::from_file(
836                    vocab_path,
837                    lower_case,
838                    strip_accents.unwrap_or(lower_case),
839                )?)
840            }
841            ModelType::Pegasus => {
842                if add_prefix_space.is_some() {
843                    return Err(RustBertError::InvalidConfigurationError(
844                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
845                                add_prefix_space.unwrap(),
846                                model_type)));
847                }
848                if strip_accents.is_some() {
849                    return Err(RustBertError::InvalidConfigurationError(format!(
850                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
851                        strip_accents.unwrap(),
852                        model_type
853                    )));
854                }
855                TokenizerOption::Pegasus(PegasusTokenizer::from_file(vocab_path, lower_case)?)
856            }
857            ModelType::MBart => {
858                if add_prefix_space.is_some() {
859                    return Err(RustBertError::InvalidConfigurationError(
860                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
861                                add_prefix_space.unwrap(),
862                                model_type)));
863                }
864                if strip_accents.is_some() {
865                    return Err(RustBertError::InvalidConfigurationError(format!(
866                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
867                        strip_accents.unwrap(),
868                        model_type
869                    )));
870                }
871                TokenizerOption::MBart50(MBart50Tokenizer::from_file(vocab_path, lower_case)?)
872            }
873            ModelType::M2M100 => {
874                if add_prefix_space.is_some() {
875                    return Err(RustBertError::InvalidConfigurationError(
876                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
877                                add_prefix_space.unwrap(),
878                                model_type)));
879                }
880                if strip_accents.is_some() {
881                    return Err(RustBertError::InvalidConfigurationError(format!(
882                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
883                        strip_accents.unwrap(),
884                        model_type
885                    )));
886                }
887                TokenizerOption::M2M100(M2M100Tokenizer::from_files(
888                    vocab_path,
889                    merges_path.expect("No merges specified!"),
890                    lower_case,
891                )?)
892            }
893            ModelType::NLLB => {
894                if add_prefix_space.is_some() {
895                    return Err(RustBertError::InvalidConfigurationError(
896                        format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
897                                add_prefix_space.unwrap(),
898                                model_type)));
899                }
900                if strip_accents.is_some() {
901                    return Err(RustBertError::InvalidConfigurationError(format!(
902                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
903                        strip_accents.unwrap(),
904                        model_type
905                    )));
906                }
907                TokenizerOption::NLLB(NLLBTokenizer::from_files(
908                    vocab_path,
909                    merges_path.expect("No merges specified."),
910                )?)
911            }
912            ModelType::FNet => TokenizerOption::FNet(FNetTokenizer::from_file(
913                vocab_path,
914                lower_case,
915                strip_accents.unwrap_or(false),
916            )?),
917            #[cfg(feature = "onnx")]
918            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
919                "Default Tokenizer not defined for generic ONNX models.".to_string(),
920            ))?,
921        };
922        Ok(tokenizer)
923    }
924
925    #[cfg(feature = "hf-tokenizers")]
926    pub fn from_hf_tokenizer_file<P: AsRef<Path>, S: AsRef<Path>>(
927        tokenizer_file: P,
928        special_token_map: S,
929    ) -> Result<Self, RustBertError> {
930        let hf_tokenizer = HFTokenizer::from_file(tokenizer_file, special_token_map)?;
931        Ok(TokenizerOption::HFTokenizer(hf_tokenizer))
932    }
933
934    /// Interface method
935    pub fn encode_list<S>(
936        &self,
937        text_list: &[S],
938        max_len: usize,
939        truncation_strategy: &TruncationStrategy,
940        stride: usize,
941    ) -> Vec<TokenizedInput>
942    where
943        S: AsRef<str> + Send + Sync,
944    {
945        match *self {
946            Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
947                tokenizer,
948                text_list,
949                max_len,
950                truncation_strategy,
951                stride,
952            ),
953            Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
954                tokenizer,
955                text_list,
956                max_len,
957                truncation_strategy,
958                stride,
959            ),
960            Self::DebertaV2(ref tokenizer) => MultiThreadedTokenizer::encode_list(
961                tokenizer,
962                text_list,
963                max_len,
964                truncation_strategy,
965                stride,
966            ),
967            Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
968                tokenizer,
969                text_list,
970                max_len,
971                truncation_strategy,
972                stride,
973            ),
974            Self::Bart(ref tokenizer) => MultiThreadedTokenizer::encode_list(
975                tokenizer,
976                text_list,
977                max_len,
978                truncation_strategy,
979                stride,
980            ),
981            Self::Marian(ref tokenizer) => MultiThreadedTokenizer::encode_list(
982                tokenizer,
983                text_list,
984                max_len,
985                truncation_strategy,
986                stride,
987            ),
988            Self::T5(ref tokenizer) => MultiThreadedTokenizer::encode_list(
989                tokenizer,
990                text_list,
991                max_len,
992                truncation_strategy,
993                stride,
994            ),
995            Self::XLMRoberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
996                tokenizer,
997                text_list,
998                max_len,
999                truncation_strategy,
1000                stride,
1001            ),
1002            Self::Albert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1003                tokenizer,
1004                text_list,
1005                max_len,
1006                truncation_strategy,
1007                stride,
1008            ),
1009            Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1010                tokenizer,
1011                text_list,
1012                max_len,
1013                truncation_strategy,
1014                stride,
1015            ),
1016            Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1017                tokenizer,
1018                text_list,
1019                max_len,
1020                truncation_strategy,
1021                stride,
1022            ),
1023            Self::OpenAiGpt(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1024                tokenizer,
1025                text_list,
1026                max_len,
1027                truncation_strategy,
1028                stride,
1029            ),
1030            Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1031                tokenizer,
1032                text_list,
1033                max_len,
1034                truncation_strategy,
1035                stride,
1036            ),
1037            Self::ProphetNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1038                tokenizer,
1039                text_list,
1040                max_len,
1041                truncation_strategy,
1042                stride,
1043            ),
1044            Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1045                tokenizer,
1046                text_list,
1047                max_len,
1048                truncation_strategy,
1049                stride,
1050            ),
1051            Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1052                tokenizer,
1053                text_list,
1054                max_len,
1055                truncation_strategy,
1056                stride,
1057            ),
1058            Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1059                tokenizer,
1060                text_list,
1061                max_len,
1062                truncation_strategy,
1063                stride,
1064            ),
1065            Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1066                tokenizer,
1067                text_list,
1068                max_len,
1069                truncation_strategy,
1070                stride,
1071            ),
1072            Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::encode_list(
1073                tokenizer,
1074                text_list,
1075                max_len,
1076                truncation_strategy,
1077                stride,
1078            ),
1079            #[cfg(feature = "hf-tokenizers")]
1080            Self::HFTokenizer(ref tokenizer) => tokenizer.encode_list(text_list).unwrap(),
1081        }
1082    }
1083
1084    /// Interface method for pair encoding
1085    pub fn encode_pair_list(
1086        &self,
1087        text_pair_list: &[(&str, &str)],
1088        max_len: usize,
1089        truncation_strategy: &TruncationStrategy,
1090        stride: usize,
1091    ) -> Vec<TokenizedInput> {
1092        match *self {
1093            Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1094                tokenizer,
1095                text_pair_list,
1096                max_len,
1097                truncation_strategy,
1098                stride,
1099            ),
1100            Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1101                tokenizer,
1102                text_pair_list,
1103                max_len,
1104                truncation_strategy,
1105                stride,
1106            ),
1107            Self::DebertaV2(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1108                tokenizer,
1109                text_pair_list,
1110                max_len,
1111                truncation_strategy,
1112                stride,
1113            ),
1114            Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1115                tokenizer,
1116                text_pair_list,
1117                max_len,
1118                truncation_strategy,
1119                stride,
1120            ),
1121            Self::Bart(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1122                tokenizer,
1123                text_pair_list,
1124                max_len,
1125                truncation_strategy,
1126                stride,
1127            ),
1128            Self::Marian(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1129                tokenizer,
1130                text_pair_list,
1131                max_len,
1132                truncation_strategy,
1133                stride,
1134            ),
1135            Self::T5(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1136                tokenizer,
1137                text_pair_list,
1138                max_len,
1139                truncation_strategy,
1140                stride,
1141            ),
1142            Self::XLMRoberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1143                tokenizer,
1144                text_pair_list,
1145                max_len,
1146                truncation_strategy,
1147                stride,
1148            ),
1149            Self::Albert(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1150                tokenizer,
1151                text_pair_list,
1152                max_len,
1153                truncation_strategy,
1154                stride,
1155            ),
1156            Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1157                tokenizer,
1158                text_pair_list,
1159                max_len,
1160                truncation_strategy,
1161                stride,
1162            ),
1163            Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1164                tokenizer,
1165                text_pair_list,
1166                max_len,
1167                truncation_strategy,
1168                stride,
1169            ),
1170            Self::OpenAiGpt(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1171                tokenizer,
1172                text_pair_list,
1173                max_len,
1174                truncation_strategy,
1175                stride,
1176            ),
1177            Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1178                tokenizer,
1179                text_pair_list,
1180                max_len,
1181                truncation_strategy,
1182                stride,
1183            ),
1184            Self::ProphetNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1185                tokenizer,
1186                text_pair_list,
1187                max_len,
1188                truncation_strategy,
1189                stride,
1190            ),
1191            Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1192                tokenizer,
1193                text_pair_list,
1194                max_len,
1195                truncation_strategy,
1196                stride,
1197            ),
1198            Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1199                tokenizer,
1200                text_pair_list,
1201                max_len,
1202                truncation_strategy,
1203                stride,
1204            ),
1205            Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1206                tokenizer,
1207                text_pair_list,
1208                max_len,
1209                truncation_strategy,
1210                stride,
1211            ),
1212            Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1213                tokenizer,
1214                text_pair_list,
1215                max_len,
1216                truncation_strategy,
1217                stride,
1218            ),
1219            Self::FNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
1220                tokenizer,
1221                text_pair_list,
1222                max_len,
1223                truncation_strategy,
1224                stride,
1225            ),
1226            #[cfg(feature = "hf-tokenizers")]
1227            Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair_list(text_pair_list).unwrap(),
1228        }
1229    }
1230
1231    /// Interface method for pair encoding (single input)
1232    pub fn encode_pair(
1233        &self,
1234        text_1: &str,
1235        text_2: Option<&str>,
1236        max_len: usize,
1237        truncation_strategy: &TruncationStrategy,
1238        stride: usize,
1239    ) -> TokenizedInput {
1240        match *self {
1241            Self::Bert(ref tokenizer) => {
1242                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1243            }
1244            Self::Deberta(ref tokenizer) => {
1245                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1246            }
1247            Self::DebertaV2(ref tokenizer) => {
1248                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1249            }
1250            Self::Roberta(ref tokenizer) => {
1251                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1252            }
1253            Self::Bart(ref tokenizer) => {
1254                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1255            }
1256            Self::Marian(ref tokenizer) => {
1257                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1258            }
1259            Self::T5(ref tokenizer) => {
1260                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1261            }
1262            Self::XLMRoberta(ref tokenizer) => {
1263                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1264            }
1265            Self::Albert(ref tokenizer) => {
1266                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1267            }
1268            Self::XLNet(ref tokenizer) => {
1269                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1270            }
1271            Self::GPT2(ref tokenizer) => {
1272                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1273            }
1274            Self::OpenAiGpt(ref tokenizer) => {
1275                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1276            }
1277            Self::Reformer(ref tokenizer) => {
1278                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1279            }
1280            Self::ProphetNet(ref tokenizer) => {
1281                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1282            }
1283            Self::Pegasus(ref tokenizer) => {
1284                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1285            }
1286            Self::MBart50(ref tokenizer) => {
1287                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1288            }
1289            Self::M2M100(ref tokenizer) => {
1290                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1291            }
1292            Self::NLLB(ref tokenizer) => {
1293                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1294            }
1295            Self::FNet(ref tokenizer) => {
1296                tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
1297            }
1298            #[cfg(feature = "hf-tokenizers")]
1299            Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair(text_1, text_2).unwrap(),
1300        }
1301    }
1302
1303    /// Interface method to tokenization
1304    pub fn tokenize(&self, text: &str) -> Vec<String> {
1305        match *self {
1306            Self::Bert(ref tokenizer) => tokenizer.tokenize(text),
1307            Self::Deberta(ref tokenizer) => tokenizer.tokenize(text),
1308            Self::DebertaV2(ref tokenizer) => tokenizer.tokenize(text),
1309            Self::Roberta(ref tokenizer) => tokenizer.tokenize(text),
1310            Self::Bart(ref tokenizer) => tokenizer.tokenize(text),
1311            Self::Marian(ref tokenizer) => tokenizer.tokenize(text),
1312            Self::T5(ref tokenizer) => tokenizer.tokenize(text),
1313            Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize(text),
1314            Self::Albert(ref tokenizer) => tokenizer.tokenize(text),
1315            Self::XLNet(ref tokenizer) => tokenizer.tokenize(text),
1316            Self::GPT2(ref tokenizer) => tokenizer.tokenize(text),
1317            Self::OpenAiGpt(ref tokenizer) => tokenizer.tokenize(text),
1318            Self::Reformer(ref tokenizer) => tokenizer.tokenize(text),
1319            Self::ProphetNet(ref tokenizer) => tokenizer.tokenize(text),
1320            Self::Pegasus(ref tokenizer) => tokenizer.tokenize(text),
1321            Self::MBart50(ref tokenizer) => tokenizer.tokenize(text),
1322            Self::M2M100(ref tokenizer) => tokenizer.tokenize(text),
1323            Self::NLLB(ref tokenizer) => tokenizer.tokenize(text),
1324            Self::FNet(ref tokenizer) => tokenizer.tokenize(text),
1325            #[cfg(feature = "hf-tokenizers")]
1326            Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize(text),
1327        }
1328    }
1329
1330    /// Interface method to tokenization
1331    pub fn tokenize_with_offsets(&self, text: &str) -> TokensWithOffsets {
1332        match *self {
1333            Self::Bert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1334            Self::Deberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1335            Self::DebertaV2(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1336            Self::Roberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1337            Self::Bart(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1338            Self::Marian(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1339            Self::T5(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1340            Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1341            Self::Albert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1342            Self::XLNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1343            Self::GPT2(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1344            Self::OpenAiGpt(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1345            Self::Reformer(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1346            Self::ProphetNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1347            Self::Pegasus(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1348            Self::MBart50(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1349            Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1350            Self::NLLB(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1351            Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1352            #[cfg(feature = "hf-tokenizers")]
1353            Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
1354        }
1355    }
1356
1357    /// Interface method to tokenization
1358    pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>>
1359    where
1360        S: AsRef<str> + Send + Sync,
1361    {
1362        match *self {
1363            Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1364            Self::Deberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1365            Self::DebertaV2(ref tokenizer) => {
1366                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
1367            }
1368            Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1369            Self::Bart(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1370            Self::Marian(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1371            Self::T5(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1372            Self::XLMRoberta(ref tokenizer) => {
1373                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
1374            }
1375            Self::Albert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1376            Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1377            Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1378            Self::OpenAiGpt(ref tokenizer) => {
1379                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
1380            }
1381            Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1382            Self::ProphetNet(ref tokenizer) => {
1383                MultiThreadedTokenizer::tokenize_list(tokenizer, text)
1384            }
1385            Self::Pegasus(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1386            Self::MBart50(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1387            Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1388            Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1389            Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
1390            #[cfg(feature = "hf-tokenizers")]
1391            Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_list(text),
1392        }
1393    }
1394
1395    /// Interface method to decoding
1396    pub fn decode(
1397        &self,
1398        token_ids: &[i64],
1399        skip_special_tokens: bool,
1400        clean_up_tokenization_spaces: bool,
1401    ) -> String {
1402        match *self {
1403            Self::Bert(ref tokenizer) => {
1404                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1405            }
1406            Self::Deberta(ref tokenizer) => {
1407                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1408            }
1409            Self::DebertaV2(ref tokenizer) => {
1410                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1411            }
1412            Self::Roberta(ref tokenizer) => {
1413                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1414            }
1415            Self::Bart(ref tokenizer) => {
1416                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1417            }
1418            Self::Marian(ref tokenizer) => {
1419                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1420            }
1421            Self::T5(ref tokenizer) => {
1422                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1423            }
1424            Self::XLMRoberta(ref tokenizer) => {
1425                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1426            }
1427            Self::Albert(ref tokenizer) => {
1428                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1429            }
1430            Self::XLNet(ref tokenizer) => {
1431                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1432            }
1433            Self::GPT2(ref tokenizer) => {
1434                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1435            }
1436            Self::OpenAiGpt(ref tokenizer) => {
1437                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1438            }
1439            Self::Reformer(ref tokenizer) => {
1440                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1441            }
1442            Self::ProphetNet(ref tokenizer) => {
1443                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1444            }
1445            Self::Pegasus(ref tokenizer) => {
1446                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1447            }
1448            Self::MBart50(ref tokenizer) => {
1449                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1450            }
1451            Self::M2M100(ref tokenizer) => {
1452                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1453            }
1454            Self::NLLB(ref tokenizer) => {
1455                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1456            }
1457            Self::FNet(ref tokenizer) => {
1458                tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
1459            }
1460            #[cfg(feature = "hf-tokenizers")]
1461            Self::HFTokenizer(ref tokenizer) => tokenizer.decode(token_ids, skip_special_tokens),
1462        }
1463    }
1464
1465    /// Interface method to build input with special tokens
1466    pub fn build_input_with_special_tokens(
1467        &self,
1468        token_ids_with_offsets_1: TokenIdsWithOffsets,
1469        token_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
1470    ) -> TokenizedInput {
1471        let token_ids_with_special_tokens = match *self {
1472            Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1473                token_ids_with_offsets_1,
1474                token_ids_with_offsets_2,
1475            ),
1476            Self::Deberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1477                token_ids_with_offsets_1,
1478                token_ids_with_offsets_2,
1479            ),
1480            Self::DebertaV2(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1481                token_ids_with_offsets_1,
1482                token_ids_with_offsets_2,
1483            ),
1484            Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1485                token_ids_with_offsets_1,
1486                token_ids_with_offsets_2,
1487            ),
1488            Self::Bart(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1489                token_ids_with_offsets_1,
1490                token_ids_with_offsets_2,
1491            ),
1492            Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1493                token_ids_with_offsets_1,
1494                token_ids_with_offsets_2,
1495            ),
1496            Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1497                token_ids_with_offsets_1,
1498                token_ids_with_offsets_2,
1499            ),
1500            Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1501                token_ids_with_offsets_1,
1502                token_ids_with_offsets_2,
1503            ),
1504            Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1505                token_ids_with_offsets_1,
1506                token_ids_with_offsets_2,
1507            ),
1508            Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1509                token_ids_with_offsets_1,
1510                token_ids_with_offsets_2,
1511            ),
1512            Self::GPT2(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1513                token_ids_with_offsets_1,
1514                token_ids_with_offsets_2,
1515            ),
1516            Self::OpenAiGpt(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1517                token_ids_with_offsets_1,
1518                token_ids_with_offsets_2,
1519            ),
1520            Self::Reformer(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1521                token_ids_with_offsets_1,
1522                token_ids_with_offsets_2,
1523            ),
1524            Self::ProphetNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1525                token_ids_with_offsets_1,
1526                token_ids_with_offsets_2,
1527            ),
1528            Self::Pegasus(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1529                token_ids_with_offsets_1,
1530                token_ids_with_offsets_2,
1531            ),
1532            Self::MBart50(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1533                token_ids_with_offsets_1,
1534                token_ids_with_offsets_2,
1535            ),
1536            Self::M2M100(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1537                token_ids_with_offsets_1,
1538                token_ids_with_offsets_2,
1539            ),
1540            Self::NLLB(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1541                token_ids_with_offsets_1,
1542                token_ids_with_offsets_2,
1543            ),
1544            Self::FNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
1545                token_ids_with_offsets_1,
1546                token_ids_with_offsets_2,
1547            ),
1548            #[cfg(feature = "hf-tokenizers")]
1549            Self::HFTokenizer(ref tokenizer) => {
1550                return tokenizer.build_input_with_special_tokens(
1551                    token_ids_with_offsets_1,
1552                    token_ids_with_offsets_2,
1553                )
1554            }
1555        };
1556        TokenizedInput {
1557            token_ids: token_ids_with_special_tokens.token_ids,
1558            segment_ids: token_ids_with_special_tokens.segment_ids,
1559            special_tokens_mask: token_ids_with_special_tokens.special_tokens_mask,
1560            overflowing_tokens: vec![],
1561            num_truncated_tokens: 0,
1562            token_offsets: token_ids_with_special_tokens.token_offsets,
1563            reference_offsets: token_ids_with_special_tokens.reference_offsets,
1564            mask: token_ids_with_special_tokens.mask,
1565        }
1566    }
1567
1568    /// Helper function to prepare the input for translation models
1569    pub fn get_prefix_and_forced_bos_id(
1570        &self,
1571        source_language: Option<&Language>,
1572        target_language: Option<&Language>,
1573        supported_source_languages: &HashSet<Language>,
1574        supported_target_languages: &HashSet<Language>,
1575    ) -> Result<(Option<String>, Option<i64>), RustBertError> {
1576        if let Some(source_language) = source_language {
1577            if !supported_source_languages.contains(source_language) {
1578                return Err(RustBertError::ValueError(format!(
1579                        "{source_language} not in list of supported languages: {supported_source_languages:?}",
1580                    )));
1581            }
1582        }
1583
1584        if let Some(target_language) = target_language {
1585            if !supported_target_languages.contains(target_language) {
1586                return Err(RustBertError::ValueError(format!(
1587                        "{target_language} not in list of supported languages: {supported_target_languages:?}"
1588                    )));
1589            }
1590        }
1591
1592        Ok(match *self {
1593                Self::Marian(_) => {
1594                    if supported_target_languages.len() > 1 {
1595                        (
1596                            Some(format!(
1597                                ">>{}<< ",
1598                                target_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
1599                                    "Missing target language for Marian \
1600                                        (multiple languages supported by model: {supported_target_languages:?}, \
1601                                        need to specify target language)",
1602                                )))?
1603                            )),
1604                            None,
1605                        )
1606                    } else {
1607                        (None, None)
1608                    }
1609                }
1610                Self::T5(_) => (
1611                    Some(format!(
1612                        "translate {} to {}:",
1613                        source_language.ok_or_else(|| RustBertError::ValueError(
1614                            "Missing source language for T5".to_string(),
1615                        ))?,
1616                        target_language.ok_or_else(|| RustBertError::ValueError(
1617                            "Missing target language for T5".to_string(),
1618                        ))?,
1619                    )),
1620                    None,
1621                ),
1622                Self::MBart50(_) => {
1623                    (
1624                        Some(format!(
1625                            ">>{}<< ",
1626                            source_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
1627                                "Missing source language for MBart\
1628                                (multiple languages supported by model: {supported_source_languages:?}, \
1629                                need to specify target language)"
1630                            )))?
1631                        )),
1632                        if let Some(target_language) = target_language {
1633                            Some(
1634                                self.convert_tokens_to_ids(&[format!(
1635                                    ">>{}<<",
1636                                    target_language.get_iso_639_1_code().ok_or_else(|| {
1637                                        RustBertError::ValueError(format!(
1638                                            "This language has no ISO639-I code. Languages supported by model: {supported_source_languages:?}."
1639                                        ))
1640                                    })?
1641                                )])[0],
1642                            )
1643                        } else {
1644                            return Err(RustBertError::ValueError(format!(
1645                                "Missing target language for MBart\
1646                        (multiple languages supported by model: {supported_target_languages:?}, \
1647                        need to specify target language)"
1648                            )));
1649                        },
1650                    )
1651                }
1652                Self::M2M100(_) => (
1653                    Some(match source_language {
1654                        Some(value) => {
1655                            let language_code = value.get_iso_639_1_code().ok_or_else(|| {
1656                                RustBertError::ValueError(format!(
1657                                    "This language has no ISO639-I language code representation. \
1658                                languages supported by the model: {supported_target_languages:?}"
1659                                ))
1660                            })?;
1661                            match language_code.len() {
1662                                2 => format!(">>{language_code}.<< "),
1663                                3 => format!(">>{language_code}<< "),
1664                                _ => {
1665                                    return Err(RustBertError::ValueError(
1666                                        "Invalid ISO 639-I code".to_string(),
1667                                    ));
1668                                }
1669                            }
1670                        }
1671                        None => {
1672                            return Err(RustBertError::ValueError(format!(
1673                                "Missing source language for M2M100 \
1674                            (multiple languages supported by model: {supported_source_languages:?}, \
1675                            need to specify target language)"
1676                            )));
1677                        }
1678                    }),
1679                    if let Some(target_language) = target_language {
1680                        let language_code = target_language.get_iso_639_1_code().ok_or_else(|| {
1681                            RustBertError::ValueError(format!(
1682                                "This language has no ISO639-I language code representation. \
1683                            languages supported by the model: {supported_target_languages:?}"
1684                            ))
1685                        })?;
1686                        Some(
1687                            self.convert_tokens_to_ids(&[
1688                                match language_code.len() {
1689                                    2 => format!(">>{language_code}.<<"),
1690                                    3 => format!(">>{language_code}<<"),
1691                                    _ => {
1692                                        return Err(RustBertError::ValueError(
1693                                            "Invalid ISO 639-3 code".to_string(),
1694                                        ));
1695                                    }
1696                                },
1697                            ])[0],
1698                        )
1699                    } else {
1700                        return Err(RustBertError::ValueError(format!(
1701                            "Missing target language for M2M100 \
1702                        (multiple languages supported by model: {supported_target_languages:?}, \
1703                        need to specify target language)",
1704                        )));
1705                    },
1706                ),
1707                Self::NLLB(_) => {
1708                    let source_language = source_language
1709                        .and_then(Language::get_nllb_code)
1710                        .map(str::to_string)
1711                        .ok_or_else(|| RustBertError::ValueError(
1712                            format!("Missing source language for NLLB. Need to specify one from: {supported_source_languages:?}")
1713                        ))?;
1714
1715                    let target_language = target_language
1716                        .and_then(Language::get_nllb_code)
1717                        .map(str::to_string)
1718                        .map(|code| self.convert_tokens_to_ids(&[code])[0])
1719                        .ok_or_else(|| RustBertError::ValueError(
1720                            format!("Missing target language for NLLB. Need to specify one from: {supported_target_languages:?}")
1721                        ))?;
1722
1723                    (Some(source_language), Some(target_language))
1724                }
1725                _ => (None, None),
1726            })
1727    }
1728
1729    /// Interface method to convert tokens to ids
1730    pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
1731    where
1732        S: AsRef<str>,
1733    {
1734        match *self {
1735            Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1736            Self::Deberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1737            Self::DebertaV2(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1738            Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1739            Self::Bart(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1740            Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1741            Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1742            Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1743            Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1744            Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1745            Self::GPT2(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1746            Self::OpenAiGpt(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1747            Self::Reformer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1748            Self::ProphetNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1749            Self::Pegasus(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1750            Self::MBart50(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1751            Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1752            Self::NLLB(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1753            Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1754            #[cfg(feature = "hf-tokenizers")]
1755            Self::HFTokenizer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
1756        }
1757    }
1758
1759    /// Interface method
1760    pub fn get_unk_id(&self) -> i64 {
1761        match *self {
1762            Self::Bert(ref tokenizer) => {
1763                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1764                vocab.token_to_id(vocab.get_unknown_value())
1765            }
1766            Self::Deberta(ref tokenizer) => {
1767                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1768                vocab.token_to_id(vocab.get_unknown_value())
1769            }
1770            Self::DebertaV2(ref tokenizer) => {
1771                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1772                vocab.token_to_id(vocab.get_unknown_value())
1773            }
1774            Self::Roberta(ref tokenizer) => {
1775                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1776                vocab.token_to_id(vocab.get_unknown_value())
1777            }
1778            Self::Bart(ref tokenizer) => {
1779                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1780                vocab.token_to_id(vocab.get_unknown_value())
1781            }
1782            Self::XLMRoberta(ref tokenizer) => {
1783                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1784                vocab.token_to_id(vocab.get_unknown_value())
1785            }
1786            Self::Marian(ref tokenizer) => {
1787                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1788                vocab.token_to_id(vocab.get_unknown_value())
1789            }
1790            Self::T5(ref tokenizer) => {
1791                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1792                vocab.token_to_id(vocab.get_unknown_value())
1793            }
1794            Self::Albert(ref tokenizer) => {
1795                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1796                vocab.token_to_id(vocab.get_unknown_value())
1797            }
1798            Self::XLNet(ref tokenizer) => {
1799                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1800                vocab.token_to_id(vocab.get_unknown_value())
1801            }
1802            Self::GPT2(ref tokenizer) => {
1803                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1804                vocab.token_to_id(vocab.get_unknown_value())
1805            }
1806            Self::OpenAiGpt(ref tokenizer) => {
1807                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1808                vocab.token_to_id(vocab.get_unknown_value())
1809            }
1810            Self::Reformer(ref tokenizer) => {
1811                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1812                vocab.token_to_id(vocab.get_unknown_value())
1813            }
1814            Self::ProphetNet(ref tokenizer) => {
1815                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1816                vocab.token_to_id(vocab.get_unknown_value())
1817            }
1818            Self::Pegasus(ref tokenizer) => {
1819                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1820                vocab.token_to_id(vocab.get_unknown_value())
1821            }
1822            Self::MBart50(ref tokenizer) => {
1823                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1824                vocab.token_to_id(vocab.get_unknown_value())
1825            }
1826            Self::M2M100(ref tokenizer) => {
1827                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1828                vocab.token_to_id(vocab.get_unknown_value())
1829            }
1830            Self::NLLB(ref tokenizer) => {
1831                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1832                vocab.token_to_id(vocab.get_unknown_value())
1833            }
1834            Self::FNet(ref tokenizer) => {
1835                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1836                vocab.token_to_id(vocab.get_unknown_value())
1837            }
1838            #[cfg(feature = "hf-tokenizers")]
1839            Self::HFTokenizer(ref tokenizer) => {
1840                tokenizer.token_to_id(&tokenizer.special_token_map.unk_token)
1841            }
1842        }
1843    }
1844
1845    /// Interface method
1846    pub fn get_pad_id(&self) -> Option<i64> {
1847        match *self {
1848            Self::Bert(ref tokenizer) => {
1849                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1850                Some(vocab.token_to_id(vocab.get_pad_value()))
1851            }
1852            Self::Deberta(ref tokenizer) => {
1853                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1854                Some(vocab.token_to_id(vocab.get_pad_value()))
1855            }
1856            Self::DebertaV2(ref tokenizer) => {
1857                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1858                Some(vocab.token_to_id(vocab.get_pad_value()))
1859            }
1860            Self::Roberta(ref tokenizer) => {
1861                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1862                Some(vocab.token_to_id(vocab.get_pad_value()))
1863            }
1864            Self::Bart(ref tokenizer) => {
1865                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1866                Some(vocab.token_to_id(vocab.get_pad_value()))
1867            }
1868            Self::XLMRoberta(ref tokenizer) => {
1869                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1870                Some(vocab.token_to_id(vocab.get_pad_value()))
1871            }
1872            Self::Marian(ref tokenizer) => {
1873                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1874                Some(vocab.token_to_id(vocab.get_pad_value()))
1875            }
1876            Self::T5(ref tokenizer) => {
1877                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1878                Some(vocab.token_to_id(vocab.get_pad_value()))
1879            }
1880            Self::Albert(ref tokenizer) => {
1881                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1882                Some(vocab.token_to_id(vocab.get_pad_value()))
1883            }
1884            Self::XLNet(ref tokenizer) => {
1885                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1886                Some(vocab.token_to_id(vocab.get_pad_value()))
1887            }
1888            Self::ProphetNet(ref tokenizer) => {
1889                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1890                Some(vocab.token_to_id(vocab.get_pad_value()))
1891            }
1892            Self::Pegasus(ref tokenizer) => {
1893                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1894                Some(vocab.token_to_id(vocab.get_pad_value()))
1895            }
1896            Self::MBart50(ref tokenizer) => {
1897                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1898                Some(vocab.token_to_id(vocab.get_pad_value()))
1899            }
1900            Self::M2M100(ref tokenizer) => {
1901                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1902                Some(vocab.token_to_id(vocab.get_pad_value()))
1903            }
1904            Self::NLLB(ref tokenizer) => {
1905                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1906                Some(vocab.token_to_id(vocab.get_pad_value()))
1907            }
1908            Self::FNet(ref tokenizer) => {
1909                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1910                Some(vocab.token_to_id(vocab.get_pad_value()))
1911            }
1912            #[cfg(feature = "hf-tokenizers")]
1913            Self::HFTokenizer(ref tokenizer) => tokenizer
1914                .special_token_map
1915                .pad_token
1916                .as_ref()
1917                .map(|token| tokenizer.token_to_id(token)),
1918            Self::Reformer(_) => None,
1919            Self::GPT2(_) => None,
1920            Self::OpenAiGpt(_) => None,
1921        }
1922    }
1923
1924    /// Interface method
1925    pub fn get_sep_id(&self) -> Option<i64> {
1926        match *self {
1927            Self::Bert(ref tokenizer) => {
1928                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1929                Some(vocab.token_to_id(vocab.get_sep_value()))
1930            }
1931            Self::Deberta(ref tokenizer) => {
1932                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1933                Some(vocab.token_to_id(vocab.get_sep_value()))
1934            }
1935            Self::DebertaV2(ref tokenizer) => {
1936                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1937                Some(vocab.token_to_id(vocab.get_sep_value()))
1938            }
1939            Self::Roberta(ref tokenizer) => {
1940                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1941                Some(vocab.token_to_id(vocab.get_sep_value()))
1942            }
1943            Self::Bart(ref tokenizer) => {
1944                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1945                Some(vocab.token_to_id(vocab.get_sep_value()))
1946            }
1947            Self::XLMRoberta(ref tokenizer) => {
1948                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1949                Some(vocab.token_to_id(vocab.get_sep_value()))
1950            }
1951            Self::Albert(ref tokenizer) => {
1952                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1953                Some(vocab.token_to_id(vocab.get_sep_value()))
1954            }
1955            Self::XLNet(ref tokenizer) => {
1956                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1957                Some(vocab.token_to_id(vocab.get_sep_value()))
1958            }
1959            Self::ProphetNet(ref tokenizer) => {
1960                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1961                Some(vocab.token_to_id(vocab.get_sep_value()))
1962            }
1963            Self::MBart50(ref tokenizer) => {
1964                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1965                Some(vocab.token_to_id(vocab.get_sep_value()))
1966            }
1967            Self::M2M100(ref tokenizer) => {
1968                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1969                Some(vocab.token_to_id(vocab.get_sep_value()))
1970            }
1971            Self::NLLB(ref tokenizer) => {
1972                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1973                Some(vocab.token_to_id(vocab.get_sep_value()))
1974            }
1975            Self::FNet(ref tokenizer) => {
1976                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1977                Some(vocab.token_to_id(vocab.get_sep_value()))
1978            }
1979            #[cfg(feature = "hf-tokenizers")]
1980            Self::HFTokenizer(ref tokenizer) => tokenizer
1981                .special_token_map
1982                .sep_token
1983                .as_ref()
1984                .map(|token| tokenizer.token_to_id(token)),
1985            Self::Marian(_) => None,
1986            Self::T5(_) => None,
1987            Self::GPT2(_) => None,
1988            Self::OpenAiGpt(_) => None,
1989            Self::Reformer(_) => None,
1990            Self::Pegasus(_) => None,
1991        }
1992    }
1993
1994    /// Interface method
1995    pub fn get_mask_id(&self) -> Option<i64> {
1996        match *self {
1997            Self::Bert(ref tokenizer) => {
1998                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
1999                Some(vocab.token_to_id(vocab.get_mask_value()))
2000            }
2001            Self::Deberta(ref tokenizer) => {
2002                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2003                Some(vocab.token_to_id(vocab.get_mask_value()))
2004            }
2005            Self::DebertaV2(ref tokenizer) => {
2006                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2007                Some(vocab.token_to_id(vocab.get_mask_value()))
2008            }
2009            Self::Roberta(ref tokenizer) => {
2010                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2011                Some(vocab.token_to_id(vocab.get_mask_value()))
2012            }
2013            Self::Bart(ref tokenizer) => {
2014                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2015                Some(vocab.token_to_id(vocab.get_mask_value()))
2016            }
2017            Self::XLMRoberta(ref tokenizer) => {
2018                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2019                Some(vocab.token_to_id(vocab.get_mask_value()))
2020            }
2021            Self::Albert(ref tokenizer) => {
2022                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2023                Some(vocab.token_to_id(vocab.get_mask_value()))
2024            }
2025            Self::XLNet(ref tokenizer) => {
2026                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2027                Some(vocab.token_to_id(vocab.get_mask_value()))
2028            }
2029            Self::ProphetNet(ref tokenizer) => {
2030                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2031                Some(vocab.token_to_id(vocab.get_mask_value()))
2032            }
2033            Self::MBart50(ref tokenizer) => {
2034                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2035                Some(vocab.token_to_id(vocab.get_mask_value()))
2036            }
2037            Self::FNet(ref tokenizer) => {
2038                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2039                Some(vocab.token_to_id(vocab.get_mask_value()))
2040            }
2041            Self::Pegasus(ref tokenizer) => {
2042                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2043                Some(vocab.token_to_id(vocab.get_mask_value()))
2044            }
2045            #[cfg(feature = "hf-tokenizers")]
2046            Self::HFTokenizer(ref tokenizer) => tokenizer
2047                .special_token_map
2048                .mask_token
2049                .as_ref()
2050                .map(|token| tokenizer.token_to_id(token)),
2051            Self::Marian(_) => None,
2052            Self::M2M100(_) => None,
2053            Self::NLLB(_) => None,
2054            Self::T5(_) => None,
2055            Self::GPT2(_) => None,
2056            Self::OpenAiGpt(_) => None,
2057            Self::Reformer(_) => None,
2058        }
2059    }
2060
2061    /// Interface method
2062    pub fn get_mask_value(&self) -> Option<&str> {
2063        match self {
2064            Self::Bert(ref tokenizer) => {
2065                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2066            }
2067            Self::Deberta(ref tokenizer) => {
2068                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2069            }
2070            Self::DebertaV2(ref tokenizer) => {
2071                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2072            }
2073            Self::Roberta(ref tokenizer) => {
2074                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2075            }
2076            Self::Bart(ref tokenizer) => {
2077                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2078            }
2079            Self::XLMRoberta(ref tokenizer) => {
2080                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2081            }
2082            Self::Albert(ref tokenizer) => {
2083                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2084            }
2085            Self::XLNet(ref tokenizer) => {
2086                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2087            }
2088            Self::ProphetNet(ref tokenizer) => {
2089                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2090            }
2091            Self::MBart50(ref tokenizer) => {
2092                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2093            }
2094            Self::FNet(ref tokenizer) => {
2095                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2096            }
2097            Self::Pegasus(ref tokenizer) => {
2098                Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
2099            }
2100            #[cfg(feature = "hf-tokenizers")]
2101            Self::HFTokenizer(ref tokenizer) => tokenizer.special_token_map.mask_token.as_deref(),
2102            Self::M2M100(_) => None,
2103            Self::NLLB(_) => None,
2104            Self::Marian(_) => None,
2105            Self::T5(_) => None,
2106            Self::GPT2(_) => None,
2107            Self::OpenAiGpt(_) => None,
2108            Self::Reformer(_) => None,
2109        }
2110    }
2111
2112    /// Interface method
2113    pub fn get_bos_id(&self) -> Option<i64> {
2114        match *self {
2115            Self::Roberta(ref tokenizer) => {
2116                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2117                Some(vocab.token_to_id(vocab.get_bos_value()))
2118            }
2119            Self::Bart(ref tokenizer) => {
2120                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2121                Some(vocab.token_to_id(vocab.get_bos_value()))
2122            }
2123            Self::DebertaV2(ref tokenizer) => {
2124                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2125                Some(vocab.token_to_id(vocab.get_bos_value()))
2126            }
2127            Self::XLMRoberta(ref tokenizer) => {
2128                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2129                Some(vocab.token_to_id(vocab.get_bos_value()))
2130            }
2131            Self::Albert(ref tokenizer) => {
2132                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2133                Some(vocab.token_to_id(vocab.get_bos_value()))
2134            }
2135            Self::XLNet(ref tokenizer) => {
2136                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2137                Some(vocab.token_to_id(vocab.get_bos_value()))
2138            }
2139            Self::M2M100(ref tokenizer) => {
2140                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2141                Some(vocab.token_to_id(vocab.get_bos_value()))
2142            }
2143            Self::NLLB(ref tokenizer) => {
2144                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2145                Some(vocab.token_to_id(vocab.get_bos_value()))
2146            }
2147            Self::GPT2(ref tokenizer) => {
2148                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2149                Some(vocab.token_to_id(vocab.get_bos_value()))
2150            }
2151            Self::Deberta(ref tokenizer) => {
2152                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2153                Some(vocab.token_to_id(vocab.get_bos_value()))
2154            }
2155            #[cfg(feature = "hf-tokenizers")]
2156            Self::HFTokenizer(ref tokenizer) => tokenizer
2157                .special_token_map
2158                .bos_token
2159                .as_ref()
2160                .map(|token| tokenizer.token_to_id(token)),
2161            Self::MBart50(_) => Some(0),
2162            Self::FNet(_) => None,
2163            Self::Bert(_) => None,
2164            Self::Marian(_) => Some(0),
2165            Self::T5(_) => None,
2166            Self::ProphetNet(_) => None,
2167            Self::OpenAiGpt(_) => None,
2168            Self::Reformer(_) => None,
2169            Self::Pegasus(_) => Some(0),
2170        }
2171    }
2172
2173    /// Interface method
2174    pub fn get_eos_id(&self) -> Option<i64> {
2175        match *self {
2176            Self::Roberta(ref tokenizer) => {
2177                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2178                Some(vocab.token_to_id(vocab.get_eos_value()))
2179            }
2180            Self::Bart(ref tokenizer) => {
2181                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2182                Some(vocab.token_to_id(vocab.get_eos_value()))
2183            }
2184            Self::DebertaV2(ref tokenizer) => {
2185                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2186                Some(vocab.token_to_id(vocab.get_eos_value()))
2187            }
2188            Self::XLMRoberta(ref tokenizer) => {
2189                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2190                Some(vocab.token_to_id(vocab.get_eos_value()))
2191            }
2192            Self::Albert(ref tokenizer) => {
2193                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2194                Some(vocab.token_to_id(vocab.get_eos_value()))
2195            }
2196            Self::XLNet(ref tokenizer) => {
2197                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2198                Some(vocab.token_to_id(vocab.get_eos_value()))
2199            }
2200            Self::MBart50(ref tokenizer) => {
2201                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2202                Some(vocab.token_to_id(vocab.get_eos_value()))
2203            }
2204            Self::M2M100(ref tokenizer) => {
2205                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2206                Some(vocab.token_to_id(vocab.get_eos_value()))
2207            }
2208            Self::NLLB(ref tokenizer) => {
2209                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2210                Some(vocab.token_to_id(vocab.get_eos_value()))
2211            }
2212            Self::GPT2(ref tokenizer) => {
2213                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2214                Some(vocab.token_to_id(vocab.get_eos_value()))
2215            }
2216            Self::Deberta(ref tokenizer) => {
2217                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2218                Some(vocab.token_to_id(vocab.get_eos_value()))
2219            }
2220            Self::Marian(ref tokenizer) => {
2221                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2222                Some(vocab.token_to_id(vocab.get_eos_value()))
2223            }
2224            Self::T5(ref tokenizer) => {
2225                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2226                Some(vocab.token_to_id(vocab.get_eos_value()))
2227            }
2228            Self::Reformer(ref tokenizer) => {
2229                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2230                Some(vocab.token_to_id(vocab.get_eos_value()))
2231            }
2232            Self::Pegasus(ref tokenizer) => {
2233                let vocab = MultiThreadedTokenizer::vocab(tokenizer);
2234                Some(vocab.token_to_id(vocab.get_eos_value()))
2235            }
2236            #[cfg(feature = "hf-tokenizers")]
2237            Self::HFTokenizer(ref tokenizer) => tokenizer
2238                .special_token_map
2239                .eos_token
2240                .as_ref()
2241                .map(|token| tokenizer.token_to_id(token)),
2242            Self::FNet(_) => None,
2243            Self::Bert(_) => None,
2244            Self::ProphetNet(_) => None,
2245            Self::OpenAiGpt(_) => None,
2246        }
2247    }
2248
2249    pub fn tokenize_and_pad<'a, S>(
2250        &self,
2251        input: S,
2252        max_length: usize,
2253        device: Device,
2254    ) -> (Tensor, Tensor)
2255    where
2256        S: AsRef<[&'a str]>,
2257    {
2258        let mut tokenized_input: Vec<TokenizedInput> = self.encode_list(
2259            input.as_ref(),
2260            max_length,
2261            &TruncationStrategy::LongestFirst,
2262            0,
2263        );
2264        let max_len = tokenized_input
2265            .iter()
2266            .map(|input| input.token_ids.len())
2267            .max()
2268            .unwrap();
2269        let pad_id = self
2270            .get_pad_id()
2271            .expect("The Tokenizer used for sequence classification should contain a PAD id");
2272        let tokenized_input_tensors: Vec<Tensor> = tokenized_input
2273            .iter_mut()
2274            .map(|input| {
2275                input.token_ids.resize(max_len, pad_id);
2276                Tensor::from_slice(&(input.token_ids))
2277            })
2278            .collect::<Vec<_>>();
2279
2280        let token_type_ids: Vec<Tensor> = tokenized_input
2281            .iter_mut()
2282            .map(|input| {
2283                input
2284                    .segment_ids
2285                    .resize(max_len, *input.segment_ids.last().unwrap_or(&0));
2286                Tensor::from_slice(&(input.segment_ids))
2287            })
2288            .collect::<Vec<_>>();
2289
2290        (
2291            Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(device),
2292            Tensor::stack(token_type_ids.as_slice(), 0)
2293                .to(device)
2294                .to_kind(Kind::Int64),
2295        )
2296    }
2297
2298    /// Interface method
2299    pub fn add_extra_ids(&mut self, num_extra_ids: i64) {
2300        match *self {
2301            Self::Bert(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2302            Self::Deberta(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2303            Self::DebertaV2(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2304            Self::Roberta(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2305            Self::Bart(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2306            Self::Marian(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2307            Self::T5(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2308            Self::XLMRoberta(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2309            Self::Albert(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2310            Self::XLNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2311            Self::GPT2(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2312            Self::OpenAiGpt(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2313            Self::Reformer(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2314            Self::ProphetNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2315            Self::Pegasus(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2316            Self::MBart50(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2317            Self::M2M100(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2318            Self::NLLB(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2319            Self::FNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2320            #[cfg(feature = "hf-tokenizers")]
2321            Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
2322        }
2323    }
2324
2325    /// Interface method
2326    pub fn add_tokens(&mut self, tokens: &[&str]) {
2327        match *self {
2328            Self::Bert(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2329            Self::Deberta(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2330            Self::DebertaV2(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2331            Self::Roberta(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2332            Self::Bart(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2333            Self::Marian(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2334            Self::T5(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2335            Self::XLMRoberta(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2336            Self::Albert(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2337            Self::XLNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2338            Self::GPT2(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2339            Self::OpenAiGpt(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2340            Self::Reformer(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2341            Self::ProphetNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2342            Self::Pegasus(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2343            Self::MBart50(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2344            Self::M2M100(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2345            Self::NLLB(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2346            Self::FNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2347            #[cfg(feature = "hf-tokenizers")]
2348            Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_tokens(tokens),
2349        }
2350    }
2351}
2352
2353pub fn cast_var_store(varstore: &mut VarStore, kind: Option<Kind>, device: Device) {
2354    match (kind, device) {
2355        (Some(kind), _) => varstore.set_kind(kind),
2356        (None, Device::Cpu) => varstore.set_kind(Kind::Float),
2357        (None, _) => {}
2358    }
2359}