rust_bert/models/bart/
bart_model.rs

1// Copyright 2020 The Facebook AI Research Team Authors
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2020 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::bart::attention::LayerState;
15use crate::bart::decoder::BartDecoder;
16use crate::bart::encoder::BartEncoder;
17use crate::common::activations::Activation;
18use crate::common::dropout::Dropout;
19use crate::common::kind::get_min;
20use crate::pipelines::common::{ModelType, TokenizerOption};
21use crate::pipelines::generation_utils::private_generation_utils::{
22    PreparedInput, PrivateLanguageGenerator,
23};
24use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
25use crate::{Config, RustBertError};
26
27use serde::{Deserialize, Serialize};
28use std::borrow::Borrow;
29use std::collections::HashMap;
30use tch::nn::{embedding, EmbeddingConfig};
31use tch::{nn, Device, Kind, Tensor};
32
/// # BART Pretrained model weight files
/// Namespace type: each associated constant is a `(resource name, download URL)` pair
/// pointing to the `rust_model.ot` weights of one pretrained checkpoint.
pub struct BartModelResources;

/// # BART Pretrained model config files
/// Namespace type: each associated constant is a `(resource name, download URL)` pair
/// pointing to the `config.json` file of one pretrained checkpoint.
pub struct BartConfigResources;

/// # BART Pretrained model vocab files
/// Namespace type: each associated constant is a `(resource name, download URL)` pair
/// pointing to the `vocab.json` file used by one pretrained checkpoint.
pub struct BartVocabResources;

/// # BART Pretrained model merges files
/// Namespace type: each associated constant is a `(resource name, download URL)` pair
/// pointing to the BPE `merges.txt` file used by one pretrained checkpoint.
pub struct BartMergesResources;
44
45impl BartModelResources {
46    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
47    pub const BART: (&'static str, &'static str) = (
48        "bart/model",
49        "https://huggingface.co/facebook/bart-large/resolve/main/rust_model.ot",
50    );
51    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
52    pub const BART_CNN: (&'static str, &'static str) = (
53        "bart-cnn/model",
54        "https://huggingface.co/facebook/bart-large-cnn/resolve/main/rust_model.ot",
55    );
56    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
57    pub const BART_XSUM: (&'static str, &'static str) = (
58        "bart-xsum/model",
59        "https://huggingface.co/facebook/bart-large-xsum/resolve/main/rust_model.ot",
60    );
61    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
62    pub const BART_MNLI: (&'static str, &'static str) = (
63        "bart-large-mnli/model",
64        "https://huggingface.co/facebook/bart-large-mnli/resolve/main/rust_model.ot",
65    );
66    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
67    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
68        "distilbart-cnn-6-6/model",
69        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/rust_model.ot",
70    );
71    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
72    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
73        "distilbart-cnn-12-6/model",
74        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/rust_model.ot",
75    );
76}
77
78impl BartConfigResources {
79    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
80    pub const BART: (&'static str, &'static str) = (
81        "bart/config",
82        "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
83    );
84    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
85    pub const BART_CNN: (&'static str, &'static str) = (
86        "bart-cnn/config",
87        "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json",
88    );
89    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
90    pub const BART_XSUM: (&'static str, &'static str) = (
91        "bart-xsum/config",
92        "https://huggingface.co/facebook/bart-large-xsum/resolve/main/config.json",
93    );
94    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
95    pub const BART_MNLI: (&'static str, &'static str) = (
96        "bart-large-mnli/config",
97        "https://huggingface.co/facebook/bart-large-mnli/resolve/main/config.json",
98    );
99    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
100    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
101        "distilbart-cnn-6-6/config",
102        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/config.json",
103    );
104    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
105    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
106        "distilbart-cnn-12-6/config",
107        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/config.json",
108    );
109}
110
111impl BartVocabResources {
112    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
113    pub const BART: (&'static str, &'static str) = (
114        "bart/vocab",
115        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
116    );
117    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
118    pub const BART_CNN: (&'static str, &'static str) = (
119        "bart-cnn/vocab",
120        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
121    );
122    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
123    pub const BART_XSUM: (&'static str, &'static str) = (
124        "bart-xsum/vocab",
125        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
126    );
127    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
128    pub const BART_MNLI: (&'static str, &'static str) = (
129        "bart-large-mnli/vocab",
130        "https://huggingface.co/roberta-large/resolve/main/vocab.json",
131    );
132    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
133    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
134        "distilbart-cnn-6-6/vocab",
135        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/vocab.json",
136    );
137    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
138    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
139        "distilbart-cnn-12-6/vocab",
140        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/vocab.json",
141    );
142}
143
144impl BartMergesResources {
145    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
146    pub const BART: (&'static str, &'static str) = (
147        "bart/merges",
148        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
149    );
150    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
151    pub const BART_CNN: (&'static str, &'static str) = (
152        "bart-cnn/merges",
153        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
154    );
155    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
156    pub const BART_XSUM: (&'static str, &'static str) = (
157        "bart-xsum/merges",
158        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
159    );
160    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
161    pub const BART_MNLI: (&'static str, &'static str) = (
162        "bart-large-mnli/merges",
163        "https://huggingface.co/roberta-large/resolve/main/merges.txt",
164    );
165    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
166    pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
167        "distilbart-cnn-6-6/merges",
168        "https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/merges.txt",
169    );
170    /// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
171    pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
172        "distilbart-cnn-12-6/merges",
173        "https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/merges.txt",
174    );
175}
176
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # BART model configuration
/// Defines the BART model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// Derives serde `Serialize`/`Deserialize` without renaming attributes, so each field name is
/// used as-is as the serialized key (as in the `config.json` files listed in
/// `BartConfigResources`). `Option` fields may be absent from the input and then deserialize to
/// `None`; all other fields are required.
pub struct BartConfig {
    pub num_labels: Option<i64>,
    pub activation_function: Option<Activation>,
    pub activation_dropout: f64,
    pub attention_dropout: f64,
    pub classif_dropout: Option<f64>,
    /// Hidden size of the model; also the dimension of the token embedding matrix built by
    /// `BartModel::new`.
    pub d_model: i64,
    pub decoder_attention_heads: i64,
    pub decoder_ffn_dim: i64,
    pub decoder_layerdrop: f64,
    pub decoder_layers: i64,
    pub decoder_start_token_id: Option<i64>,
    pub dropout: f64,
    pub encoder_attention_heads: i64,
    pub encoder_ffn_dim: i64,
    pub encoder_layerdrop: f64,
    pub encoder_layers: i64,
    pub bos_token_id: Option<i64>,
    pub eos_token_id: Option<i64>,
    pub forced_bos_token_id: Option<i64>,
    pub forced_eos_token_id: Option<i64>,
    /// Padding token id. `BartModel::new` falls back to 1 when unset; the value is used as the
    /// embedding padding index and when deriving decoder inputs from `input_ids`.
    pub pad_token_id: Option<i64>,
    /// Optional mapping from class index to label string (read by `BartClassificationHead::new`).
    pub id2label: Option<HashMap<i64, String>>,
    /// Optional mapping from label string to class index.
    pub label2id: Option<HashMap<String, i64>>,
    pub init_std: f64,
    pub is_decoder: Option<bool>,
    pub is_encoder_decoder: Option<bool>,
    pub max_position_embeddings: i64,
    pub min_length: Option<i64>,
    pub no_repeat_ngram_size: Option<i64>,
    pub normalize_embedding: Option<bool>,
    pub num_hidden_layers: i64,
    pub output_attentions: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub output_past: Option<bool>,
    pub static_position_embeddings: Option<bool>,
    pub scale_embedding: Option<bool>,
    /// Number of rows of the token embedding matrix built by `BartModel::new`.
    pub vocab_size: i64,
}
219
// Opts into the crate's `Config` trait using only its provided methods; this is what makes
// `BartConfig::from_file` (used in this file's examples) available.
impl Config for BartConfig {}
221
222impl Default for BartConfig {
223    fn default() -> Self {
224        BartConfig {
225            num_labels: Some(3),
226            activation_function: Some(Activation::gelu),
227            activation_dropout: 0.0,
228            attention_dropout: 0.0,
229            classif_dropout: Some(0.0),
230            d_model: 1024,
231            decoder_attention_heads: 16,
232            decoder_ffn_dim: 4096,
233            decoder_layerdrop: 0.0,
234            decoder_layers: 12,
235            decoder_start_token_id: Some(2),
236            dropout: 0.1,
237            encoder_attention_heads: 16,
238            encoder_ffn_dim: 4096,
239            encoder_layerdrop: 0.0,
240            encoder_layers: 12,
241            bos_token_id: Some(0),
242            eos_token_id: Some(2),
243            pad_token_id: Some(1),
244            forced_bos_token_id: Some(0),
245            forced_eos_token_id: Some(2),
246            id2label: None,
247            label2id: None,
248            init_std: 0.02,
249            is_decoder: None,
250            is_encoder_decoder: Some(true),
251            max_position_embeddings: 1024,
252            min_length: None,
253            no_repeat_ngram_size: None,
254            normalize_embedding: Some(true),
255            num_hidden_layers: 12,
256            output_attentions: None,
257            output_hidden_states: None,
258            output_past: None,
259            static_position_embeddings: None,
260            scale_embedding: Some(false),
261            vocab_size: 50265,
262        }
263    }
264}
265
266pub(crate) fn _make_causal_mask(
267    input_ids_shape: &[i64],
268    dtype: Kind,
269    device: Device,
270    past_key_values_length: i64,
271) -> Tensor {
272    let batch_size = input_ids_shape[0];
273    let target_length = input_ids_shape[1];
274
275    let mut mask = Tensor::full(
276        [target_length, target_length],
277        get_min(dtype).unwrap(),
278        (dtype, device),
279    );
280    let mask_cond = Tensor::arange(target_length, (dtype, device));
281    let _ = mask.masked_fill_(
282        &mask_cond.lt_tensor(&(&mask_cond + 1).view([target_length, 1])),
283        0,
284    );
285
286    if past_key_values_length > 0 {
287        mask = Tensor::cat(
288            &[
289                Tensor::zeros([target_length, past_key_values_length], (dtype, device)),
290                mask,
291            ],
292            -1,
293        );
294    }
295    mask.unsqueeze(0).unsqueeze(0).expand(
296        [
297            batch_size,
298            1,
299            target_length,
300            target_length + past_key_values_length,
301        ],
302        true,
303    )
304}
305
306pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>, dtype: Kind) -> Tensor {
307    let (batch_size, source_length) = mask.size2().unwrap();
308    let target_length = target_length.unwrap_or(source_length);
309    let expanded_mask = mask
310        .unsqueeze(1)
311        .unsqueeze(1)
312        .expand([batch_size, 1, target_length, source_length], true)
313        .totype(dtype);
314    let inverted_mask: Tensor = 1 - expanded_mask;
315    inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap())
316}
317
318pub(crate) fn _prepare_decoder_attention_mask(
319    attention_mask: Option<&Tensor>,
320    input_shape: &[i64],
321    input_embeds: &Tensor,
322    past_key_values_length: i64,
323) -> Option<Tensor> {
324    let last_input_shape_dim = *input_shape.last().unwrap();
325    let mut combined_attention_mask = if last_input_shape_dim > 1 {
326        Some(_make_causal_mask(
327            input_shape,
328            input_embeds.kind(),
329            input_embeds.device(),
330            past_key_values_length,
331        ))
332    } else {
333        None
334    };
335
336    if let Some(attention_mask) = attention_mask {
337        let expanded_attention_mask = _expand_mask(
338            attention_mask,
339            Some(last_input_shape_dim),
340            input_embeds.kind(),
341        );
342        combined_attention_mask = match combined_attention_mask {
343            Some(value) => Some(value + expanded_attention_mask),
344            None => Some(expanded_attention_mask),
345        };
346    }
347
348    combined_attention_mask
349}
350
351fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
352    let index_eos: Tensor =
353        input_ids
354            .ne(pad_token_id)
355            .sum_dim_intlist([-1].as_slice(), true, Kind::Int64)
356            - 1;
357    let output = input_ids.empty_like().to_kind(Kind::Int64);
358    output
359        .select(1, 0)
360        .copy_(&input_ids.gather(1, &index_eos, false).squeeze());
361    output
362        .slice(1, 1, *output.size().last().unwrap(), 1)
363        .copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
364    output
365}
366
/// # BART Base model
/// Base architecture for BART model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `BartEncoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `BartDecoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
///     caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `embeddings`: token embedding matrix passed to both the encoder and the decoder (also reused as the output projection by `BartForConditionalGeneration`)
/// - `pad_token_id`: padding token id, used to derive decoder inputs from `input_ids` when none are given
pub struct BartModel {
    pub(crate) encoder: BartEncoder,
    decoder: BartDecoder,
    pub(crate) embeddings: nn::Embedding,
    pad_token_id: i64,
}
380
381impl BartModel {
382    /// Build a new `BartModel`
383    ///
384    /// # Arguments
385    ///
386    /// * `p` - Variable store path for the root of the BART model
387    /// * `config` - `BartConfig` object defining the model architecture
388    ///
389    /// # Example
390    ///
391    /// ```no_run
392    /// use rust_bert::bart::{BartConfig, BartModel};
393    /// use rust_bert::Config;
394    /// use std::path::Path;
395    /// use tch::{nn, Device};
396    ///
397    /// let config_path = Path::new("path/to/config.json");
398    /// let device = Device::Cpu;
399    /// let p = nn::VarStore::new(device);
400    /// let config = BartConfig::from_file(config_path);
401    /// let bart: BartModel = BartModel::new(&p.root() / "bart", &config);
402    /// ```
403    pub fn new<'p, P>(p: P, config: &BartConfig) -> BartModel
404    where
405        P: Borrow<nn::Path<'p>>,
406    {
407        let p = p.borrow();
408
409        let pad_token_id = config.pad_token_id.unwrap_or(1);
410        let embedding_config = EmbeddingConfig {
411            padding_idx: pad_token_id,
412            ..Default::default()
413        };
414        let embeddings: nn::Embedding = embedding(
415            p / "shared",
416            config.vocab_size,
417            config.d_model,
418            embedding_config,
419        );
420
421        let encoder = BartEncoder::new(p / "encoder", config);
422        let decoder = BartDecoder::new(p / "decoder", config);
423
424        BartModel {
425            encoder,
426            decoder,
427            embeddings,
428            pad_token_id,
429        }
430    }
431
432    /// Forward pass through the model
433    ///
434    /// # Arguments
435    ///
436    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
437    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
438    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
439    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
440    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
441    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
442    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
443    ///
444    /// # Returns
445    ///
446    /// * `BartModelOutput` containing:
447    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
448    ///   - `encoder_hidden_states` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
449    ///   - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
450    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
451    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
452    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
453    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
454    ///
455    /// # Example
456    ///
457    /// ```no_run
458    /// # use tch::{nn, Device, Tensor, no_grad};
459    /// # use rust_bert::Config;
460    /// # use std::path::Path;
461    /// # use tch::kind::Kind::{Int64, Double};
462    /// use rust_bert::bart::{BartConfig, BartModel};
463    /// # let config_path = Path::new("path/to/config.json");
464    /// # let vocab_path = Path::new("path/to/vocab.txt");
465    /// # let device = Device::Cpu;
466    /// # let vs = nn::VarStore::new(device);
467    /// # let config = BartConfig::from_file(config_path);
468    /// # let bart_model: BartModel = BartModel::new(&vs.root(), &config);
469    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
470    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
471    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
472    /// let encoder_attention_mask =
473    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
474    /// let decoder_attention_mask =
475    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
476    ///
477    /// let model_output = no_grad(|| {
478    ///     bart_model.forward_t(
479    ///         Some(&input_tensor),
480    ///         Some(&encoder_attention_mask),
481    ///         Some(&target_tensor),
482    ///         None,
483    ///         Some(&decoder_attention_mask),
484    ///         None,
485    ///         false,
486    ///     )
487    /// });
488    /// ```
489    pub fn forward_t(
490        &self,
491        input_ids: Option<&Tensor>,
492        attention_mask: Option<&Tensor>,
493        decoder_input_ids: Option<&Tensor>,
494        encoder_output: Option<&Tensor>,
495        decoder_attention_mask: Option<&Tensor>,
496        layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
497        train: bool,
498    ) -> BartModelOutput {
499        let calc_decoder_input_ids = if decoder_input_ids.is_none() {
500            Some(_shift_tokens_right(input_ids.unwrap(), self.pad_token_id))
501        } else {
502            None
503        };
504
505        let decoder_input_ids =
506            decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap());
507
508        let calc_encoder_output = if encoder_output.is_none() {
509            Some(self.encoder.forward_t(
510                input_ids.unwrap(),
511                attention_mask,
512                &self.embeddings,
513                train,
514            ))
515        } else {
516            None
517        };
518
519        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
520            if let Some(calc_encoder_output) = calc_encoder_output {
521                (
522                    Some(calc_encoder_output.hidden_state),
523                    calc_encoder_output.all_hidden_states,
524                    calc_encoder_output.all_attentions,
525                )
526            } else {
527                (None, None, None)
528            };
529
530        let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());
531
532        let decoder_output = self.decoder.forward_t(
533            decoder_input_ids,
534            encoder_output,
535            attention_mask,
536            decoder_attention_mask,
537            &self.embeddings,
538            layer_states,
539            train,
540        );
541        BartModelOutput {
542            decoder_output: decoder_output.hidden_state,
543            encoder_hidden_state: calc_hidden_states,
544            cache: decoder_output.next_decoder_cache,
545            all_decoder_hidden_states: decoder_output.all_hidden_states,
546            all_decoder_attentions: decoder_output.all_attentions,
547            all_encoder_hidden_states,
548            all_encoder_attentions,
549        }
550    }
551}
552
/// # BART Model for conditional generation
/// BART model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `BartModel` Base BART model
///
/// The language-model head has no weights of its own: decoder outputs are projected onto the
/// vocabulary with the token embedding matrix of `base_model` (tied weights, no bias).
pub struct BartForConditionalGeneration {
    base_model: BartModel,
}
561
562impl BartForConditionalGeneration {
563    /// Build a new `BartForConditionalGeneration`
564    ///
565    /// # Arguments
566    ///
567    /// * `p` - Variable store path for the root of the BART model
568    /// * `config` - `BartConfig` object defining the model architecture
569    ///
570    /// # Example
571    ///
572    /// ```no_run
573    /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
574    /// use rust_bert::Config;
575    /// use std::path::Path;
576    /// use tch::{nn, Device};
577    ///
578    /// let config_path = Path::new("path/to/config.json");
579    /// let device = Device::Cpu;
580    /// let p = nn::VarStore::new(device);
581    /// let config = BartConfig::from_file(config_path);
582    /// let bart: BartForConditionalGeneration =
583    ///     BartForConditionalGeneration::new(&p.root() / "bart", &config);
584    /// ```
585    pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForConditionalGeneration
586    where
587        P: Borrow<nn::Path<'p>>,
588    {
589        let base_model = BartModel::new(p.borrow() / "model", config);
590        BartForConditionalGeneration { base_model }
591    }
592
593    /// Forward pass through the model
594    ///
595    /// # Arguments
596    ///
597    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
598    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
599    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
600    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
601    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
602    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
603    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
604    ///
605    /// # Returns
606    ///
607    /// * `BartModelOutput` containing:
608    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
609    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
610    ///   - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
611    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
612    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
613    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
614    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
615    ///
616    /// # Example
617    ///
618    /// ```no_run
619    /// # use tch::{nn, Device, Tensor, no_grad};
620    /// # use rust_bert::Config;
621    /// # use std::path::Path;
622    /// # use tch::kind::Kind::{Int64, Double};
623    /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
624    /// # let config_path = Path::new("path/to/config.json");
625    /// # let vocab_path = Path::new("path/to/vocab.txt");
626    /// # let device = Device::Cpu;
627    /// # let vs = nn::VarStore::new(device);
628    /// # let config = BartConfig::from_file(config_path);
629    /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
630    ///  let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
631    ///  let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
632    ///  let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
633    ///  let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
634    ///  let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
635    ///
636    ///  let model_output = no_grad(|| {
637    ///    bart_model
638    ///         .forward_t(Some(&input_tensor),
639    ///                    Some(&encoder_attention_mask),
640    ///                    None,
641    ///                    Some(&target_tensor),
642    ///                    Some(&decoder_attention_mask),
643    ///                    None,
644    ///                    false)
645    ///    });
646    /// ```
647    pub fn forward_t(
648        &self,
649        input_ids: Option<&Tensor>,
650        attention_mask: Option<&Tensor>,
651        encoder_output: Option<&Tensor>,
652        decoder_input_ids: Option<&Tensor>,
653        decoder_attention_mask: Option<&Tensor>,
654        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
655        train: bool,
656    ) -> BartModelOutput {
657        let base_model_output = self.base_model.forward_t(
658            input_ids,
659            attention_mask,
660            decoder_input_ids,
661            encoder_output,
662            decoder_attention_mask,
663            old_layer_states,
664            train,
665        );
666
667        let lm_logits = base_model_output
668            .decoder_output
669            .linear::<Tensor>(&self.base_model.embeddings.ws, None);
670        BartModelOutput {
671            decoder_output: lm_logits,
672            ..base_model_output
673        }
674    }
675
676    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
677        self.base_model
678            .encoder
679            .forward_t(
680                input_ids,
681                attention_mask,
682                &self.base_model.embeddings,
683                false,
684            )
685            .hidden_state
686    }
687}
688
689pub struct BartClassificationHead {
690    dense: nn::Linear,
691    dropout: Dropout,
692    out_proj: nn::Linear,
693}
694
695impl BartClassificationHead {
696    pub fn new<'p, P>(p: P, config: &BartConfig) -> Result<BartClassificationHead, RustBertError>
697    where
698        P: Borrow<nn::Path<'p>>,
699    {
700        let p = p.borrow();
701        let num_labels = config
702            .id2label
703            .as_ref()
704            .ok_or_else(|| {
705                RustBertError::InvalidConfigurationError(
706                    "num_labels not provided in configuration".to_string(),
707                )
708            })?
709            .len() as i64;
710        let dense = nn::linear(
711            p / "dense",
712            config.d_model,
713            config.d_model,
714            Default::default(),
715        );
716        let dropout = Dropout::new(config.classif_dropout.unwrap_or(0.0));
717        let out_proj = nn::linear(
718            p / "out_proj",
719            config.d_model,
720            num_labels,
721            Default::default(),
722        );
723
724        Ok(BartClassificationHead {
725            dense,
726            dropout,
727            out_proj,
728        })
729    }
730
731    pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
732        x.apply_t(&self.dropout, train)
733            .apply(&self.dense)
734            .tanh()
735            .apply_t(&self.dropout, train)
736            .apply(&self.out_proj)
737    }
738}
739
740/// # BART Model for sequence classification
741/// BART model with a classification head
742/// It is made of the following blocks:
743/// - `base_model`: `BartModel` Base BART model
744/// - `classification_head`: `BartClassificationHead` made of 2 linear layers mapping hidden states to a target class
745/// - `eos_token_id`: token id for the EOS token carrying the pooled representation for classification
746pub struct BartForSequenceClassification {
747    base_model: BartModel,
748    classification_head: BartClassificationHead,
749    eos_token_id: i64,
750}
751
752impl BartForSequenceClassification {
753    /// Build a new `BartForSequenceClassification`
754    ///
755    /// # Arguments
756    ///
757    /// * `p` - Variable store path for the root of the BART model
758    /// * `config` - `BartConfig` object defining the model architecture
759    ///
760    /// # Example
761    ///
762    /// ```no_run
763    /// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
764    /// use rust_bert::Config;
765    /// use std::path::Path;
766    /// use tch::{nn, Device};
767    ///
768    /// let config_path = Path::new("path/to/config.json");
769    /// let device = Device::Cpu;
770    /// let p = nn::VarStore::new(device);
771    /// let config = BartConfig::from_file(config_path);
772    /// let bart: BartForSequenceClassification =
773    ///     BartForSequenceClassification::new(&p.root() / "bart", &config).unwrap();
774    /// ```
775    pub fn new<'p, P>(
776        p: P,
777        config: &BartConfig,
778    ) -> Result<BartForSequenceClassification, RustBertError>
779    where
780        P: Borrow<nn::Path<'p>>,
781    {
782        let p = p.borrow();
783
784        let base_model = BartModel::new(p / "model", config);
785        let classification_head = BartClassificationHead::new(p / "classification_head", config)?;
786        let eos_token_id = config.eos_token_id.unwrap_or(3);
787        Ok(BartForSequenceClassification {
788            base_model,
789            classification_head,
790            eos_token_id,
791        })
792    }
793
794    /// Forward pass through the model
795    ///
796    /// # Arguments
797    ///
798    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
799    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
800    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
801    ///     These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
802    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
803    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
804    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
805    ///
806    /// # Returns
807    ///
808    /// * `BartModelOutput` containing:
809    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *num_classes*) representing the activations for each class and batch item
810    ///   - `encoder_hidden_states` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None.
811    ///   - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
812    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
813    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
814    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
815    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
816    ///
817    /// # Example
818    ///
819    /// ```no_run
820    /// # use tch::{nn, Device, Tensor, no_grad};
821    /// # use rust_bert::Config;
822    /// # use std::path::Path;
823    /// # use tch::kind::Kind::{Int64, Double};
824    /// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
825    /// # let config_path = Path::new("path/to/config.json");
826    /// # let vocab_path = Path::new("path/to/vocab.txt");
827    /// # let device = Device::Cpu;
828    /// # let vs = nn::VarStore::new(device);
829    /// # let config = BartConfig::from_file(config_path);
830    /// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();
831    ///  let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
832    ///  let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
833    ///  let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
834    ///  let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
835    ///  let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
836    ///
837    ///  let model_output = no_grad(|| {
838    ///    bart_model
839    ///         .forward_t(&input_tensor,
840    ///                    Some(&encoder_attention_mask),
841    ///                    None,
842    ///                    Some(&target_tensor),
843    ///                    Some(&decoder_attention_mask),
844    ///                    false)
845    ///    });
846    /// ```
847    pub fn forward_t(
848        &self,
849        input_ids: &Tensor,
850        attention_mask: Option<&Tensor>,
851        encoder_output: Option<&Tensor>,
852        decoder_input_ids: Option<&Tensor>,
853        decoder_attention_mask: Option<&Tensor>,
854        train: bool,
855    ) -> BartModelOutput {
856        let base_model_output = self.base_model.forward_t(
857            Some(input_ids),
858            attention_mask,
859            decoder_input_ids,
860            encoder_output,
861            decoder_attention_mask,
862            None,
863            train,
864        );
865        let eos_mask = input_ids.eq(self.eos_token_id);
866        let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
867        let sentence_representation = base_model_output
868            .decoder_output
869            .permute([2, 0, 1])
870            .masked_select(&eos_mask)
871            .view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
872            .transpose(0, 1)
873            .view((
874                base_model_output.decoder_output.size()[0],
875                -1,
876                *base_model_output.decoder_output.size().last().unwrap(),
877            ))
878            .select(1, -1);
879
880        let logits = self
881            .classification_head
882            .forward_t(&sentence_representation, train);
883        BartModelOutput {
884            decoder_output: logits,
885            encoder_hidden_state: base_model_output.encoder_hidden_state,
886            cache: None,
887            all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
888            all_decoder_attentions: base_model_output.all_decoder_attentions,
889            all_encoder_hidden_states: base_model_output.all_encoder_hidden_states,
890            all_encoder_attentions: base_model_output.all_encoder_attentions,
891        }
892    }
893}
894
895/// Container holding a BART model output. The decoder output may hold the hidden state of
896/// the last layer of the decoder, or may hold logits for a custom head module after the
897/// decoder (e.g. for classification or language modeling tasks)
898pub struct BartModelOutput {
899    /// Hidden state of the last layer of the decoder, or logits for a custom head
900    /// module after the decoder (e.g. for classification or language modeling tasks)
901    pub decoder_output: Tensor,
902    /// Hidden state for the last layer of the encoder if they are calculated (not provided), otherwise None
903    pub encoder_hidden_state: Option<Tensor>,
904    /// Cached outputs of the model (attention layers keys and values) if the model is used for generation
905    pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
906    /// Hidden states for all layers of the decoder
907    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
908    /// Attention weights for all layers of the decoder
909    pub all_decoder_attentions: Option<Vec<Tensor>>,
910    /// Hidden states for all layers of the encoder
911    pub all_encoder_hidden_states: Option<Vec<Tensor>>,
912    /// Attention weights for all layers of the encoder
913    pub all_encoder_attentions: Option<Vec<Tensor>>,
914}
915
916/// # Language generation model based on the Bart architecture
917pub struct BartGenerator {
918    model: BartForConditionalGeneration,
919    tokenizer: TokenizerOption,
920    var_store: nn::VarStore,
921    generate_config: GenerateConfig,
922    bos_token_id: Option<i64>,
923    eos_token_ids: Option<Vec<i64>>,
924    forced_bos_token_id: Option<i64>,
925    forced_eos_token_id: Option<i64>,
926    pad_token_id: Option<i64>,
927    is_encoder_decoder: bool,
928    vocab_size: i64,
929    decoder_start_id: Option<i64>,
930    max_position_embeddings: i64,
931}
932
933impl BartGenerator {
934    /// Build a new `BartGenerator`
935    ///
936    /// # Arguments
937    ///
938    /// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
939    /// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
940    /// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
941    /// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
942    /// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
943    ///
944    /// # Example
945    ///
946    /// ```no_run
947    /// # use std::path::PathBuf;
948    /// # use tch::Device;
949    /// # fn main() -> anyhow::Result<()> {
950    /// use rust_bert::bart::BartGenerator;
951    /// use rust_bert::pipelines::generation_utils::GenerateConfig;
952    /// # let mut home: PathBuf = dirs::home_dir().unwrap();
953    /// # home.push("rustbert");
954    /// # home.push("openai-gpt");
955    /// # let config_path = &home.as_path().join("config.json");
956    /// # let vocab_path = &home.as_path().join("vocab.txt");
957    /// # let merges_path = &home.as_path().join("merges.txt");
958    /// # let weights_path = &home.as_path().join("model.ot");
959    /// let device = Device::cuda_if_available();
960    /// let generate_config = GenerateConfig {
961    ///     max_length: Some(30),
962    ///     do_sample: true,
963    ///     num_beams: 5,
964    ///     temperature: 1.1,
965    ///     num_return_sequences: 3,
966    ///     ..Default::default()
967    /// };
968    /// let bart_generator = BartGenerator::new(generate_config)?;
969    /// # Ok(())
970    /// # }
971    /// ```
972    pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
973        let vocab_path = generate_config.vocab_resource.get_local_path()?;
974        let merges_path = generate_config
975            .merges_resource
976            .as_ref()
977            .ok_or_else(|| {
978                RustBertError::InvalidConfigurationError(
979                    "BART expects a merges resources to be provided".to_string(),
980                )
981            })?
982            .get_local_path()?;
983
984        let tokenizer = TokenizerOption::from_file(
985            ModelType::Bart,
986            vocab_path.to_str().unwrap(),
987            Some(merges_path.to_str().unwrap()),
988            false,
989            None,
990            false,
991        )?;
992
993        Self::new_with_tokenizer(generate_config, tokenizer)
994    }
995
996    pub fn new_with_tokenizer(
997        generate_config: GenerateConfig,
998        tokenizer: TokenizerOption,
999    ) -> Result<BartGenerator, RustBertError> {
1000        let config_path = generate_config.config_resource.get_local_path()?;
1001        let device = generate_config.device;
1002
1003        generate_config.validate();
1004        let mut var_store = nn::VarStore::new(device);
1005        let config = BartConfig::from_file(config_path);
1006        let model = BartForConditionalGeneration::new(var_store.root(), &config);
1007        crate::resources::load_weights(
1008            &generate_config.model_resource,
1009            &mut var_store,
1010            generate_config.kind,
1011            device,
1012        )?;
1013
1014        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
1015        let eos_token_ids = Some(match config.eos_token_id {
1016            Some(value) => vec![value],
1017            None => vec![2],
1018        });
1019        let forced_bos_token_id = config.forced_bos_token_id;
1020        let forced_eos_token_id = config.forced_eos_token_id;
1021        let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
1022        let vocab_size = config.vocab_size;
1023        let is_encoder_decoder = true;
1024        let decoder_start_id = config.decoder_start_token_id;
1025        let max_position_embeddings = config.max_position_embeddings;
1026
1027        Ok(BartGenerator {
1028            model,
1029            tokenizer,
1030            var_store,
1031            generate_config,
1032            bos_token_id,
1033            eos_token_ids,
1034            forced_bos_token_id,
1035            forced_eos_token_id,
1036            pad_token_id,
1037            is_encoder_decoder,
1038            vocab_size,
1039            decoder_start_id,
1040            max_position_embeddings,
1041        })
1042    }
1043}
1044
1045impl PrivateLanguageGenerator for BartGenerator {
1046    fn _get_tokenizer(&self) -> &TokenizerOption {
1047        &self.tokenizer
1048    }
1049    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
1050        &mut self.tokenizer
1051    }
1052    fn get_device(&self) -> Device {
1053        self.var_store.device()
1054    }
1055    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
1056        Ok(&mut self.var_store)
1057    }
1058    fn get_config(&self) -> &GenerateConfig {
1059        &self.generate_config
1060    }
1061    fn get_bos_id(&self) -> Option<i64> {
1062        self.bos_token_id
1063    }
1064    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
1065        self.eos_token_ids.as_ref()
1066    }
1067    fn get_forced_bos_token_id(&self) -> Option<i64> {
1068        self.forced_bos_token_id
1069    }
1070    fn get_forced_eos_token_id(&self) -> Option<i64> {
1071        self.forced_eos_token_id
1072    }
1073    fn get_pad_id(&self) -> Option<i64> {
1074        self.pad_token_id
1075    }
1076    fn is_encoder_decoder(&self) -> bool {
1077        self.is_encoder_decoder
1078    }
1079    fn get_vocab_size(&self) -> i64 {
1080        self.vocab_size
1081    }
1082    fn get_decoder_start_id(&self) -> Option<i64> {
1083        self.decoder_start_id
1084    }
1085    fn get_max_positions_embeddings(&self) -> Option<i64> {
1086        Some(self.max_position_embeddings)
1087    }
1088
1089    fn forward_t(
1090        &self,
1091        input_ids: Option<&Tensor>,
1092        cache: Cache,
1093        attention_mask: Option<&Tensor>,
1094        _token_type_ids: Option<&Tensor>,
1095        _position_ids: Option<&Tensor>,
1096        _input_embeds: Option<&Tensor>,
1097        encoder_outputs: Option<&Tensor>,
1098        decoder_input_ids: Option<&Tensor>,
1099        train: bool,
1100    ) -> Result<LMModelOutput, RustBertError> {
1101        let base_model_output = match cache {
1102            Cache::BARTCache(cached_layer_states) => self.model.forward_t(
1103                input_ids,
1104                attention_mask,
1105                encoder_outputs,
1106                decoder_input_ids,
1107                None,
1108                cached_layer_states,
1109                train,
1110            ),
1111
1112            Cache::None => self.model.forward_t(
1113                input_ids,
1114                attention_mask,
1115                encoder_outputs,
1116                decoder_input_ids,
1117                None,
1118                None,
1119                train,
1120            ),
1121            _ => {
1122                return Err(RustBertError::ValueError(
1123                    "Cache not compatible with BART Model".into(),
1124                ));
1125            }
1126        };
1127
1128        Ok(LMModelOutput {
1129            lm_logits: base_model_output.decoder_output,
1130            cache: Cache::BARTCache(base_model_output.cache),
1131        })
1132    }
1133
1134    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
1135        Some(self.model.encode(input_ids, attention_mask))
1136    }
1137
1138    fn prepare_inputs_for_generation<'a>(
1139        &self,
1140        input_ids: Tensor,
1141        encoder_outputs: Option<&'a Tensor>,
1142        past: Cache,
1143        attention_mask: Tensor,
1144    ) -> PreparedInput<'a> {
1145        match past {
1146            Cache::BARTCache(past) => PreparedInput {
1147                prepared_input: None,
1148                prepared_attention_mask: Some(attention_mask),
1149                prepared_encoder_output: encoder_outputs,
1150                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
1151                prepared_position_ids: None,
1152                prepared_past: Cache::BARTCache(past),
1153            },
1154            Cache::None => PreparedInput {
1155                prepared_input: None,
1156                prepared_attention_mask: Some(attention_mask),
1157                prepared_encoder_output: encoder_outputs,
1158                prepared_decoder_input: Some(input_ids),
1159                prepared_position_ids: None,
1160                prepared_past: Cache::BARTCache(None),
1161            },
1162            _ => panic!("Cache type incompatible with BART"),
1163        }
1164    }
1165
1166    fn reorder_cache(
1167        &self,
1168        past: &mut Cache,
1169        encoder_outputs: Option<Tensor>,
1170        beam_indices: &Tensor,
1171    ) -> Option<Tensor> {
1172        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
1173        match past {
1174            Cache::BARTCache(old_cache_option) => match old_cache_option {
1175                Some(old_cache) => {
1176                    for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
1177                        if self_layer_state.is_some() {
1178                            self_layer_state
1179                                .as_mut()
1180                                .unwrap()
1181                                .reorder_cache(beam_indices)
1182                        };
1183                        if encoder_layer_state.is_some() {
1184                            encoder_layer_state
1185                                .as_mut()
1186                                .unwrap()
1187                                .reorder_cache(beam_indices)
1188                        };
1189                    }
1190                }
1191                None => {}
1192            },
1193            Cache::None => {}
1194            _ => {
1195                panic!("Invalid cache for BART model");
1196            }
1197        };
1198        encoder_outputs
1199    }
1200}
1201
1202impl LanguageGenerator for BartGenerator {}
1203
#[cfg(test)]
mod test {
    use tch::Device;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };

    use super::{BartConfig, BartConfigResources, BartModel};

    /// Checks at compile time that `BartModel` can be boxed as `dyn Send`. The test is ignored
    /// at runtime because it resolves a remote configuration resource.
    #[test]
    #[ignore] // compilation is enough, no need to run
    fn bart_model_send() {
        let config_path = RemoteResource::from_pretrained(BartConfigResources::BART)
            .get_local_path()
            .expect("");

        // Build the base BART model and require it to coerce to `Box<dyn Send>`.
        let vs = tch::nn::VarStore::new(Device::cuda_if_available());
        let config = BartConfig::from_file(config_path);
        let _: Box<dyn Send> = Box::new(BartModel::new(vs.root(), &config));
    }
}