use crate::albert::AlbertConfig;
use crate::bart::BartConfig;
use crate::bert::BertConfig;
use crate::common::error::RustBertError;
use crate::distilbert::DistilBertConfig;
use crate::electra::ElectraConfig;
use crate::gpt2::Gpt2Config;
use crate::longformer::LongformerConfig;
use crate::mobilebert::MobileBertConfig;
use crate::prophetnet::ProphetNetConfig;
use crate::reformer::ReformerConfig;
use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
use crate::Config;
use rust_tokenizers::tokenizer::{
AlbertTokenizer, BertTokenizer, Gpt2Tokenizer, MarianTokenizer, MultiThreadedTokenizer,
OpenAiGptTokenizer, ProphetNetTokenizer, ReformerTokenizer, RobertaTokenizer, T5Tokenizer,
Tokenizer, TruncationStrategy, XLMRobertaTokenizer, XLNetTokenizer,
};
use rust_tokenizers::vocab::{
AlbertVocab, BertVocab, Gpt2Vocab, MarianVocab, OpenAiGptVocab, ProphetNetVocab, ReformerVocab,
RobertaVocab, T5Vocab, Vocab, XLMRobertaVocab, XLNetVocab,
};
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
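
/// Identifies the model architecture. The pipelines use this to select the
/// matching configuration and tokenizer for a set of pretrained weights.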
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub enum ModelType {
Bart,
Bert,
DistilBert,
Roberta,
XLMRoberta,
Electra,
Marian,
MobileBert,
T5,
Albert,
XLNet,
GPT2,
OpenAiGpt,
Reformer,
ProphetNet,
Longformer,
}
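
/// Abstraction over the model configurations supported by the library. Note
/// that Marian models reuse `BartConfig`, and OpenAI GPT reuses `Gpt2Config`.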
pub enum ConfigOption {
Bart(BartConfig),
Bert(BertConfig),
DistilBert(DistilBertConfig),
Electra(ElectraConfig),
Marian(BartConfig),
MobileBert(MobileBertConfig),
T5(T5Config),
Albert(AlbertConfig),
XLNet(XLNetConfig),
GPT2(Gpt2Config),
Reformer(ReformerConfig),
ProphetNet(ProphetNetConfig),
Longformer(LongformerConfig),
}
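
/// Abstraction over the tokenizers supported by the library, dispatching the
/// common tokenization operations to the wrapped model-specific tokenizer.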
pub enum TokenizerOption {
Bert(BertTokenizer),
Roberta(RobertaTokenizer),
XLMRoberta(XLMRobertaTokenizer),
Marian(MarianTokenizer),
T5(T5Tokenizer),
Albert(AlbertTokenizer),
XLNet(XLNetTokenizer),
GPT2(Gpt2Tokenizer),
OpenAiGpt(OpenAiGptTokenizer),
Reformer(ReformerTokenizer),
ProphetNet(ProphetNetTokenizer),
}
impl ConfigOption {
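    /// Loads the configuration for `model_type` from a JSON file.
    ///
    /// A minimal usage sketch; the path below is a placeholder, and the
    /// import assumes this module is exposed as `rust_bert::pipelines::common`.
    ///
    /// ```no_run
    /// use rust_bert::pipelines::common::{ConfigOption, ModelType};
    ///
    /// let config = ConfigOption::from_file(ModelType::Bert, "path/to/config.json");
    /// ```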
pub fn from_file<P: AsRef<Path>>(model_type: ModelType, path: P) -> Self {
match model_type {
ModelType::Bart => ConfigOption::Bart(BartConfig::from_file(path)),
ModelType::Bert | ModelType::Roberta | ModelType::XLMRoberta => {
ConfigOption::Bert(BertConfig::from_file(path))
}
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
ModelType::Marian => ConfigOption::Marian(BartConfig::from_file(path)),
ModelType::MobileBert => ConfigOption::MobileBert(MobileBertConfig::from_file(path)),
ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)),
ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)),
ModelType::XLNet => ConfigOption::XLNet(XLNetConfig::from_file(path)),
            ModelType::GPT2 | ModelType::OpenAiGpt => {
                ConfigOption::GPT2(Gpt2Config::from_file(path))
            }
ModelType::Reformer => ConfigOption::Reformer(ReformerConfig::from_file(path)),
ModelType::ProphetNet => ConfigOption::ProphetNet(ProphetNetConfig::from_file(path)),
ModelType::Longformer => ConfigOption::Longformer(LongformerConfig::from_file(path)),
}
}
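
    /// Returns the label mapping (`id2label`) stored in the configuration.
    ///
    /// # Panics
    ///
    /// Panics if the configuration does not define a label dictionary, or if
    /// the model type (T5, GPT2) does not support one.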
pub fn get_label_mapping(self) -> HashMap<i64, String> {
match self {
Self::Bart(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Bert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Electra(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Marian(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::MobileBert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Albert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::XLNet(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Reformer(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::ProphetNet(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Longformer(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
}
}
}
impl TokenizerOption {
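    /// Instantiates the tokenizer matching `model_type` from its vocabulary
    /// file. `merges_path` is required by the model types that use a merges or
    /// secondary vocabulary file (RoBERTa-family, Marian, GPT2, OpenAI GPT);
    /// `strip_accents` and `add_prefix_space` are rejected with an
    /// `InvalidConfigurationError` by the model types that do not support them.
    ///
    /// A minimal usage sketch; the vocabulary path below is a placeholder, and
    /// the import assumes this module is exposed as
    /// `rust_bert::pipelines::common`.
    ///
    /// ```no_run
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    ///
    /// let tokenizer =
    ///     TokenizerOption::from_file(ModelType::Bert, "path/to/vocab.txt", None, true, None, None);
    /// ```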
pub fn from_file(
model_type: ModelType,
vocab_path: &str,
merges_path: Option<&str>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> Result<Self, RustBertError> {
let strip_accents = strip_accents.into();
let add_prefix_space = add_prefix_space.into();
let tokenizer = match model_type {
ModelType::Bert
| ModelType::DistilBert
| ModelType::Electra
| ModelType::MobileBert => {
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::Bert(BertTokenizer::from_file(
vocab_path,
lower_case,
strip_accents.unwrap_or(lower_case),
)?)
}
ModelType::Roberta | ModelType::Bart | ModelType::Longformer => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
strip_accents.unwrap(),
model_type
)));
}
TokenizerOption::Roberta(RobertaTokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
add_prefix_space.unwrap_or(false),
)?)
}
ModelType::Marian => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
strip_accents.unwrap(),
model_type
)));
}
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::Marian(MarianTokenizer::from_files(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)?)
}
ModelType::T5 => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
strip_accents.unwrap(),
model_type
)));
}
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::T5(T5Tokenizer::from_file(vocab_path, lower_case)?)
}
ModelType::XLMRoberta => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
strip_accents.unwrap(),
model_type
)));
}
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::XLMRoberta(XLMRobertaTokenizer::from_file(vocab_path, lower_case)?)
}
ModelType::Albert => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
strip_accents.unwrap(),
model_type
)));
}
TokenizerOption::Albert(AlbertTokenizer::from_file(
vocab_path,
lower_case,
strip_accents.unwrap_or(lower_case),
)?)
}
ModelType::XLNet => {
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::XLNet(XLNetTokenizer::from_file(
vocab_path,
lower_case,
                    strip_accents.unwrap_or(false),
)?)
}
ModelType::Reformer => {
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
                if strip_accents.is_some() {
                    return Err(RustBertError::InvalidConfigurationError(format!(
                        "Optional input `strip_accents` set to value {} but cannot be used by {:?}",
                        strip_accents.unwrap(),
                        model_type
                    )));
                }
TokenizerOption::Reformer(ReformerTokenizer::from_file(vocab_path, lower_case)?)
}
ModelType::GPT2 => TokenizerOption::GPT2(Gpt2Tokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)?),
ModelType::OpenAiGpt => TokenizerOption::OpenAiGpt(OpenAiGptTokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)?),
ModelType::ProphetNet => {
if add_prefix_space.is_some() {
return Err(RustBertError::InvalidConfigurationError(
format!("Optional input `add_prefix_space` set to value {} but cannot be used by {:?}",
add_prefix_space.unwrap(),
model_type)));
}
TokenizerOption::ProphetNet(ProphetNetTokenizer::from_file(
vocab_path,
lower_case,
strip_accents.unwrap_or(lower_case),
)?)
}
};
Ok(tokenizer)
}
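
    /// Returns the `ModelType` corresponding to the wrapped tokenizer.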
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::Albert(_) => ModelType::Albert,
Self::XLNet(_) => ModelType::XLNet,
Self::GPT2(_) => ModelType::GPT2,
Self::OpenAiGpt(_) => ModelType::OpenAiGpt,
Self::Reformer(_) => ModelType::Reformer,
Self::ProphetNet(_) => ModelType::ProphetNet,
}
}
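
    /// Encodes a list of texts in parallel. Each text is truncated to
    /// `max_len` tokens following `truncation_strategy`, keeping `stride`
    /// overlapping tokens between overflowing chunks.
    ///
    /// A minimal sketch, reusing the placeholder tokenizer from `from_file`:
    ///
    /// ```no_run
    /// # use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_tokenizers::tokenizer::TruncationStrategy;
    /// # let tokenizer =
    /// #     TokenizerOption::from_file(ModelType::Bert, "path/to/vocab.txt", None, true, None, None)
    /// #         .unwrap();
    /// let inputs = tokenizer.encode_list(
    ///     &["First input", "Second input"],
    ///     128,
    ///     &TruncationStrategy::LongestFirst,
    ///     0,
    /// );
    /// ```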
pub fn encode_list(
&self,
text_list: &[&str],
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> Vec<TokenizedInput> {
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::Marian(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::T5(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::XLMRoberta(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::Albert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::OpenAiGpt(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
Self::ProphetNet(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
text_list,
max_len,
truncation_strategy,
stride,
),
}
}
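
    /// Encodes a list of text pairs in parallel, using the same truncation
    /// parameters as `encode_list`.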
pub fn encode_pair_list(
&self,
text_pair_list: &[(&str, &str)],
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> Vec<TokenizedInput> {
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::Marian(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::T5(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::XLMRoberta(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::Albert(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::OpenAiGpt(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
Self::ProphetNet(ref tokenizer) => MultiThreadedTokenizer::encode_pair_list(
tokenizer,
text_pair_list,
max_len,
truncation_strategy,
stride,
),
}
}
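
    /// Encodes a single text (with `text_2` set to `None`) or a text pair,
    /// truncating to `max_len` tokens following `truncation_strategy`.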
pub fn encode_pair(
&self,
text_1: &str,
text_2: Option<&str>,
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> TokenizedInput {
match *self {
Self::Bert(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::Roberta(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::Marian(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::T5(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::XLMRoberta(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::Albert(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::XLNet(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::GPT2(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::OpenAiGpt(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::Reformer(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
Self::ProphetNet(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
}
}
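
    /// Tokenizes a single text into model-specific sub-tokens.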
pub fn tokenize(&self, text: &str) -> Vec<String> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.tokenize(text),
Self::Roberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Marian(ref tokenizer) => tokenizer.tokenize(text),
Self::T5(ref tokenizer) => tokenizer.tokenize(text),
Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Albert(ref tokenizer) => tokenizer.tokenize(text),
Self::XLNet(ref tokenizer) => tokenizer.tokenize(text),
Self::GPT2(ref tokenizer) => tokenizer.tokenize(text),
Self::OpenAiGpt(ref tokenizer) => tokenizer.tokenize(text),
Self::Reformer(ref tokenizer) => tokenizer.tokenize(text),
Self::ProphetNet(ref tokenizer) => tokenizer.tokenize(text),
}
}
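
    /// Tokenizes a single text, additionally returning the offsets of each
    /// token in the original string.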
pub fn tokenize_with_offsets(&self, text: &str) -> TokensWithOffsets {
match *self {
Self::Bert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Roberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Marian(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::T5(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Albert(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::XLNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::GPT2(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::OpenAiGpt(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::Reformer(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::ProphetNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
}
}
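
    /// Tokenizes a list of texts in parallel.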
pub fn tokenize_list(&self, text: &[&str]) -> Vec<Vec<String>> {
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::Marian(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::T5(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::XLMRoberta(ref tokenizer) => {
MultiThreadedTokenizer::tokenize_list(tokenizer, text)
}
Self::Albert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::XLNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::GPT2(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::OpenAiGpt(ref tokenizer) => {
MultiThreadedTokenizer::tokenize_list(tokenizer, text)
}
Self::Reformer(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::ProphetNet(ref tokenizer) => {
MultiThreadedTokenizer::tokenize_list(tokenizer, text)
}
}
}
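
    /// Converts a sequence of token ids back into a string, optionally
    /// skipping special tokens and cleaning up tokenization artifacts such as
    /// spaces before punctuation.
    ///
    /// A minimal round-trip sketch with the placeholder tokenizer used above:
    ///
    /// ```no_run
    /// # use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// # use rust_tokenizers::tokenizer::TruncationStrategy;
    /// # let tokenizer =
    /// #     TokenizerOption::from_file(ModelType::Bert, "path/to/vocab.txt", None, true, None, None)
    /// #         .unwrap();
    /// let encoded =
    ///     tokenizer.encode_pair("Hello world!", None, 128, &TruncationStrategy::LongestFirst, 0);
    /// let decoded = tokenizer.decode(encoded.token_ids, true, true);
    /// ```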
pub fn decode(
&self,
token_ids: Vec<i64>,
skip_special_tokens: bool,
clean_up_tokenization_spaces: bool,
) -> String {
match *self {
Self::Bert(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::Roberta(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::Marian(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::T5(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::XLMRoberta(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::Albert(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::XLNet(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::GPT2(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::OpenAiGpt(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::Reformer(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
Self::ProphetNet(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
}
}
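
    /// Adds the model-specific special tokens (e.g. `[CLS]` and `[SEP]` for
    /// BERT) around one or two pre-tokenized sequences, and packs the result
    /// into a `TokenizedInput` with no truncation applied.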
pub fn build_input_with_special_tokens(
&self,
token_ids_with_offsets_1: TokenIdsWithOffsets,
token_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
) -> TokenizedInput {
let token_ids_with_special_tokens = match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::GPT2(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::OpenAiGpt(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Reformer(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::ProphetNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
};
TokenizedInput {
token_ids: token_ids_with_special_tokens.token_ids,
segment_ids: token_ids_with_special_tokens.segment_ids,
special_tokens_mask: token_ids_with_special_tokens.special_tokens_mask,
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets: token_ids_with_special_tokens.token_offsets,
reference_offsets: token_ids_with_special_tokens.reference_offsets,
mask: token_ids_with_special_tokens.mask,
}
}
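
    /// Converts a slice of tokens into their vocabulary ids.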
pub fn convert_tokens_to_ids<S, ST>(&self, tokens: S) -> Vec<i64>
where
S: AsRef<[ST]>,
ST: AsRef<str>,
{
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::GPT2(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::OpenAiGpt(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Reformer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::ProphetNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}
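
    /// Returns the id of the unknown (UNK) token for the wrapped tokenizer.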
pub fn get_unk_id(&self) -> i64 {
match *self {
Self::Bert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Roberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::XLMRoberta(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Marian(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MarianVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::T5(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(T5Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Albert(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::XLNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::GPT2(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(Gpt2Vocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::OpenAiGpt(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(OpenAiGptVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::Reformer(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ReformerVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
Self::ProphetNet(ref tokenizer) => *MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::unknown_value())
.expect("UNK token not found in vocabulary"),
}
}
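
    /// Returns the id of the padding token, or `None` for the model types
    /// (Reformer, GPT2, OpenAI GPT) whose vocabulary defines no padding token.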
pub fn get_pad_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Marian(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MarianVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::T5(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(T5Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Reformer(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
}
}
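
    /// Returns the id of the separator token, or `None` for the model types
    /// that do not use one.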
pub fn get_sep_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
Self::Reformer(_) => None,
}
}
}