use crate::common::resources::{RemoteResource, Resource};
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
MarianVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::generation::{
GenerateConfig, LanguageGenerator, MarianGenerator, T5Generator,
};
use crate::t5::{T5ConfigResources, T5ModelResources, T5Prefix, T5VocabResources};
use tch::{Device, Tensor};
/// Language pairs supported by the translation pipeline.
///
/// Each variant selects a predefined set of remote pretrained resources in
/// `TranslationConfig::new`. The `V2` variants are backed by the T5 base model;
/// every other variant uses a Marian model.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Language {
    FrenchToEnglish,
    CatalanToEnglish,
    SpanishToEnglish,
    PortugueseToEnglish,
    ItalianToEnglish,
    RomanianToEnglish,
    GermanToEnglish,
    RussianToEnglish,
    EnglishToFrench,
    EnglishToCatalan,
    EnglishToSpanish,
    EnglishToPortuguese,
    EnglishToItalian,
    EnglishToRomanian,
    EnglishToGerman,
    EnglishToRussian,
    /// English to French using the T5 base model.
    EnglishToFrenchV2,
    /// English to German using the T5 base model.
    EnglishToGermanV2,
    FrenchToGerman,
    GermanToFrench,
}
struct RemoteTranslationResources;
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2FRENCH,
ModelType::Marian,
);
pub const ENGLISH2FRENCH_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2FRENCH,
ModelType::T5,
);
pub const ENGLISH2GERMAN_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2GERMAN,
ModelType::T5,
);
pub const ENGLISH2CATALAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2CATALAN,
ModelType::Marian,
);
pub const ENGLISH2SPANISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2SPANISH,
ModelType::Marian,
);
pub const ENGLISH2PORTUGUESE: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2PORTUGUESE,
ModelType::Marian,
);
pub const ENGLISH2ITALIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ITALIAN,
ModelType::Marian,
);
pub const ENGLISH2ROMANIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ROMANIAN,
ModelType::Marian,
);
pub const ENGLISH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
MarianPrefix::ENGLISH2GERMAN,
ModelType::Marian,
);
pub const ENGLISH2RUSSIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
MarianPrefix::ENGLISH2RUSSIAN,
ModelType::Marian,
);
pub const FRENCH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::FRENCH2ENGLISH,
ModelType::Marian,
);
pub const CATALAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::CATALAN2ENGLISH,
ModelType::Marian,
);
pub const SPANISH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::SPANISH2ENGLISH,
ModelType::Marian,
);
pub const PORTUGUESE2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::PORTUGUESE2ENGLISH,
ModelType::Marian,
);
pub const ITALIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ITALIAN2ENGLISH,
ModelType::Marian,
);
pub const ROMANIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ROMANIAN2ENGLISH,
ModelType::Marian,
);
pub const GERMAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
MarianPrefix::GERMAN2ENGLISH,
ModelType::Marian,
);
pub const RUSSIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
MarianPrefix::RUSSIAN2ENGLISH,
ModelType::Marian,
);
pub const FRENCH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::FRENCH2GERMAN,
MarianConfigResources::FRENCH2GERMAN,
MarianVocabResources::FRENCH2GERMAN,
MarianSpmResources::FRENCH2GERMAN,
MarianPrefix::FRENCH2GERMAN,
ModelType::Marian,
);
pub const GERMAN2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2FRENCH,
MarianConfigResources::GERMAN2FRENCH,
MarianVocabResources::GERMAN2FRENCH,
MarianSpmResources::GERMAN2FRENCH,
MarianPrefix::GERMAN2FRENCH,
ModelType::Marian,
);
}
/// Configuration for a `TranslationModel`.
///
/// Built either from a predefined `Language` (`TranslationConfig::new`) or from explicit
/// resources (`TranslationConfig::new_from_resources`). Every field except `prefix` and
/// `model_type` is copied into a `GenerateConfig` by `TranslationOption::new`.
pub struct TranslationConfig {
/// Model resource (forwarded to `GenerateConfig::model_resource`).
pub model_resource: Resource,
/// Model configuration resource (forwarded to `GenerateConfig::config_resource`).
pub config_resource: Resource,
/// Vocabulary resource (forwarded to `GenerateConfig::vocab_resource`).
pub vocab_resource: Resource,
/// Second tokenizer resource (forwarded to `GenerateConfig::merges_resource`).
/// `new_from_resources` fills it with its `sentence_piece_resource` argument; for the
/// predefined T5 pairs it is the same resource as `vocab_resource`.
pub merges_resource: Resource,
/// Minimum generation length setting (default 0).
pub min_length: u64,
/// Maximum generation length setting (default 512).
pub max_length: u64,
/// Sampling flag forwarded to the generator (default `false`).
pub do_sample: bool,
/// Early-stopping flag forwarded to the generator (default `false`).
pub early_stopping: bool,
/// Number of beams forwarded to the generator (default 6).
pub num_beams: u64,
/// Temperature forwarded to the generator (default 1.0).
pub temperature: f64,
/// Top-k setting forwarded to the generator (default 50).
pub top_k: u64,
/// Top-p setting forwarded to the generator (default 1.0).
pub top_p: f64,
/// Repetition penalty forwarded to the generator (default 1.0).
pub repetition_penalty: f64,
/// Length penalty forwarded to the generator (default 1.0).
pub length_penalty: f64,
/// No-repeat n-gram size forwarded to the generator (default 0).
pub no_repeat_ngram_size: u64,
/// Number of returned sequences forwarded to the generator (default 1).
pub num_return_sequences: u64,
/// Device the generator is created on.
pub device: Device,
/// Optional text that `TranslationModel::translate` prepends, followed by a space,
/// to every input text.
pub prefix: Option<String>,
/// Architecture used to build the generator. Only `Marian` and `T5` are supported;
/// any other value panics in `TranslationOption::new`.
pub model_type: ModelType,
}
impl TranslationConfig {
    /// Creates a configuration for one of the predefined language pairs.
    ///
    /// All four resources are remote pretrained resources looked up from the
    /// `RemoteTranslationResources` table. Generation settings use the defaults of
    /// [`TranslationConfig::new_from_resources`].
    ///
    /// # Arguments
    /// * `language` - language pair (and model version) to translate with.
    /// * `device` - device the generator will be created on.
    pub fn new(language: Language, device: Device) -> TranslationConfig {
        let (model_resource, config_resource, vocab_resource, merges_resource, prefix, model_type) =
            match language {
                Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
                Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
                Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
                Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
                Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
                Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
                Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
                Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
                Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
                Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
                Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
                Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
                Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
                Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
                Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
                Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
                Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
                Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
                Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
                Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
            };
        let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
        let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
        let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
        let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
        // Delegate so the default generation settings are defined in exactly one place.
        TranslationConfig::new_from_resources(
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            prefix.map(str::to_string),
            device,
            model_type,
        )
    }

    /// Creates a configuration from explicitly provided resources.
    ///
    /// Generation defaults: `min_length` 0, `max_length` 512, no sampling, no early
    /// stopping, 6 beams, temperature 1.0, `top_k` 50, `top_p` 1.0, repetition and
    /// length penalties 1.0, `no_repeat_ngram_size` 0, one returned sequence.
    ///
    /// # Arguments
    /// * `model_resource` - model resource.
    /// * `config_resource` - model configuration resource.
    /// * `vocab_resource` - vocabulary resource.
    /// * `sentence_piece_resource` - stored as `merges_resource`.
    /// * `prefix` - optional text prepended (with a space) to each input at translation time.
    /// * `device` - device the generator will be created on.
    /// * `model_type` - architecture to build (`Marian` or `T5`).
    pub fn new_from_resources(
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        sentence_piece_resource: Resource,
        prefix: Option<String>,
        device: Device,
        model_type: ModelType,
    ) -> TranslationConfig {
        TranslationConfig {
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource: sentence_piece_resource,
            min_length: 0,
            max_length: 512,
            do_sample: false,
            early_stopping: false,
            num_beams: 6,
            temperature: 1.0,
            top_k: 50,
            top_p: 1.0,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            device,
            prefix,
            model_type,
        }
    }
}
/// Generator backing a `TranslationModel`, with one variant per supported architecture.
pub enum TranslationOption {
    /// Marian generator.
    Marian(MarianGenerator),
    /// T5 generator.
    T5(T5Generator),
}

impl TranslationOption {
    /// Builds the generator selected by `config.model_type`.
    ///
    /// Resources and generation settings from `config` are copied into a `GenerateConfig`;
    /// `config.prefix` is not used here.
    ///
    /// # Panics
    /// * if the Marian/T5 generator cannot be created from the configured resources;
    /// * if `config.model_type` is anything other than `Marian` or `T5`.
    pub fn new(config: TranslationConfig) -> Self {
        let generate_config = GenerateConfig {
            model_resource: config.model_resource,
            config_resource: config.config_resource,
            merges_resource: config.merges_resource,
            vocab_resource: config.vocab_resource,
            min_length: config.min_length,
            max_length: config.max_length,
            do_sample: config.do_sample,
            early_stopping: config.early_stopping,
            num_beams: config.num_beams,
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repetition_penalty: config.repetition_penalty,
            length_penalty: config.length_penalty,
            no_repeat_ngram_size: config.no_repeat_ngram_size,
            num_return_sequences: config.num_return_sequences,
            device: config.device,
        };
        match config.model_type {
            ModelType::Marian => TranslationOption::Marian(
                MarianGenerator::new(generate_config)
                    .expect("failed to create Marian translation generator"),
            ),
            ModelType::T5 => TranslationOption::T5(
                T5Generator::new(generate_config)
                    .expect("failed to create T5 translation generator"),
            ),
            ModelType::Bert => panic!("Translation not implemented for Bert!"),
            ModelType::DistilBert => panic!("Translation not implemented for DistilBert!"),
            ModelType::Roberta => panic!("Translation not implemented for Roberta!"),
            ModelType::XLMRoberta => panic!("Translation not implemented for XLMRoberta!"),
            ModelType::Electra => panic!("Translation not implemented for Electra!"),
            ModelType::Albert => panic!("Translation not implemented for Albert!"),
        }
    }

    /// Returns the architecture of the wrapped generator.
    pub fn model_type(&self) -> ModelType {
        match self {
            Self::Marian(_) => ModelType::Marian,
            Self::T5(_) => ModelType::T5,
        }
    }

    /// Runs generation on the wrapped generator, forwarding both arguments unchanged.
    pub fn generate(
        &self,
        prompt_texts: Option<Vec<&str>>,
        attention_mask: Option<Tensor>,
    ) -> Vec<String> {
        match self {
            Self::Marian(model) => model.generate(prompt_texts, attention_mask),
            Self::T5(model) => model.generate(prompt_texts, attention_mask),
        }
    }
}
/// Translation pipeline: a generator plus an optional prefix prepended to each input.
pub struct TranslationModel {
    model: TranslationOption,
    prefix: Option<String>,
}

impl TranslationModel {
    /// Builds a translation model from `translation_config`.
    ///
    /// # Errors
    /// Currently never returns `Err`. Failures while creating the generator panic inside
    /// `TranslationOption::new` instead of being propagated.
    pub fn new(mut translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
        // `TranslationOption::new` consumes the config but never reads the prefix, so move
        // it out instead of cloning it.
        let prefix = translation_config.prefix.take();
        let model = TranslationOption::new(translation_config);
        Ok(TranslationModel { model, prefix })
    }

    /// Translates each text in `texts` and returns the generated outputs.
    ///
    /// When a prefix is configured, each input is sent to the generator as
    /// `"{prefix} {text}"`.
    pub fn translate(&self, texts: &[&str]) -> Vec<String> {
        match &self.prefix {
            Some(prefix) => {
                let prefixed: Vec<String> = texts
                    .iter()
                    .map(|text| format!("{} {}", prefix, text))
                    .collect();
                self.model
                    .generate(Some(prefixed.iter().map(String::as_str).collect()), None)
            }
            None => self.model.generate(Some(texts.to_vec()), None),
        }
    }
}