rust_bert/models/pegasus/
pegasus_model.rs

1// Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
2// Copyright 2021 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::bart::BartModelOutput;
14use crate::mbart::MBartConfig;
15use crate::pegasus::decoder::PegasusDecoder;
16use crate::pegasus::encoder::PegasusEncoder;
17use crate::pegasus::LayerState;
18use crate::pipelines::common::{ModelType, TokenizerOption};
19use crate::pipelines::generation_utils::private_generation_utils::{
20    PreparedInput, PrivateLanguageGenerator,
21};
22use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
23use crate::{Config, RustBertError};
24use std::borrow::Borrow;
25use tch::nn::{embedding, EmbeddingConfig, Init};
26use tch::{nn, Device, Tensor};
27
/// # Pegasus Pretrained model weight files
/// Namespace type holding `(name, url)` constants for pretrained Pegasus weight files.
pub struct PegasusModelResources;

/// # Pegasus Pretrained model config files
/// Namespace type holding `(name, url)` constants for pretrained Pegasus configuration files.
pub struct PegasusConfigResources;

/// # Pegasus Pretrained model vocab files
/// Namespace type holding `(name, url)` constants for pretrained Pegasus vocabulary files.
pub struct PegasusVocabResources;
36
impl PegasusModelResources {
    /// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>. Modified with conversion to C-array format.
    ///
    /// Tuple of `(name, url)`: the first element is a relative identifier (presumably used as the
    /// local cache sub-path — confirm against the resource loader), the second the download URL.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/model",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/rust_model.ot",
    );
}

impl PegasusConfigResources {
    /// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    ///
    /// Tuple of `(name, url)` pointing at the model's `config.json`.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/config",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/config.json",
    );
}

impl PegasusVocabResources {
    /// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    ///
    /// Tuple of `(name, url)` pointing at the model's `spiece.model` vocabulary file.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/spiece",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/spiece.model",
    );
}
60
/// # Pegasus model configuration
/// Defines the Pegasus model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// Pegasus re-uses the `MBartConfig` structure as-is; this alias only gives it a model-specific name.
pub type PegasusConfig = MBartConfig;
64
65fn _shift_tokens_right(
66    input_ids: &Tensor,
67    pad_token_id: i64,
68    decoder_start_token_id: i64,
69) -> Tensor {
70    let input_ids_length = input_ids.size()[1];
71    let mut shifted_input_ids = Tensor::zeros(
72        input_ids.size().as_slice(),
73        (input_ids.kind(), input_ids.device()),
74    );
75    shifted_input_ids
76        .slice(1, 1, input_ids_length, 1)
77        .copy_(&input_ids.slice(1, 0, input_ids_length - 1, 1));
78
79    let _ = shifted_input_ids.select(1, 0).fill_(decoder_start_token_id);
80    let _ = shifted_input_ids.masked_fill_(&shifted_input_ids.eq(-100), pad_token_id);
81
82    shifted_input_ids
83}
84
/// # Pegasus Base model
/// Base architecture for Pegasus model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `PegasusEncoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `PegasusDecoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
///     Caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values).
/// - `embeddings`: token embedding matrix shared by the encoder and the decoder
pub struct PegasusModel {
    /// Encoder stack (crate-visible so the encoder can be run on its own, e.g. for `encode`).
    pub(crate) encoder: PegasusEncoder,
    /// Decoder stack.
    decoder: PegasusDecoder,
    /// Token embeddings passed to both the encoder and the decoder on every forward pass;
    /// `PegasusForConditionalGeneration` also re-uses its weights as the language model head.
    pub(crate) embeddings: nn::Embedding,
}
96
97impl PegasusModel {
98    /// Build a new `PegasusModel`
99    ///
100    /// # Arguments
101    ///
102    /// * `p` - Variable store path for the root of the Pegasus model
103    /// * `config` - `PegasusConfig` object defining the model architecture
104    ///
105    /// # Example
106    ///
107    /// ```no_run
108    /// use rust_bert::pegasus::{PegasusConfig, PegasusModel};
109    /// use rust_bert::Config;
110    /// use std::path::Path;
111    /// use tch::{nn, Device};
112    ///
113    /// let config_path = Path::new("path/to/config.json");
114    /// let device = Device::Cpu;
115    /// let p = nn::VarStore::new(device);
116    /// let config = PegasusConfig::from_file(config_path);
117    /// let pegasus: PegasusModel = PegasusModel::new(&p.root() / "pegasus", &config);
118    /// ```
119    pub fn new<'p, P>(p: P, config: &PegasusConfig) -> PegasusModel
120    where
121        P: Borrow<nn::Path<'p>>,
122    {
123        let p = p.borrow();
124
125        let pad_token_id = config.pad_token_id.unwrap_or(0);
126        let embedding_config = EmbeddingConfig {
127            padding_idx: pad_token_id,
128            ..Default::default()
129        };
130        let embeddings: nn::Embedding = embedding(
131            p / "shared",
132            config.vocab_size,
133            config.d_model,
134            embedding_config,
135        );
136
137        let encoder = PegasusEncoder::new(p / "encoder", config);
138        let decoder = PegasusDecoder::new(p / "decoder", config);
139
140        PegasusModel {
141            encoder,
142            decoder,
143            embeddings,
144        }
145    }
146
147    /// Forward pass through the model
148    ///
149    /// # Arguments
150    ///
151    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
152    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
153    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
154    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
155    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
156    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
157    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
158    ///
159    /// # Returns
160    ///
161    /// * `PegasusModelOutput` containing:
162    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
163    ///   - `encoder_hidden_states` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
164    ///   - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
165    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
166    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
167    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
168    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
169    ///
170    /// # Example
171    ///
172    /// ```no_run
173    /// # use tch::{nn, Device, Tensor, no_grad};
174    /// # use rust_bert::Config;
175    /// # use std::path::Path;
176    /// # use tch::kind::Kind::{Int64, Double};
177    /// use rust_bert::pegasus::{PegasusConfig, PegasusModel};
178    /// # let config_path = Path::new("path/to/config.json");
179    /// # let vocab_path = Path::new("path/to/vocab.txt");
180    /// # let device = Device::Cpu;
181    /// # let vs = nn::VarStore::new(device);
182    /// # let config = PegasusConfig::from_file(config_path);
183    /// # let pegasus_model: PegasusModel = PegasusModel::new(&vs.root(), &config);
184    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
185    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
186    /// let decoder_input_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
187    /// let encoder_attention_mask =
188    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
189    /// let decoder_attention_mask =
190    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
191    ///
192    /// let model_output = no_grad(|| {
193    ///     pegasus_model.forward_t(
194    ///         Some(&input_tensor),
195    ///         Some(&encoder_attention_mask),
196    ///         &decoder_input_tensor,
197    ///         None,
198    ///         Some(&decoder_attention_mask),
199    ///         None,
200    ///         false,
201    ///     )
202    /// });
203    /// ```
204    pub fn forward_t(
205        &self,
206        input_ids: Option<&Tensor>,
207        attention_mask: Option<&Tensor>,
208        decoder_input_ids: &Tensor,
209        encoder_output: Option<&Tensor>,
210        decoder_attention_mask: Option<&Tensor>,
211        layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
212        train: bool,
213    ) -> PegasusModelOutput {
214        let calc_encoder_output = if encoder_output.is_none() {
215            Some(self.encoder.forward_t(
216                input_ids.unwrap(),
217                attention_mask,
218                &self.embeddings,
219                train,
220            ))
221        } else {
222            None
223        };
224
225        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
226            if let Some(calc_encoder_output) = calc_encoder_output {
227                (
228                    Some(calc_encoder_output.hidden_state),
229                    calc_encoder_output.all_hidden_states,
230                    calc_encoder_output.all_attentions,
231                )
232            } else {
233                (None, None, None)
234            };
235
236        let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());
237
238        let decoder_output = self.decoder.forward_t(
239            decoder_input_ids,
240            encoder_output,
241            attention_mask,
242            decoder_attention_mask,
243            &self.embeddings,
244            layer_states,
245            train,
246        );
247        PegasusModelOutput {
248            decoder_output: decoder_output.hidden_state,
249            encoder_hidden_state: calc_hidden_states,
250            cache: decoder_output.next_decoder_cache,
251            all_decoder_hidden_states: decoder_output.all_hidden_states,
252            all_decoder_attentions: decoder_output.all_attentions,
253            all_encoder_hidden_states,
254            all_encoder_attentions,
255        }
256    }
257}
258
/// # Pegasus Model for conditional generation
/// Pegasus model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `PegasusModel` Base Pegasus model
/// - `final_logits_bias`: bias of shape (1, *vocab_size*) added to the vocabulary logits
///
/// The language modelling head has no weights of its own: it projects the decoder output with the
/// base model's shared token embedding matrix.
pub struct PegasusForConditionalGeneration {
    /// Underlying encoder-decoder model.
    base_model: PegasusModel,
    /// Bias added to the vocabulary logits (zero-initialised; presumably overwritten by pretrained weights).
    final_logits_bias: Tensor,
    /// Padding id written over `-100` entries when `input_ids` are shifted into decoder inputs.
    pad_token_id: i64,
    /// Id placed at the first decoder position when `input_ids` are shifted into decoder inputs.
    decoder_start_token_id: i64,
}
269
270impl PegasusForConditionalGeneration {
271    /// Build a new `PegasusForConditionalGeneration`
272    ///
273    /// # Arguments
274    ///
275    /// * `p` - Variable store path for the root of the BART model
276    /// * `config` - `PegasusConfig` object defining the model architecture
277    ///
278    /// # Example
279    ///
280    /// ```no_run
281    /// use rust_bert::pegasus::{PegasusConfig, PegasusForConditionalGeneration};
282    /// use rust_bert::Config;
283    /// use std::path::Path;
284    /// use tch::{nn, Device};
285    ///
286    /// let config_path = Path::new("path/to/config.json");
287    /// let device = Device::Cpu;
288    /// let p = nn::VarStore::new(device);
289    /// let config = PegasusConfig::from_file(config_path);
290    /// let pegasus: PegasusForConditionalGeneration =
291    ///     PegasusForConditionalGeneration::new(&p.root(), &config);
292    /// ```
293    pub fn new<'p, P>(p: P, config: &PegasusConfig) -> PegasusForConditionalGeneration
294    where
295        P: Borrow<nn::Path<'p>>,
296    {
297        let p = p.borrow();
298
299        let base_model = PegasusModel::new(p / "model", config);
300
301        let final_logits_bias = p.var(
302            "final_logits_bias",
303            &[1, config.vocab_size],
304            Init::Const(0.0),
305        );
306
307        let pad_token_id = config.pad_token_id.unwrap_or(0);
308        let decoder_start_token_id = config.decoder_start_token_id.unwrap_or(0);
309
310        PegasusForConditionalGeneration {
311            base_model,
312            final_logits_bias,
313            pad_token_id,
314            decoder_start_token_id,
315        }
316    }
317
318    /// Forward pass through the model
319    ///
320    /// # Arguments
321    ///
322    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
323    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
324    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
325    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
326    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
327    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
328    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
329    ///
330    /// # Returns
331    ///
332    /// * `PegasusModelOutput` containing:
333    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
334    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
335    ///   - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
336    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
337    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
338    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
339    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
340    ///
341    /// # Example
342    ///
343    /// ```no_run
344    /// # use tch::{nn, Device, Tensor, no_grad};
345    /// # use rust_bert::Config;
346    /// # use std::path::Path;
347    /// # use tch::kind::Kind::{Int64, Double};
348    /// use rust_bert::pegasus::{PegasusConfig, PegasusForConditionalGeneration};
349    /// # let config_path = Path::new("path/to/config.json");
350    /// # let vocab_path = Path::new("path/to/vocab.txt");
351    /// # let device = Device::Cpu;
352    /// # let vs = nn::VarStore::new(device);
353    /// # let config = PegasusConfig::from_file(config_path);
354    /// # let pegasus_model: PegasusForConditionalGeneration = PegasusForConditionalGeneration::new(&vs.root(), &config);
355    ///  let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
356    ///  let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
357    ///  let decoder_input_ids = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
358    ///  let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
359    ///  let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
360    ///
361    ///  let model_output = no_grad(|| {
362    ///    pegasus_model
363    ///         .forward_t(Some(&input_tensor),
364    ///                    Some(&encoder_attention_mask),
365    ///                    None,
366    ///                    Some(&decoder_input_ids),
367    ///                    Some(&decoder_attention_mask),
368    ///                    None,
369    ///                    false)
370    ///    });
371    /// ```
372    pub fn forward_t(
373        &self,
374        input_ids: Option<&Tensor>,
375        attention_mask: Option<&Tensor>,
376        encoder_output: Option<&Tensor>,
377        decoder_input_ids: Option<&Tensor>,
378        decoder_attention_mask: Option<&Tensor>,
379        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
380        train: bool,
381    ) -> PegasusModelOutput {
382        let calc_decoder_input_ids = if decoder_input_ids.is_none() {
383            Some(_shift_tokens_right(
384                input_ids.unwrap(),
385                self.pad_token_id,
386                self.decoder_start_token_id,
387            ))
388        } else {
389            None
390        };
391
392        let decoder_input_ids =
393            decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap());
394
395        let base_model_output = self.base_model.forward_t(
396            input_ids,
397            attention_mask,
398            decoder_input_ids,
399            encoder_output,
400            decoder_attention_mask,
401            old_layer_states,
402            train,
403        );
404
405        let lm_logits = base_model_output
406            .decoder_output
407            .linear::<Tensor>(&self.base_model.embeddings.ws, None)
408            + &self.final_logits_bias;
409        PegasusModelOutput {
410            decoder_output: lm_logits,
411            ..base_model_output
412        }
413    }
414
415    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
416        self.base_model
417            .encoder
418            .forward_t(
419                input_ids,
420                attention_mask,
421                &self.base_model.embeddings,
422                false,
423            )
424            .hidden_state
425    }
426}
427
/// # Language generation model based on the Pegasus architecture
/// Bundles a `PegasusForConditionalGeneration` model with its tokenizer, variable store,
/// generation configuration and the special token ids exposed to the generation utilities.
pub struct PegasusConditionalGenerator {
    /// Underlying model with the language modelling head.
    model: PegasusForConditionalGeneration,
    /// Tokenizer returned by `_get_tokenizer` / `_get_tokenizer_mut`.
    tokenizer: TokenizerOption,
    /// Variable store the model weights are loaded into; its device is reported by `get_device`.
    var_store: nn::VarStore,
    /// Generation settings returned by `get_config`.
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id (defaults to 0 when absent from the model configuration).
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids (defaults to `[1]` when absent from the model configuration).
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id (defaults to 0 when absent from the model configuration).
    pad_token_id: Option<i64>,
    /// Always set to `true` by the constructors of this type.
    is_encoder_decoder: bool,
    /// Vocabulary size taken from the model configuration.
    vocab_size: i64,
    /// Decoder start token id (defaults to 0 when absent from the model configuration).
    decoder_start_id: Option<i64>,
    /// Maximum number of positions, taken from the model configuration.
    max_position_embeddings: i64,
}
442
443impl PegasusConditionalGenerator {
444    /// Build a new `PegasusGenerator`
445    ///
446    /// # Arguments
447    ///
448    /// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
449    /// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
450    /// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
451    /// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
452    ///
453    /// # Example
454    ///
455    /// ```no_run
456    /// # use std::path::PathBuf;
457    /// # use tch::Device;
458    /// # fn main() -> anyhow::Result<()> {
459    /// use rust_bert::pegasus::PegasusConditionalGenerator;
460    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
461    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
462    /// # home.push("rustbert");
463    /// # home.push("pegasus-cnn_dailymail");
464    /// # let config_path = &home.as_path().join("config.json");
465    /// # let vocab_path = &home.as_path().join("spiece.model");
466    /// # let weights_path = &home.as_path().join("model.ot");
467    /// let device = Device::cuda_if_available();
468    /// let generate_config = GenerateConfig {
469    ///     max_length: Some(30),
470    ///     do_sample: true,
471    ///     num_beams: 5,
472    ///     temperature: 1.1,
473    ///     num_return_sequences: 3,
474    ///     ..Default::default()
475    /// };
476    /// let pegasus_generator = PegasusConditionalGenerator::new(generate_config)?;
477    /// # Ok(())
478    /// # }
479    /// ```
480    pub fn new(
481        generate_config: GenerateConfig,
482    ) -> Result<PegasusConditionalGenerator, RustBertError> {
483        let vocab_path = generate_config.vocab_resource.get_local_path()?;
484
485        let tokenizer = TokenizerOption::from_file(
486            ModelType::Pegasus,
487            vocab_path.to_str().unwrap(),
488            None,
489            false,
490            None,
491            None,
492        )?;
493
494        Self::new_with_tokenizer(generate_config, tokenizer)
495    }
496
497    pub fn new_with_tokenizer(
498        generate_config: GenerateConfig,
499        tokenizer: TokenizerOption,
500    ) -> Result<PegasusConditionalGenerator, RustBertError> {
501        let config_path = generate_config.config_resource.get_local_path()?;
502        let device = generate_config.device;
503
504        generate_config.validate();
505        let mut var_store = nn::VarStore::new(device);
506        let config = PegasusConfig::from_file(config_path);
507        let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
508        crate::resources::load_weights(
509            &generate_config.model_resource,
510            &mut var_store,
511            generate_config.kind,
512            device,
513        )?;
514
515        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
516        let eos_token_ids = config
517            .eos_token_id
518            .map_or(Some(vec![1]), |value| Some(vec![value]));
519        let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
520        let vocab_size = config.vocab_size;
521        let is_encoder_decoder = true;
522        let decoder_start_id = config.decoder_start_token_id.or(Some(0));
523        let max_position_embeddings = config.max_position_embeddings;
524
525        Ok(PegasusConditionalGenerator {
526            model,
527            tokenizer,
528            var_store,
529            generate_config,
530            bos_token_id,
531            eos_token_ids,
532            pad_token_id,
533            is_encoder_decoder,
534            vocab_size,
535            decoder_start_id,
536            max_position_embeddings,
537        })
538    }
539}
540
541impl PrivateLanguageGenerator for PegasusConditionalGenerator {
542    fn _get_tokenizer(&self) -> &TokenizerOption {
543        &self.tokenizer
544    }
545    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
546        &mut self.tokenizer
547    }
548    fn get_device(&self) -> Device {
549        self.var_store.device()
550    }
551    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
552        Ok(&mut self.var_store)
553    }
554    fn get_config(&self) -> &GenerateConfig {
555        &self.generate_config
556    }
557    fn get_bos_id(&self) -> Option<i64> {
558        self.bos_token_id
559    }
560    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
561        self.eos_token_ids.as_ref()
562    }
563    fn get_pad_id(&self) -> Option<i64> {
564        self.pad_token_id
565    }
566    fn is_encoder_decoder(&self) -> bool {
567        self.is_encoder_decoder
568    }
569    fn get_vocab_size(&self) -> i64 {
570        self.vocab_size
571    }
572    fn get_decoder_start_id(&self) -> Option<i64> {
573        self.decoder_start_id
574    }
575    fn get_max_positions_embeddings(&self) -> Option<i64> {
576        Some(self.max_position_embeddings)
577    }
578
579    fn forward_t(
580        &self,
581        input_ids: Option<&Tensor>,
582        cache: Cache,
583        attention_mask: Option<&Tensor>,
584        _token_type_ids: Option<&Tensor>,
585        _position_ids: Option<&Tensor>,
586        _input_embeds: Option<&Tensor>,
587        encoder_outputs: Option<&Tensor>,
588        decoder_input_ids: Option<&Tensor>,
589        train: bool,
590    ) -> Result<LMModelOutput, RustBertError> {
591        let base_model_output = match cache {
592            Cache::BARTCache(cached_layer_states) => self.model.forward_t(
593                input_ids,
594                attention_mask,
595                encoder_outputs,
596                decoder_input_ids,
597                None,
598                cached_layer_states,
599                train,
600            ),
601            Cache::None => self.model.forward_t(
602                input_ids,
603                attention_mask,
604                encoder_outputs,
605                decoder_input_ids,
606                None,
607                None,
608                train,
609            ),
610            _ => {
611                return Err(RustBertError::ValueError(
612                    "Cache not compatible with Pegasus Model".into(),
613                ));
614            }
615        };
616
617        Ok(LMModelOutput {
618            lm_logits: base_model_output.decoder_output,
619            cache: Cache::BARTCache(base_model_output.cache),
620        })
621    }
622
623    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
624        Some(self.model.encode(input_ids, attention_mask))
625    }
626
627    fn prepare_inputs_for_generation<'a>(
628        &self,
629        input_ids: Tensor,
630        encoder_outputs: Option<&'a Tensor>,
631        past: Cache,
632        attention_mask: Tensor,
633    ) -> PreparedInput<'a> {
634        match past {
635            Cache::BARTCache(past) => PreparedInput {
636                prepared_input: None,
637                prepared_attention_mask: Some(attention_mask),
638                prepared_encoder_output: encoder_outputs,
639                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
640                prepared_position_ids: None,
641                prepared_past: Cache::BARTCache(past),
642            },
643            Cache::None => PreparedInput {
644                prepared_input: None,
645                prepared_attention_mask: Some(attention_mask),
646                prepared_encoder_output: encoder_outputs,
647                prepared_decoder_input: Some(input_ids),
648                prepared_position_ids: None,
649                prepared_past: Cache::BARTCache(None),
650            },
651            _ => panic!("Cache type incompatible with Pegasus"),
652        }
653    }
654
655    fn reorder_cache(
656        &self,
657        past: &mut Cache,
658        encoder_outputs: Option<Tensor>,
659        beam_indices: &Tensor,
660    ) -> Option<Tensor> {
661        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
662        match past {
663            Cache::BARTCache(old_cache_option) => match old_cache_option {
664                Some(old_cache) => {
665                    for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
666                        if self_layer_state.is_some() {
667                            self_layer_state
668                                .as_mut()
669                                .unwrap()
670                                .reorder_cache(beam_indices)
671                        };
672                        if encoder_layer_state.is_some() {
673                            encoder_layer_state
674                                .as_mut()
675                                .unwrap()
676                                .reorder_cache(beam_indices)
677                        };
678                    }
679                }
680                None => {}
681            },
682            Cache::None => {}
683            _ => {
684                panic!("Invalid cache for Pegasus model");
685            }
686        };
687        encoder_outputs
688    }
689}
690
// Empty impl: all `LanguageGenerator` functionality comes from the trait's default methods.
impl LanguageGenerator for PegasusConditionalGenerator {}
692
/// Container holding a Pegasus model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks).
///
/// Alias of `BartModelOutput`: Pegasus returns exactly the same output structure as BART.
pub type PegasusModelOutput = BartModelOutput;