rust_bert/pipelines/text_generation.rs

// Copyright 2020 The Facebook AI Research Team Authors
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # Text generation pipeline
//! Generates text from a prompt.
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! By default, the dependencies for a GPT2-medium model will be downloaded.
//! Available architectures for text generation include:
//! - OpenAI GPT
//! - OpenAI GPT2
//! - GPT-Neo
//! - GPT-J
//! - XLNet
//! - Reformer
//! - T5
//!
//! Two APIs exist to build text generation models:
//! - `TextGenerationModel` is a high-level module that exposes text generation capabilities with a set of reasonable defaults
//! - the `LanguageGenerator` trait exposes lower-level text generation capabilities, allowing the user to provide additional
//!     generation options when building the model (via `GenerateConfig`) and at each query (via `GenerateOptions`). Please check the
//!     [`generation_utils` module](../generation_utils/index.html) for more details
//!
//! Customized text generation models can be loaded by overwriting the resources in the configuration.
//! The dependencies will be downloaded to the user's home directory, e.g. under ~/.cache/.rustbert/gpt2
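//!
//! A minimal usage sketch (assuming the default remote GPT2-medium resources are
//! available via the `remote` feature):
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::text_generation::TextGenerationModel;
//!
//! let model = TextGenerationModel::new(Default::default())?;
//! let output = model.generate(&["The dog"], None)?;
//! # Ok(())
//! # }
//! ```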
use tch::{Device, Kind};

use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
use crate::gpt_j::GptJGenerator;
use crate::gpt_neo::GptNeoGenerator;
use crate::openai_gpt::OpenAIGenerator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::reformer::ReformerGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::xlnet::XLNetGenerator;

#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXCausalGenerator;
#[cfg(feature = "remote")]
use crate::{
    gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
    resources::RemoteResource,
};

/// # Configuration for text generation
/// Contains information regarding the model to load. Mirrors `GenerateConfig` with a
/// different set of default parameters, and sets the device to place the model on.
pub struct TextGenerationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (default: pretrained GPT2-medium model)
    pub model_resource: ModelResource,
    /// Config resource (default: pretrained GPT2-medium model)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: pretrained GPT2-medium model)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (default: pretrained GPT2-medium model)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Minimum sequence length (default: 0)
    pub min_length: i64,
    /// Maximum sequence length (default: 56)
    pub max_length: Option<i64>,
    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
    pub do_sample: bool,
    /// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated (default: true)
    pub early_stopping: bool,
    /// Number of beams for beam search (default: 5)
    pub num_beams: i64,
    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
    pub temperature: f64,
    /// Top_k value for sampling tokens. Values higher than 0 will enable the feature (default: 0)
    pub top_k: i64,
    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
    pub top_p: f64,
    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have already been generated. (default: 1.0)
    pub repetition_penalty: f64,
    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
    pub length_penalty: f64,
    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature and will prevent repeats of n-grams with a length greater than or equal to this value (default: 0)
    pub no_repeat_ngram_size: i64,
    /// Number of sequences to return for each prompt text (default: 1)
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups, leading to more diverse generation.
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty for diverse beam search. Higher values enforce more difference between beam groups (default: None)
    pub diversity_penalty: Option<f64>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}

impl TextGenerationConfig {
    /// Instantiate a new text generation configuration of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` indicating the model type to load (must match the actual data to be loaded!)
    /// * `model_resource` - The `ModelResource` pointing to the model weights to load (e.g. model.ot)
    /// * `config_resource` - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
    /// * `vocab_resource` - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
    /// * `merges_resource` - The optional `ResourceProvider` pointing to the tokenizer's merges file or SentencePiece model to load (e.g. merges.txt)
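    ///
    /// # Example
    ///
    /// A sketch with hypothetical local files; any other `ResourceProvider`
    /// (e.g. `RemoteResource`) can be supplied instead:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::common::{ModelResource, ModelType};
    /// use rust_bert::pipelines::text_generation::TextGenerationConfig;
    /// use rust_bert::resources::LocalResource;
    /// use std::path::PathBuf;
    ///
    /// // Placeholder paths for illustration only
    /// let config = TextGenerationConfig::new(
    ///     ModelType::GPT2,
    ///     ModelResource::Torch(Box::new(LocalResource {
    ///         local_path: PathBuf::from("path/to/model.ot"),
    ///     })),
    ///     LocalResource {
    ///         local_path: PathBuf::from("path/to/config.json"),
    ///     },
    ///     LocalResource {
    ///         local_path: PathBuf::from("path/to/vocab.json"),
    ///     },
    ///     Some(LocalResource {
    ///         local_path: PathBuf::from("path/to/merges.txt"),
    ///     }),
    /// );
    /// ```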
    pub fn new<RC, RV>(
        model_type: ModelType,
        model_resource: ModelResource,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
    ) -> TextGenerationConfig
    where
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
    {
        TextGenerationConfig {
            model_type,
            model_resource,
            config_resource: Box::new(config_resource),
            vocab_resource: Box::new(vocab_resource),
            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
            min_length: 0,
            max_length: Some(56),
            do_sample: true,
            early_stopping: true,
            num_beams: 5,
            temperature: 1.0,
            top_k: 0,
            top_p: 0.9,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}

#[cfg(feature = "remote")]
impl Default for TextGenerationConfig {
    fn default() -> TextGenerationConfig {
        TextGenerationConfig::new(
            ModelType::GPT2,
            ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
                Gpt2ModelResources::GPT2_MEDIUM,
            ))),
            RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
            RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
            Some(RemoteResource::from_pretrained(
                Gpt2MergesResources::GPT2_MEDIUM,
            )),
        )
    }
}

impl From<TextGenerationConfig> for GenerateConfig {
    fn from(config: TextGenerationConfig) -> GenerateConfig {
        GenerateConfig {
            model_type: config.model_type,
            model_resource: config.model_resource,
            config_resource: config.config_resource,
            merges_resource: config.merges_resource,
            vocab_resource: config.vocab_resource,
            min_length: config.min_length,
            max_length: config.max_length,
            do_sample: config.do_sample,
            early_stopping: config.early_stopping,
            num_beams: config.num_beams,
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repetition_penalty: config.repetition_penalty,
            length_penalty: config.length_penalty,
            no_repeat_ngram_size: config.no_repeat_ngram_size,
            num_return_sequences: config.num_return_sequences,
            num_beam_groups: config.num_beam_groups,
            diversity_penalty: config.diversity_penalty,
            device: config.device,
            kind: config.kind,
        }
    }
}

/// # Abstraction that holds one particular text generation model, for any of the supported models
pub enum TextGenerationOption {
    /// Text Generator based on GPT2 model
    GPT2(GPT2Generator),
    /// Text Generator based on GPT model
    GPT(OpenAIGenerator),
    /// Text Generator based on GPT-Neo model
    GPTNeo(GptNeoGenerator),
    /// Text Generator based on GPT-J model
    GPTJ(GptJGenerator),
    /// Text Generator based on XLNet model
    XLNet(XLNetGenerator),
    /// Text Generator based on Reformer model
    Reformer(ReformerGenerator),
    /// Text Generator based on T5 model
    T5(T5Generator),
    /// ONNX model for text generation
    #[cfg(feature = "onnx")]
    ONNX(ONNXCausalGenerator),
}

impl TextGenerationOption {
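    /// Instantiate a new text generation model of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `config` - `TextGenerationConfig` describing the model to load and its generation defaults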
    pub fn new(config: TextGenerationConfig) -> Result<Self, RustBertError> {
        match (config.model_type, &config.model_resource) {
            #[cfg(feature = "onnx")]
            (_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
                ONNXCausalGenerator::new(config.into(), None, None)?,
            )),
            (ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
                config.into(),
            )?)),
            (ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
                config.into(),
            )?)),
            (ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
                config.into(),
            )?)),
            (ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
                config.into(),
            )?)),
            (ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
                config.into(),
            )?)),
            (ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
                config.into(),
            )?)),
            (ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Text generation not implemented for {:?}!",
                config.model_type
            ))),
        }
    }

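    /// Instantiate a new text generation model of the supplied type with a pre-built tokenizer.
    ///
    /// # Arguments
    ///
    /// * `config` - `TextGenerationConfig` describing the model to load and its generation defaults
    /// * `tokenizer` - `TokenizerOption` to use in place of a tokenizer loaded from the vocabulary resources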
    pub fn new_with_tokenizer(
        config: TextGenerationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<Self, RustBertError> {
        match (config.model_type, &config.model_resource) {
            #[cfg(feature = "onnx")]
            (_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
                ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
            )),
            (ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
                GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(
                OpenAIGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(
                XLNetGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(
                ReformerGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(
                GptNeoGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(
                GptJGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
                config.into(),
                tokenizer,
            )?)),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Text generation not implemented for {:?}!",
                config.model_type
            ))),
        }
    }

    /// Returns the `ModelType` for this TextGenerationOption
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::GPT(_) => ModelType::OpenAiGpt,
            Self::GPT2(_) => ModelType::GPT2,
            Self::GPTNeo(_) => ModelType::GPTNeo,
            Self::GPTJ(_) => ModelType::GPTJ,
            Self::XLNet(_) => ModelType::XLNet,
            Self::Reformer(_) => ModelType::Reformer,
            Self::T5(_) => ModelType::T5,
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => ModelType::ONNX,
        }
    }

    /// Interface method to access tokenizer
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        match self {
            Self::GPT(model_ref) => model_ref.get_tokenizer(),
            Self::GPT2(model_ref) => model_ref.get_tokenizer(),
            Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
            Self::GPTJ(model_ref) => model_ref.get_tokenizer(),
            Self::XLNet(model_ref) => model_ref.get_tokenizer(),
            Self::Reformer(model_ref) => model_ref.get_tokenizer(),
            Self::T5(model_ref) => model_ref.get_tokenizer(),
            #[cfg(feature = "onnx")]
            Self::ONNX(model_ref) => model_ref.get_tokenizer(),
        }
    }

    /// Interface method to mutably access tokenizer
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        match self {
            Self::GPT(model_ref) => model_ref.get_tokenizer_mut(),
            Self::GPT2(model_ref) => model_ref.get_tokenizer_mut(),
            Self::GPTNeo(model_ref) => model_ref.get_tokenizer_mut(),
            Self::GPTJ(model_ref) => model_ref.get_tokenizer_mut(),
            Self::XLNet(model_ref) => model_ref.get_tokenizer_mut(),
            Self::Reformer(model_ref) => model_ref.get_tokenizer_mut(),
            Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
            #[cfg(feature = "onnx")]
            Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
        }
    }

    /// Interface method to the `generate` method of the underlying models.
    pub fn generate_indices<S>(
        &self,
        prompt_texts: Option<&[S]>,
        min_length: Option<i64>,
        max_length: Option<i64>,
    ) -> Result<Vec<Vec<i64>>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let generate_options = Some(GenerateOptions {
            min_length,
            max_length,
            ..Default::default()
        });
        Ok(match *self {
            Self::GPT(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::GPT2(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::GPTNeo(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::GPTJ(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::XLNet(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::Reformer(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::T5(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            #[cfg(feature = "onnx")]
            Self::ONNX(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
        })
    }

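    /// Converts the model weights to half precision (not supported for ONNX models).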
    pub fn half(&mut self) -> Result<(), RustBertError> {
        match self {
            Self::GPT(model_ref) => model_ref.half(),
            Self::GPT2(model_ref) => model_ref.half(),
            Self::GPTNeo(model_ref) => model_ref.half(),
            Self::GPTJ(model_ref) => model_ref.half(),
            Self::XLNet(model_ref) => model_ref.half(),
            Self::Reformer(model_ref) => model_ref.half(),
            Self::T5(model_ref) => model_ref.half(),
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => Err(RustBertError::OrtError(
                "Type casting not supported for ONNX models.".to_string(),
            )),
        }
    }

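    /// Converts the model weights to full (float) precision (not supported for ONNX models).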
    pub fn float(&mut self) -> Result<(), RustBertError> {
        match self {
            Self::GPT(model_ref) => model_ref.float(),
            Self::GPT2(model_ref) => model_ref.float(),
            Self::GPTNeo(model_ref) => model_ref.float(),
            Self::GPTJ(model_ref) => model_ref.float(),
            Self::XLNet(model_ref) => model_ref.float(),
            Self::Reformer(model_ref) => model_ref.float(),
            Self::T5(model_ref) => model_ref.float(),
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => Err(RustBertError::OrtError(
                "Type casting not supported for ONNX models.".to_string(),
            )),
        }
    }

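    /// Moves the model to the given device (not supported for ONNX models).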
    pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
        match self {
            Self::GPT(model_ref) => model_ref.set_device(device),
            Self::GPT2(model_ref) => model_ref.set_device(device),
            Self::GPTNeo(model_ref) => model_ref.set_device(device),
            Self::GPTJ(model_ref) => model_ref.set_device(device),
            Self::XLNet(model_ref) => model_ref.set_device(device),
            Self::Reformer(model_ref) => model_ref.set_device(device),
            Self::T5(model_ref) => model_ref.set_device(device),
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => Err(RustBertError::OrtError(
                "Device assignment not supported for ONNX models.".to_string(),
            )),
        }
    }
}

/// # TextGenerationModel to generate texts from a prompt
pub struct TextGenerationModel {
    model: TextGenerationOption,
    prefix: Option<String>,
    prefix_length: Option<i64>,
    min_length: i64,
    max_length: Option<i64>,
}

impl TextGenerationModel {
    /// Build a new `TextGenerationModel`
    ///
    /// # Arguments
    ///
    /// * `generation_config` - `TextGenerationConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::ModelType;
    /// use rust_bert::pipelines::text_generation::TextGenerationModel;
    ///
    /// let generation_model = TextGenerationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        generation_config: TextGenerationConfig,
    ) -> Result<TextGenerationModel, RustBertError> {
        let (prefix, min_length, max_length) =
            TextGenerationModel::get_prefix_min_max_length(&generation_config);
        let model = TextGenerationOption::new(generation_config)?;
        let prefix_length = prefix
            .as_ref()
            .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
        Ok(TextGenerationModel {
            model,
            prefix,
            prefix_length,
            min_length,
            max_length,
        })
    }

    /// Build a new `TextGenerationModel` with a given tokenizer
    ///
    /// # Arguments
    ///
    /// * `generation_config` - `TextGenerationConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for text generation
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::text_generation::TextGenerationModel;
    ///
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::GPT2,
    ///     "path/to/vocab.json",
    ///     Some("path/to/merges.txt"),
    ///     false,
    ///     None,
    ///     None,
    /// )?;
    /// let generation_model = TextGenerationModel::new_with_tokenizer(Default::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_tokenizer(
        generation_config: TextGenerationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<TextGenerationModel, RustBertError> {
        let (prefix, min_length, max_length) =
            TextGenerationModel::get_prefix_min_max_length(&generation_config);
        let model = TextGenerationOption::new_with_tokenizer(generation_config, tokenizer)?;
        let prefix_length = prefix
            .as_ref()
            .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
        Ok(TextGenerationModel {
            model,
            prefix,
            prefix_length,
            min_length,
            max_length,
        })
    }

    fn get_prefix_min_max_length(
        generation_config: &TextGenerationConfig,
    ) -> (Option<String>, i64, Option<i64>) {
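        // XLNet tends to generate poorly from short prompts; a long dummy "padding text"
        // (the same trick used in the original XLNet generation examples) is therefore
        // prepended by default and later stripped from the decoded output.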
        let prefix = match generation_config.model_type {
            ModelType::XLNet => Some(
                "In 1991, the remains of Russian Tsar Nicholas II and his family \
(except for Alexei and Maria) are discovered. \
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the \
remainder of the story. 1883 Western Siberia, \
a young Grigori Rasputin is asked by his father and a group of men to perform magic. \
Rasputin has a vision and denounces one of the men as a horse thief. Although his \
father initially slaps him for making such an accusation, Rasputin watches as the \
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of \
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, \
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
                    .to_string(),
            ),
            _ => None,
        };

        let min_length = generation_config.min_length;
        let max_length = generation_config.max_length;
        (prefix, min_length, max_length)
    }

    pub fn get_tokenizer(&self) -> &TokenizerOption {
        self.model.get_tokenizer()
    }

    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        self.model.get_tokenizer_mut()
    }

    pub fn half(&mut self) -> Result<(), RustBertError> {
        self.model.half()
    }

    pub fn float(&mut self) -> Result<(), RustBertError> {
        self.model.float()
    }

    pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
        self.model.set_device(device)
    }

    /// Generate texts from provided prompts
    ///
    /// # Arguments
    ///
    /// * `texts` - `&[S]` Array of prompt texts to complete.
    /// * `prefix` - `impl Into<Option<&'a str>>`: Optional string to pass as a prefix for generation. Will be excluded from generated sequences.
    ///
    /// # Returns
    /// * `Vec<String>` Generated texts
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::ModelType;
    /// use rust_bert::pipelines::text_generation::TextGenerationModel;
    ///
    /// let model = TextGenerationModel::new(Default::default())?;
    ///
    /// let input = ["The dog", "The cat was"];
    /// let prefix = None;
    ///
    /// let output = model.generate(&input, prefix);
    /// # Ok(())
    /// # }
    /// ```
    pub fn generate<'a, S>(
        &self,
        texts: &[S],
        prefix: impl Into<Option<&'a str>>,
    ) -> Result<Vec<String>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
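        // Resolve the effective prefix: a per-call prefix takes precedence over the
        // pipeline-level prefix (set automatically for XLNet). Its tokenized length is
        // tracked so the prefix can be stripped from the generated indices below.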
        let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
            (Some(query_prefix), _) => (
                Some(query_prefix),
                Some(self.model.get_tokenizer().tokenize(query_prefix).len() as i64),
            ),
            (None, Some(pipeline_prefix)) => (Some(pipeline_prefix.as_str()), self.prefix_length),
            (None, None) => (None, None),
        };
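        // When a prefix is present, prepend it to every prompt and shift the length
        // bounds by the prefix length so the effective generated length is unchanged.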
        let generated_indices = match (prefix, prefix_length) {
            (None, _) => self.model.generate_indices(Some(texts), None, None),
            (Some(prefix), Some(prefix_length)) => {
                let texts = texts
                    .as_ref()
                    .iter()
                    .map(|text| format!("{} {}", prefix, text.as_ref()))
                    .collect::<Vec<String>>();
                self.model.generate_indices(
                    Some(&texts),
                    Some(self.min_length + prefix_length),
                    self.max_length.map(|max_length| max_length + prefix_length),
                )
            }
            _ => Err(RustBertError::ValueError(
                "Prefix length not defined but prefix provided!".to_string(),
            )),
        }?;

        let mut output = Vec::with_capacity(generated_indices.len());
        for generated_sequence in generated_indices {
            output.push(self.model.get_tokenizer().decode(
                &generated_sequence[prefix_length.unwrap_or(0) as usize..],
                true,
                true,
            ));
        }
        Ok(output)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let config = TextGenerationConfig::default();
        let _: Box<dyn Send> = Box::new(TextGenerationModel::new(config));
    }
}