use tch::{Device, Tensor};
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::marian::{
MarianConfigResources, MarianGenerator, MarianModelResources, MarianPrefix, MarianSpmResources,
MarianVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::t5::{T5ConfigResources, T5Generator, T5ModelResources, T5Prefix, T5VocabResources};
/// Translation directions for which this file provides a pretrained configuration.
///
/// `TranslationConfig::new` maps every variant to one `RemoteTranslationResources`
/// constant, which in turn selects the model resources, the input prefix and the
/// model type (Marian or T5) used for that direction.
pub enum Language {
/// French to English, using the Marian Romance-to-English resources.
FrenchToEnglish,
/// Catalan to English, using the Marian Romance-to-English resources.
CatalanToEnglish,
/// Spanish to English, using the Marian Romance-to-English resources.
SpanishToEnglish,
/// Portuguese to English, using the Marian Romance-to-English resources.
PortugueseToEnglish,
/// Italian to English, using the Marian Romance-to-English resources.
ItalianToEnglish,
/// Romanian to English, using the Marian Romance-to-English resources.
RomanianToEnglish,
/// German to English, using the dedicated Marian German-to-English resources.
GermanToEnglish,
/// Russian to English, using the dedicated Marian Russian-to-English resources.
RussianToEnglish,
/// English to French, using the Marian English-to-Romance resources.
EnglishToFrench,
/// English to Catalan, using the Marian English-to-Romance resources.
EnglishToCatalan,
/// English to Spanish, using the Marian English-to-Romance resources.
EnglishToSpanish,
/// English to Portuguese, using the Marian English-to-Romance resources.
EnglishToPortuguese,
/// English to Italian, using the Marian English-to-Romance resources.
EnglishToItalian,
/// English to Romanian, using the Marian English-to-Romance resources.
EnglishToRomanian,
/// English to German, using the dedicated Marian English-to-German resources.
EnglishToGerman,
/// English to Russian, using the dedicated Marian English-to-Russian resources.
EnglishToRussian,
/// English to French, using the `T5_BASE` resources with the T5 English-to-French prefix.
EnglishToFrenchV2,
/// English to German, using the `T5_BASE` resources with the T5 English-to-German prefix.
EnglishToGermanV2,
/// French to German, using the dedicated Marian French-to-German resources.
FrenchToGerman,
/// German to French, using the dedicated Marian German-to-French resources.
GermanToFrench,
}
/// Static description of one pretrained translation bundle.
///
/// Every resource field is a pair of static strings that `TranslationConfig::new`
/// hands unchanged to `RemoteResource::from_pretrained`.
/// NOTE(review): the pair is presumably (cache name, download URL) — confirm against
/// `RemoteResource`.
struct RemoteTranslationResources {
// Identifier pair of the model resource.
model_resource: (&'static str, &'static str),
// Identifier pair of the model configuration resource.
config_resource: (&'static str, &'static str),
// Identifier pair of the vocabulary resource.
vocab_resource: (&'static str, &'static str),
// Identifier pair stored in the "merges" slot: Marian entries put their
// `MarianSpmResources` value here, T5 entries repeat their vocabulary resource.
merges_resource: (&'static str, &'static str),
// Optional text that `TranslationModel::translate` prepends to every input.
prefix: Option<&'static str>,
// Selects which generator `TranslationOption::new` instantiates.
model_type: ModelType,
}
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2FRENCH,
model_type: ModelType::Marian,
};
pub const ENGLISH2FRENCH_V2: RemoteTranslationResources = Self {
model_resource: T5ModelResources::T5_BASE,
config_resource: T5ConfigResources::T5_BASE,
vocab_resource: T5VocabResources::T5_BASE,
merges_resource: T5VocabResources::T5_BASE,
prefix: T5Prefix::ENGLISH2FRENCH,
model_type: ModelType::T5,
};
pub const ENGLISH2GERMAN_V2: RemoteTranslationResources = Self {
model_resource: T5ModelResources::T5_BASE,
config_resource: T5ConfigResources::T5_BASE,
vocab_resource: T5VocabResources::T5_BASE,
merges_resource: T5VocabResources::T5_BASE,
prefix: T5Prefix::ENGLISH2GERMAN,
model_type: ModelType::T5,
};
pub const ENGLISH2CATALAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2CATALAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2SPANISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2SPANISH,
model_type: ModelType::Marian,
};
pub const ENGLISH2PORTUGUESE: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2PORTUGUESE,
model_type: ModelType::Marian,
};
pub const ENGLISH2ITALIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2ITALIAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2ROMANIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2ROMANIAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2GERMAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2GERMAN,
config_resource: MarianConfigResources::ENGLISH2GERMAN,
vocab_resource: MarianVocabResources::ENGLISH2GERMAN,
merges_resource: MarianSpmResources::ENGLISH2GERMAN,
prefix: MarianPrefix::ENGLISH2GERMAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2RUSSIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2RUSSIAN,
config_resource: MarianConfigResources::ENGLISH2RUSSIAN,
vocab_resource: MarianVocabResources::ENGLISH2RUSSIAN,
merges_resource: MarianSpmResources::ENGLISH2RUSSIAN,
prefix: MarianPrefix::ENGLISH2RUSSIAN,
model_type: ModelType::Marian,
};
pub const FRENCH2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::FRENCH2ENGLISH,
model_type: ModelType::Marian,
};
pub const CATALAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::CATALAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const SPANISH2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::SPANISH2ENGLISH,
model_type: ModelType::Marian,
};
pub const PORTUGUESE2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::PORTUGUESE2ENGLISH,
model_type: ModelType::Marian,
};
pub const ITALIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::ITALIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const ROMANIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::ROMANIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const GERMAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::GERMAN2ENGLISH,
config_resource: MarianConfigResources::GERMAN2ENGLISH,
vocab_resource: MarianVocabResources::GERMAN2ENGLISH,
merges_resource: MarianSpmResources::GERMAN2ENGLISH,
prefix: MarianPrefix::GERMAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const RUSSIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::RUSSIAN2ENGLISH,
config_resource: MarianConfigResources::RUSSIAN2ENGLISH,
vocab_resource: MarianVocabResources::RUSSIAN2ENGLISH,
merges_resource: MarianSpmResources::RUSSIAN2ENGLISH,
prefix: MarianPrefix::RUSSIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const FRENCH2GERMAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::FRENCH2GERMAN,
config_resource: MarianConfigResources::FRENCH2GERMAN,
vocab_resource: MarianVocabResources::FRENCH2GERMAN,
merges_resource: MarianSpmResources::FRENCH2GERMAN,
prefix: MarianPrefix::FRENCH2GERMAN,
model_type: ModelType::Marian,
};
pub const GERMAN2FRENCH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::GERMAN2FRENCH,
config_resource: MarianConfigResources::GERMAN2FRENCH,
vocab_resource: MarianVocabResources::GERMAN2FRENCH,
merges_resource: MarianSpmResources::GERMAN2FRENCH,
prefix: MarianPrefix::GERMAN2FRENCH,
model_type: ModelType::Marian,
};
}
/// Settings for a translation model: where its resources live, how text is
/// generated, and which model family to instantiate.
///
/// Every field except `prefix` and `model_type` is copied into a `GenerateConfig`
/// by the `From<TranslationConfig>` implementation in this file. The defaults
/// quoted below are the values both constructors of this type assign.
pub struct TranslationConfig {
/// Model resource.
pub model_resource: Resource,
/// Model configuration resource.
pub config_resource: Resource,
/// Vocabulary resource.
pub vocab_resource: Resource,
/// Merges resource; `new_from_resources` stores its `sentence_piece_resource` argument here.
pub merges_resource: Resource,
/// Lower bound on the generated length (default: 0).
pub min_length: i64,
/// Upper bound on the generated length (default: 512).
pub max_length: i64,
/// Sampling flag passed to the generator (default: false).
pub do_sample: bool,
/// Early-stopping flag passed to the generator (default: true).
pub early_stopping: bool,
/// Number of beams passed to the generator (default: 6).
pub num_beams: i64,
/// Temperature passed to the generator (default: 1.0).
pub temperature: f64,
/// Top-k setting passed to the generator (default: 50).
pub top_k: i64,
/// Top-p setting passed to the generator (default: 1.0).
pub top_p: f64,
/// Repetition penalty passed to the generator (default: 1.0).
pub repetition_penalty: f64,
/// Length penalty passed to the generator (default: 1.0).
pub length_penalty: f64,
/// N-gram size that may not repeat, passed to the generator (default: 0).
pub no_repeat_ngram_size: i64,
/// Number of sequences requested from the generator (default: 1).
pub num_return_sequences: i64,
/// Device handed to the generator.
pub device: Device,
/// Optional text that `TranslationModel::translate` prepends to each input.
pub prefix: Option<String>,
/// Number of beam groups passed to the generator (default: `None`).
pub num_beam_groups: Option<i64>,
/// Diversity penalty passed to the generator (default: `None`).
pub diversity_penalty: Option<f64>,
/// Model family; `TranslationOption::new` uses it to pick the generator to build.
pub model_type: ModelType,
}
impl TranslationConfig {
    /// Builds the configuration for one of the pretrained [`Language`] directions.
    ///
    /// The direction is resolved to its `RemoteTranslationResources` entry. Each
    /// resource identifier of that entry is wrapped in
    /// `Resource::Remote(RemoteResource::from_pretrained(..))`, and the result is
    /// passed to [`TranslationConfig::new_from_resources`], so both constructors
    /// share one set of generation defaults.
    ///
    /// # Arguments
    ///
    /// * `language` - Translation direction to configure.
    /// * `device` - Device stored in the configuration.
    pub fn new(language: Language, device: Device) -> TranslationConfig {
        let resources = match language {
            Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
            Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
            Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
            Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
            Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
            Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
            Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
            Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,

            Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
            Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
            Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
            Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
            Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
            Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
            Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
            Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,

            Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
            Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,

            Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
            Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
        };

        // Arguments are evaluated left to right, so the remote resources are
        // created in the order model, config, vocab, merges.
        Self::new_from_resources(
            Resource::Remote(RemoteResource::from_pretrained(resources.model_resource)),
            Resource::Remote(RemoteResource::from_pretrained(resources.config_resource)),
            Resource::Remote(RemoteResource::from_pretrained(resources.vocab_resource)),
            Resource::Remote(RemoteResource::from_pretrained(resources.merges_resource)),
            resources.prefix.map(String::from),
            device,
            resources.model_type,
        )
    }

    /// Builds a configuration from caller-supplied resources, using the default
    /// generation settings.
    ///
    /// Defaults: `min_length` 0, `max_length` 512, `do_sample` false,
    /// `early_stopping` true, `num_beams` 6, `temperature` 1.0, `top_k` 50,
    /// `top_p` 1.0, `repetition_penalty` 1.0, `length_penalty` 1.0,
    /// `no_repeat_ngram_size` 0, `num_return_sequences` 1, and no beam groups or
    /// diversity penalty.
    ///
    /// # Arguments
    ///
    /// * `model_resource` - Model resource.
    /// * `config_resource` - Model configuration resource.
    /// * `vocab_resource` - Vocabulary resource.
    /// * `sentence_piece_resource` - Stored in the `merges_resource` field.
    /// * `prefix` - Optional text prepended to every input before generation.
    /// * `device` - Device stored in the configuration.
    /// * `model_type` - Model family used to pick the generator implementation.
    pub fn new_from_resources(
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        sentence_piece_resource: Resource,
        prefix: Option<String>,
        device: Device,
        model_type: ModelType,
    ) -> TranslationConfig {
        TranslationConfig {
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource: sentence_piece_resource,
            min_length: 0,
            max_length: 512,
            do_sample: false,
            early_stopping: true,
            num_beams: 6,
            temperature: 1.0,
            top_k: 50,
            top_p: 1.0,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            device,
            prefix,
            num_beam_groups: None,
            diversity_penalty: None,
            model_type,
        }
    }
}
impl From<TranslationConfig> for GenerateConfig {
fn from(config: TranslationConfig) -> GenerateConfig {
GenerateConfig {
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
vocab_resource: config.vocab_resource,
min_length: config.min_length,
max_length: config.max_length,
do_sample: config.do_sample,
early_stopping: config.early_stopping,
num_beams: config.num_beams,
temperature: config.temperature,
top_k: config.top_k,
top_p: config.top_p,
repetition_penalty: config.repetition_penalty,
length_penalty: config.length_penalty,
no_repeat_ngram_size: config.no_repeat_ngram_size,
num_return_sequences: config.num_return_sequences,
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
}
}
}
/// Wrapper over the generator types that can back a translation model, so that
/// callers can hold either one behind a single type.
pub enum TranslationOption {
/// Marian generator; built when the configuration's model type is `ModelType::Marian`.
Marian(MarianGenerator),
/// T5 generator; built when the configuration's model type is `ModelType::T5`.
T5(T5Generator),
}
impl TranslationOption {
    /// Instantiates the generator that matches `config.model_type`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the generator constructor, and returns
    /// `RustBertError::InvalidConfigurationError` when the model type is neither
    /// `ModelType::Marian` nor `ModelType::T5`.
    pub fn new(config: TranslationConfig) -> Result<Self, RustBertError> {
        let option = match config.model_type {
            ModelType::Marian => Self::Marian(MarianGenerator::new(config.into())?),
            ModelType::T5 => Self::T5(T5Generator::new(config.into())?),
            _ => {
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Translation not implemented for {:?}!",
                    config.model_type
                )));
            }
        };
        Ok(option)
    }

    /// Returns the `ModelType` corresponding to the wrapped generator.
    pub fn model_type(&self) -> ModelType {
        match self {
            Self::Marian(_) => ModelType::Marian,
            Self::T5(_) => ModelType::T5,
        }
    }

    /// Runs generation on the wrapped generator and returns its output strings.
    ///
    /// `prompt_texts` and `attention_mask` are passed through unchanged; the
    /// three remaining arguments of the generator's `generate` call are left at
    /// `None`.
    pub fn generate<'a, S>(
        &self,
        prompt_texts: Option<S>,
        attention_mask: Option<Tensor>,
    ) -> Vec<String>
    where
        S: AsRef<[&'a str]>,
    {
        match self {
            Self::Marian(generator) => {
                generator.generate(prompt_texts, attention_mask, None, None, None)
            }
            Self::T5(generator) => {
                generator.generate(prompt_texts, attention_mask, None, None, None)
            }
        }
    }
}
/// Translation front-end: a generator together with the optional text prefix
/// that is prepended to every input before generation.
pub struct TranslationModel {
// Generator that produces the translations.
model: TranslationOption,
// Copied from `TranslationConfig::prefix` at construction time.
prefix: Option<String>,
}
impl TranslationModel {
    /// Builds a translation model from `translation_config`.
    ///
    /// The prefix is copied out of the configuration first, because the
    /// configuration itself is consumed when the generator is created.
    ///
    /// # Errors
    ///
    /// Returns whatever error `TranslationOption::new` reports.
    pub fn new(translation_config: TranslationConfig) -> Result<TranslationModel, RustBertError> {
        let prefix = translation_config.prefix.clone();
        TranslationOption::new(translation_config).map(|model| TranslationModel { model, prefix })
    }

    /// Translates `texts` and returns the generated strings.
    ///
    /// When the model has a prefix, it is prepended verbatim (no separator is
    /// added) to every input before generation; otherwise the inputs are passed
    /// to the generator as they are.
    pub fn translate<'a, S>(&self, texts: S) -> Vec<String>
    where
        S: AsRef<[&'a str]>,
    {
        let prefix = match &self.prefix {
            Some(prefix) => prefix,
            None => return self.model.generate(Some(texts), None),
        };

        // Owned prefixed inputs must outlive the `&str` views handed to the generator.
        let prefixed: Vec<String> = texts
            .as_ref()
            .iter()
            .map(|text| format!("{}{}", prefix, text))
            .collect();
        let views: Vec<&str> = prefixed.iter().map(String::as_str).collect();

        self.model.generate(Some(views), None)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that the value returned by `TranslationModel::new` can be boxed
    /// as `dyn Send`; the coercion is verified at compile time.
    /// The test is `#[ignore]`d, so it only runs when explicitly requested.
    #[test]
    #[ignore]
    fn test() {
        let device = Device::cuda_if_available();
        let config = TranslationConfig::new(Language::FrenchToEnglish, device);
        let model = TranslationModel::new(config);
        let _: Box<dyn Send> = Box::new(model);
    }
}