// rust_bert/models/prophetnet/prophetnet_model.rs
1// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
2// Copyright 2020 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use std::borrow::Borrow;
14use std::collections::HashMap;
15
16use serde::{Deserialize, Serialize};
17use tch::{nn, Device, Kind, Tensor};
18
19use crate::pipelines::common::{ModelType, TokenizerOption};
20use crate::pipelines::generation_utils::private_generation_utils::{
21 PreparedInput, PrivateLanguageGenerator,
22};
23use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
24use crate::prophetnet::attention::LayerState;
25use crate::prophetnet::decoder::ProphetNetDecoder;
26use crate::prophetnet::encoder::ProphetNetEncoder;
27use crate::{Activation, Config, RustBertError};
28
/// # ProphetNet Pretrained model weight files
/// Marker type carrying `(alias, download URL)` constants for pre-trained weights.
pub struct ProphetNetModelResources;

/// # ProphetNet Pretrained model config files
/// Marker type carrying `(alias, download URL)` constants for model configuration files.
pub struct ProphetNetConfigResources;

/// # ProphetNet Pretrained model vocab files
/// Marker type carrying `(alias, download URL)` constants for tokenizer vocabulary files.
pub struct ProphetNetVocabResources;
37
impl ProphetNetModelResources {
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// `(local alias, download URL)` for the base prophetnet-large-uncased weights.
    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
        "prophetnet-large-uncased/model",
        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
    /// `(local alias, download URL)` for the CNN/DailyMail summarization fine-tuned weights.
    pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
        "prophetnet-large-uncased-cnndm/model",
        "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/rust_model.ot",
    );
}
50
impl ProphetNetConfigResources {
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>.
    /// `(local alias, download URL)` for the base model configuration (JSON, no conversion needed).
    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
        "prophetnet-large-uncased/config",
        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>.
    /// `(local alias, download URL)` for the CNN/DailyMail fine-tuned model configuration.
    pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
        "prophetnet-large-uncased-cnndm/config",
        "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/config.json",
    );
}
63
impl ProphetNetVocabResources {
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>.
    /// `(local alias, download URL)` for the base model tokenizer file.
    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
        "prophetnet-large-uncased/vocab",
        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>.
    /// `(local alias, download URL)` for the CNN/DailyMail fine-tuned model tokenizer file.
    pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
        "prophetnet-large-uncased-cnndm/vocab",
        "https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/prophetnet.tokenizer",
    );
}
76
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # ProphetNet model configuration
/// Defines the ProphetNet model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct ProphetNetConfig {
    /// Activation function used in the feed-forward blocks
    pub activation_function: Activation,
    /// Dropout applied after the activation in feed-forward blocks
    pub activation_dropout: f64,
    /// Dropout applied to attention weights
    pub attention_dropout: f64,
    /// Inner (feed-forward) dimension of the decoder layers
    pub decoder_ffn_dim: i64,
    /// Token id prepended to decoder sequences; required by the conditional generation model
    pub decoder_start_token_id: Option<i64>,
    /// Disables the loss computed on the ngram prediction streams (training only)
    pub disable_ngram_loss: bool,
    /// General dropout probability
    pub dropout: f64,
    /// Inner (feed-forward) dimension of the encoder layers
    pub encoder_ffn_dim: i64,
    /// Numerical stability epsilon — presumably for layer normalization (TODO confirm in encoder/decoder)
    pub eps: f64,
    /// Dimension of the hidden representations (and word embeddings)
    pub hidden_size: i64,
    /// Standard deviation for weight initialization
    pub init_std: f64,
    /// Whether the model is used as an encoder/decoder pair (set to false for causal generation)
    pub is_encoder_decoder: bool,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Beginning-of-sequence token id
    pub bos_token_id: i64,
    /// End-of-sequence token id
    pub eos_token_id: i64,
    /// Optional token id forced as the first generated token
    pub forced_bos_token_id: Option<i64>,
    /// Optional token id forced as the last generated token
    pub forced_eos_token_id: Option<i64>,
    /// Number of future-token prediction streams in the decoder
    pub ngram: i64,
    /// Optional mapping from class index to label (classification heads)
    pub id2label: Option<HashMap<i64, String>>,
    /// Optional mapping from label to class index (classification heads)
    pub label2id: Option<HashMap<String, i64>>,
    /// Number of buckets — presumably for relative position encoding (TODO confirm in attention module)
    pub num_buckets: i64,
    /// Number of attention heads per decoder layer
    pub num_decoder_attention_heads: i64,
    /// Number of decoder layers
    pub num_decoder_layers: i64,
    /// Number of attention heads per encoder layer
    pub num_encoder_attention_heads: i64,
    /// Number of encoder layers
    pub num_encoder_layers: i64,
    /// Whether past key/value caches should be returned
    pub output_past: Option<bool>,
    /// Padding token id (also used as the embedding `padding_idx`)
    pub pad_token_id: i64,
    /// Maximum relative distance — presumably for relative position bucketing (TODO confirm)
    pub relative_max_distance: i64,
    /// Size of the vocabulary
    pub vocab_size: i64,
    /// Whether attention weights should be returned by forward passes
    pub output_attentions: Option<bool>,
    /// Whether intermediate hidden states should be returned by forward passes
    pub output_hidden_states: Option<bool>,
    /// Whether decoder layers include cross-attention over encoder states
    pub add_cross_attention: Option<bool>,
}
114
// Marker implementation: provides JSON loading via `ProphetNetConfig::from_file` (see doc examples below).
impl Config for ProphetNetConfig {}
116
impl Default for ProphetNetConfig {
    /// Default configuration values.
    /// NOTE(review): these appear to mirror the prophetnet-large-uncased checkpoint —
    /// verify against the published config.json (in particular `eps` and `is_encoder_decoder`).
    fn default() -> Self {
        ProphetNetConfig {
            activation_function: Activation::gelu,
            activation_dropout: 0.1,
            attention_dropout: 0.1,
            decoder_ffn_dim: 4096,
            decoder_start_token_id: Some(0),
            disable_ngram_loss: false,
            dropout: 0.1,
            encoder_ffn_dim: 4096,
            eps: 0.0,
            hidden_size: 1024,
            init_std: 0.02,
            is_encoder_decoder: false,
            max_position_embeddings: 512,
            bos_token_id: 1,
            eos_token_id: 2,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
            ngram: 2,
            id2label: None,
            label2id: None,
            num_buckets: 32,
            num_decoder_attention_heads: 16,
            num_decoder_layers: 12,
            num_encoder_attention_heads: 16,
            num_encoder_layers: 12,
            output_past: None,
            pad_token_id: 0,
            relative_max_distance: 128,
            vocab_size: 30522,
            output_attentions: None,
            output_hidden_states: None,
            add_cross_attention: Some(true),
        }
    }
}
155
/// # ProphetNet Base model
/// Base architecture for ProphetNet models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `word_embeddings`: Word embeddings
/// - `encoder`: ProphetNetEncoder
/// - `decoder`: ProphetNetDecoder
pub struct ProphetNetModel {
    // Token embedding table, shared between encoder and decoder (both receive it in `forward_t`)
    pub(crate) word_embeddings: nn::Embedding,
    // Encoder stack; crate-visible so task heads can run it stand-alone (see `encode`)
    pub(crate) encoder: ProphetNetEncoder,
    // Decoder stack producing main and ngram-stream hidden states
    decoder: ProphetNetDecoder,
}
167
168impl ProphetNetModel {
169 /// Build a new `ProphetNetModel`
170 ///
171 /// # Arguments
172 ///
173 /// * `p` - Variable store path for the root of the ProphetNet model
174 /// * `config` - `ProphetNetConfig` object defining the model architecture
175 ///
176 /// # Example
177 ///
178 /// ```no_run
179 /// use rust_bert::prophetnet::{ProphetNetConfig, ProphetNetModel};
180 /// use rust_bert::Config;
181 /// use std::path::Path;
182 /// use tch::{nn, Device};
183 ///
184 /// let config_path = Path::new("path/to/config.json");
185 /// let device = Device::Cpu;
186 /// let p = nn::VarStore::new(device);
187 /// let config = ProphetNetConfig::from_file(config_path);
188 /// let prophetnet_model = ProphetNetModel::new(&p.root(), &config);
189 /// ```
190 pub fn new<'p, P>(p: P, config: &ProphetNetConfig) -> Result<ProphetNetModel, RustBertError>
191 where
192 P: Borrow<nn::Path<'p>>,
193 {
194 let p = p.borrow();
195
196 let word_embeddings_config = nn::EmbeddingConfig {
197 padding_idx: config.pad_token_id,
198 ..Default::default()
199 };
200 let word_embeddings = nn::embedding(
201 p / "word_embeddings",
202 config.vocab_size,
203 config.hidden_size,
204 word_embeddings_config,
205 );
206
207 let encoder = ProphetNetEncoder::new(p / "encoder", config)?;
208 let decoder = ProphetNetDecoder::new(p / "decoder", config)?;
209
210 Ok(ProphetNetModel {
211 word_embeddings,
212 encoder,
213 decoder,
214 })
215 }
216
217 /// Forward pass through the model
218 ///
219 /// # Arguments
220 ///
221 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
222 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
223 /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
224 /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
225 /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
226 /// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation)
227 /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
228 /// * `old_layer_states` - Optional Vector `Option<Vec<Option<&LayerState>, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
229 /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
230 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
231 ///
232 /// # Returns
233 ///
234 /// * `ProphetNetOutput` containing:
235 /// - `last_hidden_states` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last hidden state for the decoder
236 /// - `ngram_hidden_states` - `Tensor` of shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last hidden state for the decoder ngram stream
237 /// - `next_decoder_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
238 /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
239 /// - `all_ngram_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
240 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
241 /// - `all_ngram_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
242 /// - `all_cross_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
243 ///
244 /// # Example
245 ///
246 /// ```no_run
247 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
248 /// # use rust_bert::Config;
249 /// # use std::path::Path;
250 /// # use tch::kind::Kind::{Int64, Double};
251 /// use rust_bert::prophetnet::{ProphetNetModel, ProphetNetConfig};
252 /// # let config_path = Path::new("path/to/config.json");
253 /// # let vocab_path = Path::new("path/to/vocab.txt");
254 /// # let device = Device::Cpu;
255 /// # let vs = nn::VarStore::new(device);
256 /// # let config = ProphetNetConfig::from_file(config_path);
257 /// # let prophetnet_model: ProphetNetModel = ProphetNetModel::new(&vs.root(), &config).unwrap();
258 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
259 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
260 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
261 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
262 /// let decoder_input_ids = Tensor::ones(&[batch_size, target_sequence_length], (Kind::Float, device));
263 ///
264 /// let model_output = no_grad(|| {
265 /// prophetnet_model.forward_t(
266 /// Some(&input_tensor),
267 /// Some(&attention_mask),
268 /// None,
269 /// Some(&decoder_input_ids),
270 /// None,
271 /// None,
272 /// None,
273 /// None,
274 /// false
275 /// )
276 /// });
277 /// ```
278 pub fn forward_t(
279 &self,
280 input_ids: Option<&Tensor>,
281 attention_mask: Option<&Tensor>,
282 input_embeds: Option<&Tensor>,
283 decoder_input_ids: Option<&Tensor>,
284 decoder_attention_mask: Option<&Tensor>,
285 encoder_hidden_states: Option<&Tensor>,
286 old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
287 decoder_input_embeds: Option<&Tensor>,
288 train: bool,
289 ) -> Result<ProphetNetOutput, RustBertError> {
290 let calc_encoder_hidden_states = if encoder_hidden_states.is_none() {
291 Some(
292 self.encoder
293 .forward_t(
294 input_ids,
295 attention_mask,
296 input_embeds,
297 Some(&self.word_embeddings),
298 train,
299 )?
300 .hidden_states,
301 )
302 } else {
303 None
304 };
305 let encoder_hidden_states =
306 encoder_hidden_states.unwrap_or_else(|| calc_encoder_hidden_states.as_ref().unwrap());
307
308 let decoder_output = self.decoder.forward_t(
309 decoder_input_ids,
310 decoder_attention_mask,
311 encoder_hidden_states.into(),
312 decoder_attention_mask,
313 old_layer_states,
314 decoder_input_embeds,
315 Some(&self.word_embeddings),
316 train,
317 )?;
318
319 Ok(ProphetNetOutput {
320 last_hidden_states: decoder_output.hidden_states,
321 ngram_hidden_states: decoder_output.ngram_hidden_states,
322 all_decoder_hidden_states: decoder_output.all_hidden_states,
323 all_ngram_hidden_states: decoder_output.all_ngram_hidden_states,
324 all_attentions: decoder_output.all_attentions,
325 all_ngram_attentions: decoder_output.all_ngram_attentions,
326 all_cross_attentions: decoder_output.all_cross_attentions,
327 next_decoder_cache: decoder_output.next_decoder_cache,
328 })
329 }
330}
331
/// # ProphetNet Model for conditional generation
/// ProphetNet model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `ProphetNetModel` Base ProphetNet model
/// - `lm_head`: Linear layer without bias to project the hidden states to the vocabulary
pub struct ProphetNetForConditionalGeneration {
    // Encoder/decoder base model
    base_model: ProphetNetModel,
    // Bias-free projection from hidden states to vocabulary logits
    lm_head: nn::Linear,
    // Token id prepended when deriving decoder inputs from `input_ids` (see `shift_right`)
    decoder_start_token_id: i64,
    // Padding token id, substituted for -100 label placeholders in `shift_right`
    pad_token_id: i64,
    // Number of decoder prediction streams (copied from the configuration)
    ngram: i64,
}
344
345impl ProphetNetForConditionalGeneration {
346 /// Build a new `ProphetNetForConditionalGeneration`
347 ///
348 /// # Arguments
349 ///
350 /// * `p` - Variable store path for the root of the ProphetNet model
351 /// * `config` - `ProphetNetConfig` object defining the model architecture
352 ///
353 /// # Example
354 ///
355 /// ```no_run
356 /// use rust_bert::prophetnet::{ProphetNetConfig, ProphetNetForConditionalGeneration};
357 /// use rust_bert::Config;
358 /// use std::path::Path;
359 /// use tch::{nn, Device};
360 ///
361 /// let config_path = Path::new("path/to/config.json");
362 /// let device = Device::Cpu;
363 /// let p = nn::VarStore::new(device);
364 /// let config = ProphetNetConfig::from_file(config_path);
365 /// let prophetnet_model = ProphetNetForConditionalGeneration::new(&p.root(), &config);
366 /// ```
367 pub fn new<'p, P>(
368 p: P,
369 config: &ProphetNetConfig,
370 ) -> Result<ProphetNetForConditionalGeneration, RustBertError>
371 where
372 P: Borrow<nn::Path<'p>>,
373 {
374 let p = p.borrow();
375 let base_model = ProphetNetModel::new(p / "prophetnet", config)?;
376 let linear_config = nn::LinearConfig {
377 bias: false,
378 ..Default::default()
379 };
380 let lm_head = nn::linear(
381 p / "lm_head",
382 config.hidden_size,
383 config.vocab_size,
384 linear_config,
385 );
386
387 let decoder_start_token_id = config.decoder_start_token_id.ok_or_else(|| {
388 RustBertError::InvalidConfigurationError(
389 "`decoder_start_token_id` must be provided for ProphetNet models".to_string(),
390 )
391 })?;
392 let pad_token_id = config.pad_token_id;
393 let ngram = config.ngram;
394
395 Ok(ProphetNetForConditionalGeneration {
396 base_model,
397 lm_head,
398 decoder_start_token_id,
399 pad_token_id,
400 ngram,
401 })
402 }
403
404 fn shift_right(&self, input_ids: &Tensor) -> Tensor {
405 let shifted_input_ids = Tensor::zeros(
406 input_ids.size().as_slice(),
407 (Kind::Int64, input_ids.device()),
408 );
409
410 shifted_input_ids
411 .slice(-1, 1, *shifted_input_ids.size().last().unwrap(), 1)
412 .copy_(&input_ids.slice(-1, 0, *input_ids.size().last().unwrap() - 1, 1));
413
414 let _ = shifted_input_ids
415 .get(-1)
416 .get(0)
417 .fill_(self.decoder_start_token_id);
418
419 let _ = shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), self.pad_token_id);
420
421 shifted_input_ids
422 }
423
424 /// Forward pass through the model
425 ///
426 /// # Arguments
427 ///
428 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
429 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
430 /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
431 /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
432 /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
433 /// * `encoder_hidden_states` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) corresponding to pre-calculated encoder hidden states (useful for conditional generation)
434 /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
435 /// * `old_layer_states` - Optional Vector `Option<Vec<Option<&LayerState>, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
436 /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
437 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
438 ///
439 /// # Returns
440 ///
441 /// * `ProphetNetGenerationOutput` containing:
442 /// - `logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder
443 /// - `ngram_logits` - `Tensor` of shape (*ngram*, *batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder ngram stream
444 /// - `next_decoder_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
445 /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
446 /// - `all_ngram_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
447 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
448 /// - `all_ngram_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
449 /// - `all_cross_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
450 ///
451 /// # Example
452 ///
453 /// ```no_run
454 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
455 /// # use rust_bert::Config;
456 /// # use std::path::Path;
457 /// # use tch::kind::Kind::{Int64, Double};
458 /// use rust_bert::prophetnet::{ProphetNetModel, ProphetNetConfig, ProphetNetForConditionalGeneration};
459 /// # let config_path = Path::new("path/to/config.json");
460 /// # let vocab_path = Path::new("path/to/vocab.txt");
461 /// # let device = Device::Cpu;
462 /// # let vs = nn::VarStore::new(device);
463 /// # let config = ProphetNetConfig::from_file(config_path);
464 /// # let prophetnet_model: ProphetNetForConditionalGeneration = ProphetNetForConditionalGeneration::new(&vs.root(), &config).unwrap();
465 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
466 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
467 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
468 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
469 /// let decoder_input_ids = Tensor::ones(&[batch_size, target_sequence_length], (Kind::Float, device));
470 ///
471 /// let model_output = no_grad(|| {
472 /// prophetnet_model.forward_t(
473 /// Some(&input_tensor),
474 /// Some(&attention_mask),
475 /// None,
476 /// Some(&decoder_input_ids),
477 /// None,
478 /// None,
479 /// None,
480 /// None,
481 /// false
482 /// )
483 /// });
484 /// ```
485 pub fn forward_t(
486 &self,
487 input_ids: Option<&Tensor>,
488 attention_mask: Option<&Tensor>,
489 input_embeds: Option<&Tensor>,
490 decoder_input_ids: Option<&Tensor>,
491 decoder_attention_mask: Option<&Tensor>,
492 encoder_hidden_states: Option<&Tensor>,
493 old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
494 decoder_input_embeds: Option<&Tensor>,
495 train: bool,
496 ) -> Result<ProphetNetGenerationOutput, RustBertError> {
497 let calc_decoder_input_ids = if decoder_input_ids.is_none() & decoder_input_embeds.is_none()
498 {
499 if let Some(input_ids) = input_ids {
500 Some(self.shift_right(input_ids))
501 } else {
502 return Err(RustBertError::ValueError("input_ids must be provided if decoder_input_ids and decoder_input_embeds are not given.".into()));
503 }
504 } else {
505 None
506 };
507
508 let decoder_input_ids = if decoder_input_ids.is_some() {
509 decoder_input_ids
510 } else {
511 Some(calc_decoder_input_ids.as_ref().unwrap())
512 };
513
514 let base_model_output = self.base_model.forward_t(
515 input_ids,
516 attention_mask,
517 input_embeds,
518 decoder_input_ids,
519 decoder_attention_mask,
520 encoder_hidden_states,
521 old_layer_states,
522 decoder_input_embeds,
523 train,
524 )?;
525
526 let (batch_size, sequence_length) = if let Some(decoder_input_ids) = decoder_input_ids {
527 let shape = decoder_input_ids.size();
528 (shape[0], shape[1])
529 } else if let Some(decoder_input_embeds) = decoder_input_embeds {
530 let shape = decoder_input_embeds.size();
531 (shape[0], shape[1])
532 } else {
533 return Err(RustBertError::ValueError(
534 "At least one of decoder_input_ids or decoder_input_embeds must be set".into(),
535 ));
536 };
537
538 if base_model_output.ngram_hidden_states.is_none() {
539 return Err(RustBertError::InvalidConfigurationError(
540 "ngram must be set > 0 in the configuration for conditional generation".into(),
541 ));
542 }
543
544 let predict_logits = base_model_output
545 .ngram_hidden_states
546 .as_ref()
547 .unwrap()
548 .view([batch_size, self.ngram, sequence_length, -1])
549 .apply(&self.lm_head);
550
551 let logits = predict_logits.select(1, 0).contiguous();
552
553 let ngram_logits = if self.ngram > 1 {
554 Some(predict_logits.slice(1, 1, predict_logits.size()[1], 1))
555 } else {
556 None
557 };
558
559 Ok(ProphetNetGenerationOutput {
560 logits,
561 ngram_logits,
562 ngram_hidden_states: base_model_output.ngram_hidden_states,
563 all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
564 all_ngram_hidden_states: base_model_output.all_ngram_hidden_states,
565 all_attentions: base_model_output.all_attentions,
566 all_ngram_attentions: base_model_output.all_ngram_attentions,
567 all_cross_attentions: base_model_output.all_cross_attentions,
568 next_decoder_cache: base_model_output.next_decoder_cache,
569 })
570 }
571
572 pub fn encode(
573 &self,
574 input_ids: Option<&Tensor>,
575 attention_mask: Option<&Tensor>,
576 input_embeds: Option<&Tensor>,
577 ) -> Result<Tensor, RustBertError> {
578 Ok(self
579 .base_model
580 .encoder
581 .forward_t(
582 input_ids,
583 attention_mask,
584 input_embeds,
585 Some(&self.base_model.word_embeddings),
586 false,
587 )?
588 .hidden_states)
589 }
590}
591
/// # ProphetNet Model for causal generation
/// ProphetNet decoder with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `ProphetNetDecoder` Base ProphetNet decoder
/// - `word_embeddings`: word embeddings used by the decoder
/// - `lm_head`: Linear layer without bias to project the hidden states to the vocabulary
pub struct ProphetNetForCausalGeneration {
    // Decoder stack, usable stand-alone (encoder inputs are optional)
    decoder: ProphetNetDecoder,
    // Token embedding table passed to the decoder at each forward pass
    word_embeddings: nn::Embedding,
    // Bias-free projection from hidden states to vocabulary logits
    lm_head: nn::Linear,
    // Number of decoder prediction streams (copied from the configuration)
    ngram: i64,
}
604
605impl ProphetNetForCausalGeneration {
606 /// Build a new `ProphetNetForCausalGeneration`
607 ///
608 /// # Arguments
609 ///
610 /// * `p` - Variable store path for the root of the ProphetNet model
611 /// * `config` - `ProphetNetConfig` object defining the model architecture
612 ///
613 /// # Example
614 ///
615 /// ```no_run
616 /// use rust_bert::prophetnet::{ProphetNetConfig, ProphetNetForCausalGeneration};
617 /// use rust_bert::Config;
618 /// use std::path::Path;
619 /// use tch::{nn, Device};
620 ///
621 /// let config_path = Path::new("path/to/config.json");
622 /// let device = Device::Cpu;
623 /// let p = nn::VarStore::new(device);
624 /// let config = ProphetNetConfig::from_file(config_path);
625 /// let prophetnet_model = ProphetNetForCausalGeneration::new(&p.root(), &config);
626 /// ```
627 pub fn new<'p, P>(
628 p: P,
629 config: &ProphetNetConfig,
630 ) -> Result<ProphetNetForCausalGeneration, RustBertError>
631 where
632 P: Borrow<nn::Path<'p>>,
633 {
634 let p = p.borrow();
635 let mut updated_config = config.clone();
636 updated_config.is_encoder_decoder = false;
637
638 let p_prophetnet = p / "prophetnet";
639 let decoder = ProphetNetDecoder::new(&p_prophetnet / "decoder", &updated_config)?;
640 let linear_config = nn::LinearConfig {
641 bias: false,
642 ..Default::default()
643 };
644
645 let word_embeddings_config = nn::EmbeddingConfig {
646 padding_idx: config.pad_token_id,
647 ..Default::default()
648 };
649 let p_decoder = &p_prophetnet / "decoder";
650 let word_embeddings = nn::embedding(
651 &p_decoder / "word_embeddings",
652 config.vocab_size,
653 config.hidden_size,
654 word_embeddings_config,
655 );
656
657 let lm_head = nn::linear(
658 p / "lm_head",
659 config.hidden_size,
660 config.vocab_size,
661 linear_config,
662 );
663
664 let ngram = config.ngram;
665
666 Ok(ProphetNetForCausalGeneration {
667 decoder,
668 word_embeddings,
669 lm_head,
670 ngram,
671 })
672 }
673
674 /// Forward pass through the model
675 ///
676 /// # Arguments
677 ///
678 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
679 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
680 /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
681 /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
682 /// * `old_layer_states` - Optional Vector `Option<Vec<Option<&LayerState>, Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
683 /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
684 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
685 ///
686 /// # Returns
687 ///
688 /// * `ProphetNetGenerationOutput` containing:
689 /// - `logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder
690 /// - `ngram_logits` - `Tensor` of shape (*ngram*, *batch size*, *target_sequence_length*, *vocabulary_size*) representing the activations of the last hidden state for the decoder ngram stream
691 /// - `next_decoder_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
692 /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
693 /// - `all_ngram_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
694 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
695 /// - `all_ngram_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*ngram*, *batch size*, *target_sequence_length*, *hidden_size*)
696 /// - `all_cross_attentions` - `Option<Vec<Tensor>>` of length *n_layer* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
697 ///
698 /// # Example
699 ///
700 /// ```no_run
701 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
702 /// # use rust_bert::Config;
703 /// # use std::path::Path;
704 /// # use tch::kind::Kind::{Int64, Double};
705 /// use rust_bert::prophetnet::{ProphetNetModel, ProphetNetConfig, ProphetNetForCausalGeneration};
706 /// # let config_path = Path::new("path/to/config.json");
707 /// # let vocab_path = Path::new("path/to/vocab.txt");
708 /// # let device = Device::Cpu;
709 /// # let vs = nn::VarStore::new(device);
710 /// # let config = ProphetNetConfig::from_file(config_path);
711 /// # let prophetnet_model: ProphetNetForCausalGeneration = ProphetNetForCausalGeneration::new(&vs.root(), &config).unwrap();
712 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
713 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
714 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
715 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
716 /// let decoder_input_ids = Tensor::ones(&[batch_size, target_sequence_length], (Kind::Float, device));
717 ///
718 /// let model_output = no_grad(|| {
719 /// prophetnet_model.forward_t(
720 /// Some(&input_tensor),
721 /// Some(&attention_mask),
722 /// None,
723 /// Some(&decoder_input_ids),
724 /// None,
725 /// None,
726 /// false
727 /// )
728 /// });
729 /// ```
730 pub fn forward_t(
731 &self,
732 input_ids: Option<&Tensor>,
733 attention_mask: Option<&Tensor>,
734 input_embeds: Option<&Tensor>,
735 encoder_hidden_states: Option<&Tensor>,
736 encoder_attention_mask: Option<&Tensor>,
737 old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
738 train: bool,
739 ) -> Result<ProphetNetGenerationOutput, RustBertError> {
740 let base_model_output = self.decoder.forward_t(
741 input_ids,
742 attention_mask,
743 encoder_hidden_states,
744 encoder_attention_mask,
745 old_layer_states,
746 input_embeds,
747 Some(&self.word_embeddings),
748 train,
749 )?;
750
751 let (batch_size, sequence_length) = if let Some(input_ids) = input_ids {
752 let shape = input_ids.size();
753 (shape[0], shape[1])
754 } else if let Some(input_embeds) = input_embeds {
755 let shape = input_embeds.size();
756 (shape[0], shape[1])
757 } else {
758 return Err(RustBertError::ValueError(
759 "At least one of input_ids or input_embeds must be set".into(),
760 ));
761 };
762
763 if base_model_output.ngram_hidden_states.is_none() {
764 return Err(RustBertError::InvalidConfigurationError(
765 "ngram must be set > 0 in the configuration for conditional generation".into(),
766 ));
767 }
768
769 let predict_logits = base_model_output
770 .ngram_hidden_states
771 .as_ref()
772 .unwrap()
773 .view([batch_size, self.ngram, sequence_length, -1])
774 .apply(&self.lm_head);
775
776 let logits = predict_logits.select(1, 0).contiguous();
777
778 let ngram_logits = if self.ngram > 1 {
779 Some(predict_logits.slice(1, 1, predict_logits.size()[1], 1))
780 } else {
781 None
782 };
783
784 Ok(ProphetNetGenerationOutput {
785 logits,
786 ngram_logits,
787 ngram_hidden_states: base_model_output.ngram_hidden_states,
788 all_decoder_hidden_states: base_model_output.all_hidden_states,
789 all_ngram_hidden_states: base_model_output.all_ngram_hidden_states,
790 all_attentions: base_model_output.all_attentions,
791 all_ngram_attentions: base_model_output.all_ngram_attentions,
792 all_cross_attentions: base_model_output.all_cross_attentions,
793 next_decoder_cache: base_model_output.next_decoder_cache,
794 })
795 }
796}
797
/// Container holding a ProphetNet base model output.
pub struct ProphetNetOutput {
    /// Hidden state of the last decoder layer (main token stream)
    pub last_hidden_states: Tensor,
    /// Hidden state of the last decoder layer for the ngram (future prediction) stream,
    /// `None` when the model is configured with ngram == 0
    pub ngram_hidden_states: Option<Tensor>,
    /// Hidden states for all intermediate layers, populated when output_hidden_states is enabled
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Hidden states (ngram stream) for all intermediate layers
    pub all_ngram_hidden_states: Option<Vec<Tensor>>,
    /// Self-attention weights for all intermediate layers, populated when output_attentions is enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Ngram-stream attention weights for all intermediate layers
    pub all_ngram_attentions: Option<Vec<Tensor>>,
    /// Cross-attention (encoder-decoder) weights for all intermediate layers
    pub all_cross_attentions: Option<Vec<Tensor>>,
    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
    pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}
817
/// Container holding a ProphetNet model generation output.
pub struct ProphetNetGenerationOutput {
    /// Prediction logits for the next token (first ngram slice after the LM head)
    pub logits: Tensor,
    /// Logits for the additional future-ngram predictions, `None` when ngram <= 1
    pub ngram_logits: Option<Tensor>,
    /// Hidden state of the last decoder layer for the ngram (future prediction) stream
    pub ngram_hidden_states: Option<Tensor>,
    /// Hidden states for all intermediate layers, populated when output_hidden_states is enabled
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Hidden states (ngram stream) for all intermediate layers
    pub all_ngram_hidden_states: Option<Vec<Tensor>>,
    /// Self-attention weights for all intermediate layers, populated when output_attentions is enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Ngram-stream attention weights for all intermediate layers
    pub all_ngram_attentions: Option<Vec<Tensor>>,
    /// Cross-attention (encoder-decoder) weights for all intermediate layers
    pub all_cross_attentions: Option<Vec<Tensor>>,
    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
    pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}
839
/// # Language generation model based on the ProphetNet architecture
pub struct ProphetNetConditionalGenerator {
    // Underlying ProphetNet conditional generation model (weights + forward pass)
    model: ProphetNetForConditionalGeneration,
    // Tokenizer used to encode prompts and decode generated token ids
    tokenizer: TokenizerOption,
    // Variable store holding the model weights; also defines the device
    var_store: nn::VarStore,
    // Generation hyper-parameters (beam size, sampling settings, lengths, ...)
    generate_config: GenerateConfig,
    // Beginning-of-sequence token id, taken from the model configuration
    bos_token_id: Option<i64>,
    // Token ids that terminate generation, taken from the model configuration
    eos_token_ids: Option<Vec<i64>>,
    // Padding token id, taken from the model configuration
    pad_token_id: Option<i64>,
    // Always set to true for ProphetNet (encoder-decoder architecture)
    is_encoder_decoder: bool,
    // Size of the output vocabulary
    vocab_size: i64,
    // Token id used to start decoder sequences, if defined by the configuration
    decoder_start_id: Option<i64>,
    // Maximum number of position embeddings (upper bound on sequence length)
    max_position_embeddings: i64,
}
854
855impl ProphetNetConditionalGenerator {
856 /// Build a new `ProphetNetConditionalGenerator`
857 ///
858 /// # Arguments
859 ///
860 /// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
861 /// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
862 /// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
863 /// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
864 /// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
865 ///
866 /// # Example
867 ///
868 /// ```no_run
869 /// # use std::path::PathBuf;
870 /// # use tch::Device;
871 /// # fn main() -> anyhow::Result<()> {
872 /// use rust_bert::pipelines::generation_utils::GenerateConfig;
873 /// use rust_bert::prophetnet::ProphetNetConditionalGenerator;
874 /// # let mut home: PathBuf = dirs::home_dir().unwrap();
875 /// # home.push("rustbert");
876 /// # home.push("prophetnet");
877 /// # let config_path = &home.as_path().join("config.json");
878 /// # let vocab_path = &home.as_path().join("vocab.txt");
879 /// # let merges_path = &home.as_path().join("merges.txt");
880 /// # let weights_path = &home.as_path().join("model.ot");
881 /// let device = Device::cuda_if_available();
882 /// let generate_config = GenerateConfig {
883 /// max_length: Some(30),
884 /// do_sample: true,
885 /// num_beams: 5,
886 /// temperature: 1.1,
887 /// num_return_sequences: 3,
888 /// ..Default::default()
889 /// };
890 /// let prophetnet_generator = ProphetNetConditionalGenerator::new(generate_config)?;
891 /// # Ok(())
892 /// # }
893 /// ```
894 pub fn new(
895 generate_config: GenerateConfig,
896 ) -> Result<ProphetNetConditionalGenerator, RustBertError> {
897 let vocab_path = generate_config.vocab_resource.get_local_path()?;
898
899 let tokenizer = TokenizerOption::from_file(
900 ModelType::ProphetNet,
901 vocab_path.to_str().unwrap(),
902 None,
903 true,
904 true,
905 None,
906 )?;
907
908 Self::new_with_tokenizer(generate_config, tokenizer)
909 }
910
911 pub fn new_with_tokenizer(
912 generate_config: GenerateConfig,
913 tokenizer: TokenizerOption,
914 ) -> Result<ProphetNetConditionalGenerator, RustBertError> {
915 let config_path = generate_config.config_resource.get_local_path()?;
916 let device = generate_config.device;
917
918 generate_config.validate();
919 let mut var_store = nn::VarStore::new(device);
920 let config = ProphetNetConfig::from_file(config_path);
921 let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
922 crate::resources::load_weights(
923 &generate_config.model_resource,
924 &mut var_store,
925 generate_config.kind,
926 device,
927 )?;
928
929 let bos_token_id = Some(config.bos_token_id);
930 let eos_token_ids = Some(vec![config.eos_token_id]);
931 let pad_token_id = Some(config.pad_token_id);
932 let vocab_size = config.vocab_size;
933 let is_encoder_decoder = true;
934 let decoder_start_id = config.decoder_start_token_id;
935 let max_position_embeddings = config.max_position_embeddings;
936
937 Ok(ProphetNetConditionalGenerator {
938 model,
939 tokenizer,
940 var_store,
941 generate_config,
942 bos_token_id,
943 eos_token_ids,
944 pad_token_id,
945 is_encoder_decoder,
946 vocab_size,
947 decoder_start_id,
948 max_position_embeddings,
949 })
950 }
951}
952
953impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
954 fn _get_tokenizer(&self) -> &TokenizerOption {
955 &self.tokenizer
956 }
957 fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
958 &mut self.tokenizer
959 }
960 fn get_device(&self) -> Device {
961 self.var_store.device()
962 }
963 fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
964 Ok(&mut self.var_store)
965 }
966 fn get_config(&self) -> &GenerateConfig {
967 &self.generate_config
968 }
969 fn get_bos_id(&self) -> Option<i64> {
970 self.bos_token_id
971 }
972 fn get_eos_ids(&self) -> Option<&Vec<i64>> {
973 self.eos_token_ids.as_ref()
974 }
975 fn get_pad_id(&self) -> Option<i64> {
976 self.pad_token_id
977 }
978 fn is_encoder_decoder(&self) -> bool {
979 self.is_encoder_decoder
980 }
981 fn get_vocab_size(&self) -> i64 {
982 self.vocab_size
983 }
984 fn get_decoder_start_id(&self) -> Option<i64> {
985 self.decoder_start_id
986 }
987 fn get_max_positions_embeddings(&self) -> Option<i64> {
988 Some(self.max_position_embeddings)
989 }
990
991 fn forward_t(
992 &self,
993 input_ids: Option<&Tensor>,
994 cache: Cache,
995 attention_mask: Option<&Tensor>,
996 _token_type_ids: Option<&Tensor>,
997 _position_ids: Option<&Tensor>,
998 input_embeds: Option<&Tensor>,
999 encoder_outputs: Option<&Tensor>,
1000 decoder_input_ids: Option<&Tensor>,
1001 train: bool,
1002 ) -> Result<LMModelOutput, RustBertError> {
1003 let base_model_output = match cache {
1004 Cache::ProphetNetCache(cached_layer_states) => self.model.forward_t(
1005 input_ids,
1006 attention_mask,
1007 input_embeds,
1008 decoder_input_ids,
1009 None,
1010 encoder_outputs,
1011 cached_layer_states,
1012 None,
1013 train,
1014 )?,
1015 Cache::None => self.model.forward_t(
1016 input_ids,
1017 attention_mask,
1018 input_embeds,
1019 decoder_input_ids,
1020 None,
1021 encoder_outputs,
1022 None,
1023 None,
1024 train,
1025 )?,
1026 _ => {
1027 return Err(RustBertError::ValueError(
1028 "Cache not compatible with ProphetNet Model".into(),
1029 ));
1030 }
1031 };
1032
1033 Ok(LMModelOutput {
1034 lm_logits: base_model_output.logits,
1035 cache: Cache::ProphetNetCache(base_model_output.next_decoder_cache),
1036 })
1037 }
1038
1039 fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
1040 Some(
1041 self.model
1042 .encode(Some(input_ids), attention_mask, None)
1043 .unwrap(),
1044 )
1045 }
1046
1047 fn prepare_inputs_for_generation<'a>(
1048 &self,
1049 input_ids: Tensor,
1050 encoder_outputs: Option<&'a Tensor>,
1051 past: Cache,
1052 attention_mask: Tensor,
1053 ) -> PreparedInput<'a> {
1054 match past {
1055 Cache::ProphetNetCache(past) => PreparedInput {
1056 prepared_input: None,
1057 prepared_attention_mask: Some(attention_mask),
1058 prepared_encoder_output: encoder_outputs,
1059 prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
1060 prepared_position_ids: None,
1061 prepared_past: Cache::ProphetNetCache(past),
1062 },
1063 Cache::None => PreparedInput {
1064 prepared_input: None,
1065 prepared_attention_mask: Some(attention_mask),
1066 prepared_encoder_output: encoder_outputs,
1067 prepared_decoder_input: Some(input_ids),
1068 prepared_position_ids: None,
1069 prepared_past: Cache::ProphetNetCache(None),
1070 },
1071 _ => panic!("Cache type incompatible with ProphetNet"),
1072 }
1073 }
1074
1075 fn reorder_cache(
1076 &self,
1077 past: &mut Cache,
1078 encoder_outputs: Option<Tensor>,
1079 beam_indices: &Tensor,
1080 ) -> Option<Tensor> {
1081 let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
1082 match past {
1083 Cache::ProphetNetCache(old_cache_option) => match old_cache_option {
1084 Some(old_cache) => {
1085 for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
1086 if self_layer_state.is_some() {
1087 self_layer_state
1088 .as_mut()
1089 .unwrap()
1090 .reorder_cache(beam_indices)
1091 };
1092 if encoder_layer_state.is_some() {
1093 encoder_layer_state
1094 .as_mut()
1095 .unwrap()
1096 .reorder_cache(beam_indices)
1097 };
1098 }
1099 }
1100 None => {}
1101 },
1102 Cache::None => {}
1103 _ => {
1104 panic!("Invalid cache for ProphetNet model");
1105 }
1106 };
1107 encoder_outputs
1108 }
1109}
1110
// Empty marker impl: every `LanguageGenerator` method has a default implementation
// built on top of the `PrivateLanguageGenerator` methods provided above.
impl LanguageGenerator for ProphetNetConditionalGenerator {}