pub trait LanguageGenerator: PrivateLanguageGenerator {
// Provided methods
fn generate<S>(
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions<'_>>,
) -> Result<Vec<GeneratedTextOutput>, RustBertError>
where S: AsRef<str> + Send + Sync { ... }
fn generate_indices<S>(
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions<'_>>,
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
where S: AsRef<str> + Send + Sync { ... }
fn generate_from_ids_and_past(
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
generate_options: Option<GenerateOptions<'_>>,
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError> { ... }
fn get_tokenizer(&self) -> &TokenizerOption { ... }
fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption { ... }
fn half(&mut self) -> Result<(), RustBertError> { ... }
fn float(&mut self) -> Result<(), RustBertError> { ... }
fn set_device(&mut self, device: Device) -> Result<(), RustBertError> { ... }
}
Expand description
§Common trait for text generation models.
Main API for text generation
Provided Methods§
Sourcefn generate<S>(
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions<'_>>,
) -> Result<Vec<GeneratedTextOutput>, RustBertError>
fn generate<S>( &self, prompt_texts: Option<&[S]>, generate_options: Option<GenerateOptions<'_>>, ) -> Result<Vec<GeneratedTextOutput>, RustBertError>
Generate text based on a vector of prompt texts.
§Arguments
prompt_texts
-Option<Vec<&str>>
Optional vector of text prompts. An empty prompt to the model may be passed if the model implements a `bos_id`.
generate_options
-Option<GenerateOptions>
Optional set of generate options. If not (or partially) provided, will use the settings provided when creating the generator
§Returns
Vec<GeneratedTextOutput>
Vector of length number_of_prompts x num_return_sequences containing GeneratedTextOutput with the generated texts and the generation score if output_scores
is true.
§Example
use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{
GenerateConfig, GenerateOptions, LanguageGenerator,
};
use tch::{Device, Tensor};
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
max_length: Some(30),
do_sample: true,
num_beams: 5,
temperature: 1.1,
num_return_sequences: 3,
..Default::default()
};
let gpt2_generator = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";
//Example custom function for fine-grained generation control
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
let paragraph_tokens = [198, 628];
for paragraph_token in paragraph_tokens.iter() {
if previous_token_ids
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>()
.contains(paragraph_token)
{
return vec![50256];
}
}
(0..50255).collect()
}
let generate_options = GenerateOptions {
min_length: Some(32),
max_length: Some(128),
output_scores: true,
prefix_allowed_tokens_fn: Some(&force_one_paragraph),
..Default::default()
};
let output = gpt2_generator.generate(
Some(&[input_context, second_input_context]),
Some(generate_options),
);
Example output: \
[
"The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
"The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
"The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's",
"The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,",
"The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
"The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
]
Sourcefn generate_indices<S>(
&self,
prompt_texts: Option<&[S]>,
generate_options: Option<GenerateOptions<'_>>,
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
fn generate_indices<S>( &self, prompt_texts: Option<&[S]>, generate_options: Option<GenerateOptions<'_>>, ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
Generate token indices without decoding (useful for token-level operations before returning final text or as validation step during training).
§Arguments
prompt_texts
-Option<Vec<&str>>
Optional vector of text prompts. An empty prompt to the model may be passed if the model implements a `bos_id`.
generate_options
-Option<GenerateOptions>
Optional set of generate options. If not (or partially) provided, will use the settings provided when creating the generator
§Returns
Vec<GeneratedIndicesOutput>
Vector of length number_of_prompts x num_return_sequences containing GeneratedIndicesOutput with the generated indices and the generation score if output_scores
is true.
§Example
use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{
GenerateConfig, GenerateOptions, LanguageGenerator,
};
use tch::{Device, Tensor};
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
max_length: Some(30),
do_sample: true,
num_beams: 5,
temperature: 1.1,
num_return_sequences: 3,
..Default::default()
};
let gpt2_generator = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";
//Example custom function for fine-grained generation control
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
let paragraph_tokens = [198, 628];
for paragraph_token in paragraph_tokens.iter() {
if previous_token_ids
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>()
.contains(paragraph_token)
{
return vec![50256];
}
}
(0..50255).collect()
}
let generate_options = GenerateOptions {
min_length: Some(32),
max_length: Some(128),
output_scores: true,
prefix_allowed_tokens_fn: Some(&force_one_paragraph),
..Default::default()
};
let output = gpt2_generator.generate_indices(
Some(&[input_context, second_input_context]),
Some(generate_options),
);
Sourcefn generate_from_ids_and_past(
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
generate_options: Option<GenerateOptions<'_>>,
) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
fn generate_from_ids_and_past( &self, input_ids: Tensor, attention_mask: Option<Tensor>, generate_options: Option<GenerateOptions<'_>>, ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
Generate token indices given a list of indices (useful when the input has been pre-tokenized). Returns a list of output tokens that need to be decoded using a tokenizer.
§Arguments
input_ids
-Tensor
Pre-tokenized and encoded input for generation.
generate_options
-Option<GenerateOptions>
Optional set of generate options. If not (or partially) provided, will use the settings provided when creating the generator
§Returns
Vec<GeneratedIndicesOutput>
Vector of length number_of_prompts x num_return_sequences containing GeneratedIndicesOutput with the generated indices and the generation score if output_scores
is true.
§Example
use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{
GenerateConfig, GenerateOptions, LanguageGenerator,
};
use tch::{Device, Kind, Tensor};
let device = Device::cuda_if_available();
let gpt2_generator = GPT2Generator::new(Default::default())?;
let input_tensor = Tensor::randint(50255, &[32, 128], (Kind::Int64, Device::Cpu));
let input_mask = Tensor::ones(&[32, 128], (Kind::Int64, Device::Cpu));
let generate_options = GenerateOptions {
min_length: Some(32),
max_length: Some(128),
output_scores: true,
..Default::default()
};
let output = gpt2_generator.generate_from_ids_and_past(
input_tensor,
Some(input_mask),
Some(generate_options),
);
Sourcefn get_tokenizer(&self) -> &TokenizerOption
fn get_tokenizer(&self) -> &TokenizerOption
Returns a reference to the text generator’s tokenizer
§Returns
&TokenizerOption
Reference to the generator’s tokenizer.
§Example
use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use tch::{Device, Tensor};
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
max_length: Some(30),
do_sample: true,
num_beams: 5,
temperature: 1.1,
num_return_sequences: 3,
..Default::default()
};
let gpt2_generator = GPT2Generator::new(generate_config)?;
let tokenizer = gpt2_generator.get_tokenizer();
tokenizer.tokenize("Hello, world!");
fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption
fn half(&mut self) -> Result<(), RustBertError>
fn float(&mut self) -> Result<(), RustBertError>
fn set_device(&mut self, device: Device) -> Result<(), RustBertError>
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.