// rust_bert/models/prophetnet/prophetnet_model.rs

1// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
2// Copyright 2020 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use std::borrow::Borrow;
14use std::collections::HashMap;
15
16use serde::{Deserialize, Serialize};
17use tch::{nn, Device, Kind, Tensor};
18
19use crate::pipelines::common::{ModelType, TokenizerOption};
20use crate::pipelines::generation_utils::private_generation_utils::{
21    PreparedInput, PrivateLanguageGenerator,
22};
23use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
24use crate::prophetnet::attention::LayerState;
25use crate::prophetnet::decoder::ProphetNetDecoder;
26use crate::prophetnet::encoder::ProphetNetEncoder;
27use crate::{Activation, Config, RustBertError};
28
/// # ProphetNet Pretrained model weight files
/// Registry of downloadable pretrained weight resources (see the associated constants below).
pub struct ProphetNetModelResources;
31
/// # ProphetNet Pretrained model config files
/// Registry of downloadable model configuration resources (see the associated constants below).
pub struct ProphetNetConfigResources;
34
/// # ProphetNet Pretrained model vocab files
/// Registry of downloadable tokenizer vocabulary resources (see the associated constants below).
pub struct ProphetNetVocabResources;
37
impl ProphetNetModelResources {
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// Tuple of (local resource alias, remote URL for the model weights).
    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
        "prophetnet-large-uncased/model",
        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// Tuple of (local resource alias, remote URL for the model weights). CNN/DailyMail summarization fine-tuned checkpoint.
    pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
        "prophetnet-large-uncased-cnndm/model",
        "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/rust_model.ot",
    );
}
50
impl ProphetNetConfigResources {
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// Tuple of (local resource alias, remote URL for the model configuration).
    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
        "prophetnet-large-uncased/config",
        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// Tuple of (local resource alias, remote URL for the model configuration). CNN/DailyMail summarization fine-tuned checkpoint.
    pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
        "prophetnet-large-uncased-cnndm/config",
        "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/config.json",
    );
}
63
impl ProphetNetVocabResources {
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// Tuple of (local resource alias, remote URL for the tokenizer vocabulary).
    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
        "prophetnet-large-uncased/vocab",
        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// Tuple of (local resource alias, remote URL for the tokenizer vocabulary). CNN/DailyMail summarization fine-tuned checkpoint.
    pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
        "prophetnet-large-uncased-cnndm/vocab",
        "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/prophetnet.tokenizer",
    );
}
76
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # ProphetNet model configuration
/// Defines the ProphetNet model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct ProphetNetConfig {
    /// Activation function used in the feed-forward layers
    pub activation_function: Activation,
    /// Dropout probability applied on activations inside the feed-forward layers
    pub activation_dropout: f64,
    /// Dropout probability applied on attention weights
    pub attention_dropout: f64,
    /// Inner (hidden) dimension of the decoder feed-forward layers
    pub decoder_ffn_dim: i64,
    /// Token id prepended to shifted decoder inputs; must be set for conditional generation
    pub decoder_start_token_id: Option<i64>,
    /// Disables the n-gram (future token) prediction loss when set
    pub disable_ngram_loss: bool,
    /// General dropout probability
    pub dropout: f64,
    /// Inner (hidden) dimension of the encoder feed-forward layers
    pub encoder_ffn_dim: i64,
    /// Epsilon constant — presumably for layer normalization stability; confirm against layer implementations
    pub eps: f64,
    /// Dimension of the hidden representations (and of the word embeddings)
    pub hidden_size: i64,
    /// Standard deviation for weight initialization
    pub init_std: f64,
    /// Flags the model as an encoder-decoder architecture
    pub is_encoder_decoder: bool,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Beginning-of-sequence token id
    pub bos_token_id: i64,
    /// End-of-sequence token id
    pub eos_token_id: i64,
    /// Token id to force as the first generated token (optional)
    pub forced_bos_token_id: Option<i64>,
    /// Token id to force as the last generated token (optional)
    pub forced_eos_token_id: Option<i64>,
    /// Number of future tokens predicted by the n-gram stream
    pub ngram: i64,
    /// Optional mapping from label id to label name (classification heads)
    pub id2label: Option<HashMap<i64, String>>,
    /// Optional mapping from label name to label id (classification heads)
    pub label2id: Option<HashMap<String, i64>>,
    /// Number of buckets — presumably for relative position bucketing; confirm against attention implementation
    pub num_buckets: i64,
    /// Number of attention heads in each decoder layer
    pub num_decoder_attention_heads: i64,
    /// Number of decoder layers
    pub num_decoder_layers: i64,
    /// Number of attention heads in each encoder layer
    pub num_encoder_attention_heads: i64,
    /// Number of encoder layers
    pub num_encoder_layers: i64,
    /// Whether past key/values are returned (cache for generation)
    pub output_past: Option<bool>,
    /// Padding token id (also used as `padding_idx` of the embedding table)
    pub pad_token_id: i64,
    /// Maximum relative distance — presumably clamping for relative positions; confirm against attention implementation
    pub relative_max_distance: i64,
    /// Vocabulary size
    pub vocab_size: i64,
    /// Whether attention weights are returned by forward passes
    pub output_attentions: Option<bool>,
    /// Whether intermediate hidden states are returned by forward passes
    pub output_hidden_states: Option<bool>,
    /// Whether decoder layers include cross-attention over the encoder output
    pub add_cross_attention: Option<bool>,
}
114
/// Marker implementation enabling JSON deserialization via `ProphetNetConfig::from_file`.
impl Config for ProphetNetConfig {}
116
117impl Default for ProphetNetConfig {
118    fn default() -> Self {
119        ProphetNetConfig {
120            activation_function: Activation::gelu,
121            activation_dropout: 0.1,
122            attention_dropout: 0.1,
123            decoder_ffn_dim: 4096,
124            decoder_start_token_id: Some(0),
125            disable_ngram_loss: false,
126            dropout: 0.1,
127            encoder_ffn_dim: 4096,
128            eps: 0.0,
129            hidden_size: 1024,
130            init_std: 0.02,
131            is_encoder_decoder: false,
132            max_position_embeddings: 512,
133            bos_token_id: 1,
134            eos_token_id: 2,
135            forced_bos_token_id: None,
136            forced_eos_token_id: None,
137            ngram: 2,
138            id2label: None,
139            label2id: None,
140            num_buckets: 32,
141            num_decoder_attention_heads: 16,
142            num_decoder_layers: 12,
143            num_encoder_attention_heads: 16,
144            num_encoder_layers: 12,
145            output_past: None,
146            pad_token_id: 0,
147            relative_max_distance: 128,
148            vocab_size: 30522,
149            output_attentions: None,
150            output_hidden_states: None,
151            add_cross_attention: Some(true),
152        }
153    }
154}
155
/// # ProphetNet Base model
/// Base architecture for ProphetNet models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `word_embeddings`: Word embeddings
/// - `encoder`: ProphetNetEncoder
/// - `decoder`: ProphetNetDecoder
pub struct ProphetNetModel {
    // Embedding table passed to both the encoder and the decoder in `forward_t`
    pub(crate) word_embeddings: nn::Embedding,
    // Transformer encoder stack
    pub(crate) encoder: ProphetNetEncoder,
    // Transformer decoder stack (produces the n-gram hidden state streams)
    decoder: ProphetNetDecoder,
}
167
168impl ProphetNetModel {
169    /// Build a new `ProphetNetModel`
170    ///
171    /// # Arguments
172    ///
173    /// * `p` - Variable store path for the root of the ProphetNet model
174    /// * `config` - `ProphetNetConfig` object defining the model architecture
175    ///
176    /// # Example
177    ///
178    /// ```no_run
179    /// use rust_bert::prophetnet::{ProphetNetConfig, ProphetNetModel};
180    /// use rust_bert::Config;
181    /// use std::path::Path;
182    /// use tch::{nn, Device};
183    ///
184    /// let config_path = Path::new("path/to/config.json");
185    /// let device = Device::Cpu;
186    /// let p = nn::VarStore::new(device);
187    /// let config = ProphetNetConfig::from_file(config_path);
188    /// let prophetnet_model = ProphetNetModel::new(&p.root(), &config);
189    /// ```
190    pub fn new<'p, P>(p: P, config: &ProphetNetConfig) -> Result<ProphetNetModel, RustBertError>
191    where
192        P: Borrow<nn::Path<'p>>,
193    {
194        let p = p.borrow();
195
196        let word_embeddings_config = nn::EmbeddingConfig {
197            padding_idx: config.pad_token_id,
198            ..Default::default()
199        };
200        let word_embeddings = nn::embedding(
201            p / "word_embeddings",
202            config.vocab_size,
203            config.hidden_size,
204            word_embeddings_config,
205        );
206
207        let encoder = ProphetNetEncoder::new(p / "encoder", config)?;
208        let decoder = ProphetNetDecoder::new(p / "decoder", config)?;
209
210        Ok(ProphetNetModel {
211            word_embeddings,
212            encoder,
213            decoder,
214        })
215    }
216
217    /// Forward pass through the model
218    ///
219    /// # Arguments
220    ///
221    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
222    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
223    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
224    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
225    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
226    /// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation)
227    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
228    /// * `old_layer_states` - Optional Vector `Option<Vec<Option<&LayerState>, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
229    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
230    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
231    ///
232    /// # Returns
233    ///
234    /// * `ProphetNetOutput` containing:
235    ///   - `last_hidden_states` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last hidden state for the decoder
236    ///   - `ngram_hidden_states` - `Tensor` of shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last hidden state for the decoder ngram stream
237    ///   - `next_decoder_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
238    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
239    ///   - `all_ngram_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
240    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
241    ///   - `all_ngram_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
242    ///   - `all_cross_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
243    ///
244    /// # Example
245    ///
246    /// ```no_run
247    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
248    /// # use rust_bert::Config;
249    /// # use std::path::Path;
250    /// # use tch::kind::Kind::{Int64, Double};
251    /// use rust_bert::prophetnet::{ProphetNetModel, ProphetNetConfig};
252    /// # let config_path = Path::new("path/to/config.json");
253    /// # let vocab_path = Path::new("path/to/vocab.txt");
254    /// # let device = Device::Cpu;
255    /// # let vs = nn::VarStore::new(device);
256    /// # let config = ProphetNetConfig::from_file(config_path);
257    /// # let prophetnet_model: ProphetNetModel = ProphetNetModel::new(&vs.root(), &config).unwrap();
258    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
259    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
260    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
261    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
262    /// let decoder_input_ids = Tensor::ones(&[batch_size, target_sequence_length], (Kind::Float, device));
263    ///
264    /// let model_output = no_grad(|| {
265    ///     prophetnet_model.forward_t(
266    ///         Some(&input_tensor),
267    ///         Some(&attention_mask),
268    ///         None,
269    ///         Some(&decoder_input_ids),
270    ///         None,
271    ///         None,
272    ///         None,
273    ///         None,
274    ///         false
275    ///     )
276    /// });
277    /// ```
278    pub fn forward_t(
279        &self,
280        input_ids: Option<&Tensor>,
281        attention_mask: Option<&Tensor>,
282        input_embeds: Option<&Tensor>,
283        decoder_input_ids: Option<&Tensor>,
284        decoder_attention_mask: Option<&Tensor>,
285        encoder_hidden_states: Option<&Tensor>,
286        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
287        decoder_input_embeds: Option<&Tensor>,
288        train: bool,
289    ) -> Result<ProphetNetOutput, RustBertError> {
290        let calc_encoder_hidden_states = if encoder_hidden_states.is_none() {
291            Some(
292                self.encoder
293                    .forward_t(
294                        input_ids,
295                        attention_mask,
296                        input_embeds,
297                        Some(&self.word_embeddings),
298                        train,
299                    )?
300                    .hidden_states,
301            )
302        } else {
303            None
304        };
305        let encoder_hidden_states =
306            encoder_hidden_states.unwrap_or_else(|| calc_encoder_hidden_states.as_ref().unwrap());
307
308        let decoder_output = self.decoder.forward_t(
309            decoder_input_ids,
310            decoder_attention_mask,
311            encoder_hidden_states.into(),
312            decoder_attention_mask,
313            old_layer_states,
314            decoder_input_embeds,
315            Some(&self.word_embeddings),
316            train,
317        )?;
318
319        Ok(ProphetNetOutput {
320            last_hidden_states: decoder_output.hidden_states,
321            ngram_hidden_states: decoder_output.ngram_hidden_states,
322            all_decoder_hidden_states: decoder_output.all_hidden_states,
323            all_ngram_hidden_states: decoder_output.all_ngram_hidden_states,
324            all_attentions: decoder_output.all_attentions,
325            all_ngram_attentions: decoder_output.all_ngram_attentions,
326            all_cross_attentions: decoder_output.all_cross_attentions,
327            next_decoder_cache: decoder_output.next_decoder_cache,
328        })
329    }
330}
331
/// # ProphetNet Model for conditional generation
/// ProphetNet model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `ProphetNetModel` Base ProphetNet model
/// - `lm_head`: Linear layer without bias to project the hidden states to the vocabulary
pub struct ProphetNetForConditionalGeneration {
    // Base encoder-decoder ProphetNet model
    base_model: ProphetNetModel,
    // Bias-free projection from hidden states to vocabulary logits
    lm_head: nn::Linear,
    // Token id prepended to shifted decoder inputs (from the configuration)
    decoder_start_token_id: i64,
    // Padding token id, used to replace -100 label sentinels in `shift_right`
    pad_token_id: i64,
    // Number of prediction streams (future tokens) produced by the decoder
    ngram: i64,
}
344
345impl ProphetNetForConditionalGeneration {
346    /// Build a new `ProphetNetForConditionalGeneration`
347    ///
348    /// # Arguments
349    ///
350    /// * `p` - Variable store path for the root of the ProphetNet model
351    /// * `config` - `ProphetNetConfig` object defining the model architecture
352    ///
353    /// # Example
354    ///
355    /// ```no_run
356    /// use rust_bert::prophetnet::{ProphetNetConfig, ProphetNetForConditionalGeneration};
357    /// use rust_bert::Config;
358    /// use std::path::Path;
359    /// use tch::{nn, Device};
360    ///
361    /// let config_path = Path::new("path/to/config.json");
362    /// let device = Device::Cpu;
363    /// let p = nn::VarStore::new(device);
364    /// let config = ProphetNetConfig::from_file(config_path);
365    /// let prophetnet_model = ProphetNetForConditionalGeneration::new(&p.root(), &config);
366    /// ```
367    pub fn new<'p, P>(
368        p: P,
369        config: &ProphetNetConfig,
370    ) -> Result<ProphetNetForConditionalGeneration, RustBertError>
371    where
372        P: Borrow<nn::Path<'p>>,
373    {
374        let p = p.borrow();
375        let base_model = ProphetNetModel::new(p / "prophetnet", config)?;
376        let linear_config = nn::LinearConfig {
377            bias: false,
378            ..Default::default()
379        };
380        let lm_head = nn::linear(
381            p / "lm_head",
382            config.hidden_size,
383            config.vocab_size,
384            linear_config,
385        );
386
387        let decoder_start_token_id = config.decoder_start_token_id.ok_or_else(|| {
388            RustBertError::InvalidConfigurationError(
389                "`decoder_start_token_id` must be provided for ProphetNet models".to_string(),
390            )
391        })?;
392        let pad_token_id = config.pad_token_id;
393        let ngram = config.ngram;
394
395        Ok(ProphetNetForConditionalGeneration {
396            base_model,
397            lm_head,
398            decoder_start_token_id,
399            pad_token_id,
400            ngram,
401        })
402    }
403
404    fn shift_right(&self, input_ids: &Tensor) -> Tensor {
405        let shifted_input_ids = Tensor::zeros(
406            input_ids.size().as_slice(),
407            (Kind::Int64, input_ids.device()),
408        );
409
410        shifted_input_ids
411            .slice(-1, 1, *shifted_input_ids.size().last().unwrap(), 1)
412            .copy_(&input_ids.slice(-1, 0, *input_ids.size().last().unwrap() - 1, 1));
413
414        let _ = shifted_input_ids
415            .get(-1)
416            .get(0)
417            .fill_(self.decoder_start_token_id);
418
419        let _ = shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), self.pad_token_id);
420
421        shifted_input_ids
422    }
423
424    /// Forward pass through the model
425    ///
426    /// # Arguments
427    ///
428    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
429    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
430    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
431    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
432    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
433    /// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation)
434    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
435    /// * `old_layer_states` - Optional Vector `Option<Vec<Option<&LayerState>, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
436    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
437    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
438    ///
439    /// # Returns
440    ///
441    /// * `ProphetNetGenerationOutput` containing:
442    ///   - `logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder
443    ///   - `ngram_logits` - `Tensor` of shape (*ngram*, *batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder ngram stream
444    ///   - `next_decoder_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
445    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
446    ///   - `all_ngram_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
447    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
448    ///   - `all_ngram_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
449    ///   - `all_cross_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
450    ///
451    /// # Example
452    ///
453    /// ```no_run
454    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
455    /// # use rust_bert::Config;
456    /// # use std::path::Path;
457    /// # use tch::kind::Kind::{Int64, Double};
458    /// use rust_bert::prophetnet::{ProphetNetModel, ProphetNetConfig, ProphetNetForConditionalGeneration};
459    /// # let config_path = Path::new("path/to/config.json");
460    /// # let vocab_path = Path::new("path/to/vocab.txt");
461    /// # let device = Device::Cpu;
462    /// # let vs = nn::VarStore::new(device);
463    /// # let config = ProphetNetConfig::from_file(config_path);
464    /// # let prophetnet_model: ProphetNetForConditionalGeneration = ProphetNetForConditionalGeneration::new(&vs.root(), &config).unwrap();
465    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
466    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
467    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
468    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
469    /// let decoder_input_ids = Tensor::ones(&[batch_size, target_sequence_length], (Kind::Float, device));
470    ///
471    /// let model_output = no_grad(|| {
472    ///     prophetnet_model.forward_t(
473    ///         Some(&input_tensor),
474    ///         Some(&attention_mask),
475    ///         None,
476    ///         Some(&decoder_input_ids),
477    ///         None,
478    ///         None,
479    ///         None,
480    ///         None,
481    ///         false
482    ///     )
483    /// });
484    /// ```
485    pub fn forward_t(
486        &self,
487        input_ids: Option<&Tensor>,
488        attention_mask: Option<&Tensor>,
489        input_embeds: Option<&Tensor>,
490        decoder_input_ids: Option<&Tensor>,
491        decoder_attention_mask: Option<&Tensor>,
492        encoder_hidden_states: Option<&Tensor>,
493        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
494        decoder_input_embeds: Option<&Tensor>,
495        train: bool,
496    ) -> Result<ProphetNetGenerationOutput, RustBertError> {
497        let calc_decoder_input_ids = if decoder_input_ids.is_none() & decoder_input_embeds.is_none()
498        {
499            if let Some(input_ids) = input_ids {
500                Some(self.shift_right(input_ids))
501            } else {
502                return Err(RustBertError::ValueError("input_ids must be provided if decoder_input_ids and decoder_input_embeds are not given.".into()));
503            }
504        } else {
505            None
506        };
507
508        let decoder_input_ids = if decoder_input_ids.is_some() {
509            decoder_input_ids
510        } else {
511            Some(calc_decoder_input_ids.as_ref().unwrap())
512        };
513
514        let base_model_output = self.base_model.forward_t(
515            input_ids,
516            attention_mask,
517            input_embeds,
518            decoder_input_ids,
519            decoder_attention_mask,
520            encoder_hidden_states,
521            old_layer_states,
522            decoder_input_embeds,
523            train,
524        )?;
525
526        let (batch_size, sequence_length) = if let Some(decoder_input_ids) = decoder_input_ids {
527            let shape = decoder_input_ids.size();
528            (shape[0], shape[1])
529        } else if let Some(decoder_input_embeds) = decoder_input_embeds {
530            let shape = decoder_input_embeds.size();
531            (shape[0], shape[1])
532        } else {
533            return Err(RustBertError::ValueError(
534                "At least one of decoder_input_ids or decoder_input_embeds must be set".into(),
535            ));
536        };
537
538        if base_model_output.ngram_hidden_states.is_none() {
539            return Err(RustBertError::InvalidConfigurationError(
540                "ngram must be set > 0 in the configuration for conditional generation".into(),
541            ));
542        }
543
544        let predict_logits = base_model_output
545            .ngram_hidden_states
546            .as_ref()
547            .unwrap()
548            .view([batch_size, self.ngram, sequence_length, -1])
549            .apply(&self.lm_head);
550
551        let logits = predict_logits.select(1, 0).contiguous();
552
553        let ngram_logits = if self.ngram > 1 {
554            Some(predict_logits.slice(1, 1, predict_logits.size()[1], 1))
555        } else {
556            None
557        };
558
559        Ok(ProphetNetGenerationOutput {
560            logits,
561            ngram_logits,
562            ngram_hidden_states: base_model_output.ngram_hidden_states,
563            all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
564            all_ngram_hidden_states: base_model_output.all_ngram_hidden_states,
565            all_attentions: base_model_output.all_attentions,
566            all_ngram_attentions: base_model_output.all_ngram_attentions,
567            all_cross_attentions: base_model_output.all_cross_attentions,
568            next_decoder_cache: base_model_output.next_decoder_cache,
569        })
570    }
571
572    pub fn encode(
573        &self,
574        input_ids: Option<&Tensor>,
575        attention_mask: Option<&Tensor>,
576        input_embeds: Option<&Tensor>,
577    ) -> Result<Tensor, RustBertError> {
578        Ok(self
579            .base_model
580            .encoder
581            .forward_t(
582                input_ids,
583                attention_mask,
584                input_embeds,
585                Some(&self.base_model.word_embeddings),
586                false,
587            )?
588            .hidden_states)
589    }
590}
591
/// # ProphetNet Model for causal generation
/// ProphetNet decoder with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `ProphetNetDecoder` Base ProphetNet decoder
/// - `word_embeddings`: word embeddings used by the decoder
/// - `lm_head`: Linear layer without bias to project the hidden states to the vocabulary
pub struct ProphetNetForCausalGeneration {
    // Standalone decoder (constructed with `is_encoder_decoder` forced to false)
    decoder: ProphetNetDecoder,
    // Word embedding table used by the decoder
    word_embeddings: nn::Embedding,
    // Bias-free projection from hidden states to vocabulary logits
    lm_head: nn::Linear,
    // Number of prediction streams (future tokens) produced by the decoder
    ngram: i64,
}
604
605impl ProphetNetForCausalGeneration {
606    /// Build a new `ProphetNetForCausalGeneration`
607    ///
608    /// # Arguments
609    ///
610    /// * `p` - Variable store path for the root of the ProphetNet model
611    /// * `config` - `ProphetNetConfig` object defining the model architecture
612    ///
613    /// # Example
614    ///
615    /// ```no_run
616    /// use rust_bert::prophetnet::{ProphetNetConfig, ProphetNetForCausalGeneration};
617    /// use rust_bert::Config;
618    /// use std::path::Path;
619    /// use tch::{nn, Device};
620    ///
621    /// let config_path = Path::new("path/to/config.json");
622    /// let device = Device::Cpu;
623    /// let p = nn::VarStore::new(device);
624    /// let config = ProphetNetConfig::from_file(config_path);
625    /// let prophetnet_model = ProphetNetForCausalGeneration::new(&p.root(), &config);
626    /// ```
627    pub fn new<'p, P>(
628        p: P,
629        config: &ProphetNetConfig,
630    ) -> Result<ProphetNetForCausalGeneration, RustBertError>
631    where
632        P: Borrow<nn::Path<'p>>,
633    {
634        let p = p.borrow();
635        let mut updated_config = config.clone();
636        updated_config.is_encoder_decoder = false;
637
638        let p_prophetnet = p / "prophetnet";
639        let decoder = ProphetNetDecoder::new(&p_prophetnet / "decoder", &updated_config)?;
640        let linear_config = nn::LinearConfig {
641            bias: false,
642            ..Default::default()
643        };
644
645        let word_embeddings_config = nn::EmbeddingConfig {
646            padding_idx: config.pad_token_id,
647            ..Default::default()
648        };
649        let p_decoder = &p_prophetnet / "decoder";
650        let word_embeddings = nn::embedding(
651            &p_decoder / "word_embeddings",
652            config.vocab_size,
653            config.hidden_size,
654            word_embeddings_config,
655        );
656
657        let lm_head = nn::linear(
658            p / "lm_head",
659            config.hidden_size,
660            config.vocab_size,
661            linear_config,
662        );
663
664        let ngram = config.ngram;
665
666        Ok(ProphetNetForCausalGeneration {
667            decoder,
668            word_embeddings,
669            lm_head,
670            ngram,
671        })
672    }
673
674    /// Forward pass through the model
675    ///
676    /// # Arguments
677    ///
678    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
679    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
680    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
681    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
682    /// * `old_layer_states` - Optional Vector `Option<Vec<Option<&LayerState>, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
683    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
684    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
685    ///
686    /// # Returns
687    ///
688    /// * `ProphetNetGenerationOutput` containing:
689    ///   - `logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder
690    ///   - `ngram_logits` - `Tensor` of shape (*ngram*, *batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder ngram stream
691    ///   - `next_decoder_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
692    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
693    ///   - `all_ngram_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
694    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
695    ///   - `all_ngram_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
696    ///   - `all_cross_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
697    ///
698    /// # Example
699    ///
700    /// ```no_run
701    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
702    /// # use rust_bert::Config;
703    /// # use std::path::Path;
704    /// # use tch::kind::Kind::{Int64, Double};
705    /// use rust_bert::prophetnet::{ProphetNetModel, ProphetNetConfig, ProphetNetForCausalGeneration};
706    /// # let config_path = Path::new("path/to/config.json");
707    /// # let vocab_path = Path::new("path/to/vocab.txt");
708    /// # let device = Device::Cpu;
709    /// # let vs = nn::VarStore::new(device);
710    /// # let config = ProphetNetConfig::from_file(config_path);
711    /// # let prophetnet_model: ProphetNetForCausalGeneration = ProphetNetForCausalGeneration::new(&vs.root(), &config).unwrap();
712    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
713    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
714    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
715    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
716    /// let decoder_input_ids = Tensor::ones(&[batch_size, target_sequence_length], (Kind::Float, device));
717    ///
718    /// let model_output = no_grad(|| {
719    ///     prophetnet_model.forward_t(
720    ///         Some(&input_tensor),
721    ///         Some(&attention_mask),
722    ///         None,
723    ///         Some(&decoder_input_ids),
724    ///         None,
725    ///         None,
726    ///         false
727    ///     )
728    /// });
729    /// ```
730    pub fn forward_t(
731        &self,
732        input_ids: Option<&Tensor>,
733        attention_mask: Option<&Tensor>,
734        input_embeds: Option<&Tensor>,
735        encoder_hidden_states: Option<&Tensor>,
736        encoder_attention_mask: Option<&Tensor>,
737        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
738        train: bool,
739    ) -> Result<ProphetNetGenerationOutput, RustBertError> {
740        let base_model_output = self.decoder.forward_t(
741            input_ids,
742            attention_mask,
743            encoder_hidden_states,
744            encoder_attention_mask,
745            old_layer_states,
746            input_embeds,
747            Some(&self.word_embeddings),
748            train,
749        )?;
750
751        let (batch_size, sequence_length) = if let Some(input_ids) = input_ids {
752            let shape = input_ids.size();
753            (shape[0], shape[1])
754        } else if let Some(input_embeds) = input_embeds {
755            let shape = input_embeds.size();
756            (shape[0], shape[1])
757        } else {
758            return Err(RustBertError::ValueError(
759                "At least one of input_ids or input_embeds must be set".into(),
760            ));
761        };
762
763        if base_model_output.ngram_hidden_states.is_none() {
764            return Err(RustBertError::InvalidConfigurationError(
765                "ngram must be set > 0 in the configuration for conditional generation".into(),
766            ));
767        }
768
769        let predict_logits = base_model_output
770            .ngram_hidden_states
771            .as_ref()
772            .unwrap()
773            .view([batch_size, self.ngram, sequence_length, -1])
774            .apply(&self.lm_head);
775
776        let logits = predict_logits.select(1, 0).contiguous();
777
778        let ngram_logits = if self.ngram > 1 {
779            Some(predict_logits.slice(1, 1, predict_logits.size()[1], 1))
780        } else {
781            None
782        };
783
784        Ok(ProphetNetGenerationOutput {
785            logits,
786            ngram_logits,
787            ngram_hidden_states: base_model_output.ngram_hidden_states,
788            all_decoder_hidden_states: base_model_output.all_hidden_states,
789            all_ngram_hidden_states: base_model_output.all_ngram_hidden_states,
790            all_attentions: base_model_output.all_attentions,
791            all_ngram_attentions: base_model_output.all_ngram_attentions,
792            all_cross_attentions: base_model_output.all_cross_attentions,
793            next_decoder_cache: base_model_output.next_decoder_cache,
794        })
795    }
796}
797
/// Container holding a ProphetNet model output
pub struct ProphetNetOutput {
    /// Last decoder layer hidden state
    pub last_hidden_states: Tensor,
    /// Last decoder layer ngram hidden state
    pub ngram_hidden_states: Option<Tensor>,
    /// Hidden states for all intermediate decoder layers
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Ngram hidden states for all intermediate decoder layers
    pub all_ngram_hidden_states: Option<Vec<Tensor>>,
    /// Self-attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
    /// Ngram attention weights for all intermediate layers
    pub all_ngram_attentions: Option<Vec<Tensor>>,
    /// Cross-attention weights for all intermediate layers
    pub all_cross_attentions: Option<Vec<Tensor>>,
    /// Cached attention-layer keys and values (self and cross attention per layer),
    /// populated when the model is used for generation
    pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}
817
/// Container holding a ProphetNet model generation output
pub struct ProphetNetGenerationOutput {
    /// Next-token prediction logits
    pub logits: Tensor,
    /// Ngram (future token) prediction logits
    pub ngram_logits: Option<Tensor>,
    /// Last decoder layer ngram hidden state
    pub ngram_hidden_states: Option<Tensor>,
    /// Hidden states for all intermediate decoder layers
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Ngram hidden states for all intermediate decoder layers
    pub all_ngram_hidden_states: Option<Vec<Tensor>>,
    /// Self-attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
    /// Ngram attention weights for all intermediate layers
    pub all_ngram_attentions: Option<Vec<Tensor>>,
    /// Cross-attention weights for all intermediate layers
    pub all_cross_attentions: Option<Vec<Tensor>>,
    /// Cached attention-layer keys and values (self and cross attention per layer),
    /// populated when the model is used for generation
    pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}
839
/// # Language generation model based on the ProphetNet architecture
pub struct ProphetNetConditionalGenerator {
    /// Underlying conditional generation model
    model: ProphetNetForConditionalGeneration,
    /// Tokenizer used to encode prompts and decode generated token ids
    tokenizer: TokenizerOption,
    /// Variable store holding the model weights
    var_store: nn::VarStore,
    /// Generation options (beam search, sampling, lengths, ...)
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id, read from the model configuration
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids, read from the model configuration
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id, read from the model configuration
    pad_token_id: Option<i64>,
    /// Set to true by the constructor: this generator wraps an encoder-decoder model
    is_encoder_decoder: bool,
    /// Vocabulary size, read from the model configuration
    vocab_size: i64,
    /// Token id used to start decoder sequences, if defined in the configuration
    decoder_start_id: Option<i64>,
    /// Maximum number of position embeddings supported by the model
    max_position_embeddings: i64,
}
854
855impl ProphetNetConditionalGenerator {
856    /// Build a new `ProphetNetConditionalGenerator`
857    ///
858    /// # Arguments
859    ///
860    /// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
861    /// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
862    /// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
863    /// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
864    /// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
865    ///
866    /// # Example
867    ///
868    /// ```no_run
869    /// # use std::path::PathBuf;
870    /// # use tch::Device;
871    /// # fn main() -> anyhow::Result<()> {
872    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
873    /// use rust_bert::prophetnet::ProphetNetConditionalGenerator;
874    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
875    /// # home.push("rustbert");
876    /// # home.push("prophetnet");
877    /// # let config_path = &home.as_path().join("config.json");
878    /// # let vocab_path = &home.as_path().join("vocab.txt");
879    /// # let merges_path = &home.as_path().join("merges.txt");
880    /// # let weights_path = &home.as_path().join("model.ot");
881    /// let device = Device::cuda_if_available();
882    /// let generate_config = GenerateConfig {
883    ///     max_length: Some(30),
884    ///     do_sample: true,
885    ///     num_beams: 5,
886    ///     temperature: 1.1,
887    ///     num_return_sequences: 3,
888    ///     ..Default::default()
889    /// };
890    /// let prophetnet_generator = ProphetNetConditionalGenerator::new(generate_config)?;
891    /// # Ok(())
892    /// # }
893    /// ```
894    pub fn new(
895        generate_config: GenerateConfig,
896    ) -> Result<ProphetNetConditionalGenerator, RustBertError> {
897        let vocab_path = generate_config.vocab_resource.get_local_path()?;
898
899        let tokenizer = TokenizerOption::from_file(
900            ModelType::ProphetNet,
901            vocab_path.to_str().unwrap(),
902            None,
903            true,
904            true,
905            None,
906        )?;
907
908        Self::new_with_tokenizer(generate_config, tokenizer)
909    }
910
911    pub fn new_with_tokenizer(
912        generate_config: GenerateConfig,
913        tokenizer: TokenizerOption,
914    ) -> Result<ProphetNetConditionalGenerator, RustBertError> {
915        let config_path = generate_config.config_resource.get_local_path()?;
916        let device = generate_config.device;
917
918        generate_config.validate();
919        let mut var_store = nn::VarStore::new(device);
920        let config = ProphetNetConfig::from_file(config_path);
921        let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
922        crate::resources::load_weights(
923            &generate_config.model_resource,
924            &mut var_store,
925            generate_config.kind,
926            device,
927        )?;
928
929        let bos_token_id = Some(config.bos_token_id);
930        let eos_token_ids = Some(vec![config.eos_token_id]);
931        let pad_token_id = Some(config.pad_token_id);
932        let vocab_size = config.vocab_size;
933        let is_encoder_decoder = true;
934        let decoder_start_id = config.decoder_start_token_id;
935        let max_position_embeddings = config.max_position_embeddings;
936
937        Ok(ProphetNetConditionalGenerator {
938            model,
939            tokenizer,
940            var_store,
941            generate_config,
942            bos_token_id,
943            eos_token_ids,
944            pad_token_id,
945            is_encoder_decoder,
946            vocab_size,
947            decoder_start_id,
948            max_position_embeddings,
949        })
950    }
951}
952
953impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
954    fn _get_tokenizer(&self) -> &TokenizerOption {
955        &self.tokenizer
956    }
957    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
958        &mut self.tokenizer
959    }
960    fn get_device(&self) -> Device {
961        self.var_store.device()
962    }
963    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
964        Ok(&mut self.var_store)
965    }
966    fn get_config(&self) -> &GenerateConfig {
967        &self.generate_config
968    }
969    fn get_bos_id(&self) -> Option<i64> {
970        self.bos_token_id
971    }
972    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
973        self.eos_token_ids.as_ref()
974    }
975    fn get_pad_id(&self) -> Option<i64> {
976        self.pad_token_id
977    }
978    fn is_encoder_decoder(&self) -> bool {
979        self.is_encoder_decoder
980    }
981    fn get_vocab_size(&self) -> i64 {
982        self.vocab_size
983    }
984    fn get_decoder_start_id(&self) -> Option<i64> {
985        self.decoder_start_id
986    }
987    fn get_max_positions_embeddings(&self) -> Option<i64> {
988        Some(self.max_position_embeddings)
989    }
990
991    fn forward_t(
992        &self,
993        input_ids: Option<&Tensor>,
994        cache: Cache,
995        attention_mask: Option<&Tensor>,
996        _token_type_ids: Option<&Tensor>,
997        _position_ids: Option<&Tensor>,
998        input_embeds: Option<&Tensor>,
999        encoder_outputs: Option<&Tensor>,
1000        decoder_input_ids: Option<&Tensor>,
1001        train: bool,
1002    ) -> Result<LMModelOutput, RustBertError> {
1003        let base_model_output = match cache {
1004            Cache::ProphetNetCache(cached_layer_states) => self.model.forward_t(
1005                input_ids,
1006                attention_mask,
1007                input_embeds,
1008                decoder_input_ids,
1009                None,
1010                encoder_outputs,
1011                cached_layer_states,
1012                None,
1013                train,
1014            )?,
1015            Cache::None => self.model.forward_t(
1016                input_ids,
1017                attention_mask,
1018                input_embeds,
1019                decoder_input_ids,
1020                None,
1021                encoder_outputs,
1022                None,
1023                None,
1024                train,
1025            )?,
1026            _ => {
1027                return Err(RustBertError::ValueError(
1028                    "Cache not compatible with ProphetNet Model".into(),
1029                ));
1030            }
1031        };
1032
1033        Ok(LMModelOutput {
1034            lm_logits: base_model_output.logits,
1035            cache: Cache::ProphetNetCache(base_model_output.next_decoder_cache),
1036        })
1037    }
1038
1039    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
1040        Some(
1041            self.model
1042                .encode(Some(input_ids), attention_mask, None)
1043                .unwrap(),
1044        )
1045    }
1046
1047    fn prepare_inputs_for_generation<'a>(
1048        &self,
1049        input_ids: Tensor,
1050        encoder_outputs: Option<&'a Tensor>,
1051        past: Cache,
1052        attention_mask: Tensor,
1053    ) -> PreparedInput<'a> {
1054        match past {
1055            Cache::ProphetNetCache(past) => PreparedInput {
1056                prepared_input: None,
1057                prepared_attention_mask: Some(attention_mask),
1058                prepared_encoder_output: encoder_outputs,
1059                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
1060                prepared_position_ids: None,
1061                prepared_past: Cache::ProphetNetCache(past),
1062            },
1063            Cache::None => PreparedInput {
1064                prepared_input: None,
1065                prepared_attention_mask: Some(attention_mask),
1066                prepared_encoder_output: encoder_outputs,
1067                prepared_decoder_input: Some(input_ids),
1068                prepared_position_ids: None,
1069                prepared_past: Cache::ProphetNetCache(None),
1070            },
1071            _ => panic!("Cache type incompatible with ProphetNet"),
1072        }
1073    }
1074
1075    fn reorder_cache(
1076        &self,
1077        past: &mut Cache,
1078        encoder_outputs: Option<Tensor>,
1079        beam_indices: &Tensor,
1080    ) -> Option<Tensor> {
1081        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
1082        match past {
1083            Cache::ProphetNetCache(old_cache_option) => match old_cache_option {
1084                Some(old_cache) => {
1085                    for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
1086                        if self_layer_state.is_some() {
1087                            self_layer_state
1088                                .as_mut()
1089                                .unwrap()
1090                                .reorder_cache(beam_indices)
1091                        };
1092                        if encoder_layer_state.is_some() {
1093                            encoder_layer_state
1094                                .as_mut()
1095                                .unwrap()
1096                                .reorder_cache(beam_indices)
1097                        };
1098                    }
1099                }
1100                None => {}
1101            },
1102            Cache::None => {}
1103            _ => {
1104                panic!("Invalid cache for ProphetNet model");
1105            }
1106        };
1107        encoder_outputs
1108    }
1109}
1110
/// Public generation interface: all behavior comes from the default methods of
/// `LanguageGenerator`, backed by the `PrivateLanguageGenerator` implementation.
impl LanguageGenerator for ProphetNetConditionalGenerator {}