use itertools::Itertools;
use tch::{Device, Tensor};
use crate::common::error::RustBertError;
use crate::common::resources::RemoteResource;
use crate::gpt2::{
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::openai_gpt::OpenAIGenerator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::reformer::ReformerGenerator;
use crate::resources::Resource;
use crate::xlnet::XLNetGenerator;
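
/// # Configuration for text generation
/// Holds the resources pointing to the model, configuration and tokenizer files, together with
/// the generation options and the device to run the model on.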
pub struct TextGenerationConfig {
    /// Model type (GPT2, GPT, XLNet or Reformer)
    pub model_type: ModelType,
    /// Model weights resource
    pub model_resource: Resource,
    /// Model configuration resource
    pub config_resource: Resource,
    /// Tokenizer vocabulary resource
    pub vocab_resource: Resource,
    /// Tokenizer merges resource (used by BPE-based tokenizers such as GPT-2's)
    pub merges_resource: Resource,
    /// Minimum length of the generated sequences
    pub min_length: i64,
    /// Maximum length of the generated sequences
    pub max_length: i64,
    /// Sampling flag. If true, performs top-k and/or nucleus sampling; otherwise decoding is greedy
    pub do_sample: bool,
    /// Early stopping flag: stop the beam search as soon as `num_beams` hypotheses are finished
    pub early_stopping: bool,
    /// Number of beams for beam search
    pub num_beams: i64,
    /// Sampling temperature. Values above 1 increase diversity at the cost of relevance
    pub temperature: f64,
    /// Top-k value for sampling. Values above 0 enable the feature
    pub top_k: i64,
    /// Top-p value for nucleus sampling: keep the smallest set of most probable tokens whose cumulative probability exceeds this value
    pub top_p: f64,
    /// Repetition penalty. Values above 1 penalize tokens that have already been generated
    pub repetition_penalty: f64,
    /// Exponential penalty based on the length of the generated hypotheses
    pub length_penalty: f64,
    /// Size of n-grams that may not be repeated. Values above 0 enable the feature
    pub no_repeat_ngram_size: i64,
    /// Number of sequences to return for each prompt
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam search
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty between beam groups for diverse beam search
    pub diversity_penalty: Option<f64>,
    /// Device to place the model on
    pub device: Device,
}

impl TextGenerationConfig {
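    /// Instantiate a new text generation configuration of the supplied type.
    ///
    /// All generation options not listed below are taken from `Default::default()`, and the
    /// device is set to CUDA when available.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` indicating the model architecture to load
    /// * `model_resource` - `Resource` pointing to the model weights
    /// * `config_resource` - `Resource` pointing to the model configuration
    /// * `vocab_resource` - `Resource` pointing to the tokenizer vocabulary
    /// * `merges_resource` - `Resource` pointing to the tokenizer merges file (BPE tokenizers)
    ///
    /// A minimal construction sketch, assuming the GPT-2 pretrained resources bundled with the
    /// crate and that this module is exposed as `rust_bert::pipelines::text_generation`:
    ///
    /// ```no_run
    /// use rust_bert::gpt2::{
    ///     Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
    /// };
    /// use rust_bert::pipelines::common::ModelType;
    /// use rust_bert::pipelines::text_generation::TextGenerationConfig;
    /// use rust_bert::resources::{RemoteResource, Resource};
    ///
    /// // Remote resources are downloaded and cached on first use.
    /// let config = TextGenerationConfig::new(
    ///     ModelType::GPT2,
    ///     Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
    ///     Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
    ///     Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
    ///     Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
    /// );
    /// ```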
pub fn new(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Resource,
) -> TextGenerationConfig {
TextGenerationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
device: Device::cuda_if_available(),
..Default::default()
}
}
}
impl Default for TextGenerationConfig {
fn default() -> TextGenerationConfig {
TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2_MEDIUM,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::GPT2_MEDIUM,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::GPT2_MEDIUM,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2_MEDIUM,
)),
min_length: 0,
max_length: 20,
do_sample: true,
early_stopping: true,
num_beams: 5,
temperature: 1.0,
top_k: 0,
top_p: 0.9,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 3,
num_return_sequences: 1,
num_beam_groups: None,
diversity_penalty: None,
device: Device::cuda_if_available(),
}
}
}
impl From<TextGenerationConfig> for GenerateConfig {
fn from(config: TextGenerationConfig) -> GenerateConfig {
GenerateConfig {
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
vocab_resource: config.vocab_resource,
min_length: config.min_length,
max_length: config.max_length,
do_sample: config.do_sample,
early_stopping: config.early_stopping,
num_beams: config.num_beams,
temperature: config.temperature,
top_k: config.top_k,
top_p: config.top_p,
repetition_penalty: config.repetition_penalty,
length_penalty: config.length_penalty,
no_repeat_ngram_size: config.no_repeat_ngram_size,
num_return_sequences: config.num_return_sequences,
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
}
}
}
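
/// # Abstraction over the supported text generation models
/// Wraps one of the concrete generators so the pipeline can dispatch to it uniformly.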
pub enum TextGenerationOption {
GPT2(GPT2Generator),
GPT(OpenAIGenerator),
XLNet(XLNetGenerator),
Reformer(ReformerGenerator),
}
impl TextGenerationOption {
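    /// Instantiate a new text generation model of the supplied type.
    ///
    /// Returns an `InvalidConfigurationError` if the `ModelType` has no text generation
    /// implementation.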
pub fn new(config: TextGenerationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
config.into(),
)?)),
ModelType::OpenAiGpt => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
config.into(),
)?)),
ModelType::XLNet => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
config.into(),
)?)),
ModelType::Reformer => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
config.into(),
)?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Text generation not implemented for {:?}!",
config.model_type
))),
}
}
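
    /// Returns the `ModelType` of the wrapped model.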
pub fn model_type(&self) -> ModelType {
match *self {
Self::GPT2(_) => ModelType::GPT2,
Self::GPT(_) => ModelType::OpenAiGpt,
Self::XLNet(_) => ModelType::XLNet,
Self::Reformer(_) => ModelType::Reformer,
}
}
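
    /// Returns a reference to the wrapped model's tokenizer.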
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
Self::GPT(model_ref) => model_ref.get_tokenizer(),
Self::XLNet(model_ref) => model_ref.get_tokenizer(),
Self::Reformer(model_ref) => model_ref.get_tokenizer(),
}
}
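
    /// Dispatches to the wrapped model's `generate_indices`, returning the generated token
    /// indices for each prompt.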
pub fn generate_indices<'a, S>(
&self,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
min_length: Option<i64>,
max_length: Option<i64>,
) -> Vec<Vec<i64>>
where
S: AsRef<[&'a str]>,
{
match *self {
Self::GPT2(ref model) => {
model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
}
Self::GPT(ref model) => {
model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
}
Self::XLNet(ref model) => {
model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
}
Self::Reformer(ref model) => {
model.generate_indices(prompt_texts, attention_mask, min_length, max_length, None)
}
}
}
}
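
/// # TextGenerationModel for generating text from a prompt
/// Holds the generator, an optional pipeline-level prefix (set for XLNet, see `new`) and the
/// minimum/maximum generation lengths.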
pub struct TextGenerationModel {
model: TextGenerationOption,
prefix: Option<String>,
prefix_length: Option<i64>,
min_length: i64,
max_length: i64,
}
impl TextGenerationModel {
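    /// Build a new `TextGenerationModel`.
    ///
    /// # Arguments
    ///
    /// * `generation_config` - `TextGenerationConfig` with the resource references and generation options
    ///
    /// A minimal usage sketch (the default configuration downloads GPT2-medium weights at
    /// runtime; the module path below is an assumption):
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::text_generation::TextGenerationModel;
    ///
    /// let model = TextGenerationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```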
pub fn new(
generation_config: TextGenerationConfig,
) -> Result<TextGenerationModel, RustBertError> {
let prefix = match generation_config.model_type {
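            // XLNet tends to generate poorly from short prompts: prepend a long dummy
            // "padding text" as context, which is stripped from the decoded output later.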
ModelType::XLNet => Some(
"In 1991, the remains of Russian Tsar Nicholas II and his family \
(except for Alexei and Maria) are discovered. \
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the \
remainder of the story. 1883 Western Siberia, \
a young Grigori Rasputin is asked by his father and a group of men to perform magic. \
Rasputin has a vision and denounces one of the men as a horse thief. Although his \
father initially slaps him for making such an accusation, Rasputin watches as the \
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of \
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, \
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
.to_string(),
),
_ => None,
};
let min_length = generation_config.min_length;
let max_length = generation_config.max_length;
let model = TextGenerationOption::new(generation_config)?;
        let prefix_length = prefix
            .as_ref()
            .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
Ok(TextGenerationModel {
model,
prefix,
prefix_length,
min_length,
max_length,
})
}
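
    /// Generate texts from a set of prompts, with an optional prefix.
    ///
    /// An explicit `prefix` argument overrides the pipeline-level prefix; prefix tokens are
    /// stripped from the decoded output, and the minimum/maximum lengths are extended by the
    /// prefix length so the effective generation budget is unchanged.
    ///
    /// A usage sketch under the same assumptions as `new` (weights downloaded at runtime):
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::text_generation::TextGenerationModel;
    ///
    /// let model = TextGenerationModel::new(Default::default())?;
    /// let input_context_1 = "The dog";
    /// let input_context_2 = "The cat was";
    /// // Returns `num_return_sequences` strings per input prompt.
    /// let output = model.generate(&[input_context_1, input_context_2], None);
    /// # Ok(())
    /// # }
    /// ```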
pub fn generate<'a, S>(&self, texts: S, prefix: impl Into<Option<&'a str>>) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
(Some(query_prefix), _) => (
Some(query_prefix),
Some(self.model.get_tokenizer().tokenize(query_prefix).len() as i64),
),
(None, Some(pipeline_prefix)) => (Some(pipeline_prefix.as_str()), self.prefix_length),
(None, None) => (None, None),
};
let generated_indices = match (prefix, prefix_length) {
(None, _) => self.model.generate_indices(Some(texts), None, None, None),
(Some(prefix), Some(prefix_length)) => {
let texts = texts
.as_ref()
.iter()
.map(|text| format!("{} {}", prefix, text))
.collect_vec();
self.model.generate_indices(
Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
None,
Some(self.min_length + prefix_length),
Some(self.max_length + prefix_length),
)
}
_ => panic!("Prefix length not defined but prefix provided!"),
};
let mut output = Vec::with_capacity(generated_indices.len());
for generated_sequence in generated_indices {
            output.push(self.model.get_tokenizer().decode(
                if let Some(prefix_length) = prefix_length {
                    // Skip the prefix tokens so only the generated continuation is decoded.
                    generated_sequence
                        .into_iter()
                        .skip(prefix_length as usize)
                        .collect_vec()
                } else {
                    generated_sequence
                },
                true,
                true,
            ));
}
output
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
#[ignore]
fn test() {
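        // Compile-time assertion that the pipeline creation result is `Send`, so the
        // model can be moved across threads.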
let config = TextGenerationConfig::default();
let _: Box<dyn Send> = Box::new(TextGenerationModel::new(config));
}
}