rust_bert/models/pegasus/pegasus_model.rs
1// Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
2// Copyright 2021 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::bart::BartModelOutput;
14use crate::mbart::MBartConfig;
15use crate::pegasus::decoder::PegasusDecoder;
16use crate::pegasus::encoder::PegasusEncoder;
17use crate::pegasus::LayerState;
18use crate::pipelines::common::{ModelType, TokenizerOption};
19use crate::pipelines::generation_utils::private_generation_utils::{
20 PreparedInput, PrivateLanguageGenerator,
21};
22use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
23use crate::{Config, RustBertError};
24use std::borrow::Borrow;
25use tch::nn::{embedding, EmbeddingConfig, Init};
26use tch::{nn, Device, Tensor};
27
/// # Pegasus Pretrained model weight files
///
/// Marker type used as a namespace for associated constants describing pretrained
/// Pegasus weights (see `PegasusModelResources::CNN_DAILYMAIL`).
pub struct PegasusModelResources;
30
/// # Pegasus Pretrained model config files
///
/// Marker type used as a namespace for associated constants describing pretrained
/// Pegasus configuration files (see `PegasusConfigResources::CNN_DAILYMAIL`).
pub struct PegasusConfigResources;
33
/// # Pegasus Pretrained model vocab files
///
/// Marker type used as a namespace for associated constants describing pretrained
/// Pegasus vocabulary files (see `PegasusVocabResources::CNN_DAILYMAIL`).
pub struct PegasusVocabResources;
36
impl PegasusModelResources {
    /// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>. Modified with conversion to C-array format.
    ///
    /// Pair of string slices: a resource name (`"pegasus-cnn_dailymail/model"`) followed by
    /// the URL of the converted `rust_model.ot` weights file.
    // NOTE(review): the first element is presumably the local cache name used by the
    // resource loader - confirm against the resources module.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/model",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/rust_model.ot",
    );
}
44
impl PegasusConfigResources {
    /// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    ///
    /// Pair of string slices: a resource name (`"pegasus-cnn_dailymail/config"`) followed by
    /// the URL of the model `config.json` file.
    // NOTE(review): the first element is presumably the local cache name used by the
    // resource loader - confirm against the resources module.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/config",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/config.json",
    );
}
52
impl PegasusVocabResources {
    /// Shared under Apache 2.0 license by the Pegasus team at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    ///
    /// Pair of string slices: a resource name (`"pegasus-cnn_dailymail/spiece"`) followed by
    /// the URL of the SentencePiece `spiece.model` file.
    // NOTE(review): the first element is presumably the local cache name used by the
    // resource loader - confirm against the resources module.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/spiece",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/spiece.model",
    );
}
60
/// # Pegasus model configuration
/// Defines the Pegasus model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// This is a type alias: Pegasus reuses the `MBartConfig` structure rather than declaring
/// its own configuration type.
pub type PegasusConfig = MBartConfig;
64
65fn _shift_tokens_right(
66 input_ids: &Tensor,
67 pad_token_id: i64,
68 decoder_start_token_id: i64,
69) -> Tensor {
70 let input_ids_length = input_ids.size()[1];
71 let mut shifted_input_ids = Tensor::zeros(
72 input_ids.size().as_slice(),
73 (input_ids.kind(), input_ids.device()),
74 );
75 shifted_input_ids
76 .slice(1, 1, input_ids_length, 1)
77 .copy_(&input_ids.slice(1, 0, input_ids_length - 1, 1));
78
79 let _ = shifted_input_ids.select(1, 0).fill_(decoder_start_token_id);
80 let _ = shifted_input_ids.masked_fill_(&shifted_input_ids.eq(-100), pad_token_id);
81
82 shifted_input_ids
83}
84
/// # Pegasus Base model
/// Base architecture for Pegasus model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `PegasusEncoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `PegasusDecoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `embeddings`: token embedding table shared by the encoder and the decoder
pub struct PegasusModel {
    /// Encoder stack; only run by `forward_t` when no pre-computed encoder output is supplied.
    pub(crate) encoder: PegasusEncoder,
    /// Decoder stack, fed with the encoder output and the optional cached layer states.
    decoder: PegasusDecoder,
    /// Embedding table registered under the `shared` sub-path. It is passed to both the
    /// encoder and the decoder, and its weights are reused by the language modeling head.
    pub(crate) embeddings: nn::Embedding,
}
96
97impl PegasusModel {
98 /// Build a new `PegasusModel`
99 ///
100 /// # Arguments
101 ///
102 /// * `p` - Variable store path for the root of the Pegasus model
103 /// * `config` - `PegasusConfig` object defining the model architecture
104 ///
105 /// # Example
106 ///
107 /// ```no_run
108 /// use rust_bert::pegasus::{PegasusConfig, PegasusModel};
109 /// use rust_bert::Config;
110 /// use std::path::Path;
111 /// use tch::{nn, Device};
112 ///
113 /// let config_path = Path::new("path/to/config.json");
114 /// let device = Device::Cpu;
115 /// let p = nn::VarStore::new(device);
116 /// let config = PegasusConfig::from_file(config_path);
117 /// let pegasus: PegasusModel = PegasusModel::new(&p.root() / "pegasus", &config);
118 /// ```
119 pub fn new<'p, P>(p: P, config: &PegasusConfig) -> PegasusModel
120 where
121 P: Borrow<nn::Path<'p>>,
122 {
123 let p = p.borrow();
124
125 let pad_token_id = config.pad_token_id.unwrap_or(0);
126 let embedding_config = EmbeddingConfig {
127 padding_idx: pad_token_id,
128 ..Default::default()
129 };
130 let embeddings: nn::Embedding = embedding(
131 p / "shared",
132 config.vocab_size,
133 config.d_model,
134 embedding_config,
135 );
136
137 let encoder = PegasusEncoder::new(p / "encoder", config);
138 let decoder = PegasusDecoder::new(p / "decoder", config);
139
140 PegasusModel {
141 encoder,
142 decoder,
143 embeddings,
144 }
145 }
146
147 /// Forward pass through the model
148 ///
149 /// # Arguments
150 ///
151 /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
152 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
153 /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
154 /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
155 /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
156 /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
157 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
158 ///
159 /// # Returns
160 ///
161 /// * `PegasusModelOutput` containing:
162 /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
163 /// - `encoder_hidden_states` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
164 /// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
165 /// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
166 /// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
167 /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
168 /// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
169 ///
170 /// # Example
171 ///
172 /// ```no_run
173 /// # use tch::{nn, Device, Tensor, no_grad};
174 /// # use rust_bert::Config;
175 /// # use std::path::Path;
176 /// # use tch::kind::Kind::{Int64, Double};
177 /// use rust_bert::pegasus::{PegasusConfig, PegasusModel};
178 /// # let config_path = Path::new("path/to/config.json");
179 /// # let vocab_path = Path::new("path/to/vocab.txt");
180 /// # let device = Device::Cpu;
181 /// # let vs = nn::VarStore::new(device);
182 /// # let config = PegasusConfig::from_file(config_path);
183 /// # let pegasus_model: PegasusModel = PegasusModel::new(&vs.root(), &config);
184 /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
185 /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
186 /// let decoder_input_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
187 /// let encoder_attention_mask =
188 /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
189 /// let decoder_attention_mask =
190 /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
191 ///
192 /// let model_output = no_grad(|| {
193 /// pegasus_model.forward_t(
194 /// Some(&input_tensor),
195 /// Some(&encoder_attention_mask),
196 /// &decoder_input_tensor,
197 /// None,
198 /// Some(&decoder_attention_mask),
199 /// None,
200 /// false,
201 /// )
202 /// });
203 /// ```
204 pub fn forward_t(
205 &self,
206 input_ids: Option<&Tensor>,
207 attention_mask: Option<&Tensor>,
208 decoder_input_ids: &Tensor,
209 encoder_output: Option<&Tensor>,
210 decoder_attention_mask: Option<&Tensor>,
211 layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
212 train: bool,
213 ) -> PegasusModelOutput {
214 let calc_encoder_output = if encoder_output.is_none() {
215 Some(self.encoder.forward_t(
216 input_ids.unwrap(),
217 attention_mask,
218 &self.embeddings,
219 train,
220 ))
221 } else {
222 None
223 };
224
225 let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
226 if let Some(calc_encoder_output) = calc_encoder_output {
227 (
228 Some(calc_encoder_output.hidden_state),
229 calc_encoder_output.all_hidden_states,
230 calc_encoder_output.all_attentions,
231 )
232 } else {
233 (None, None, None)
234 };
235
236 let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());
237
238 let decoder_output = self.decoder.forward_t(
239 decoder_input_ids,
240 encoder_output,
241 attention_mask,
242 decoder_attention_mask,
243 &self.embeddings,
244 layer_states,
245 train,
246 );
247 PegasusModelOutput {
248 decoder_output: decoder_output.hidden_state,
249 encoder_hidden_state: calc_hidden_states,
250 cache: decoder_output.next_decoder_cache,
251 all_decoder_hidden_states: decoder_output.all_hidden_states,
252 all_decoder_attentions: decoder_output.all_attentions,
253 all_encoder_hidden_states,
254 all_encoder_attentions,
255 }
256 }
257}
258
/// # Pegasus Model for conditional generation
/// Pegasus model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `PegasusModel` Base Pegasus model
/// - `final_logits_bias`: bias tensor added to the vocabulary logits
pub struct PegasusForConditionalGeneration {
    /// Base encoder/decoder model; its embedding weights are reused to project the decoder
    /// output onto the vocabulary.
    base_model: PegasusModel,
    /// Variable of shape (1, *vocab_size*), initialized to 0, added to the projected logits.
    final_logits_bias: Tensor,
    /// Padding token id (0 when absent from the configuration), used when decoder inputs
    /// are derived from `input_ids`.
    pad_token_id: i64,
    /// Decoder start token id (0 when absent from the configuration), used when decoder
    /// inputs are derived from `input_ids`.
    decoder_start_token_id: i64,
}
269
270impl PegasusForConditionalGeneration {
271 /// Build a new `PegasusForConditionalGeneration`
272 ///
273 /// # Arguments
274 ///
275 /// * `p` - Variable store path for the root of the BART model
276 /// * `config` - `PegasusConfig` object defining the model architecture
277 ///
278 /// # Example
279 ///
280 /// ```no_run
281 /// use rust_bert::pegasus::{PegasusConfig, PegasusForConditionalGeneration};
282 /// use rust_bert::Config;
283 /// use std::path::Path;
284 /// use tch::{nn, Device};
285 ///
286 /// let config_path = Path::new("path/to/config.json");
287 /// let device = Device::Cpu;
288 /// let p = nn::VarStore::new(device);
289 /// let config = PegasusConfig::from_file(config_path);
290 /// let pegasus: PegasusForConditionalGeneration =
291 /// PegasusForConditionalGeneration::new(&p.root(), &config);
292 /// ```
293 pub fn new<'p, P>(p: P, config: &PegasusConfig) -> PegasusForConditionalGeneration
294 where
295 P: Borrow<nn::Path<'p>>,
296 {
297 let p = p.borrow();
298
299 let base_model = PegasusModel::new(p / "model", config);
300
301 let final_logits_bias = p.var(
302 "final_logits_bias",
303 &[1, config.vocab_size],
304 Init::Const(0.0),
305 );
306
307 let pad_token_id = config.pad_token_id.unwrap_or(0);
308 let decoder_start_token_id = config.decoder_start_token_id.unwrap_or(0);
309
310 PegasusForConditionalGeneration {
311 base_model,
312 final_logits_bias,
313 pad_token_id,
314 decoder_start_token_id,
315 }
316 }
317
318 /// Forward pass through the model
319 ///
320 /// # Arguments
321 ///
322 /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
323 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
324 /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
325 /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
326 /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
327 /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
328 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
329 ///
330 /// # Returns
331 ///
332 /// * `PegasusModelOutput` containing:
333 /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
334 /// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
335 /// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
336 /// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
337 /// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
338 /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
339 /// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
340 ///
341 /// # Example
342 ///
343 /// ```no_run
344 /// # use tch::{nn, Device, Tensor, no_grad};
345 /// # use rust_bert::Config;
346 /// # use std::path::Path;
347 /// # use tch::kind::Kind::{Int64, Double};
348 /// use rust_bert::pegasus::{PegasusConfig, PegasusForConditionalGeneration};
349 /// # let config_path = Path::new("path/to/config.json");
350 /// # let vocab_path = Path::new("path/to/vocab.txt");
351 /// # let device = Device::Cpu;
352 /// # let vs = nn::VarStore::new(device);
353 /// # let config = PegasusConfig::from_file(config_path);
354 /// # let pegasus_model: PegasusForConditionalGeneration = PegasusForConditionalGeneration::new(&vs.root(), &config);
355 /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
356 /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
357 /// let decoder_input_ids = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
358 /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
359 /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
360 ///
361 /// let model_output = no_grad(|| {
362 /// pegasus_model
363 /// .forward_t(Some(&input_tensor),
364 /// Some(&encoder_attention_mask),
365 /// None,
366 /// Some(&decoder_input_ids),
367 /// Some(&decoder_attention_mask),
368 /// None,
369 /// false)
370 /// });
371 /// ```
372 pub fn forward_t(
373 &self,
374 input_ids: Option<&Tensor>,
375 attention_mask: Option<&Tensor>,
376 encoder_output: Option<&Tensor>,
377 decoder_input_ids: Option<&Tensor>,
378 decoder_attention_mask: Option<&Tensor>,
379 old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
380 train: bool,
381 ) -> PegasusModelOutput {
382 let calc_decoder_input_ids = if decoder_input_ids.is_none() {
383 Some(_shift_tokens_right(
384 input_ids.unwrap(),
385 self.pad_token_id,
386 self.decoder_start_token_id,
387 ))
388 } else {
389 None
390 };
391
392 let decoder_input_ids =
393 decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap());
394
395 let base_model_output = self.base_model.forward_t(
396 input_ids,
397 attention_mask,
398 decoder_input_ids,
399 encoder_output,
400 decoder_attention_mask,
401 old_layer_states,
402 train,
403 );
404
405 let lm_logits = base_model_output
406 .decoder_output
407 .linear::<Tensor>(&self.base_model.embeddings.ws, None)
408 + &self.final_logits_bias;
409 PegasusModelOutput {
410 decoder_output: lm_logits,
411 ..base_model_output
412 }
413 }
414
415 pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
416 self.base_model
417 .encoder
418 .forward_t(
419 input_ids,
420 attention_mask,
421 &self.base_model.embeddings,
422 false,
423 )
424 .hidden_state
425 }
426}
427
/// # Language generation model based on the Pegasus architecture
///
/// Bundles a `PegasusForConditionalGeneration` model with its tokenizer, variable store,
/// generation configuration and the token ids / sizes exposed through the
/// `PrivateLanguageGenerator` accessors.
pub struct PegasusConditionalGenerator {
    /// Model used for the forward pass and for encoding the inputs.
    model: PegasusForConditionalGeneration,
    /// Tokenizer returned by `_get_tokenizer` / `_get_tokenizer_mut`.
    tokenizer: TokenizerOption,
    /// Variable store the model was built from and the weights were loaded into.
    var_store: nn::VarStore,
    /// Generation options supplied at construction.
    generate_config: GenerateConfig,
    /// `bos_token_id` from the model configuration, 0 when absent.
    bos_token_id: Option<i64>,
    /// Single-element vector holding `eos_token_id` from the model configuration, 1 when absent.
    eos_token_ids: Option<Vec<i64>>,
    /// `pad_token_id` from the model configuration, 0 when absent.
    pad_token_id: Option<i64>,
    /// Always set to `true` by the constructor.
    is_encoder_decoder: bool,
    /// `vocab_size` from the model configuration.
    vocab_size: i64,
    /// `decoder_start_token_id` from the model configuration, 0 when absent.
    decoder_start_id: Option<i64>,
    /// `max_position_embeddings` from the model configuration.
    max_position_embeddings: i64,
}
442
443impl PegasusConditionalGenerator {
444 /// Build a new `PegasusGenerator`
445 ///
446 /// # Arguments
447 ///
448 /// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
449 /// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
450 /// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
451 /// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
452 ///
453 /// # Example
454 ///
455 /// ```no_run
456 /// # use std::path::PathBuf;
457 /// # use tch::Device;
458 /// # fn main() -> anyhow::Result<()> {
459 /// use rust_bert::pegasus::PegasusConditionalGenerator;
460 /// use rust_bert::pipelines::generation_utils::GenerateConfig;
461 /// # let mut home: PathBuf = dirs::home_dir().unwrap();
462 /// # home.push("rustbert");
463 /// # home.push("pegasus-cnn_dailymail");
464 /// # let config_path = &home.as_path().join("config.json");
465 /// # let vocab_path = &home.as_path().join("spiece.model");
466 /// # let weights_path = &home.as_path().join("model.ot");
467 /// let device = Device::cuda_if_available();
468 /// let generate_config = GenerateConfig {
469 /// max_length: Some(30),
470 /// do_sample: true,
471 /// num_beams: 5,
472 /// temperature: 1.1,
473 /// num_return_sequences: 3,
474 /// ..Default::default()
475 /// };
476 /// let pegasus_generator = PegasusConditionalGenerator::new(generate_config)?;
477 /// # Ok(())
478 /// # }
479 /// ```
480 pub fn new(
481 generate_config: GenerateConfig,
482 ) -> Result<PegasusConditionalGenerator, RustBertError> {
483 let vocab_path = generate_config.vocab_resource.get_local_path()?;
484
485 let tokenizer = TokenizerOption::from_file(
486 ModelType::Pegasus,
487 vocab_path.to_str().unwrap(),
488 None,
489 false,
490 None,
491 None,
492 )?;
493
494 Self::new_with_tokenizer(generate_config, tokenizer)
495 }
496
497 pub fn new_with_tokenizer(
498 generate_config: GenerateConfig,
499 tokenizer: TokenizerOption,
500 ) -> Result<PegasusConditionalGenerator, RustBertError> {
501 let config_path = generate_config.config_resource.get_local_path()?;
502 let device = generate_config.device;
503
504 generate_config.validate();
505 let mut var_store = nn::VarStore::new(device);
506 let config = PegasusConfig::from_file(config_path);
507 let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
508 crate::resources::load_weights(
509 &generate_config.model_resource,
510 &mut var_store,
511 generate_config.kind,
512 device,
513 )?;
514
515 let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
516 let eos_token_ids = config
517 .eos_token_id
518 .map_or(Some(vec![1]), |value| Some(vec![value]));
519 let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
520 let vocab_size = config.vocab_size;
521 let is_encoder_decoder = true;
522 let decoder_start_id = config.decoder_start_token_id.or(Some(0));
523 let max_position_embeddings = config.max_position_embeddings;
524
525 Ok(PegasusConditionalGenerator {
526 model,
527 tokenizer,
528 var_store,
529 generate_config,
530 bos_token_id,
531 eos_token_ids,
532 pad_token_id,
533 is_encoder_decoder,
534 vocab_size,
535 decoder_start_id,
536 max_position_embeddings,
537 })
538 }
539}
540
541impl PrivateLanguageGenerator for PegasusConditionalGenerator {
542 fn _get_tokenizer(&self) -> &TokenizerOption {
543 &self.tokenizer
544 }
545 fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
546 &mut self.tokenizer
547 }
548 fn get_device(&self) -> Device {
549 self.var_store.device()
550 }
551 fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
552 Ok(&mut self.var_store)
553 }
554 fn get_config(&self) -> &GenerateConfig {
555 &self.generate_config
556 }
557 fn get_bos_id(&self) -> Option<i64> {
558 self.bos_token_id
559 }
560 fn get_eos_ids(&self) -> Option<&Vec<i64>> {
561 self.eos_token_ids.as_ref()
562 }
563 fn get_pad_id(&self) -> Option<i64> {
564 self.pad_token_id
565 }
566 fn is_encoder_decoder(&self) -> bool {
567 self.is_encoder_decoder
568 }
569 fn get_vocab_size(&self) -> i64 {
570 self.vocab_size
571 }
572 fn get_decoder_start_id(&self) -> Option<i64> {
573 self.decoder_start_id
574 }
575 fn get_max_positions_embeddings(&self) -> Option<i64> {
576 Some(self.max_position_embeddings)
577 }
578
579 fn forward_t(
580 &self,
581 input_ids: Option<&Tensor>,
582 cache: Cache,
583 attention_mask: Option<&Tensor>,
584 _token_type_ids: Option<&Tensor>,
585 _position_ids: Option<&Tensor>,
586 _input_embeds: Option<&Tensor>,
587 encoder_outputs: Option<&Tensor>,
588 decoder_input_ids: Option<&Tensor>,
589 train: bool,
590 ) -> Result<LMModelOutput, RustBertError> {
591 let base_model_output = match cache {
592 Cache::BARTCache(cached_layer_states) => self.model.forward_t(
593 input_ids,
594 attention_mask,
595 encoder_outputs,
596 decoder_input_ids,
597 None,
598 cached_layer_states,
599 train,
600 ),
601 Cache::None => self.model.forward_t(
602 input_ids,
603 attention_mask,
604 encoder_outputs,
605 decoder_input_ids,
606 None,
607 None,
608 train,
609 ),
610 _ => {
611 return Err(RustBertError::ValueError(
612 "Cache not compatible with Pegasus Model".into(),
613 ));
614 }
615 };
616
617 Ok(LMModelOutput {
618 lm_logits: base_model_output.decoder_output,
619 cache: Cache::BARTCache(base_model_output.cache),
620 })
621 }
622
623 fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
624 Some(self.model.encode(input_ids, attention_mask))
625 }
626
627 fn prepare_inputs_for_generation<'a>(
628 &self,
629 input_ids: Tensor,
630 encoder_outputs: Option<&'a Tensor>,
631 past: Cache,
632 attention_mask: Tensor,
633 ) -> PreparedInput<'a> {
634 match past {
635 Cache::BARTCache(past) => PreparedInput {
636 prepared_input: None,
637 prepared_attention_mask: Some(attention_mask),
638 prepared_encoder_output: encoder_outputs,
639 prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
640 prepared_position_ids: None,
641 prepared_past: Cache::BARTCache(past),
642 },
643 Cache::None => PreparedInput {
644 prepared_input: None,
645 prepared_attention_mask: Some(attention_mask),
646 prepared_encoder_output: encoder_outputs,
647 prepared_decoder_input: Some(input_ids),
648 prepared_position_ids: None,
649 prepared_past: Cache::BARTCache(None),
650 },
651 _ => panic!("Cache type incompatible with Pegasus"),
652 }
653 }
654
655 fn reorder_cache(
656 &self,
657 past: &mut Cache,
658 encoder_outputs: Option<Tensor>,
659 beam_indices: &Tensor,
660 ) -> Option<Tensor> {
661 let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
662 match past {
663 Cache::BARTCache(old_cache_option) => match old_cache_option {
664 Some(old_cache) => {
665 for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
666 if self_layer_state.is_some() {
667 self_layer_state
668 .as_mut()
669 .unwrap()
670 .reorder_cache(beam_indices)
671 };
672 if encoder_layer_state.is_some() {
673 encoder_layer_state
674 .as_mut()
675 .unwrap()
676 .reorder_cache(beam_indices)
677 };
678 }
679 }
680 None => {}
681 },
682 Cache::None => {}
683 _ => {
684 panic!("Invalid cache for Pegasus model");
685 }
686 };
687 encoder_outputs
688 }
689}
690
// Empty body: no `LanguageGenerator` method is overridden, so this type relies entirely on
// the implementations provided by the trait itself.
impl LanguageGenerator for PegasusConditionalGenerator {}
692
/// Container holding a Pegasus model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks)
///
/// This is a type alias: Pegasus reuses the `BartModelOutput` structure.
pub type PegasusModelOutput = BartModelOutput;