use crate::pipelines::generation::{MarianGenerator, GenerateConfig, LanguageGenerator};
use tch::Device;
use crate::common::resources::{Resource, RemoteResource};
use crate::marian::{MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
/// Source/target language pair of a pretrained Marian translation model.
///
/// Passed to [`TranslationConfig::new`], which maps each variant to the remote
/// model, config, vocabulary and SentencePiece resources, plus an optional
/// target-language prefix.
///
/// The enum is fieldless, so it cheaply derives the common value traits.
/// Callers can therefore copy, compare, hash and debug-print a language choice.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Language {
    FrenchToEnglish,
    CatalanToEnglish,
    SpanishToEnglish,
    PortugueseToEnglish,
    ItalianToEnglish,
    RomanianToEnglish,
    GermanToEnglish,
    RussianToEnglish,
    EnglishToFrench,
    EnglishToCatalan,
    EnglishToSpanish,
    EnglishToPortuguese,
    EnglishToItalian,
    EnglishToRomanian,
    EnglishToGerman,
    EnglishToRussian,
    FrenchToGerman,
    GermanToFrench,
}
/// A single remote resource descriptor, as passed to `RemoteResource::from_pretrained`.
///
/// NOTE(review): presumably `(cache name, URL)`; confirm against the `Marian*Resources` constants.
type ResourcePair = (&'static str, &'static str);

/// Everything needed to set up one translation direction, in this order:
/// model, config, vocabulary, SentencePiece model, and an optional prefix
/// that is prepended to every input text.
type TranslationResources = (
    ResourcePair,
    ResourcePair,
    ResourcePair,
    ResourcePair,
    Option<&'static str>,
);

/// Table of pretrained Marian resources for each supported translation direction.
///
/// Several directions share one multilingual model. For example, all English→Romance
/// pairs use the `ENGLISH2ROMANCE` resources, and all Romance→English pairs use the
/// `ROMANCE2ENGLISH` resources. Those directions differ only in their `MarianPrefix`.
struct RemoteTranslationResources;

impl RemoteTranslationResources {
    pub const ENGLISH2FRENCH: TranslationResources = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2FRENCH,
    );
    pub const ENGLISH2CATALAN: TranslationResources = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2CATALAN,
    );
    pub const ENGLISH2SPANISH: TranslationResources = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2SPANISH,
    );
    pub const ENGLISH2PORTUGUESE: TranslationResources = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2PORTUGUESE,
    );
    pub const ENGLISH2ITALIAN: TranslationResources = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2ITALIAN,
    );
    pub const ENGLISH2ROMANIAN: TranslationResources = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2ROMANIAN,
    );
    pub const ENGLISH2GERMAN: TranslationResources = (
        MarianModelResources::ENGLISH2GERMAN,
        MarianConfigResources::ENGLISH2GERMAN,
        MarianVocabResources::ENGLISH2GERMAN,
        MarianSpmResources::ENGLISH2GERMAN,
        MarianPrefix::ENGLISH2GERMAN,
    );
    pub const ENGLISH2RUSSIAN: TranslationResources = (
        MarianModelResources::ENGLISH2RUSSIAN,
        MarianConfigResources::ENGLISH2RUSSIAN,
        MarianVocabResources::ENGLISH2RUSSIAN,
        MarianSpmResources::ENGLISH2RUSSIAN,
        MarianPrefix::ENGLISH2RUSSIAN,
    );
    pub const FRENCH2ENGLISH: TranslationResources = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::FRENCH2ENGLISH,
    );
    pub const CATALAN2ENGLISH: TranslationResources = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::CATALAN2ENGLISH,
    );
    pub const SPANISH2ENGLISH: TranslationResources = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::SPANISH2ENGLISH,
    );
    pub const PORTUGUESE2ENGLISH: TranslationResources = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::PORTUGUESE2ENGLISH,
    );
    pub const ITALIAN2ENGLISH: TranslationResources = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::ITALIAN2ENGLISH,
    );
    pub const ROMANIAN2ENGLISH: TranslationResources = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::ROMANIAN2ENGLISH,
    );
    pub const GERMAN2ENGLISH: TranslationResources = (
        MarianModelResources::GERMAN2ENGLISH,
        MarianConfigResources::GERMAN2ENGLISH,
        MarianVocabResources::GERMAN2ENGLISH,
        MarianSpmResources::GERMAN2ENGLISH,
        MarianPrefix::GERMAN2ENGLISH,
    );
    pub const RUSSIAN2ENGLISH: TranslationResources = (
        MarianModelResources::RUSSIAN2ENGLISH,
        MarianConfigResources::RUSSIAN2ENGLISH,
        MarianVocabResources::RUSSIAN2ENGLISH,
        MarianSpmResources::RUSSIAN2ENGLISH,
        MarianPrefix::RUSSIAN2ENGLISH,
    );
    pub const FRENCH2GERMAN: TranslationResources = (
        MarianModelResources::FRENCH2GERMAN,
        MarianConfigResources::FRENCH2GERMAN,
        MarianVocabResources::FRENCH2GERMAN,
        MarianSpmResources::FRENCH2GERMAN,
        MarianPrefix::FRENCH2GERMAN,
    );
    pub const GERMAN2FRENCH: TranslationResources = (
        MarianModelResources::GERMAN2FRENCH,
        MarianConfigResources::GERMAN2FRENCH,
        MarianVocabResources::GERMAN2FRENCH,
        MarianSpmResources::GERMAN2FRENCH,
        MarianPrefix::GERMAN2FRENCH,
    );
}
/// Configuration for a [`TranslationModel`]: the resources to load and the
/// text-generation settings.
///
/// `TranslationModel::new` forwards every field except `prefix` unchanged into a
/// `GenerateConfig`. Build one with `TranslationConfig::new` (pretrained remote
/// resources for a [`Language`]) or `TranslationConfig::new_from_resources`
/// (caller-supplied resources). Both constructors apply the same generation defaults.
pub struct TranslationConfig {
/// Model resource (from `MarianModelResources` when built via `new`).
pub model_resource: Resource,
/// Model configuration resource (from `MarianConfigResources` when built via `new`).
pub config_resource: Resource,
/// Vocabulary resource (from `MarianVocabResources` when built via `new`).
pub vocab_resource: Resource,
/// SentencePiece model resource. `new` fills it from `MarianSpmResources`;
/// `new_from_resources` stores its `sentence_piece_resource` argument here.
/// The field keeps the name it has in `GenerateConfig::merges_resource`, where it is forwarded.
pub merges_resource: Resource,
/// Minimum generation length (constructors default to 0).
pub min_length: u64,
/// Maximum generation length (constructors default to 512).
pub max_length: u64,
/// Whether to sample during generation (constructors default to `false`).
pub do_sample: bool,
/// Early-stopping flag for generation (constructors default to `false`).
pub early_stopping: bool,
/// Number of beams (constructors default to 6).
pub num_beams: u64,
/// Sampling temperature (constructors default to 1.0).
pub temperature: f64,
/// Top-k value (constructors default to 50).
pub top_k: u64,
/// Top-p value (constructors default to 1.0).
pub top_p: f64,
/// Repetition penalty (constructors default to 1.0).
pub repetition_penalty: f64,
/// Length penalty (constructors default to 1.0).
pub length_penalty: f64,
/// No-repeat n-gram size (constructors default to 0).
pub no_repeat_ngram_size: u64,
/// Number of sequences to return per input (constructors default to 1).
pub num_return_sequences: u64,
/// Device forwarded to `GenerateConfig`.
pub device: Device,
/// Optional text that `TranslationModel::translate` prepends to each input, separated by a space.
/// The multilingual Marian models (e.g. English→Romance) share one set of resources and
/// differ only in this prefix, so it presumably selects the target language. Confirm this
/// against `MarianPrefix`.
pub prefix: Option<String>,
}
impl TranslationConfig {
    /// Creates a configuration for the pretrained Marian model serving `language`.
    ///
    /// The model, config, vocabulary and SentencePiece resources are all wrapped as
    /// `Resource::Remote` via `RemoteResource::from_pretrained`. The language-specific
    /// prefix (if any) is copied into an owned `String`. Generation settings use the
    /// same defaults as [`TranslationConfig::new_from_resources`].
    ///
    /// # Arguments
    /// * `language` - translation direction to configure.
    /// * `device` - device forwarded to the generator.
    pub fn new(language: Language, device: Device) -> TranslationConfig {
        let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
            match language {
                Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
                Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
                Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
                Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
                Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
                Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
                Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
                Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
                Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
                Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
                Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
                Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
                Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
                Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
                Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
                Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
                Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
                Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
            };
        // Delegate so both constructors share a single set of generation defaults.
        TranslationConfig::new_from_resources(
            Resource::Remote(RemoteResource::from_pretrained(model_resource)),
            Resource::Remote(RemoteResource::from_pretrained(config_resource)),
            Resource::Remote(RemoteResource::from_pretrained(vocab_resource)),
            Resource::Remote(RemoteResource::from_pretrained(merges_resource)),
            prefix.map(String::from),
            device,
        )
    }

    /// Creates a configuration from caller-supplied resources.
    ///
    /// # Arguments
    /// * `model_resource` - model resource.
    /// * `config_resource` - model configuration resource.
    /// * `vocab_resource` - vocabulary resource.
    /// * `sentence_piece_resource` - SentencePiece model, stored in `merges_resource`.
    /// * `prefix` - optional text prepended (followed by a space) to every input.
    /// * `device` - device forwarded to the generator.
    ///
    /// Generation defaults: `min_length` 0, `max_length` 512, no sampling, no early
    /// stopping, 6 beams, `temperature` 1.0, `top_k` 50, `top_p` 1.0, repetition and
    /// length penalties 1.0, `no_repeat_ngram_size` 0, one returned sequence.
    pub fn new_from_resources(
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        sentence_piece_resource: Resource,
        prefix: Option<String>,
        device: Device,
    ) -> TranslationConfig {
        TranslationConfig {
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource: sentence_piece_resource,
            min_length: 0,
            max_length: 512,
            do_sample: false,
            early_stopping: false,
            num_beams: 6,
            temperature: 1.0,
            top_k: 50,
            top_p: 1.0,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            device,
            prefix,
        }
    }
}
/// Translation pipeline wrapping a `MarianGenerator` built from a [`TranslationConfig`].
pub struct TranslationModel {
/// Generator that produces the translated text.
model: MarianGenerator,
/// Optional text prepended (followed by a space) to every input in `translate`.
prefix: Option<String>,
}
impl TranslationModel {
    /// Builds a translation model from `translation_config`.
    ///
    /// Every resource and generation setting is moved into a `GenerateConfig`, which
    /// is used to construct the underlying `MarianGenerator`. The config's `prefix`
    /// is kept on the model and applied in [`TranslationModel::translate`].
    ///
    /// # Errors
    /// Propagates any error returned by `MarianGenerator::new`. NOTE(review): this is
    /// presumably a failure to fetch or load the configured resources; confirm.
    pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
        let generate_config = GenerateConfig {
            model_resource: translation_config.model_resource,
            config_resource: translation_config.config_resource,
            merges_resource: translation_config.merges_resource,
            vocab_resource: translation_config.vocab_resource,
            min_length: translation_config.min_length,
            max_length: translation_config.max_length,
            do_sample: translation_config.do_sample,
            early_stopping: translation_config.early_stopping,
            num_beams: translation_config.num_beams,
            temperature: translation_config.temperature,
            top_k: translation_config.top_k,
            top_p: translation_config.top_p,
            repetition_penalty: translation_config.repetition_penalty,
            length_penalty: translation_config.length_penalty,
            no_repeat_ngram_size: translation_config.no_repeat_ngram_size,
            num_return_sequences: translation_config.num_return_sequences,
            device: translation_config.device,
        };
        let model = MarianGenerator::new(generate_config)?;
        Ok(TranslationModel {
            model,
            prefix: translation_config.prefix,
        })
    }

    /// Translates each text in `texts` and returns the generated outputs.
    ///
    /// If the model has a prefix, each input is sent to the generator as
    /// `"{prefix} {text}"`. Otherwise the inputs are passed through unchanged.
    pub fn translate(&mut self, texts: &[&str]) -> Vec<String> {
        match &self.prefix {
            Some(prefix) => {
                // Own the prefixed strings here; the generator receives borrowed `&str`s into them.
                let prefixed: Vec<String> = texts
                    .iter()
                    .map(|text| format!("{} {}", prefix, text))
                    .collect();
                self.model
                    .generate(Some(prefixed.iter().map(String::as_str).collect()), None)
            }
            None => self.model.generate(Some(texts.to_vec()), None),
        }
    }
}