// rust_bert/pipelines/generation_utils.rs

1// Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors.
2// Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4// Copyright 2019 Guillaume Becquin
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//     http://www.apache.org/licenses/LICENSE-2.0
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # Natural Language Generation utilities
16//! Set of text generation utilities, serving as a basis for the TextGenerationModel, SummarizationModel and TranslationModel pipelines.
17//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
18//! Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, or with the unknown token otherwise.
19//! This may impact the results; it is recommended to submit prompts of similar length for best results.
20//!
21//! ```no_run
22//! # fn main() -> anyhow::Result<()> {
23//! use rust_bert::gpt2::GPT2Generator;
24//! use rust_bert::pipelines::generation_utils::{
25//!     GenerateConfig, GenerateOptions, LanguageGenerator,
26//! };
27//!
28//! let generate_config = GenerateConfig {
29//!     do_sample: true,
30//!     num_beams: 5,
31//!     temperature: 1.1,
32//!     num_return_sequences: 3,
33//!     ..Default::default()
34//! };
35//! let mut gpt2_generator = GPT2Generator::new(generate_config)?;
36//!
37//! let input_context = "The dog";
38//! let second_input_context = "The cat was";
39//!
40//! let generate_options = GenerateOptions {
41//!     min_length: Some(32),
42//!     max_length: Some(128),
43//!     output_scores: true,
44//!     ..Default::default()
45//! };
46//!
47//! let output = gpt2_generator.generate(
48//!     Some(&[input_context, second_input_context]),
49//!     Some(generate_options),
50//! );
51//! # Ok(())
52//! # }
53//! ```
54//!
55//! Example output: \
56//! ```no_run
57//! # let output =
58//! [
59//!     "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
60//!     "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
61//!     "The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's",
62//!     "The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,",
63//!     "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
64//!     "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
65//! ]
66//! # ;
67//! ```
68
69use tch::kind::Kind::Int64;
70use tch::{no_grad, Device, Kind, Tensor};
71
72use crate::bart::LayerState as BartLayerState;
73use crate::common::resources::ResourceProvider;
74use crate::gpt_j::LayerState as GPTJLayerState;
75use crate::gpt_neo::LayerState as GPTNeoLayerState;
76use crate::pipelines::generation_utils::private_generation_utils::{
77    InternalGenerateOptions, PrivateLanguageGenerator,
78};
79use crate::prophetnet::LayerState as ProphetNetLayerState;
80use crate::reformer::LayerState as ReformerLayerState;
81use crate::t5::LayerState as T5LayerState;
82use crate::xlnet::LayerState as XLNetLayerState;
83
84use self::ordered_float::OrderedFloat;
85use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
86
87extern crate ordered_float;
88#[cfg(feature = "onnx")]
89use crate::pipelines::onnx::ONNXLayerCache;
90use crate::RustBertError;
91#[cfg(feature = "remote")]
92use crate::{
93    gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
94    resources::RemoteResource,
95};
96
97/// # Configuration for text generation
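/// Groups the model resources and decoding settings used by the text generation pipelines.
/// A minimal construction sketch, relying on the `Default` implementation (available with the
/// `remote` feature, pointing to a pretrained GPT2 model):
///
/// ```no_run
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// // Override only the decoding settings of interest, keeping the default GPT2 resources.
/// let config = GenerateConfig {
///     max_length: Some(64),
///     do_sample: true,
///     top_p: 0.95,
///     ..Default::default()
/// };
/// ```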
98pub struct GenerateConfig {
99    /// Model type used for generation
100    pub model_type: ModelType,
101    /// Model weights resource (default: pretrained GPT2 model)
102    pub model_resource: ModelResource,
103    /// Config resource (default: pretrained GPT2 model)
104    pub config_resource: Box<dyn ResourceProvider + Send>,
105    /// Vocab resource (default: pretrained GPT2 model)
106    pub vocab_resource: Box<dyn ResourceProvider + Send>,
107    /// Merges resource (default: pretrained GPT2 model)
108    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
109    /// Minimum sequence length (default: 0)
110    pub min_length: i64,
111    /// Maximum sequence length (default: 56)
112    pub max_length: Option<i64>,
113    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
114    pub do_sample: bool,
116    /// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated (default: true)
116    pub early_stopping: bool,
117    /// Number of beams for beam search (default: 5)
118    pub num_beams: i64,
119    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
120    pub temperature: f64,
122    /// Top_k value for sampling tokens. Values higher than 0 enable the feature (default: 0)
122    pub top_k: i64,
123    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
124    pub top_p: f64,
125    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
126    pub repetition_penalty: f64,
127    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
128    pub length_penalty: f64,
130    /// Size of n-grams that may only occur once in the generated sequence. Values higher than 0 enable this feature (default: 3)
130    pub no_repeat_ngram_size: i64,
131    /// Number of sequences to return for each prompt text (default: 1)
132    pub num_return_sequences: i64,
133    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
134    pub num_beam_groups: Option<i64>,
135    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
136    pub diversity_penalty: Option<f64>,
137    /// Device to place the model on (default: CUDA/GPU when available)
138    pub device: Device,
139    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
140    pub kind: Option<Kind>,
141}
142
143#[cfg(feature = "remote")]
144impl Default for GenerateConfig {
145    fn default() -> GenerateConfig {
146        GenerateConfig {
147            model_type: ModelType::GPT2,
148            model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
149                Gpt2ModelResources::GPT2,
150            ))),
151            config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
152            vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
153            merges_resource: Some(Box::new(RemoteResource::from_pretrained(
154                Gpt2MergesResources::GPT2,
155            ))),
156            min_length: 0,
157            max_length: Some(56),
158            do_sample: true,
159            early_stopping: true,
160            num_beams: 5,
161            temperature: 1.0,
162            top_k: 0,
163            top_p: 0.9,
164            repetition_penalty: 1.0,
165            length_penalty: 1.0,
166            no_repeat_ngram_size: 3,
167            num_return_sequences: 1,
168            num_beam_groups: None,
169            diversity_penalty: None,
170            device: Device::cuda_if_available(),
171            kind: None,
172        }
173    }
174}
175
176impl GenerateConfig {
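    /// Checks that the decoding settings are mutually consistent and panics otherwise.
    /// A sketch of configurations accepted by this validation (assuming the `remote` feature
    /// for the `Default` resources):
    ///
    /// ```no_run
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    ///
    /// // Greedy decoding: a single beam must return a single sequence.
    /// let greedy = GenerateConfig {
    ///     do_sample: false,
    ///     num_beams: 1,
    ///     num_return_sequences: 1,
    ///     ..Default::default()
    /// };
    ///
    /// // Diverse beam search: num_beams must be a multiple of num_beam_groups and at least
    /// // as large as num_return_sequences.
    /// let diverse_beam_search = GenerateConfig {
    ///     do_sample: false,
    ///     num_beams: 6,
    ///     num_beam_groups: Some(3),
    ///     num_return_sequences: 2,
    ///     ..Default::default()
    /// };
    /// ```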
177    pub(crate) fn validate(&self) {
178        assert!(self.temperature > 0f64, "temperature must be strictly positive");
179        assert!(
180            (self.top_p >= 0f64) & (self.top_p <= 1f64),
181            "top_p must be 0 and 1"
182        );
183        assert!(
184            self.repetition_penalty >= 1f64,
185            "repetition_penalty must be greater than 1"
186        );
187        assert!(
188            self.length_penalty > 0f64,
189            "length_penalty must be strictly greater than 0"
190        );
191        assert!(
192            self.num_return_sequences > 0i64,
193            "num_return_sequences must be strictly greater than 0"
194        );
195        assert!(
196            self.num_beams > 0i64,
197            "num_beams must be strictly greater than 0"
198        );
199
200        if !self.do_sample {
201            if self.num_beams == 1 {
202                assert_eq!(
203                    self.num_return_sequences, 1,
204                    "num_return_sequences must be set to 1 for greedy decoding"
205                )
206            } else {
207                assert!(
208                    self.num_beams >= self.num_return_sequences,
209                    "num_return_sequences must be lower than the number of beams"
210                )
211            }
212        }
213        if let Some(num_beam_groups_value) = self.num_beam_groups {
214            if num_beam_groups_value > 1 {
215                assert_eq!(
216                    self.num_beams % num_beam_groups_value,
217                    0,
218                    "num_beam_groups must be a multiple of num_beam_groups"
219                )
220            }
221        }
222    }
223}
224
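/// Cached computation state ("past" keys and values) reused across generation steps to avoid
/// re-attending over the already generated sequence from scratch. Each variant holds the cache
/// format of the corresponding model family; `Cache::None` is used when no cache is available.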
225#[derive(Debug)]
226pub enum Cache {
227    GPT2Cache(Option<Vec<Tensor>>),
228    BARTCache(Option<Vec<(Option<BartLayerState>, Option<BartLayerState>)>>),
229    T5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
230    LongT5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
231    XLNetCache(Option<Vec<Option<XLNetLayerState>>>),
232    ReformerCache(Option<Vec<Option<ReformerLayerState>>>),
233    ProphetNetCache(Option<Vec<(Option<ProphetNetLayerState>, Option<ProphetNetLayerState>)>>),
234    GPTNeoCache(Option<Vec<Option<GPTNeoLayerState>>>),
235    GPTJCache(Option<Vec<Option<GPTJLayerState>>>),
236    #[cfg(feature = "onnx")]
237    ONNXCache(ONNXLayerCache),
238    None,
239}
240
241pub(crate) mod private_generation_utils {
242    use rust_tokenizers::TokenIdsWithOffsets;
243    use std::cmp::{max, min};
244    use std::collections::HashMap;
245    use std::convert::TryFrom;
246    use std::mem;
247
248    use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy};
249    use tch::{nn, Device, Kind, Tensor};
250
251    use crate::pipelines::common::TokenizerOption;
252    use crate::pipelines::generation_utils::{
253        BeamHypotheses, Cache, GenerateConfig, LMModelOutput, PrefixAllowedFunction,
254    };
255
256    use super::ordered_float::OrderedFloat;
257    use crate::common::kind::{get_negative_infinity, get_positive_infinity};
258    use crate::RustBertError;
259
260    pub struct InternalGenerateOptions<'a> {
261        pub min_length: i64,
262        pub max_length: Option<i64>,
263        pub do_sample: bool,
264        pub temperature: f64,
265        pub top_k: i64,
266        pub top_p: f64,
267        pub repetition_penalty: f64,
268        pub no_repeat_ngram_size: i64,
269        pub pad_token_id: Option<i64>,
270        pub eos_token_ids: Option<Vec<i64>>,
271        pub num_return_sequences: i64,
272        pub early_stopping: bool,
273        pub num_beams: i64,
274        pub length_penalty: f64,
275        pub num_beam_groups: Option<i64>,
276        pub diversity_penalty: Option<f64>,
277        pub forced_bos_token_id: Option<i64>,
278        pub bad_word_ids: Option<&'a Vec<Vec<i64>>>,
279    }
280
281    pub struct PreparedInput<'a> {
282        pub prepared_input: Option<Tensor>,
283        pub prepared_attention_mask: Option<Tensor>,
284        pub prepared_encoder_output: Option<&'a Tensor>,
285        pub prepared_decoder_input: Option<Tensor>,
286        pub prepared_position_ids: Option<Tensor>,
287        pub prepared_past: Cache,
288    }
289
290    pub struct GeneratedOutputWithScores {
291        pub indices: Tensor,
292        pub scores: Option<Vec<f64>>,
293        pub token_scores: Option<Vec<Vec<f64>>>,
294    }
295
296    pub trait PrivateLanguageGenerator {
297        fn _get_tokenizer(&self) -> &TokenizerOption;
298        fn get_device(&self) -> Device;
299        fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError>;
300        fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption;
301        fn get_config(&self) -> &GenerateConfig;
302        fn get_bos_id(&self) -> Option<i64>;
303        fn get_eos_ids(&self) -> Option<&Vec<i64>>;
304        fn get_forced_bos_token_id(&self) -> Option<i64> {
305            None
306        }
307        fn get_forced_eos_token_id(&self) -> Option<i64> {
308            None
309        }
310        fn get_pad_id(&self) -> Option<i64>;
311        fn is_encoder_decoder(&self) -> bool;
312        fn get_vocab_size(&self) -> i64;
313        fn get_decoder_start_id(&self) -> Option<i64>;
314        fn get_max_positions_embeddings(&self) -> Option<i64>;
315
316        fn forward_t(
317            &self,
318            input_ids: Option<&Tensor>,
319            layer_past: Cache,
320            attention_mask: Option<&Tensor>,
321            token_type_ids: Option<&Tensor>,
322            position_ids: Option<&Tensor>,
323            input_embeds: Option<&Tensor>,
324            encoder_outputs: Option<&Tensor>,
325            decoder_input_ids: Option<&Tensor>,
326            train: bool,
327        ) -> Result<LMModelOutput, RustBertError>;
328
329        fn prepare_scores_for_generation(
330            &self,
331            scores: &mut Tensor,
332            current_length: i64,
333            max_length: Option<i64>,
334            forced_bos_token_id: Option<i64>,
335        ) {
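            // At the very first decoding step, a `forced_bos_token_id` (if any) becomes the only
            // allowed token; one step before `max_length`, a `forced_eos_token_id` (if any) is
            // forced so that generation terminates with the expected end-of-sequence token.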
336            if current_length == 1 {
337                if let Some(forced_bos_token_id) =
338                    forced_bos_token_id.or(self.get_forced_bos_token_id())
339                {
340                    force_token_id_generation(
341                        scores,
342                        &[forced_bos_token_id],
343                        self.get_vocab_size(),
344                    );
345                }
346            } else if let Some(max_length) = max_length {
347                if let Some(forced_eos_token_id) = self.get_forced_eos_token_id() {
348                    if current_length == max_length - 1 {
349                        force_token_id_generation(
350                            scores,
351                            &[forced_eos_token_id],
352                            self.get_vocab_size(),
353                        );
354                    }
355                }
356            }
357        }
358
359        fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> {
360            None
361        }
362
363        fn prepare_inputs_for_generation<'a>(
364            &self,
365            input_ids: Tensor,
366            _encoder_outputs: Option<&'a Tensor>,
367            past: Cache,
368            attention_mask: Tensor,
369        ) -> PreparedInput<'a> {
370            PreparedInput {
371                prepared_input: Some(input_ids),
372                prepared_attention_mask: Some(attention_mask),
373                prepared_encoder_output: None,
374                prepared_decoder_input: None,
375                prepared_position_ids: None,
376                prepared_past: past,
377            }
378        }
379
380        fn encode_prompt_text<S>(
381            &self,
382            prompt_text: &[S],
383            max_len: Option<i64>,
384            pad_token_id: Option<i64>,
385        ) -> Tensor
386        where
387            S: AsRef<str> + Send + Sync,
388        {
389            let token_ids = if self.is_encoder_decoder() {
390                let tokens = self._get_tokenizer().encode_list(
391                    prompt_text,
392                    max_len
393                        .map(|max_len| max_len as usize)
394                        .unwrap_or(usize::MAX),
395                    &TruncationStrategy::LongestFirst,
396                    0,
397                );
398                tokens
399                    .into_iter()
400                    .map(|tokenized_input| tokenized_input.token_ids)
401                    .collect::<Vec<Vec<i64>>>()
402            } else {
403                // Special tokens (e.g. BOS) are not added at the end of the prompt for causal generation
404                let tokens = self._get_tokenizer().tokenize_list(prompt_text);
405                let token_ids = tokens
406                    .into_iter()
407                    .map(|prompt_tokens| {
408                        self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens)
409                    })
410                    .collect::<Vec<Vec<i64>>>();
411
412                let num_truncated_tokens = token_ids
413                    .iter()
414                    .map(|token_ids| {
415                        max_len
416                            .map(|max_len| {
417                                if token_ids.len() > max_len as usize {
418                                    token_ids.len() - max_len as usize
419                                } else {
420                                    0
421                                }
422                            })
423                            .unwrap_or(0)
424                    })
425                    .collect::<Vec<usize>>();
426
427                token_ids
428                    .into_iter()
429                    .zip(num_truncated_tokens)
430                    .map(|(tokens, num_truncated_tokens)| {
431                        truncate_sequences(
432                            TokenIdsWithOffsets {
433                                ids: tokens,
434                                offsets: vec![],
435                                reference_offsets: vec![],
436                                masks: vec![],
437                            },
438                            None,
439                            num_truncated_tokens,
440                            &TruncationStrategy::LongestFirst,
441                            0,
442                        )
443                        .unwrap()
444                        .0
445                        .ids
446                    })
447                    .collect::<Vec<Vec<i64>>>()
448            };
449
450            let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
451
452            let pad_token = match pad_token_id {
453                Some(value) => value,
454                None => self._get_tokenizer().get_unk_id(),
455            };
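            // Padding illustration (hypothetical token ids, pad_token = 0) for prompts of length 2 and 4:
            //   encoder-decoder models (pad right): [[5, 6, 0, 0], [7, 8, 9, 10]]
            //   causal models (pad left):           [[0, 0, 5, 6], [7, 8, 9, 10]]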
456
457            let token_ids = token_ids
458                .into_iter()
459                .map(|mut input| {
460                    let mut temp = vec![pad_token; max_len - input.len()];
461                    if self.is_encoder_decoder() {
462                        input.extend(temp);
463                        input
464                    } else {
465                        // Pad left for causal generation
466                        temp.extend(input);
467                        temp
468                    }
469                })
470                .map(|tokens| Tensor::from_slice(&tokens).to(self.get_device()))
471                .collect::<Vec<Tensor>>();
472
473            Tensor::stack(&token_ids, 0)
474        }
475
476        fn enforce_repetition_penalty(
477            &self,
478            next_token_logits: &mut Tensor,
479            batch_size: i64,
480            num_beams: i64,
481            prev_output_tokens: &Tensor,
482            repetition_penalty: f64,
483        ) {
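            // Repetition penalty as introduced in the CTRL paper (Keskar et al., 2019): logits of
            // tokens already present in the sequence are divided by the penalty when positive and
            // multiplied by it when negative, making them less likely in either case.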
484            for i in 0..(batch_size * num_beams) {
485                for token_position in 0..prev_output_tokens.get(i).size()[0] {
486                    let token = prev_output_tokens.get(i).int64_value(&[token_position]);
487                    let updated_value = &next_token_logits.double_value(&[i, token]);
488                    if updated_value < &0f64 {
489                        let _ = next_token_logits.get(i).index_fill_(
490                            0,
491                            &Tensor::from_slice(&[token])
492                                .to_kind(Kind::Int64)
493                                .to_device(next_token_logits.device()),
494                            updated_value * repetition_penalty,
495                        );
496                    } else {
497                        let _ = next_token_logits.get(i).index_fill_(
498                            0,
499                            &Tensor::from_slice(&[token])
500                                .to_kind(Kind::Int64)
501                                .to_device(next_token_logits.device()),
502                            updated_value / repetition_penalty,
503                        );
504                    }
505                }
506            }
507        }
508
509        fn get_banned_tokens(
510            &self,
511            input_ids: &Tensor,
512            no_repeat_ngram_size: i64,
513            cur_len: i64,
514        ) -> Vec<Vec<i64>> {
515            //        Ported from hugging face's transformers and fairseq (https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py)
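            // Illustration (hypothetical ids): with no_repeat_ngram_size = 2 and a hypothesis
            // [5, 3, 5], the generated bigrams are (5, 3) and (3, 5); the last token is 5, so 3 is
            // banned for the next step, as it would recreate the bigram (5, 3).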
516            if cur_len + 1 < no_repeat_ngram_size {
517                vec![vec![]]
518            } else {
519                let input_ids = input_ids.to(Device::Cpu);
520                let num_hypothesis = *input_ids.size().first().unwrap();
521                let mut banned_tokens: Vec<Vec<i64>> = Vec::with_capacity(num_hypothesis as usize);
522                for hypothesis_index in 0..num_hypothesis {
523                    let hypothesis_input_ids = input_ids.get(hypothesis_index);
524                    let mut generated_ngram: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
525                    let input: Vec<i64> = (0..hypothesis_input_ids.size1().unwrap()).collect();
526                    let hypothesis_input_ids = hypothesis_input_ids
527                        .iter::<i64>()
528                        .unwrap()
529                        .collect::<Vec<i64>>();
530                    let query = &hypothesis_input_ids
531                        [cur_len as usize + 1 - no_repeat_ngram_size as usize..]
532                        .to_vec();
533                    for ngram in input
534                        .windows(no_repeat_ngram_size as usize)
535                        .map(|win| (*win.first().unwrap(), *win.last().unwrap()))
536                    {
537                        let ngram = &hypothesis_input_ids[ngram.0 as usize..ngram.1 as usize + 1];
538                        let key = ngram[..no_repeat_ngram_size as usize - 1].to_vec();
539                        let value = *ngram.last().unwrap();
540                        generated_ngram
541                            .entry(key)
542                            .or_insert_with(|| vec![value])
543                            .push(value);
544                    }
545                    let hypothesis_banned_tokens = match generated_ngram.get(query) {
546                        Some(banned_tokens) => banned_tokens.clone(),
547                        None => vec![],
548                    };
549                    banned_tokens.push(hypothesis_banned_tokens);
550                }
551                banned_tokens
552            }
553        }
554
555        fn top_k_top_p_filtering(
556            &self,
557            logits: &mut Tensor,
558            top_k: i64,
559            top_p: f64,
560            min_tokens_to_keep: i64,
561        ) {
562            //        Nucleus and top-k filtering introduced by Holtzman et al. (http://arxiv.org/abs/1904.09751)
563            //        Ported from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
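            // Illustration: with top_p = 0.9 and sorted probabilities [0.5, 0.3, 0.15, 0.05], the
            // cumulative sums are [0.5, 0.8, 0.95, 1.0]; the first three tokens are kept (the set
            // grows until the cumulative probability reaches top_p) and the last one is masked to
            // negative infinity.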
564            let vocab_size = *logits.size().last().unwrap();
565            if top_k > 0 {
566                let top_k = vocab_size - min(max(top_k, min_tokens_to_keep), vocab_size);
567                let (_, indices_to_remove) = logits.topk(top_k, -1, false, false);
568                for index in 0..*logits.size().first().unwrap() {
569                    let _ = logits.get(index).index_fill_(
570                        0,
571                        &indices_to_remove.get(index),
572                        f64::NEG_INFINITY,
573                    );
574                }
575            }
576            if top_p < 1f64 {
577                let (sorted_logits, sorted_indices) = logits.sort(-1, true);
578                let cumulative_probabilities = sorted_logits
579                    .softmax(-1, sorted_logits.kind())
580                    .cumsum(-1, sorted_logits.kind());
581                let mut sorted_indices_to_remove =
582                    cumulative_probabilities.ge(top_p).to_kind(Kind::Int64);
583                if min_tokens_to_keep > 1 {
584                    let _ = sorted_indices_to_remove.index_fill_(
585                        1,
586                        &Tensor::arange_start(
587                            0,
588                            min_tokens_to_keep + 1,
589                            (Kind::Int64, logits.device()),
590                        ),
591                        0,
592                    );
593                }
594                let _ = sorted_indices_to_remove.index_copy_(
595                    1,
596                    &Tensor::arange_start(1, vocab_size, (Kind::Int64, logits.device())),
597                    &sorted_indices_to_remove
598                        .slice(1, 0, vocab_size - 1, 1)
599                        .copy(),
600                );
601                let _ = sorted_indices_to_remove.index_fill_(
602                    1,
603                    &Tensor::from_slice(&[0])
604                        .to_kind(Kind::Int64)
605                        .to_device(sorted_indices_to_remove.device()),
606                    0,
607                );
608                let indices_to_remove = sorted_indices_to_remove
609                    .scatter(1, &sorted_indices, &sorted_indices_to_remove)
610                    .to_kind(Kind::Bool);
611                let _ = logits.masked_fill_(&indices_to_remove, f64::NEG_INFINITY);
612            }
613        }
614
615        fn run_hamming_diversity_penalty(
616            &self,
617            scores: &mut Tensor,
618            current_tokens: &Tensor,
619            diversity_penalty: f64,
620            num_beams: i64,
621            batch_size: i64,
622            group_size: i64,
623            group_start_index: i64,
624        ) {
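            // Hamming diversity penalty from "Diverse Beam Search" (Vijayakumar et al., 2016): for
            // each batch entry, the scores of the current beam group are decreased by
            // diversity_penalty times the number of times each token was already selected by the
            // previous groups at this generation step.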
625            if group_start_index > 0 {
626                let vocab_size = *scores.size().last().unwrap();
627                for batch_index in 0..batch_size {
628                    let previous_group_tokens = current_tokens.slice(
629                        0,
630                        batch_index * num_beams,
631                        batch_index * num_beams + group_start_index,
632                        1,
633                    );
634                    let diversity_penalty = previous_group_tokens
635                        .bincount::<Tensor>(None, vocab_size)
636                        * diversity_penalty;
637                    let _ = scores
638                        .slice(
639                            0,
640                            batch_index * group_size,
641                            (batch_index + 1) * group_size,
642                            1,
643                        )
644                        .subtract_(&diversity_penalty);
645                }
646            }
647        }
648
649        fn apply_prefix_allowed_tokens_function(
650            &self,
651            prefix_allowed_tokens_fn: &dyn Fn(i64, &Tensor) -> Vec<i64>,
652            num_beams: i64,
653            input_ids: &Tensor,
654            scores: &mut Tensor,
655        ) {
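            // Build a mask filled with +inf, set the positions of allowed tokens to 0 for each
            // hypothesis, then subtract it from the scores: allowed tokens keep their score while
            // all other tokens are pushed to negative infinity.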
656            let mask = scores.new_full(
657                scores.size().as_slice(),
658                get_positive_infinity(scores.kind()).unwrap(),
659                (scores.kind(), scores.device()),
660            );
661            for idx in 0..scores.size()[0] {
662                let batch_id = idx / num_beams;
663                let allowed_tokens: Vec<i64> =
664                    prefix_allowed_tokens_fn(batch_id, &input_ids.get(idx));
665                let _ = mask.get(idx).index_fill_(
666                    0,
667                    &Tensor::from_slice(allowed_tokens.as_slice()).to(scores.device()),
668                    0,
669                );
670            }
671            let _ = scores.subtract_(&mask);
672        }
673
674        fn split_bad_word_ids<'a>(
675            &self,
676            bad_word_ids: Option<&'a Vec<Vec<i64>>>,
677        ) -> (Option<Vec<i64>>, Option<Vec<&'a Vec<i64>>>) {
678            if let Some(bad_word_ids) = bad_word_ids {
679                let mut bad_word_ids_length_1 = vec![];
680                let mut bad_word_ids_length_greater_than_1 = vec![];
681                for bad_word in bad_word_ids {
682                    if bad_word.len() == 1 {
683                        bad_word_ids_length_1.push(bad_word[0]);
684                    } else {
685                        bad_word_ids_length_greater_than_1.push(bad_word);
686                    }
687                }
688                let bad_word_ids_length_1 = if !bad_word_ids_length_1.is_empty() {
689                    Some(bad_word_ids_length_1)
690                } else {
691                    None
692                };
693                let bad_word_ids_length_greater_than_1 =
694                    if !bad_word_ids_length_greater_than_1.is_empty() {
695                        Some(bad_word_ids_length_greater_than_1)
696                    } else {
697                        None
698                    };
699                (bad_word_ids_length_1, bad_word_ids_length_greater_than_1)
700            } else {
701                (None, None)
702            }
703        }
704
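        // Returns true when `tokens` is a suffix of `prev_tokens` (or is empty). With hypothetical
        // ids: tokens_match(&[1, 2, 3], &[2, 3]) == true, tokens_match(&[1, 2, 3], &[1, 2]) == false.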
705        fn tokens_match(&self, prev_tokens: &[i64], tokens: &[i64]) -> bool {
706            if tokens.is_empty() {
707                true
708            } else if tokens.len() > prev_tokens.len() {
709                false
710            } else {
711                &prev_tokens[prev_tokens.len() - tokens.len()..] == tokens
712            }
713        }
714
715        fn calc_static_bad_word_mask(
716            &self,
717            scores: &Tensor,
718            bad_words_id_length_1: &[i64],
719        ) -> Tensor {
720            let mut static_bad_words_mask =
721                Tensor::zeros([scores.size()[1]], (Kind::Int8, scores.device()));
722            let _ = static_bad_words_mask.index_fill_(
723                0,
724                &Tensor::from_slice(bad_words_id_length_1).to_device(scores.device()),
725                1,
726            );
727            static_bad_words_mask.unsqueeze(0).totype(Kind::Bool)
728        }
729
730        fn get_dynamic_bad_word_ids(
731            &self,
732            prev_tokens: &[Vec<i64>],
733            bad_word_ids_length_greater_than_1: &[&Vec<i64>],
734        ) -> Vec<Vec<i64>> {
735            let mut banned_tokens = Vec::new();
736            for prev_token_sequence in prev_tokens {
737                let mut sequence_banned_tokens = Vec::new();
738                for bad_word_ids in bad_word_ids_length_greater_than_1 {
739                    if self
740                        .tokens_match(prev_token_sequence, &bad_word_ids[..bad_word_ids.len() - 1])
741                    {
742                        sequence_banned_tokens.push(*bad_word_ids.last().unwrap());
743                    }
744                }
745                banned_tokens.push(sequence_banned_tokens);
746            }
747
748            banned_tokens
749        }
750
751        fn ban_bad_words(
752            &self,
753            dynamic_bad_words: Option<&Vec<&Vec<i64>>>,
754            static_bad_words_mask: Option<&Tensor>,
755            token_ids: &Tensor,
756            scores: &mut Tensor,
757        ) {
758            let longest_bad_word = dynamic_bad_words
759                .iter()
760                .map(|bad_word| bad_word.len())
761                .max()
762                .unwrap() as i64;
763
764            let last_token_ids = token_ids.slice(1, -longest_bad_word, None, 1);
765            let mut prev_tokens = Vec::new();
766            for sequence_idx in 0..token_ids.size()[0] {
767                prev_tokens.push(
768                    last_token_ids
769                        .get(sequence_idx)
770                        .iter::<i64>()
771                        .unwrap()
772                        .collect::<Vec<i64>>(),
773                )
774            }
775
776            let dynamic_bad_words_mask = if let Some(dynamic_bad_words) = dynamic_bad_words {
777                let dynamic_banned_tokens =
778                    self.get_dynamic_bad_word_ids(&prev_tokens, dynamic_bad_words);
779                let dynamic_banned_mask =
780                    Tensor::zeros(scores.size().as_slice(), (Kind::Int, scores.device()));
781                for (sequence_index, sequence_ban_tokens) in
782                    dynamic_banned_tokens.iter().enumerate()
783                {
784                    if !sequence_ban_tokens.is_empty() {
785                        let _ = dynamic_banned_mask.get(sequence_index as i64).index_fill_(
786                            0,
787                            &Tensor::from_slice(sequence_ban_tokens).to_device(scores.device()),
788                            1,
789                        );
790                    }
791                }
792                Some(dynamic_banned_mask.to_kind(Kind::Bool))
793            } else {
794                None
795            };
796
797            let combined_bad_word_mask = {
798                if let (Some(static_mask), Some(dynamic_mask)) =
799                    (static_bad_words_mask, &dynamic_bad_words_mask)
800                {
801                    Some(static_mask.bitwise_or_tensor(dynamic_mask))
802                } else {
803                    None
804                }
805            };
806
807            let bad_word_mask = if combined_bad_word_mask.is_some() {
808                combined_bad_word_mask.as_ref()
809            } else if static_bad_words_mask.is_some() {
810                static_bad_words_mask
811            } else if dynamic_bad_words_mask.is_some() {
812                dynamic_bad_words_mask.as_ref()
813            } else {
814                None
815            };
816
817            if let Some(bad_word_mask) = bad_word_mask {
818                let _ = scores.masked_fill_(bad_word_mask, f64::NEG_INFINITY);
819            }
820        }
821
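        // Greedy / sampling decoding without beam search: at each step the model is run on the
        // current sequences, the logits are post-processed (repetition penalty, bad words, banned
        // n-grams, forced tokens), the next token is chosen by argmax or top-k / nucleus sampling,
        // finished sequences receive the pad token, and the loop stops once every sequence has
        // produced an EOS token or max_length is reached.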
822        fn generate_no_beam_search(
823            &self,
824            input_ids: Tensor,
825            encoder_outputs: Option<Tensor>,
826            cur_len: i64,
827            batch_size: i64,
828            attention_mask: Tensor,
829            gen_opt: InternalGenerateOptions,
830            prefix_allowed_tokens_fn: Option<PrefixAllowedFunction>,
831            output_scores: bool,
832        ) -> GeneratedOutputWithScores {
833            let mut unfinished_sentences =
834                Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
835            let mut sentence_lengths: Tensor =
836                Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
837            let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
838                self.split_bad_word_ids(gen_opt.bad_word_ids);
839            let mut static_bad_words_mask: Option<Tensor> = None;
840            let mut attention_mask = attention_mask.copy();
841            let mut input_ids = input_ids.copy();
842            let mut past: Cache = Cache::None;
843            let mut outputs: Tensor;
844            let mut current_length = cur_len;
845            let mut token_scores_output: Option<Vec<Tensor>> =
846                if output_scores { Some(vec![]) } else { None };
847
848            loop {
849                let prepared_input = self.prepare_inputs_for_generation(
850                    input_ids.copy(),
851                    encoder_outputs.as_ref(),
852                    past,
853                    attention_mask.copy(),
854                );
855                let temp = self
856                    .forward_t(
857                        prepared_input.prepared_input.as_ref(),
858                        prepared_input.prepared_past,
859                        prepared_input.prepared_attention_mask.as_ref(),
860                        None,
861                        prepared_input.prepared_position_ids.as_ref(),
862                        None,
863                        prepared_input.prepared_encoder_output,
864                        prepared_input.prepared_decoder_input.as_ref(),
865                        false,
866                    )
867                    .unwrap();
868                outputs = temp.lm_logits;
869                past = temp.cache;
870
871                let mut next_token_logits = outputs.select(1, -1);
872                // Reduce probability for repeated inputs
873                if gen_opt.repetition_penalty > 1f64 {
874                    self.enforce_repetition_penalty(
875                        &mut next_token_logits,
876                        batch_size,
877                        1,
878                        &input_ids,
879                        gen_opt.repetition_penalty,
880                    )
881                }
882
883                // Get bad word_ids and set their probability to 0
884                if gen_opt.bad_word_ids.is_some() {
885                    // Calculate static bad words masks if not set yet
886                    if let Some(bad_word_ids_length_1) = &bad_word_ids_length_1 {
887                        if static_bad_words_mask.is_none() {
888                            static_bad_words_mask = Some(self.calc_static_bad_word_mask(
889                                &next_token_logits,
890                                bad_word_ids_length_1,
891                            ));
892                        }
893                    }
894                    self.ban_bad_words(
895                        bad_word_ids_length_greater_than_1.as_ref(),
896                        static_bad_words_mask.as_ref(),
897                        &input_ids,
898                        &mut next_token_logits,
899                    );
900                }
901
902                // Get banned tokens and set their probability to 0
903                if gen_opt.no_repeat_ngram_size > 0 {
904                    let banned_tokens = self.get_banned_tokens(
905                        &input_ids,
906                        gen_opt.no_repeat_ngram_size,
907                        current_length,
908                    );
909                    for (batch_index, index_banned_token) in
910                        (0..banned_tokens.len() as i64).zip(banned_tokens)
911                    {
912                        let _ = next_token_logits.get(batch_index).index_fill_(
913                            0,
914                            &Tensor::from_slice(&index_banned_token)
915                                .to_device(next_token_logits.device()),
916                            f64::NEG_INFINITY,
917                        );
918                    }
919                }
920
921                // Apply custom prefix constraint function
922                if let Some(prefix_allowed_tokens_function) = prefix_allowed_tokens_fn {
923                    self.apply_prefix_allowed_tokens_function(
924                        prefix_allowed_tokens_function,
925                        1,
926                        &input_ids,
927                        &mut next_token_logits,
928                    )
929                }
930
931                // Do not allow eos token if min length is not reached
932                if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
933                    let _ = next_token_logits.index_fill_(
934                        1,
935                        &Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
936                            .to(next_token_logits.device()),
937                        f64::NEG_INFINITY,
938                    );
939                }
940
941                self.prepare_scores_for_generation(
942                    &mut next_token_logits,
943                    current_length,
944                    gen_opt.max_length,
945                    gen_opt.forced_bos_token_id,
946                );
947
948                // Top-k and top-p sampling
949                let next_token = if gen_opt.do_sample {
950                    if gen_opt.temperature > 1f64 {
951                        next_token_logits /= gen_opt.temperature;
952                    }
953                    self.top_k_top_p_filtering(
954                        &mut next_token_logits,
955                        gen_opt.top_k,
956                        gen_opt.top_p,
957                        1,
958                    );
959                    let probabilities = next_token_logits.softmax(-1, next_token_logits.kind());
960                    probabilities.multinomial(1, false).squeeze_dim(1)
961                } else {
962                    next_token_logits.argmax(-1, false)
963                };
964
965                if let Some(prev_scores) = token_scores_output.as_mut() {
966                    let finished_mask = unfinished_sentences.eq(0);
967                    prev_scores.push(
968                        next_token_logits
969                            .log_softmax(-1, next_token_logits.kind())
970                            .gather(1, &next_token.reshape([-1, 1]), false)
971                            .squeeze()
972                            .masked_fill(&finished_mask, 0),
973                    );
974                };
975
976                // Add tokens to unfinished sentences
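                // For sequences that are already finished (unfinished flag = 0) the pad token is
                // appended instead: next_token * flag - pad * (flag - 1) evaluates to next_token
                // when the flag is 1 and to pad_token_id when the flag is 0.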
977                let tokens_to_add = match &gen_opt.eos_token_ids {
978                    Some(_) => {
979                        next_token * &unfinished_sentences
980                            - gen_opt.pad_token_id.unwrap() * (&unfinished_sentences - 1)
981                    }
982                    None => next_token,
983                };
984
985                input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
986                if gen_opt.eos_token_ids.is_some() {
987                    for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
988                        let sentence_with_eos =
989                            tokens_to_add.eq(*eos_token_id).to_kind(Kind::Int64);
990                        let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
991                        let _ = sentence_lengths.masked_fill_(
992                            &sentence_with_eos
993                                .to_kind(Kind::Bool)
994                                .to_device(sentence_lengths.device()),
995                            current_length + 1,
996                        );
997                        unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
998                    }
999                    if i64::try_from(unfinished_sentences.max()).unwrap() == 0 {
1000                        break;
1001                    }
1002                }
1003                if !self.is_encoder_decoder() {
1004                    attention_mask = Tensor::cat(
1005                        &[
1006                            attention_mask.as_ref(),
1007                            Tensor::ones(
1008                                [*attention_mask.size().first().unwrap(), 1],
1009                                (Kind::Int64, attention_mask.device()),
1010                            )
1011                            .as_ref(),
1012                        ],
1013                        -1,
1014                    );
1015                }
1016                current_length += 1;
1017                if let Some(max_length) = gen_opt.max_length {
1018                    if current_length >= max_length {
1019                        let _ = sentence_lengths.masked_fill_(
1020                            &unfinished_sentences
1021                                .to_kind(Kind::Bool)
1022                                .to_device(sentence_lengths.device()),
1023                            current_length,
1024                        );
1025                        break;
1026                    }
1027                }
1028            }
1029            let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
1030                (Tensor::stack(scores_tensor, 1).sum_dim_intlist(
1031                    [1].as_slice(),
1032                    false,
1033                    Kind::Float,
1034                ) / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
1035                .iter::<f64>()
1036                .unwrap()
1037                .collect::<Vec<f64>>()
1038            });
1039            let token_scores_output = token_scores_output.map(|score_tensors| {
1040                Tensor::stack(&score_tensors, 1)
1041                    .split(1, 0)
1042                    .iter()
1043                    .map(|sequence_scores| {
1044                        sequence_scores
1045                            .squeeze_dim(0)
1046                            .iter::<f64>()
1047                            .unwrap()
1048                            .collect::<Vec<f64>>()
1049                    })
1050                    .collect()
1051            });
1052            GeneratedOutputWithScores {
1053                indices: input_ids,
1054                scores: scores_output,
1055                token_scores: token_scores_output,
1056            }
1057        }
1058
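        // Beam search decoding (optionally with grouped / diverse beams): at each step the
        // log-probabilities of every beam continuation are combined with the running beam scores,
        // the best candidates are kept per batch entry (by top-k selection or sampling), completed
        // hypotheses are collected in BeamHypotheses, and the search stops when all batch entries
        // are done or max_length is reached.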
1059        fn generate_beam_search(
1060            &self,
1061            mut input_ids: Tensor,
1062            encoder_outputs: Option<Tensor>,
1063            cur_len: i64,
1064            batch_size: i64,
1065            mut attention_mask: Tensor,
1066            gen_opt: InternalGenerateOptions,
1067            prefix_allowed_tokens_fn: Option<PrefixAllowedFunction>,
1068            output_scores: bool,
1069        ) -> GeneratedOutputWithScores {
1070            let num_beam_groups = gen_opt.num_beam_groups.unwrap_or(1);
1071            let num_sub_beams = gen_opt.num_beams / num_beam_groups;
1072            let diversity_penalty = gen_opt.diversity_penalty.unwrap_or(5.5);
1073            let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
1074                self.split_bad_word_ids(gen_opt.bad_word_ids);
1075            let mut static_bad_words_mask: Option<Tensor> = None;
1076
1077            let mut hypotheses = (0..batch_size)
1078                .map(|_| {
1079                    BeamHypotheses::new(
1080                        gen_opt.num_beams,
1081                        gen_opt.max_length,
1082                        gen_opt.length_penalty,
1083                        gen_opt.early_stopping,
1084                    )
1085                })
1086                .collect::<Vec<BeamHypotheses>>();
1087
1088            let vocab_size = self.get_vocab_size();
1089            let beam_scores = Tensor::ones(
1090                [batch_size, gen_opt.num_beams],
1091                (Kind::Float, self.get_device()),
1092            ) * -1e9;
1093            let _ = beam_scores
1094                .slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams)
1095                .fill_(0);
1096
1097            let mut beam_scores = beam_scores.view_([-1]);
1098            let mut beam_tokens = Tensor::zeros(
1099                [batch_size * gen_opt.num_beams],
1100                (Kind::Int64, self.get_device()),
1101            );
1102            let mut beam_indices = Tensor::zeros(
1103                [batch_size * gen_opt.num_beams],
1104                (Kind::Int64, self.get_device()),
1105            );
1106            let mut saved_beam_scores: Option<Vec<Tensor>> =
1107                if output_scores { Some(vec![]) } else { None };
1108            let mut current_tokens = Tensor::new();
1109
1110            let mut past: Cache = Cache::None;
1111            let mut done = vec![false; batch_size as usize];
1112
1113            let mut outputs: Tensor;
1114            let mut encoder_outputs = encoder_outputs;
1115            let mut current_length = cur_len;
1116
1117            loop {
1118                if num_beam_groups > 1 {
1119                    current_tokens = Tensor::zeros(
1120                        [batch_size * gen_opt.num_beams],
1121                        (input_ids.kind(), input_ids.device()),
1122                    );
1123                }
1124                let prepared_input = self.prepare_inputs_for_generation(
1125                    input_ids.copy(),
1126                    encoder_outputs.as_ref(),
1127                    past,
1128                    attention_mask.copy(),
1129                );
1130                let temp = self
1131                    .forward_t(
1132                        prepared_input.prepared_input.as_ref(),
1133                        prepared_input.prepared_past,
1134                        prepared_input.prepared_attention_mask.as_ref(),
1135                        None,
1136                        prepared_input.prepared_position_ids.as_ref(),
1137                        None,
1138                        prepared_input.prepared_encoder_output,
1139                        prepared_input.prepared_decoder_input.as_ref(),
1140                        false,
1141                    )
1142                    .unwrap();
1143                outputs = temp.lm_logits;
1144                past = temp.cache;
1145
1146                for beam_group_index in 0..num_beam_groups {
1147                    let group_start_index = beam_group_index * num_sub_beams;
1148                    let group_end_index = min(group_start_index + num_sub_beams, gen_opt.num_beams);
1149                    let group_size = group_end_index - group_start_index;
1150
1151                    let (group_input_ids, batch_group_indices) = if num_beam_groups > 1 {
1152                        let mut batch_group_indices: Vec<i64> =
1153                            Vec::with_capacity((batch_size * group_size) as usize);
1154                        for batch_index in 0..batch_size {
1155                            batch_group_indices.extend(
1156                                (group_start_index..group_end_index)
1157                                    .map(|value| value + batch_index * gen_opt.num_beams),
1158                            )
1159                        }
1160                        let batch_group_indices =
1161                            Tensor::from_slice(batch_group_indices.as_slice())
1162                                .to(input_ids.device());
1163                        (
1164                            Some(input_ids.index_select(0, &batch_group_indices)),
1165                            Some(batch_group_indices),
1166                        )
1167                    } else {
1168                        (None, None)
1169                    };
1170
1171                    let mut next_token_logits = if num_beam_groups <= 1 {
1172                        outputs.select(1, -1)
1173                    } else {
1174                        outputs
1175                            .select(1, -1)
1176                            .index_select(0, batch_group_indices.as_ref().unwrap())
1177                    };
1178                    // Reduce probability for repeated inputs
1179                    if gen_opt.repetition_penalty > 1f64 {
1180                        self.enforce_repetition_penalty(
1181                            &mut next_token_logits,
1182                            batch_size,
1183                            1,
1184                            group_input_ids.as_ref().unwrap_or(&input_ids),
1185                            gen_opt.repetition_penalty,
1186                        )
1187                    }
1188
1189                    if gen_opt.temperature > 1f64 {
1190                        next_token_logits /= gen_opt.temperature;
1191                    }
1192                    self.prepare_scores_for_generation(
1193                        &mut next_token_logits,
1194                        current_length,
1195                        gen_opt.max_length,
1196                        gen_opt.forced_bos_token_id,
1197                    );
1198
1199                    let mut scores = next_token_logits.log_softmax(-1, next_token_logits.kind());
1200
1201                    // Do not allow eos token if min length is not reached
1202                    if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
1203                        let _ = scores.index_fill_(
1204                            1,
1205                            &Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
1206                                .to(scores.device()),
1207                            f64::NEG_INFINITY,
1208                        );
1209                    }
1210
1211                    // Get bad word_ids and set their probability to 0
1212                    if gen_opt.bad_word_ids.is_some() {
1213                        // Calculate static bad words masks if not set yet
1214                        if let Some(bad_word_ids_length_1) = &bad_word_ids_length_1 {
1215                            if static_bad_words_mask.is_none() {
1216                                static_bad_words_mask = Some(
1217                                    self.calc_static_bad_word_mask(&scores, bad_word_ids_length_1),
1218                                );
1219                            }
1220                        }
1221                        self.ban_bad_words(
1222                            bad_word_ids_length_greater_than_1.as_ref(),
1223                            static_bad_words_mask.as_ref(),
1224                            group_input_ids.as_ref().unwrap_or(&input_ids),
1225                            &mut scores,
1226                        );
1227                    }
1228
1229                    // Get repeated tokens and set their probability to 0
1230                    if gen_opt.no_repeat_ngram_size > 0 {
1231                        let banned_tokens = self.get_banned_tokens(
1232                            group_input_ids.as_ref().unwrap_or(&input_ids),
1233                            gen_opt.no_repeat_ngram_size,
1234                            current_length,
1235                        );
1236                        for (batch_index, index_banned_token) in
1237                            (0..banned_tokens.len() as i64).zip(banned_tokens)
1238                        {
1239                            let _ = scores.get(batch_index).index_fill_(
1240                                0,
1241                                &Tensor::from_slice(&index_banned_token)
1242                                    .to_device(next_token_logits.device()),
1243                                f64::NEG_INFINITY,
1244                            );
1245                        }
1246                    }
1247
1248                    // Update scores with diversity penalty
1249                    if num_beam_groups > 1 {
1250                        self.run_hamming_diversity_penalty(
1251                            &mut scores,
1252                            &current_tokens,
1253                            diversity_penalty,
1254                            gen_opt.num_beams,
1255                            batch_size,
1256                            group_size,
1257                            group_start_index,
1258                        );
1259                    }
1260
1261                    // Apply custom prefix constraint function
1262                    if let Some(prefix_allowed_tokens_function) = prefix_allowed_tokens_fn {
1263                        self.apply_prefix_allowed_tokens_function(
1264                            prefix_allowed_tokens_function,
1265                            num_sub_beams,
1266                            &input_ids,
1267                            &mut scores,
1268                        )
1269                    }
1270
1271                    let mut next_scores: Tensor = &scores
1272                        + (if num_beam_groups > 1 {
1273                            beam_scores
1274                                .index_select(0, batch_group_indices.as_ref().unwrap())
1275                                .unsqueeze(-1)
1276                                .expand_as(&scores)
1277                        } else {
1278                            beam_scores.unsqueeze(-1).expand_as(&scores)
1279                        });
1280
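                    // Draw 2 * group_size candidate tokens per batch entry: even if up to group_size
                    // of them turn out to be EOS tokens (closed into finished hypotheses below),
                    // enough candidates remain to keep every beam of the group populated.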
1281                    let (next_scores, next_tokens) = if gen_opt.do_sample {
1282                        self.top_k_top_p_filtering(
1283                            &mut next_scores,
1284                            gen_opt.top_k,
1285                            gen_opt.top_p,
1286                            2,
1287                        );
1288                        let _scores = next_scores
1289                            .contiguous()
1290                            .view((batch_size, group_size * vocab_size));
1291
1292                        let probabilities = _scores.softmax(-1, _scores.kind());
1293                        let next_tokens = probabilities.multinomial(2 * group_size, false);
1294                        let _scores = _scores.gather(-1, &next_tokens, false);
1295                        let (_scores, next_scores_indices) = _scores.sort(1, true);
1296                        let next_tokens = next_tokens.gather(-1, &next_scores_indices, false);
1297                        (_scores, next_tokens)
1298                    } else {
1299                        let _scores = next_scores
1300                            .contiguous()
1301                            .view((batch_size, group_size * vocab_size));
1302                        _scores.topk(2 * group_size, 1, true, true)
1303                    };
1304
1305                    let eos_token_ids = gen_opt.eos_token_ids.as_ref();
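                    // `next_tokens` indexes into the flattened (group_size * vocab_size) score matrix:
                    // integer division by the vocabulary size recovers the beam each candidate came
                    // from, and the remainder recovers the token id itself.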
1306                    let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor");
1307                    let effective_beam_ids_tensor =
1308                        (&next_tokens.ones_like().cumsum(0, Kind::Int64) - 1) * group_size
1309                            + beam_ids_tensor;
1310                    let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
1311                    let (max_scores, _) = next_scores.max_dim(1, false);
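                    // Mask keeping, for each batch entry, the first group_size candidates that are not
                    // an EOS token; EOS candidates are handled separately below by adding the
                    // corresponding sequences to the finished hypotheses.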
1312                    let mut eos_mask = token_id_tensor.ones_like();
1313                    if let Some(eos_token_id) = eos_token_ids {
1314                        eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Kind::Int64);
1315                    }
1316                    let eos_mask2 = eos_mask
1317                        .cumsum(1, Kind::Int64)
1318                        .le(group_size)
1319                        .to_kind(Kind::Bool)
1320                        .logical_and(&eos_mask);
1321
1322                    let group_beam_scores = next_scores.masked_select(&eos_mask2);
1323                    let group_beam_tokens = token_id_tensor.masked_select(&eos_mask2);
1324                    let group_beam_indices = effective_beam_ids_tensor.masked_select(&eos_mask2);
1325                    let eos_pos = (eos_mask.ones_like() - eos_mask).nonzero();
1326
1327                    for eos_idx in 0..eos_pos.size()[0] {
1328                        let eos_data = eos_pos.get(eos_idx);
1329                        let batch_index = eos_data.int64_value(&[0]);
1330                        if !done[batch_index as usize] {
1331                            let beam_index_pos = eos_data.int64_value(&[1]);
1332                            let is_beam_token_worse_than_top_num_beams =
1333                                beam_index_pos >= gen_opt.num_beams;
1334                            if is_beam_token_worse_than_top_num_beams {
1335                                continue;
1336                            }
1337                            let effective_beam_id = effective_beam_ids_tensor
1338                                .int64_value(&[batch_index, beam_index_pos]);
1339                            let beam_token_score =
1340                                next_scores.double_value(&[batch_index, beam_index_pos]);
1341                            let saved_beam_scores =
1342                                saved_beam_scores.as_ref().map(|step_wise_scores| {
1343                                    Tensor::stack(step_wise_scores, 1)
1344                                        .get(effective_beam_id)
1345                                        .copy()
1346                                });
1347                            hypotheses[batch_index as usize].add(
1348                                input_ids.get(effective_beam_id).copy(),
1349                                beam_token_score,
1350                                saved_beam_scores,
1351                            );
1352                        }
1353                    }
1354
1355                    for batch_index in 0..batch_size {
1356                        if done[batch_index as usize] {
1357                            let _ = group_beam_scores
1358                                .narrow(0, batch_index * gen_opt.num_beams, gen_opt.num_beams)
1359                                .fill_(0f64);
1360                            let _ = group_beam_tokens
1361                                .narrow(0, batch_index * gen_opt.num_beams, gen_opt.num_beams)
1362                                .fill_(gen_opt.pad_token_id.unwrap());
1363                            let _ = group_beam_indices
1364                                .narrow(0, batch_index * gen_opt.num_beams, gen_opt.num_beams)
1365                                .fill_(0);
1366                            continue;
1367                        } else {
1368                            done[batch_index as usize] |= hypotheses[batch_index as usize]
1369                                .is_done(max_scores.double_value(&[batch_index]), current_length);
1370                        }
1371                    }
1372
1373                    if num_beam_groups <= 1 {
1374                        beam_scores = group_beam_scores.view(-1);
1375                        beam_tokens = group_beam_tokens.view(-1);
1376                        beam_indices = group_beam_indices.view(-1);
1377                    } else {
1378                        let _ = beam_scores.index_copy_(
1379                            0,
1380                            batch_group_indices.as_ref().unwrap(),
1381                            &group_beam_scores,
1382                        );
1383                        let _ = beam_tokens.index_copy_(
1384                            0,
1385                            batch_group_indices.as_ref().unwrap(),
1386                            &group_beam_tokens,
1387                        );
1388                        let new_indices = gen_opt.num_beams
1389                            * group_beam_indices.divide_scalar_mode(group_size, "floor")
1390                            + group_start_index
1391                            + group_beam_indices.remainder(group_size);
1392                        let _ = beam_indices.index_copy_(
1393                            0,
1394                            batch_group_indices.as_ref().unwrap(),
1395                            &new_indices,
1396                        );
1397                        let _ = current_tokens.index_copy_(
1398                            0,
1399                            batch_group_indices.as_ref().unwrap(),
1400                            &group_beam_tokens,
1401                        );
1402                    }
1403                }
1404
1405                if let Some(scores_output) = saved_beam_scores.as_mut() {
1406                    scores_output.push(beam_scores.copy());
1407                }
1408                if done.iter().all(|&x| x) {
1409                    break;
1410                }
1411
1412                input_ids = Tensor::cat(
1413                    &[
1414                        input_ids.index_select(0, &beam_indices),
1415                        beam_tokens.unsqueeze(1),
1416                    ],
1417                    -1,
1418                );
1419
1420                current_length += 1;
1421                if let Some(max_length) = gen_opt.max_length {
1422                    if current_length >= max_length {
1423                        break;
1424                    }
1425                }
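                // Reorder the cached decoder states (and encoder outputs for encoder-decoder models)
                // so that each cache entry follows the beam selected to continue it.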
1426                encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
1427
1428                if !self.is_encoder_decoder() {
1429                    attention_mask = Tensor::cat(
1430                        &[
1431                            attention_mask.as_ref(),
1432                            Tensor::ones(
1433                                [*attention_mask.size().first().unwrap(), 1],
1434                                (Kind::Int64, attention_mask.device()),
1435                            )
1436                            .as_ref(),
1437                        ],
1438                        -1,
1439                    );
1440                }
1441            }
1442
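            // Generation loop finished: register the live beams of every batch entry not yet marked
            // as done as final hypotheses before selecting the best output sequences.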
1443            let mut batch_index = 0i64;
1444
1445            let mut saved_beam_scores = saved_beam_scores
1446                .map(|step_wise_scores| Tensor::stack(&step_wise_scores, 1).split(1, 0));
1447            loop {
1448                if batch_index == batch_size {
1449                    break;
1450                }
1451                if done[batch_index as usize] {
1452                    batch_index += 1;
1453                    continue;
1454                }
1455                for beam_index in 0..gen_opt.num_beams {
1456                    let effective_beam_id = batch_index * gen_opt.num_beams + beam_index;
1457                    let beam_saved_token_scores = saved_beam_scores.as_mut().map(|saved_tokens| {
1458                        mem::replace(&mut saved_tokens[effective_beam_id as usize], Tensor::new())
1459                    });
1460                    let final_score = f64::try_from(beam_scores.get(effective_beam_id)).unwrap();
1461                    let final_tokens = input_ids.get(effective_beam_id);
1462                    hypotheses[batch_index as usize].add(
1463                        final_tokens,
1464                        final_score,
1465                        beam_saved_token_scores,
1466                    );
1467                }
1468                batch_index += 1;
1469            }
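            // When sampling, the batch was already expanded by `num_return_sequences` before decoding,
            // so a single sequence is kept per expanded batch entry; with beam search, the top
            // `num_return_sequences` hypotheses are instead extracted here for each prompt.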
1470            let (output_batch_size, output_num_return_sequences_per_batch) = if gen_opt.do_sample {
1471                (batch_size, 1)
1472            } else {
1473                (
1474                    batch_size * gen_opt.num_return_sequences,
1475                    gen_opt.num_return_sequences,
1476                )
1477            };
1478
1479            let mut sentence_lengths =
1480                Tensor::zeros([output_batch_size], (Kind::Int64, input_ids.device()));
1481            let mut best_ids = vec![];
1482
1483            let mut scores_output = if output_scores {
1484                Some(Vec::with_capacity(best_ids.len()))
1485            } else {
1486                None
1487            };
1488            let mut token_scores_output = if output_scores {
1489                Some(Vec::with_capacity(best_ids.len()))
1490            } else {
1491                None
1492            };
1493            for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
1494                let mut sorted_hypotheses = hypothesis.clone();
1495                sorted_hypotheses
1496                    .beams
1497                    .sort_by_key(|(score, _, _)| OrderedFloat(*score));
1498                for j in 0..output_num_return_sequences_per_batch {
1499                    let effective_batch_index =
1500                        output_num_return_sequences_per_batch * hypothesis_index as i64 + j;
1501
1502                    let (best_score, best_hyp, best_token_scores) =
1503                        sorted_hypotheses.beams.pop().unwrap();
1504                    let _ = sentence_lengths.index_fill_(
1505                        0,
1506                        &Tensor::from_slice(&[effective_batch_index]).to(sentence_lengths.device()),
1507                        *best_hyp.size().first().unwrap(),
1508                    );
1509                    best_ids.push(best_hyp);
1510                    if let Some(current_best_scores) = &mut scores_output {
1511                        current_best_scores.push(best_score);
1512                    }
1513                    if let Some(current_best_token_scores) = &mut token_scores_output {
1514                        current_best_token_scores.push(
1515                            best_token_scores
1516                                .unwrap()
1517                                .iter::<f64>()
1518                                .unwrap()
1519                                .collect::<Vec<f64>>(),
1520                        );
1521                    }
1522                }
1523            }
1524            let sentence_max_length = gen_opt
1525                .max_length
1526                .map(|max_length| {
1527                    min(
1528                        i64::try_from(sentence_lengths.max()).unwrap() + 1,
1529                        max_length,
1530                    )
1531                })
1532                .unwrap_or(i64::try_from(sentence_lengths.max()).unwrap() + 1);
1533
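            // Allocate the output tensor; if the selected hypotheses have different lengths, pre-fill
            // it with the padding token, then overwrite each row up to its own sentence length.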
1534            let mut decoded = input_ids.new_empty(
1535                [output_batch_size, sentence_max_length],
1536                (Kind::Int64, input_ids.device()),
1537            );
1538            if i64::try_from(sentence_lengths.max()).unwrap()
1539                != i64::try_from(sentence_lengths.min()).unwrap()
1540            {
1541                let _ = decoded.fill_(
1542                    gen_opt
1543                        .pad_token_id
1544                        .unwrap_or_else(|| gen_opt.eos_token_ids.as_ref().unwrap()[0]),
1545                );
1546            }
1547            for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
1548                let _ = decoded.get(hypothesis_index as i64).index_copy_(
1549                    0,
1550                    &Tensor::arange_start(
1551                        0,
1552                        i64::try_from(sentence_lengths.get(hypothesis_index as i64)).unwrap(),
1553                        (Kind::Int64, input_ids.device()),
1554                    ),
1555                    best_id,
1556                );
1557                let sentence_length =
1558                    i64::try_from(sentence_lengths.get(hypothesis_index as i64)).unwrap();
1559                let sentence_length_max = gen_opt
1560                    .max_length
1561                    .unwrap_or_else(|| i64::try_from(sentence_lengths.max()).unwrap());
1562                if sentence_length < sentence_length_max {
1563                    let _ = decoded.get(hypothesis_index as i64).index_fill_(
1564                        0,
1565                        &Tensor::from_slice(&[sentence_length]).to_device(input_ids.device()),
1566                        gen_opt.eos_token_ids.as_ref().unwrap()[0],
1567                    );
1568                }
1569            }
1570            GeneratedOutputWithScores {
1571                indices: decoded,
1572                scores: scores_output,
1573                token_scores: token_scores_output,
1574            }
1575        }
1576
1577        fn reorder_cache(
1578            &self,
1579            past: &mut Cache,
1580            _encoder_outputs: Option<Tensor>,
1581            _beam_indices: &Tensor,
1582        ) -> Option<Tensor> {
1583            match past {
1584                Cache::None => None,
1585                _ => {
1586                    panic!("Not implemented");
1587                }
1588            }
1589        }
1590    }
1591
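    /// Forces the generation of a set of token ids by setting the scores of all other
    /// vocabulary entries to negative infinity, in place.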
1592    pub fn force_token_id_generation(scores: &mut Tensor, token_ids: &[i64], vocab_size: i64) {
1593        let impossible_tokens: Vec<i64> = (0..vocab_size)
1594            .filter(|pos| !token_ids.contains(pos))
1595            .collect();
1596        let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
1597        let _ = scores.index_fill_(
1598            1,
1599            &impossible_tokens,
1600            get_negative_infinity(scores.kind()).unwrap(),
1601        );
1602    }
1603}
1604
1605#[derive(Debug, Clone)]
1606/// # Generated text output
1607/// Contains generated text and an optional log-likelihood score for the generated sequence
1608pub struct GeneratedTextOutput {
1609    pub text: String,
1610    pub score: Option<f64>,
1611}
1612
1613#[derive(Debug, Clone)]
1614/// # Generated indices output
1615/// Contains generated indices and an optional log-likelihood score for the generated sequence and individual tokens
1616pub struct GeneratedIndicesOutput {
1617    pub indices: Vec<i64>,
1618    pub score: Option<f64>,
1619    pub token_scores: Option<Vec<f64>>,
1620}
1621
1622/// Type alias for a function defining allowed tokens based on the tokens generated so far.
1623/// This function should take a `batch_id` and the associated tensor of already generated tokens and
1624/// should return a vector of allowed tokens. This is useful for controlled generation, i.e.
1625/// deterministic generation of a token continuation once a given sequence of tokens occurs.
1626pub type PrefixAllowedFunction<'a> = &'a dyn Fn(i64, &Tensor) -> Vec<i64>;
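
// A minimal sketch of a function matching this alias (hypothetical helper, not part of the
// crate's API): a constraint restricting decoding to a fixed whitelist of token ids,
// regardless of the tokens generated so far. See the `generate` documentation below for a
// history-dependent variant.
#[allow(dead_code)]
fn example_whitelist_constraint(_batch_id: i64, _previous_token_ids: &Tensor) -> Vec<i64> {
    // Hypothetical token ids that remain allowed at every generation step.
    vec![318, 257, 3797, 50256]
}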
1627
1628#[derive(Clone, Copy, Default)]
1629/// # Generation options for text generation.
1630/// When provided to a `generate` method, these options will take priority over the `GenerateConfig` used to create the
1631/// `LanguageGenerator`. Some of these options may be left as `None`; options without a value will individually fall back
1632/// to the corresponding `GenerateConfig` value.
1633pub struct GenerateOptions<'a> {
1634    /// Minimum sequence length
1635    pub min_length: Option<i64>,
1636    /// Maximum sequence length
1637    pub max_length: Option<i64>,
1638    /// Maximum number of new tokens to generate (useful for causal generation models).
1639    /// Only one of `max_length` and `max_new_tokens` should be provided.
1640    /// When both are given, `max_new_tokens` is ignored and the `max_length` setting is used.
1641    pub max_new_tokens: Option<i64>,
1642    /// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated
1643    pub early_stopping: Option<bool>,
1644    /// Number of sequences to return for each prompt text
1645    pub num_return_sequences: Option<i64>,
1646    /// Number of beams for beam search
1647    pub num_beams: Option<i64>,
1648    pub num_beam_groups: Option<i64>,
1649    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding
1650    pub do_sample: Option<bool>,
1651    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance
1652    pub temperature: Option<f64>,
1653    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature
1654    pub top_k: Option<i64>,
1655    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p
1656    pub top_p: Option<f64>,
1657    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated.
1658    pub repetition_penalty: Option<f64>,
1659    /// Exponential penalty based on the length of the hypotheses generated
1660    pub length_penalty: Option<f64>,
1661    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature
1662    pub no_repeat_ngram_size: Option<i64>,
1663    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups
1664    pub diversity_penalty: Option<f64>,
1665    /// Decoder start token id
1666    pub decoder_start_token_id: Option<i64>,
1667    /// Forced first token generated
1668    pub forced_bos_token_id: Option<i64>,
1669    /// Function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and return a `Vec<i64>` of allowed tokens.
1670    pub prefix_allowed_tokens_fn: Option<PrefixAllowedFunction<'a>>,
1671    /// List of bad word ids (may be a sequence of word ids) that will be banned during the generation
1672    pub bad_word_ids: Option<&'a Vec<Vec<i64>>>,
1673    /// Flag indicating if text generation scores should be returned
1674    pub output_scores: bool,
1675}
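
// A minimal sketch (field values are arbitrary) of requesting a budget of newly generated
// tokens rather than an absolute sequence length; every field left unset falls back to the
// `GenerateConfig` the generator was created with.
#[allow(dead_code)]
fn example_new_token_budget_options() -> GenerateOptions<'static> {
    GenerateOptions {
        max_new_tokens: Some(64),
        num_beams: Some(3),
        ..Default::default()
    }
}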
1676
1677macro_rules! unpack_config {
1678    ($field_name:ident, $generate_options: ident, $generate_config: ident) => {
1679        $generate_options.map_or($generate_config.$field_name, |opts| {
1680            opts.$field_name.unwrap_or($generate_config.$field_name)
1681        })
1682    };
1683}
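
// As an illustration, `unpack_config!(do_sample, generate_options, config)` expands to
// `generate_options.map_or(config.do_sample, |opts| opts.do_sample.unwrap_or(config.do_sample))`:
// a value explicitly set on `GenerateOptions` wins, otherwise the `GenerateConfig` value is used.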
1684
1685/// # Common trait for text generation models.
1686/// Main API for text generation
1687pub trait LanguageGenerator: PrivateLanguageGenerator {
1688    /// Generate text based on a vector of prompt texts.
1689    ///
1690    /// # Arguments
1691    ///
1692    /// * `prompt_texts` - `Option<Vec<&str>>` Optional vector of text prompts. An empty prompt may be passed if the model implements a `bos_id`.
1693    /// * `generate_options` - `Option<GenerateOptions>` Optional set of generate options. If not (or partially) provided, will use the settings provided when creating the generator
1694    ///
1695    /// # Returns
1696    /// * `Vec<GeneratedTextOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing `GeneratedTextOutput` with the generated texts and the generation score if `output_scores` is true.
1697    ///
1698    /// # Example
1699    ///
1700    /// ```no_run
1701    /// # use std::path::PathBuf;
1702    /// # use tch::Device;
1703    /// # fn main() -> anyhow::Result<()> {
1704    /// use rust_bert::gpt2::GPT2Generator;
1705    /// use rust_bert::pipelines::generation_utils::{
1706    ///     GenerateConfig, GenerateOptions, LanguageGenerator,
1707    /// };
1708    /// use tch::Tensor;
1709    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
1710    /// # home.push("rustbert");
1711    /// # home.push("gpt2");
1712    /// # let config_path = &home.as_path().join("config.json");
1713    /// # let vocab_path = &home.as_path().join("vocab.txt");
1714    /// # let merges_path = &home.as_path().join("merges.txt");
1715    /// # let weights_path = &home.as_path().join("model.ot");
1716    /// let device = Device::cuda_if_available();
1717    /// let generate_config = GenerateConfig {
1718    ///     max_length: Some(30),
1719    ///     do_sample: true,
1720    ///     num_beams: 5,
1721    ///     temperature: 1.1,
1722    ///     num_return_sequences: 3,
1723    ///     ..Default::default()
1724    /// };
1725    /// let gpt2_generator = GPT2Generator::new(generate_config)?;
1726    /// let input_context = "The dog";
1727    /// let second_input_context = "The cat was";
1728    ///
1729    /// //Example custom function for fine-grained generation control
1730    /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
1731    ///     let paragraph_tokens = [198, 628];
1732    ///
1733    ///     for paragraph_token in paragraph_tokens.iter() {
1734    ///         if previous_token_ids
1735    ///             .iter::<i64>()
1736    ///             .unwrap()
1737    ///             .collect::<Vec<i64>>()
1738    ///             .contains(paragraph_token)
1739    ///         {
1740    ///             return vec![50256];
1741    ///         }
1742    ///     }
1743    ///     (0..50255).collect()
1744    /// }
1745    ///
1746    /// let generate_options = GenerateOptions {
1747    ///     min_length: Some(32),
1748    ///     max_length: Some(128),
1749    ///     output_scores: true,
1750    ///     prefix_allowed_tokens_fn: Some(&force_one_paragraph),
1751    ///     ..Default::default()
1752    /// };
1753    ///
1754    /// let output = gpt2_generator.generate(
1755    ///     Some(&[input_context, second_input_context]),
1756    ///     Some(generate_options),
1757    /// );
1758    /// # Ok(())
1759    /// # }
1760    /// ```
1761    /// Example output: \
1762    /// ```no_run
1763    /// # let output =
1764    /// [
1765    ///     "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
1766    ///     "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
1767    ///     "The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's",
1768    ///     "The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,",
1769    ///     "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
1770    ///     "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
1771    /// ]
1772    /// # ;
1773    /// ```
1774    fn generate<S>(
1775        &self,
1776        prompt_texts: Option<&[S]>,
1777        generate_options: Option<GenerateOptions>,
1778    ) -> Result<Vec<GeneratedTextOutput>, RustBertError>
1779    where
1780        S: AsRef<str> + Send + Sync,
1781    {
1782        let indices_outputs = self.generate_indices(prompt_texts, generate_options)?;
1783        let mut output = Vec::with_capacity(indices_outputs.len());
1784        for generated_sequence in indices_outputs {
1785            output.push(GeneratedTextOutput {
1786                text: self
1787                    ._get_tokenizer()
1788                    .decode(&generated_sequence.indices, true, true),
1789                score: generated_sequence.score,
1790            });
1791        }
1792        Ok(output)
1793    }
1794
1795    /// Generate token indices without decoding (useful for token-level operations before returning final text or as a validation step during training).
1796    ///
1797    /// # Arguments
1798    ///
1799    /// * `prompt_texts` - `Option<Vec<&str>>` Optional vector of text prompts. An empty prompt may be passed if the model implements a `bos_id`.
1800    /// * `generate_options` - `Option<GenerateOptions>` Optional set of generate options. If not (or partially) provided, will use the settings provided when creating the generator
1801    ///
1802    /// # Returns
1803    /// * `Vec<GeneratedIndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing `GeneratedIndicesOutput` with the generated indices and the generation score if `output_scores` is true.
1804    ///
1805    /// # Example
1806    ///
1807    /// ```no_run
1808    /// # use std::path::PathBuf;
1809    /// # use tch::Device;
1810    /// # fn main() -> anyhow::Result<()> {
1811    /// use rust_bert::gpt2::GPT2Generator;
1812    /// use rust_bert::pipelines::generation_utils::{
1813    ///     GenerateConfig, GenerateOptions, LanguageGenerator,
1814    /// };
1815    /// use tch::Tensor;
1816    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
1817    /// # home.push("rustbert");
1818    /// # home.push("gpt2");
1819    /// # let config_path = &home.as_path().join("config.json");
1820    /// # let vocab_path = &home.as_path().join("vocab.txt");
1821    /// # let merges_path = &home.as_path().join("merges.txt");
1822    /// # let weights_path = &home.as_path().join("model.ot");
1823    /// let device = Device::cuda_if_available();
1824    /// let generate_config = GenerateConfig {
1825    ///     max_length: Some(30),
1826    ///     do_sample: true,
1827    ///     num_beams: 5,
1828    ///     temperature: 1.1,
1829    ///     num_return_sequences: 3,
1830    ///     ..Default::default()
1831    /// };
1832    /// let gpt2_generator = GPT2Generator::new(generate_config)?;
1833    /// let input_context = "The dog";
1834    /// let second_input_context = "The cat was";
1835    ///
1836    /// //Example custom function for fine-grained generation control
1837    /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
1838    ///     let paragraph_tokens = [198, 628];
1839    ///
1840    ///     for paragraph_token in paragraph_tokens.iter() {
1841    ///         if previous_token_ids
1842    ///             .iter::<i64>()
1843    ///             .unwrap()
1844    ///             .collect::<Vec<i64>>()
1845    ///             .contains(paragraph_token)
1846    ///         {
1847    ///             return vec![50256];
1848    ///         }
1849    ///     }
1850    ///     (0..50255).collect()
1851    /// }
1852    ///
1853    /// let generate_options = GenerateOptions {
1854    ///     min_length: Some(32),
1855    ///     max_length: Some(128),
1856    ///     output_scores: true,
1857    ///     prefix_allowed_tokens_fn: Some(&force_one_paragraph),
1858    ///     ..Default::default()
1859    /// };
1860    ///
1861    /// let output = gpt2_generator.generate_indices(
1862    ///     Some(&[input_context, second_input_context]),
1863    ///     Some(generate_options),
1864    /// );
1865    /// # Ok(())
1866    /// # }
1867    /// ```
1868    fn generate_indices<S>(
1869        &self,
1870        prompt_texts: Option<&[S]>,
1871        generate_options: Option<GenerateOptions>,
1872    ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError>
1873    where
1874        S: AsRef<str> + Send + Sync,
1875    {
1876        let eos_token_ids = self.get_eos_ids();
1877
1878        let config = self.get_config();
1879
1880        let max_length = generate_options.map_or(config.max_length, |generate_options| {
1881            generate_options.max_length
1882        });
1883        let encoding_max_len = if self.is_encoder_decoder() {
1884            self.get_max_positions_embeddings()
1885        } else {
1886            max_length
1887        };
1888        let pad_token_id = match self.get_pad_id() {
1889            Some(value) => Some(value),
1890            None => eos_token_ids.as_ref().map(|eos_ids| eos_ids[0]),
1891        };
1892
1893        let input_ids = match prompt_texts {
1894            Some(prompts) if !prompts.is_empty() => {
1895                self.encode_prompt_text(prompts, encoding_max_len, pad_token_id)
1896            }
1897            None => match self.get_bos_id() {
1898                Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
1899                None => return Err(RustBertError::ValueError(
1900                    "A model with a BOS token must be used to start generation with an empty input"
1901                        .to_string(),
1902                )),
1903            },
1904            _ => return Ok(Vec::new()),
1905        };
1906        self.generate_from_ids_and_past(input_ids, None, generate_options)
1907    }
1908
1909    /// Generate token indices given a tensor of pre-tokenized and encoded input ids.
1910    /// Returns a list of output tokens that need to be decoded using a tokenizer.
1911    ///
1912    /// # Arguments
1913    ///
1914    /// * `input_ids` - `Tensor` pre-tokenized and encoded input for generation.
1915    /// * `generate_options` - `Option<GenerateOptions>` Optional set of generate options. If not (or partially) provided, will use the settings provided when creating the generator
1916    ///
1917    /// # Returns
1918    /// * `Vec<GeneratedIndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing `GeneratedIndicesOutput` with the generated indices and the generation score if `output_scores` is true.
1919    ///
1920    /// # Example
1921    ///
1922    /// ```no_run
1923    /// # use std::path::PathBuf;
1924    /// # use tch::Device;
1925    /// # fn main() -> anyhow::Result<()> {
1926    /// use rust_bert::gpt2::GPT2Generator;
1927    /// use rust_bert::pipelines::generation_utils::{
1928    ///     GenerateConfig, GenerateOptions, LanguageGenerator,
1929    /// };
1930    /// use tch::{Kind, Tensor};
1931    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
1932    /// # home.push("rustbert");
1933    /// # home.push("gpt2");
1934    /// # let config_path = &home.as_path().join("config.json");
1935    /// # let vocab_path = &home.as_path().join("vocab.txt");
1936    /// # let merges_path = &home.as_path().join("merges.txt");
1937    /// # let weights_path = &home.as_path().join("model.ot");
1938    /// let device = Device::cuda_if_available();
1939    ///
1940    /// let gpt2_generator = GPT2Generator::new(Default::default())?;
1941    /// let input_tensor = Tensor::randint(50257, &[32, 128], (Kind::Int64, Device::Cpu));
1942    /// let input_mask = Tensor::ones(&[32, 128], (Kind::Int64, Device::Cpu));
1943    ///
1944    /// let generate_options = GenerateOptions {
1945    ///     min_length: Some(32),
1946    ///     max_length: Some(128),
1947    ///     output_scores: true,
1948    ///     ..Default::default()
1949    /// };
1950    ///
1951    /// let output = gpt2_generator.generate_from_ids_and_past(
1952    ///     input_tensor,
1953    ///     Some(input_mask),
1954    ///     Some(generate_options),
1955    /// );
1956    /// # Ok(())
1957    /// # }
1958    /// ```
1959    fn generate_from_ids_and_past(
1960        &self,
1961        mut input_ids: Tensor,
1962        mut attention_mask: Option<Tensor>,
1963        generate_options: Option<GenerateOptions>,
1964    ) -> Result<Vec<GeneratedIndicesOutput>, RustBertError> {
1965        let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).cloned();
1966
1967        let config = PrivateLanguageGenerator::get_config(self);
1968
1969        // Set generation options. Priority goes to options provided to the `generate` method, then
1970        // model configuration, then default values.
1971        let do_sample = unpack_config!(do_sample, generate_options, config);
1972        let num_return_sequences = unpack_config!(num_return_sequences, generate_options, config);
1973        let num_beams = unpack_config!(num_beams, generate_options, config);
1974        let min_length = unpack_config!(min_length, generate_options, config);
1975        let early_stopping = unpack_config!(early_stopping, generate_options, config);
1976        let temperature = unpack_config!(temperature, generate_options, config);
1977        let top_k = unpack_config!(top_k, generate_options, config);
1978        let top_p = unpack_config!(top_p, generate_options, config);
1979        let repetition_penalty = unpack_config!(repetition_penalty, generate_options, config);
1980        let length_penalty = unpack_config!(length_penalty, generate_options, config);
1981        let no_repeat_ngram_size = unpack_config!(no_repeat_ngram_size, generate_options, config);
1982        let num_beam_groups = generate_options.map_or(config.num_beam_groups, |opts| {
1983            opts.num_beam_groups.or(config.num_beam_groups)
1984        });
1985        let diversity_penalty = generate_options.map_or(config.diversity_penalty, |opts| {
1986            opts.diversity_penalty.or(config.diversity_penalty)
1987        });
1988        let decoder_start_token_id = generate_options.and_then(|opts| opts.decoder_start_token_id);
1989        let forced_bos_token_id = generate_options.and_then(|opts| opts.forced_bos_token_id);
1990        let bad_word_ids = generate_options.and_then(|opts| opts.bad_word_ids);
1991        let prefix_allowed_tokens_fn =
1992            generate_options.and_then(|opts| opts.prefix_allowed_tokens_fn);
1993        let output_scores = generate_options.map_or(false, |opts| opts.output_scores);
1994
1995        let pad_token_id = match self.get_pad_id() {
1996            Some(value) => Some(value),
1997            None => eos_token_ids.as_ref().map(|eos_ids| eos_ids[0]),
1998        };
1999
2000        let input_id_size = input_ids.size();
2001        let mut input_ids_len = *input_id_size.last().unwrap();
2002        if input_ids_len == 0 {
2003            input_ids = Tensor::ones(
2004                [*input_id_size.first().unwrap(), 1],
2005                (Int64, input_ids.device()),
2006            ) * self
2007                .get_bos_id()
2008                .expect("`bos_token_id` has to be defined when no `input_ids` are provided.");
2009            attention_mask = Some(Tensor::ones(
2010                [*input_id_size.first().unwrap(), 1],
2011                (Int64, input_ids.device()),
2012            ));
2013            input_ids_len += 1;
2014        }
2015
2016        let cur_len = if !self.is_encoder_decoder() {
2017            *input_ids.size().last().unwrap()
2018        } else {
2019            1
2020        };
2021        let batch_size = *input_ids.size().first().unwrap();
2022
2023        let (effective_batch_size, effective_batch_mult) = match do_sample {
2024            true => (batch_size * num_return_sequences, num_return_sequences),
2025            false => (batch_size, 1),
2026        };
2027
2028        let attention_mask = match attention_mask {
2029            Some(value) => value,
2030            None => match pad_token_id {
2031                Some(pad_id) => input_ids.ne(pad_id).to_kind(Int64),
2032                None => input_ids.ones_like().to_kind(Int64),
2033            },
2034        };
2035
2036        let encoder_outputs = if self.is_encoder_decoder() {
2037            let encoder_outputs = self
2038                .encode(&input_ids, Some(&attention_mask))
2039                .ok_or(RustBertError::UnsupportedError)?;
2040            let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
2041                .view((-1, 1))
2042                .repeat([1, num_beams * effective_batch_mult])
2043                .view(-1);
2044            Some(encoder_outputs.index_select(0, &expanded_batch_indices))
2045        } else {
2046            None
2047        };
2048
2049        let (input_ids, attention_mask) = if !self.is_encoder_decoder() {
2050            if (num_return_sequences > 1) | (num_beams > 1) {
2051                (
2052                    input_ids
2053                        .unsqueeze(1)
2054                        .expand(
2055                            [batch_size, effective_batch_mult * num_beams, cur_len],
2056                            true,
2057                        )
2058                        .contiguous()
2059                        .view((effective_batch_size * num_beams, cur_len)),
2060                    attention_mask
2061                        .unsqueeze(1)
2062                        .expand(
2063                            [batch_size, effective_batch_mult * num_beams, cur_len],
2064                            true,
2065                        )
2066                        .contiguous()
2067                        .view((effective_batch_size * num_beams, cur_len)),
2068                )
2069            } else {
2070                (input_ids, attention_mask)
2071            }
2072        } else {
2073            let decoder_start_token_id = decoder_start_token_id
2074                .or(self.get_decoder_start_id())
2075                .ok_or(RustBertError::ValueError(
2076                    "decoder start id must be specified for encoder decoders".to_string(),
2077                ))?;
2078            let input_ids = Tensor::full(
2079                [effective_batch_size * num_beams, 1],
2080                decoder_start_token_id,
2081                (Int64, input_ids.device()),
2082            );
2083            let attention_mask = if (num_return_sequences > 1) | (num_beams > 1) {
2084                attention_mask
2085                    .unsqueeze(1)
2086                    .expand(
2087                        [batch_size, effective_batch_mult * num_beams, input_ids_len],
2088                        true,
2089                    )
2090                    .contiguous()
2091                    .view((effective_batch_size * num_beams, input_ids_len))
2092            } else {
2093                attention_mask
2094            };
2095            (input_ids, attention_mask)
2096        };
2097
2098        let max_length = if let Some(generate_options) = generate_options {
2099            match (generate_options.max_length, generate_options.max_new_tokens) {
2100                (Some(max_length), _) => Some(max_length),
2101                (None, Some(max_new_tokens)) => {
2102                    Some(max_new_tokens + input_ids.size().last().unwrap())
2103                }
2104                (None, None) => config.max_length,
2105            }
2106        } else {
2107            config.max_length
2108        };
2109
2110        if let Some(max_length) = max_length {
2111            if input_ids.size2()?.1 > max_length {
2112                return Err(RustBertError::ValueError("The input ids exceed the maximum length for generation.\
2113                 Reduce the size of the provided input ids or increase the allowable maximum generation length.".to_string()));
2114            }
2115        }
2116
2117        if max_length.is_none() & eos_token_ids.is_none() {
2118            return Err(RustBertError::InvalidConfigurationError("No maximum length given for a model without an EOS token. \
2119            This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`".to_string()));
2120        }
2121
2122        let gen_opt = InternalGenerateOptions {
2123            min_length,
2124            max_length,
2125            do_sample,
2126            temperature,
2127            top_k,
2128            top_p,
2129            repetition_penalty,
2130            no_repeat_ngram_size,
2131            pad_token_id,
2132            eos_token_ids,
2133            num_return_sequences,
2134            early_stopping,
2135            num_beams,
2136            length_penalty,
2137            num_beam_groups,
2138            diversity_penalty,
2139            forced_bos_token_id,
2140            bad_word_ids,
2141        };
2142
2143        let generated_output_with_scores = no_grad(|| {
2144            if num_beams > 1 {
2145                self.generate_beam_search(
2146                    input_ids,
2147                    encoder_outputs,
2148                    cur_len,
2149                    effective_batch_size,
2150                    attention_mask,
2151                    gen_opt,
2152                    prefix_allowed_tokens_fn,
2153                    output_scores,
2154                )
2155            } else {
2156                self.generate_no_beam_search(
2157                    input_ids,
2158                    encoder_outputs,
2159                    cur_len,
2160                    effective_batch_size,
2161                    attention_mask,
2162                    gen_opt,
2163                    prefix_allowed_tokens_fn,
2164                    output_scores,
2165                )
2166            }
2167        });
2168        let (decoded, scores, mut token_scores) = (
2169            generated_output_with_scores.indices,
2170            generated_output_with_scores.scores,
2171            generated_output_with_scores.token_scores,
2172        );
2173        let num_sequences = *decoded.size().first().unwrap();
2174        let mut output = Vec::with_capacity(num_sequences as usize);
2175        for sequence_index in 0..num_sequences {
2176            let indices = decoded
2177                .as_ref()
2178                .get(sequence_index)
2179                .iter::<i64>()
2180                .unwrap()
2181                .collect::<Vec<i64>>();
2182            let score = scores
2183                .as_ref()
2184                .map(|scores_value| scores_value[sequence_index as usize]);
2185
2186            let token_scores = token_scores
2187                .as_mut()
2188                .map(|token_scores| std::mem::take(&mut token_scores[sequence_index as usize]));
2189
2190            output.push(GeneratedIndicesOutput {
2191                indices,
2192                score,
2193                token_scores,
2194            });
2195        }
2196        Ok(output)
2197    }
2198
2199    /// Returns a reference to the text generator's tokenizer
2200    ///
2201    /// # Returns
2202    /// * `&TokenizerOption` Reference to the generator's tokenizer.
2203    ///
2204    /// # Example
2205    ///
2206    /// ```no_run
2207    /// # use std::path::PathBuf;
2208    /// # use tch::Device;
2209    /// # fn main() -> anyhow::Result<()> {
2210    /// use rust_bert::gpt2::GPT2Generator;
2211    /// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
2212    /// use tch::Tensor;
2213    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
2214    /// # home.push("rustbert");
2215    /// # home.push("gpt2");
2216    /// # let config_path = &home.as_path().join("config.json");
2217    /// # let vocab_path = &home.as_path().join("vocab.txt");
2218    /// # let merges_path = &home.as_path().join("merges.txt");
2219    /// # let weights_path = &home.as_path().join("model.ot");
2220    /// let device = Device::cuda_if_available();
2221    /// let generate_config = GenerateConfig {
2222    ///     max_length: Some(30),
2223    ///     do_sample: true,
2224    ///     num_beams: 5,
2225    ///     temperature: 1.1,
2226    ///     num_return_sequences: 3,
2227    ///     ..Default::default()
2228    /// };
2229    /// let gpt2_generator = GPT2Generator::new(generate_config)?;
2230    /// let tokenizer = gpt2_generator.get_tokenizer();
2231    /// tokenizer.tokenize("Hello, world!");
2232    /// # Ok(())
2233    /// # }
2234    /// ```
2235    fn get_tokenizer(&self) -> &TokenizerOption {
2236        self._get_tokenizer()
2237    }
2238
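    /// Returns a mutable reference to the text generator's tokenizer.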
2239    fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
2240        self._get_tokenizer_mut()
2241    }
2242
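    /// Converts the generator's model weights to half precision (f16), provided the generator
    /// exposes a mutable variable store.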
2243    fn half(&mut self) -> Result<(), RustBertError> {
2244        self.get_var_store_mut()?.half();
2245        Ok(())
2246    }
2247
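    /// Converts the generator's model weights back to single precision (f32), provided the
    /// generator exposes a mutable variable store.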
2248    fn float(&mut self) -> Result<(), RustBertError> {
2249        self.get_var_store_mut()?.float();
2250        Ok(())
2251    }
2252
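    /// Moves the generator's model weights to the given device, provided the generator exposes
    /// a mutable variable store.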
2253    fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
2254        self.get_var_store_mut()?.set_device(device);
2255        Ok(())
2256    }
2257}
2258
2259#[derive(Debug)]
2260struct BeamHypotheses {
2261    max_length: Option<i64>,
2262    length_penalty: f64,
2263    early_stopping: bool,
2264    num_beams: i64,
2265    beams: Vec<(f64, Tensor, Option<Tensor>)>,
2266    worst_score: f64,
2267}
2268
2269impl Clone for BeamHypotheses {
2270    fn clone(&self) -> Self {
2271        BeamHypotheses {
2272            max_length: self.max_length,
2273            length_penalty: self.length_penalty,
2274            early_stopping: self.early_stopping,
2275            num_beams: self.num_beams,
2276            beams: self
2277                .beams
2278                .iter()
2279                .map(|(score, tensor, scores_tensor)| {
2280                    (
2281                        *score,
2282                        tensor.copy(),
2283                        scores_tensor
2284                            .as_ref()
2285                            .map(|scores_tensor| scores_tensor.copy()),
2286                    )
2287                })
2288                .collect::<Vec<(f64, Tensor, Option<Tensor>)>>(),
2289            worst_score: self.worst_score,
2290        }
2291    }
2292}
2293
2294impl BeamHypotheses {
2295    fn new(
2296        num_beams: i64,
2297        max_length: Option<i64>,
2298        length_penalty: f64,
2299        early_stopping: bool,
2300    ) -> BeamHypotheses {
2301        BeamHypotheses {
2302            max_length: max_length.map(|max_length| max_length - 1),
2303            length_penalty,
2304            early_stopping,
2305            num_beams,
2306            beams: Vec::with_capacity(num_beams as usize + 1),
2307            worst_score: 1e9f64,
2308        }
2309    }
2310
2311    fn len(&self) -> i64 {
2312        self.beams.len() as i64
2313    }
2314
2315    fn add(
2316        &mut self,
2317        hypothesis: Tensor,
2318        sum_log_probabilities: f64,
2319        token_scores: Option<Tensor>,
2320    ) {
2321        let score =
2322            sum_log_probabilities / ((hypothesis.size()[0] as f64).powf(self.length_penalty));
2323        if (self.len() < self.num_beams) | (score > self.worst_score) {
2324            let token_scores = token_scores.map(|scores_tensor| {
2325                scores_tensor.squeeze_dim(0).diff::<Tensor>(
2326                    1,
2327                    0,
2328                    Some(Tensor::zeros(
2329                        [1],
2330                        (scores_tensor.kind(), scores_tensor.device()),
2331                    )),
2332                    None,
2333                )
2334            });
2335            self.beams.push((score, hypothesis, token_scores));
2336            if self.len() > self.num_beams {
2337                let (worst_score_position, _) = self
2338                    .beams
2339                    .iter()
2340                    .enumerate()
2341                    .min_by_key(|(_, (score, _, _))| OrderedFloat(*score))
2342                    .unwrap();
2343                let _ = self.beams.remove(worst_score_position);
2344            }
2345            self.worst_score = self
2346                .beams
2347                .iter()
2348                .min_by_key(|(score, _, _)| OrderedFloat(*score))
2349                .unwrap()
2350                .0;
2351        }
2352    }
2353
2354    fn is_done(&self, best_sum_log_probabilities: f64, current_length: i64) -> bool {
2355        if self.len() < self.num_beams {
2356            false
2357        } else if self.early_stopping {
2358            true
2359        } else {
2360            self.worst_score
2361                >= best_sum_log_probabilities / (current_length as f64).powf(self.length_penalty)
2362        }
2363    }
2364}
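
// Worked example of the length-penalised score computed in `add` above (illustrative values):
// a 10-token hypothesis with summed log-probability -6.0 and `length_penalty` 1.0 scores
// -6.0 / 10^1.0 = -0.6, while a 20-token hypothesis with the same sum scores -0.3 and is
// therefore ranked higher; `length_penalty` values above 1.0 strengthen this bias towards
// longer sequences, values below 1.0 weaken it.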
2365
2366/// Container holding a language model output for generation tasks
2367pub struct LMModelOutput {
2368    /// Logits for each vocab item and position
2369    pub lm_logits: Tensor,
2370    /// Cached state for improved efficiency during decoding
2371    pub cache: Cache,
2372}