rust_bert/models/bart/bart_model.rs
// Copyright 2020 The Facebook AI Research Team Authors
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::bart::attention::LayerState;
use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::kind::get_min;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};

use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Device, Kind, Tensor};

/// # BART Pretrained model weight files
pub struct BartModelResources;

/// # BART Pretrained model config files
pub struct BartConfigResources;

/// # BART Pretrained model vocab files
pub struct BartVocabResources;

/// # BART Pretrained model merges files
pub struct BartMergesResources;

impl BartModelResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART: (&'static str, &'static str) = (
        "bart/model",
        "https://huggingface.co/facebook/bart-large/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_CNN: (&'static str, &'static str) = (
        "bart-cnn/model",
        "https://huggingface.co/facebook/bart-large-cnn/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_XSUM: (&'static str, &'static str) = (
        "bart-xsum/model",
        "https://huggingface.co/facebook/bart-large-xsum/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_MNLI: (&'static str, &'static str) = (
        "bart-large-mnli/model",
        "https://huggingface.co/facebook/bart-large-mnli/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
        "distilbart-cnn-6-6/model",
        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
        "distilbart-cnn-12-6/model",
        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/rust_model.ot",
    );
}

impl BartConfigResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART: (&'static str, &'static str) = (
        "bart/config",
        "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_CNN: (&'static str, &'static str) = (
        "bart-cnn/config",
        "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_XSUM: (&'static str, &'static str) = (
        "bart-xsum/config",
        "https://huggingface.co/facebook/bart-large-xsum/resolve/main/config.json",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_MNLI: (&'static str, &'static str) = (
        "bart-large-mnli/config",
        "https://huggingface.co/facebook/bart-large-mnli/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
        "distilbart-cnn-6-6/config",
        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
        "distilbart-cnn-12-6/config",
        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/config.json",
    );
}

impl BartVocabResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART: (&'static str, &'static str) = (
        "bart/vocab",
        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_CNN: (&'static str, &'static str) = (
        "bart-cnn/vocab",
        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_XSUM: (&'static str, &'static str) = (
        "bart-xsum/vocab",
        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_MNLI: (&'static str, &'static str) = (
        "bart-large-mnli/vocab",
        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
        "distilbart-cnn-6-6/vocab",
        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/vocab.json",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
        "distilbart-cnn-12-6/vocab",
        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/vocab.json",
    );
}

impl BartMergesResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART: (&'static str, &'static str) = (
        "bart/merges",
        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_CNN: (&'static str, &'static str) = (
        "bart-cnn/merges",
        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_XSUM: (&'static str, &'static str) = (
        "bart-xsum/merges",
        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
    );
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const BART_MNLI: (&'static str, &'static str) = (
        "bart-large-mnli/merges",
        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
        "distilbart-cnn-6-6/merges",
        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/merges.txt",
    );
    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
        "distilbart-cnn-12-6/merges",
        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/merges.txt",
    );
}

#[derive(Debug, Serialize, Deserialize, Clone)]
/// # BART model configuration
/// Defines the BART model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct BartConfig {
    pub num_labels: Option<i64>,
    pub activation_function: Option<Activation>,
    pub activation_dropout: f64,
    pub attention_dropout: f64,
    pub classif_dropout: Option<f64>,
    pub d_model: i64,
    pub decoder_attention_heads: i64,
    pub decoder_ffn_dim: i64,
    pub decoder_layerdrop: f64,
    pub decoder_layers: i64,
    pub decoder_start_token_id: Option<i64>,
    pub dropout: f64,
    pub encoder_attention_heads: i64,
    pub encoder_ffn_dim: i64,
    pub encoder_layerdrop: f64,
    pub encoder_layers: i64,
    pub bos_token_id: Option<i64>,
    pub eos_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
    pub pad_token_id: Option<i64>,
    pub id2label: Option<HashMap<i64, String>>,
    pub label2id: Option<HashMap<String, i64>>,
    pub init_std: f64,
    pub is_decoder: Option<bool>,
    pub is_encoder_decoder: Option<bool>,
    pub max_position_embeddings: i64,
    pub min_length: Option<i64>,
    pub no_repeat_ngram_size: Option<i64>,
    pub normalize_embedding: Option<bool>,
    pub num_hidden_layers: i64,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub output_past: Option<bool>,
    pub static_position_embeddings: Option<bool>,
    pub scale_embedding: Option<bool>,
    pub vocab_size: i64,
}

impl Config for BartConfig {}

impl Default for BartConfig {
    fn default() -> Self {
        BartConfig {
            num_labels: Some(3),
            activation_function: Some(Activation::gelu),
            activation_dropout: 0.0,
            attention_dropout: 0.0,
            classif_dropout: Some(0.0),
            d_model: 1024,
            decoder_attention_heads: 16,
            decoder_ffn_dim: 4096,
            decoder_layerdrop: 0.0,
            decoder_layers: 12,
            decoder_start_token_id: Some(2),
            dropout: 0.1,
            encoder_attention_heads: 16,
            encoder_ffn_dim: 4096,
            encoder_layerdrop: 0.0,
            encoder_layers: 12,
            bos_token_id: Some(0),
            eos_token_id: Some(2),
            pad_token_id: Some(1),
            forced_bos_token_id: Some(0),
            forced_eos_token_id: Some(2),
            id2label: None,
            label2id: None,
            init_std: 0.02,
            is_decoder: None,
            is_encoder_decoder: Some(true),
            max_position_embeddings: 1024,
            min_length: None,
            no_repeat_ngram_size: None,
            normalize_embedding: Some(true),
            num_hidden_layers: 12,
            output_attentions: None,
            output_hidden_states: None,
            output_past: None,
            static_position_embeddings: None,
            scale_embedding: Some(false),
            vocab_size: 50265,
        }
    }
}

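/// Builds a causal (lower-triangular) self-attention mask for the decoder.
///
/// Positions a token is not allowed to attend to are filled with the minimum value of `dtype`
/// (acting as a large negative bias before the softmax) and allowed positions are set to 0.
/// When `past_key_values_length > 0`, zero columns are prepended so that cached keys/values from
/// previous generation steps remain visible. The returned mask has shape
/// (*batch size*, 1, *target_length*, *target_length + past_key_values_length*).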
pub(crate) fn _make_causal_mask(
    input_ids_shape: &[i64],
    dtype: Kind,
    device: Device,
    past_key_values_length: i64,
) -> Tensor {
    let batch_size = input_ids_shape[0];
    let target_length = input_ids_shape[1];

    let mut mask = Tensor::full(
        [target_length, target_length],
        get_min(dtype).unwrap(),
        (dtype, device),
    );
    let mask_cond = Tensor::arange(target_length, (dtype, device));
    let _ = mask.masked_fill_(
        &mask_cond.lt_tensor(&(&mask_cond + 1).view([target_length, 1])),
        0,
    );

    if past_key_values_length > 0 {
        mask = Tensor::cat(
            &[
                Tensor::zeros([target_length, past_key_values_length], (dtype, device)),
                mask,
            ],
            -1,
        );
    }
    mask.unsqueeze(0).unsqueeze(0).expand(
        [
            batch_size,
            1,
            target_length,
            target_length + past_key_values_length,
        ],
        true,
    )
}

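/// Expands a (*batch size*, *source_length*) padding mask (1 for tokens to keep, 0 for padding)
/// to the (*batch size*, 1, *target_length*, *source_length*) additive form expected by the
/// attention layers, with masked positions set to the minimum value of `dtype`.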
pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>, dtype: Kind) -> Tensor {
    let (batch_size, source_length) = mask.size2().unwrap();
    let target_length = target_length.unwrap_or(source_length);
    let expanded_mask = mask
        .unsqueeze(1)
        .unsqueeze(1)
        .expand([batch_size, 1, target_length, source_length], true)
        .totype(dtype);
    let inverted_mask: Tensor = 1 - expanded_mask;
    inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap())
}

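/// Combines the causal mask with the (optional) decoder padding mask into a single additive
/// attention mask. The causal part is skipped when decoding a single new token (target length
/// of 1), since the cached keys/values already enforce causality in that case.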
pub(crate) fn _prepare_decoder_attention_mask(
    attention_mask: Option<&Tensor>,
    input_shape: &[i64],
    input_embeds: &Tensor,
    past_key_values_length: i64,
) -> Option<Tensor> {
    let last_input_shape_dim = *input_shape.last().unwrap();
    let mut combined_attention_mask = if last_input_shape_dim > 1 {
        Some(_make_causal_mask(
            input_shape,
            input_embeds.kind(),
            input_embeds.device(),
            past_key_values_length,
        ))
    } else {
        None
    };

    if let Some(attention_mask) = attention_mask {
        let expanded_attention_mask = _expand_mask(
            attention_mask,
            Some(last_input_shape_dim),
            input_embeds.kind(),
        );
        combined_attention_mask = match combined_attention_mask {
            Some(value) => Some(value + expanded_attention_mask),
            None => Some(expanded_attention_mask),
        };
    }

    combined_attention_mask
}

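/// Shifts the input ids one position to the right to build decoder inputs for teacher forcing:
/// the last non-padding token (EOS) is moved to the first position and the remaining tokens are
/// shifted right by one.
///
/// For example (illustrative values), with `pad_token_id = 1`:
/// `[0, 31414, 232, 2, 1]` becomes `[2, 0, 31414, 232, 2]`.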
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
    let index_eos: Tensor =
        input_ids
            .ne(pad_token_id)
            .sum_dim_intlist([-1].as_slice(), true, Kind::Int64)
            - 1;
    let output = input_ids.empty_like().to_kind(Kind::Int64);
    output
        .select(1, 0)
        .copy_(&input_ids.gather(1, &index_eos, false).squeeze());
    output
        .slice(1, 1, *output.size().last().unwrap(), 1)
        .copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
    output
}

/// # BART Base model
/// Base architecture for BART model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `BartEncoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `BartDecoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
///   Caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `pad_token_id`: padding token id
pub struct BartModel {
    pub(crate) encoder: BartEncoder,
    decoder: BartDecoder,
    pub(crate) embeddings: nn::Embedding,
    pad_token_id: i64,
}

impl BartModel {
    /// Build a new `BartModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the BART model
    /// * `config` - `BartConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bart::{BartConfig, BartModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BartConfig::from_file(config_path);
    /// let bart: BartModel = BartModel::new(&p.root() / "bart", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &BartConfig) -> BartModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let pad_token_id = config.pad_token_id.unwrap_or(1);
        let embedding_config = EmbeddingConfig {
            padding_idx: pad_token_id,
            ..Default::default()
        };
        let embeddings: nn::Embedding = embedding(
            p / "shared",
            config.vocab_size,
            config.d_model,
            embedding_config,
        );

        let encoder = BartEncoder::new(p / "encoder", config);
        let decoder = BartDecoder::new(p / "decoder", config);

        BartModel {
            encoder,
            decoder,
            embeddings,
            pad_token_id,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
    /// * `encoder_output` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*) corresponding to a pre-computed encoder last hidden state. When provided, the encoder will not be run again. Useful for generation tasks.
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `layer_states` - Optional cached decoder states (one pair of self attention / cross attention `LayerState` per decoder layer) from previous generation steps, used to speed up decoding.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `BartModelOutput` containing:
    /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
    /// - `encoder_hidden_state` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
    /// - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *num_decoder_layers* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
    /// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    /// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    /// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::bart::{BartConfig, BartModel};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BartConfig::from_file(config_path);
    /// # let bart_model: BartModel = BartModel::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask =
    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask =
    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     bart_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&encoder_attention_mask),
    ///         Some(&target_tensor),
    ///         None,
    ///         Some(&decoder_attention_mask),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        encoder_output: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> BartModelOutput {
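        // When no decoder input ids are provided (e.g. denoising or summarization training),
        // they are derived from the encoder input ids by shifting the tokens to the right.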
        let calc_decoder_input_ids = if decoder_input_ids.is_none() {
            Some(_shift_tokens_right(input_ids.unwrap(), self.pad_token_id))
        } else {
            None
        };

        let decoder_input_ids =
            decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap());

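        // The encoder is only run when no pre-computed encoder output is provided. During
        // generation, the encoder output is computed once and re-used for every decoding step.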
        let calc_encoder_output = if encoder_output.is_none() {
            Some(self.encoder.forward_t(
                input_ids.unwrap(),
                attention_mask,
                &self.embeddings,
                train,
            ))
        } else {
            None
        };

        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
            if let Some(calc_encoder_output) = calc_encoder_output {
                (
                    Some(calc_encoder_output.hidden_state),
                    calc_encoder_output.all_hidden_states,
                    calc_encoder_output.all_attentions,
                )
            } else {
                (None, None, None)
            };

        let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());

        let decoder_output = self.decoder.forward_t(
            decoder_input_ids,
            encoder_output,
            attention_mask,
            decoder_attention_mask,
            &self.embeddings,
            layer_states,
            train,
        );
        BartModelOutput {
            decoder_output: decoder_output.hidden_state,
            encoder_hidden_state: calc_hidden_states,
            cache: decoder_output.next_decoder_cache,
            all_decoder_hidden_states: decoder_output.all_hidden_states,
            all_decoder_attentions: decoder_output.all_attentions,
            all_encoder_hidden_states,
            all_encoder_attentions,
        }
    }
}

/// # BART Model for conditional generation
/// BART model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `BartModel` Base BART model
/// - `linear`: Linear layer without bias tied to the weights of the token id embeddings
pub struct BartForConditionalGeneration {
    base_model: BartModel,
}

impl BartForConditionalGeneration {
    /// Build a new `BartForConditionalGeneration`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the BART model
    /// * `config` - `BartConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BartConfig::from_file(config_path);
    /// let bart: BartForConditionalGeneration =
    ///     BartForConditionalGeneration::new(&p.root() / "bart", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForConditionalGeneration
    where
        P: Borrow<nn::Path<'p>>,
    {
        let base_model = BartModel::new(p.borrow() / "model", config);
        BartForConditionalGeneration { base_model }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `encoder_output` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*) corresponding to a pre-computed encoder last hidden state. When provided, the encoder will not be run again. Useful for generation tasks.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `old_layer_states` - Optional cached decoder states (one pair of self attention / cross attention `LayerState` per decoder layer) from previous generation steps, used to speed up decoding.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `BartModelOutput` containing:
    /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
    /// - `encoder_hidden_state` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
    /// - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *num_decoder_layers* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
    /// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    /// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    /// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BartConfig::from_file(config_path);
    /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     bart_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&encoder_attention_mask),
    ///         None,
    ///         Some(&target_tensor),
    ///         Some(&decoder_attention_mask),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_output: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> BartModelOutput {
        let base_model_output = self.base_model.forward_t(
            input_ids,
            attention_mask,
            decoder_input_ids,
            encoder_output,
            decoder_attention_mask,
            old_layer_states,
            train,
        );

        let lm_logits = base_model_output
            .decoder_output
            .linear::<Tensor>(&self.base_model.embeddings.ws, None);
        BartModelOutput {
            decoder_output: lm_logits,
            ..base_model_output
        }
    }

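    /// Runs the encoder only and returns the last encoder hidden state of shape
    /// (*batch size*, *source_sequence_length*, *hidden_size*). A minimal usage sketch
    /// (assuming the model and input ids are already built as in the examples above):
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor};
    /// # use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # let config = BartConfig::from_file(Path::new("path/to/config.json"));
    /// # let vs = nn::VarStore::new(Device::Cpu);
    /// # let bart_model = BartForConditionalGeneration::new(&vs.root(), &config);
    /// let input_ids = Tensor::ones(&[1, 32], (tch::Kind::Int64, Device::Cpu));
    /// let encoder_hidden_state = bart_model.encode(&input_ids, None);
    /// ```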
    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
        self.base_model
            .encoder
            .forward_t(
                input_ids,
                attention_mask,
                &self.base_model.embeddings,
                false,
            )
            .hidden_state
    }
}

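/// # BART classification head
/// Two-layer feed-forward head (dense -> tanh -> dense) with dropout, mapping the pooled decoder
/// representation to the number of classes defined in the configuration `id2label` mapping.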
pub struct BartClassificationHead {
    dense: nn::Linear,
    dropout: Dropout,
    out_proj: nn::Linear,
}

impl BartClassificationHead {
    pub fn new<'p, P>(p: P, config: &BartConfig) -> Result<BartClassificationHead, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let num_labels = config
            .id2label
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                )
            })?
            .len() as i64;
        let dense = nn::linear(
            p / "dense",
            config.d_model,
            config.d_model,
            Default::default(),
        );
        let dropout = Dropout::new(config.classif_dropout.unwrap_or(0.0));
        let out_proj = nn::linear(
            p / "out_proj",
            config.d_model,
            num_labels,
            Default::default(),
        );

        Ok(BartClassificationHead {
            dense,
            dropout,
            out_proj,
        })
    }

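    /// Applies the classification head to a pooled hidden state of shape
    /// (*batch size*, *hidden_size*), returning logits of shape (*batch size*, *num_labels*).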
    pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        x.apply_t(&self.dropout, train)
            .apply(&self.dense)
            .tanh()
            .apply_t(&self.dropout, train)
            .apply(&self.out_proj)
    }
}
739
740/// # BART Model for sequence classification
741/// BART model with a classification head
742/// It is made of the following blocks:
743/// - `base_model`: `BartModel` Base BART model
744/// - `classification_head`: `BartClassificationHead` made of 2 linear layers mapping hidden states to a target class
745/// - `eos_token_id`: token id for the EOS token carrying the pooled representation for classification
746pub struct BartForSequenceClassification {
747 base_model: BartModel,
748 classification_head: BartClassificationHead,
749 eos_token_id: i64,
750}
751
752impl BartForSequenceClassification {
753 /// Build a new `BartForSequenceClassification`
754 ///
755 /// # Arguments
756 ///
757 /// * `p` - Variable store path for the root of the BART model
758 /// * `config` - `BartConfig` object defining the model architecture
759 ///
760 /// # Example
761 ///
762 /// ```no_run
763 /// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
764 /// use rust_bert::Config;
765 /// use std::path::Path;
766 /// use tch::{nn, Device};
767 ///
768 /// let config_path = Path::new("path/to/config.json");
769 /// let device = Device::Cpu;
770 /// let p = nn::VarStore::new(device);
771 /// let config = BartConfig::from_file(config_path);
772 /// let bart: BartForSequenceClassification =
773 /// BartForSequenceClassification::new(&p.root() / "bart", &config).unwrap();
774 /// ```
775 pub fn new<'p, P>(
776 p: P,
777 config: &BartConfig,
778 ) -> Result<BartForSequenceClassification, RustBertError>
779 where
780 P: Borrow<nn::Path<'p>>,
781 {
782 let p = p.borrow();
783
784 let base_model = BartModel::new(p / "model", config);
785 let classification_head = BartClassificationHead::new(p / "classification_head", config)?;
786 let eos_token_id = config.eos_token_id.unwrap_or(3);
787 Ok(BartForSequenceClassification {
788 base_model,
789 classification_head,
790 eos_token_id,
791 })
792 }
793
    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Input tensor of shape (*batch size*, *source_sequence_length*)
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    /// * `encoder_output` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*) corresponding to a pre-computed encoder last hidden state. When provided, the encoder will not be run again. Useful for generation tasks.
    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `BartModelOutput` containing:
    /// - `decoder_output` - `Tensor` of shape (*batch size*, *num_classes*) representing the activations for each class and batch item
    /// - `encoder_hidden_state` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None.
    /// - `cache` - set to `None` for this model (no cached decoder states are returned)
    /// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    /// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
    /// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    /// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BartConfig::from_file(config_path);
    /// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();
    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
    /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
    /// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     bart_model.forward_t(
    ///         &input_tensor,
    ///         Some(&encoder_attention_mask),
    ///         None,
    ///         Some(&target_tensor),
    ///         Some(&decoder_attention_mask),
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: &Tensor,
        attention_mask: Option<&Tensor>,
        encoder_output: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        train: bool,
    ) -> BartModelOutput {
        let base_model_output = self.base_model.forward_t(
            Some(input_ids),
            attention_mask,
            decoder_input_ids,
            encoder_output,
            decoder_attention_mask,
            None,
            train,
        );
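        // Pool the decoder output by gathering, for each sequence in the batch, the hidden state
        // at the position of the last EOS token; this pooled representation is then fed to the
        // classification head.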
        let eos_mask = input_ids.eq(self.eos_token_id);
        let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
        let sentence_representation = base_model_output
            .decoder_output
            .permute([2, 0, 1])
            .masked_select(&eos_mask)
            .view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
            .transpose(0, 1)
            .view((
                base_model_output.decoder_output.size()[0],
                -1,
                *base_model_output.decoder_output.size().last().unwrap(),
            ))
            .select(1, -1);

        let logits = self
            .classification_head
            .forward_t(&sentence_representation, train);
        BartModelOutput {
            decoder_output: logits,
            encoder_hidden_state: base_model_output.encoder_hidden_state,
            cache: None,
            all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
            all_decoder_attentions: base_model_output.all_decoder_attentions,
            all_encoder_hidden_states: base_model_output.all_encoder_hidden_states,
            all_encoder_attentions: base_model_output.all_encoder_attentions,
        }
    }
}

/// Container holding a BART model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks)
pub struct BartModelOutput {
    /// Hidden state of the last layer of the decoder, or logits for a custom head
    /// module after the decoder (e.g. for classification or language modeling tasks)
    pub decoder_output: Tensor,
    /// Hidden state of the last layer of the encoder if it was calculated by the model (i.e. not provided as an input), otherwise None
    pub encoder_hidden_state: Option<Tensor>,
    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
    pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
    /// Hidden states for all layers of the decoder
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers of the decoder
    pub all_decoder_attentions: Option<Vec<Tensor>>,
    /// Hidden states for all layers of the encoder
    pub all_encoder_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers of the encoder
    pub all_encoder_attentions: Option<Vec<Tensor>>,
}

/// # Language generation model based on the Bart architecture
pub struct BartGenerator {
    model: BartForConditionalGeneration,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    forced_bos_token_id: Option<i64>,
    forced_eos_token_id: Option<i64>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}

impl BartGenerator {
    /// Build a new `BartGenerator`
    ///
    /// # Arguments
    ///
    /// * `generate_config` - `GenerateConfig` object containing the model resources (weights, configuration, vocabulary and merges files, expected to follow the [Transformers library](https://github.com/huggingface/transformers) convention), the generation options and the device to run the model on (e.g. `Device::Cpu` or `Device::Cuda(0)`)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use std::path::PathBuf;
    /// # use tch::Device;
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::bart::BartGenerator;
    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
    /// # home.push("rustbert");
    /// # home.push("openai-gpt");
    /// # let config_path = &home.as_path().join("config.json");
    /// # let vocab_path = &home.as_path().join("vocab.txt");
    /// # let merges_path = &home.as_path().join("merges.txt");
    /// # let weights_path = &home.as_path().join("model.ot");
    /// let device = Device::cuda_if_available();
    /// let generate_config = GenerateConfig {
    ///     max_length: Some(30),
    ///     do_sample: true,
    ///     num_beams: 5,
    ///     temperature: 1.1,
    ///     num_return_sequences: 3,
    ///     ..Default::default()
    /// };
    /// let bart_generator = BartGenerator::new(generate_config)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "BART expects a merges resource to be provided".to_string(),
                )
            })?
            .get_local_path()?;

        let tokenizer = TokenizerOption::from_file(
            ModelType::Bart,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            false,
        )?;

        Self::new_with_tokenizer(generate_config, tokenizer)
    }

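    /// Build a new `BartGenerator` from a `GenerateConfig` and an already instantiated
    /// `TokenizerOption`, loading the model weights from the configured model resource.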
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<BartGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = BartConfig::from_file(config_path);
        let model = BartForConditionalGeneration::new(var_store.root(), &config);
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
        let eos_token_ids = Some(match config.eos_token_id {
            Some(value) => vec![value],
            None => vec![2],
        });
        let forced_bos_token_id = config.forced_bos_token_id;
        let forced_eos_token_id = config.forced_eos_token_id;
        let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
        let vocab_size = config.vocab_size;
        let is_encoder_decoder = true;
        let decoder_start_id = config.decoder_start_token_id;
        let max_position_embeddings = config.max_position_embeddings;

        Ok(BartGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            forced_bos_token_id,
            forced_eos_token_id,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}

impl PrivateLanguageGenerator for BartGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_forced_bos_token_id(&self) -> Option<i64> {
        self.forced_bos_token_id
    }
    fn get_forced_eos_token_id(&self) -> Option<i64> {
        self.forced_eos_token_id
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }

    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        cache: Cache,
        attention_mask: Option<&Tensor>,
        _token_type_ids: Option<&Tensor>,
        _position_ids: Option<&Tensor>,
        _input_embeds: Option<&Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match cache {
            Cache::BARTCache(cached_layer_states) => self.model.forward_t(
                input_ids,
                attention_mask,
                encoder_outputs,
                decoder_input_ids,
                None,
                cached_layer_states,
                train,
            ),

            Cache::None => self.model.forward_t(
                input_ids,
                attention_mask,
                encoder_outputs,
                decoder_input_ids,
                None,
                None,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with BART Model".into(),
                ));
            }
        };

        Ok(LMModelOutput {
            lm_logits: base_model_output.decoder_output,
            cache: Cache::BARTCache(base_model_output.cache),
        })
    }

    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
        Some(self.model.encode(input_ids, attention_mask))
    }

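    // With a populated cache only the last generated token needs to be fed to the decoder;
    // earlier positions are recovered from the cached keys/values.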
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        match past {
            Cache::BARTCache(past) => PreparedInput {
                prepared_input: None,
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: encoder_outputs,
                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
                prepared_position_ids: None,
                prepared_past: Cache::BARTCache(past),
            },
            Cache::None => PreparedInput {
                prepared_input: None,
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: encoder_outputs,
                prepared_decoder_input: Some(input_ids),
                prepared_position_ids: None,
                prepared_past: Cache::BARTCache(None),
            },
            _ => panic!("Cache type incompatible with BART"),
        }
    }

    fn reorder_cache(
        &self,
        past: &mut Cache,
        encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
        match past {
            Cache::BARTCache(old_cache_option) => match old_cache_option {
                Some(old_cache) => {
                    for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
                        if self_layer_state.is_some() {
                            self_layer_state
                                .as_mut()
                                .unwrap()
                                .reorder_cache(beam_indices)
                        };
                        if encoder_layer_state.is_some() {
                            encoder_layer_state
                                .as_mut()
                                .unwrap()
                                .reorder_cache(beam_indices)
                        };
                    }
                }
                None => {}
            },
            Cache::None => {}
            _ => {
                panic!("Invalid cache for BART model");
            }
        };
        encoder_outputs
    }
}

impl LanguageGenerator for BartGenerator {}

#[cfg(test)]
mod test {
    use tch::Device;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };

    use super::{BartConfig, BartConfigResources, BartModel};

    #[test]
    #[ignore] // compilation is enough, no need to run
    fn bart_model_send() {
        let config_resource = Box::new(RemoteResource::from_pretrained(BartConfigResources::BART));
        let config_path = config_resource.get_local_path().expect("");

        // Set-up BART model
        let device = Device::cuda_if_available();
        let vs = tch::nn::VarStore::new(device);
        let config = BartConfig::from_file(config_path);

        let _: Box<dyn Send> = Box::new(BartModel::new(vs.root(), &config));
    }
}