// rust_bert/models/longformer/longformer_model.rs

1// Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
2// Copyright 2021 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::common::activations::{TensorFunction, _tanh};
14use crate::common::dropout::Dropout;
15use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
16use crate::longformer::embeddings::LongformerEmbeddings;
17use crate::longformer::encoder::LongformerEncoder;
18use crate::{Activation, Config, RustBertError};
19use serde::{Deserialize, Serialize};
20use std::borrow::Borrow;
21use std::collections::HashMap;
22use tch::nn::{Init, Module, ModuleT};
23use tch::{nn, Kind, Tensor};
24
/// # Longformer Pretrained model weight files
pub struct LongformerModelResources;

/// # Longformer Pretrained model config files
pub struct LongformerConfigResources;

/// # Longformer Pretrained model vocab files
pub struct LongformerVocabResources;

/// # Longformer Pretrained model merges files
pub struct LongformerMergesResources;

// NOTE: the first tuple element is the local cache alias the downloaded file is
// stored under; it must be unique per remote resource, otherwise two different
// checkpoints would overwrite each other in the cache. The SQuAD-finetuned
// resources therefore use a distinct "longformer-base-4096-squad1/..." alias.
impl LongformerModelResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/model",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096-squad1/model",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/rust_model.ot",
    );
}

impl LongformerConfigResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/config",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/config.json",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096-squad1/config",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/config.json",
    );
}

impl LongformerVocabResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/vocab",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096-squad1/vocab",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/vocab.json",
    );
}

impl LongformerMergesResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/merges",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096-squad1/merges",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/merges.txt",
    );
}
88
89#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
90#[serde(rename_all = "camelCase")]
91/// # Longformer Position embeddings type
92pub enum PositionEmbeddingType {
93    Absolute,
94    RelativeKey,
95}
96
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Longformer model configuration
/// Defines the Longformer model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct LongformerConfig {
    /// Activation function used in the encoder feed-forward layers
    pub hidden_act: Activation,
    /// Local attention window sizes; the model pads sequences to a multiple of the maximum value
    pub attention_window: Vec<i64>,
    /// Dropout probability applied to attention weights
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability applied to hidden states
    pub hidden_dropout_prob: f64,
    /// Dimension of the hidden representations
    pub hidden_size: i64,
    /// Standard deviation used for weight initialization
    /// (NOTE(review): declared `f32` while the other float fields are `f64` — presumably
    /// mirrors the upstream checkpoint config; confirm before changing)
    pub initializer_range: f32,
    /// Dimension of the feed-forward intermediate layer
    pub intermediate_size: i64,
    /// Maximum number of position embeddings
    pub max_position_embeddings: i64,
    /// Number of attention heads per layer
    pub num_attention_heads: i64,
    /// Number of encoder layers
    pub num_hidden_layers: i64,
    /// Number of segment (token type) embeddings
    pub type_vocab_size: i64,
    /// Vocabulary size
    pub vocab_size: i64,
    /// Id of the separator token (used to split question/context for QA masks)
    pub sep_token_id: i64,
    /// Id of the padding token; `LongformerModel` falls back to 1 when absent
    pub pad_token_id: Option<i64>,
    /// Layer normalization epsilon; defaults to 1e-12 when absent
    pub layer_norm_eps: Option<f64>,
    /// If true, attention weights are returned by the encoder
    pub output_attentions: Option<bool>,
    /// If true, all intermediate hidden states are returned by the encoder
    pub output_hidden_states: Option<bool>,
    /// Position embedding flavor (absolute or relative)
    pub position_embedding_type: Option<PositionEmbeddingType>,
    /// If true, the model is used as a decoder and a causal mask is applied
    pub is_decoder: Option<bool>,
    /// Mapping from label index to label name (classification heads)
    pub id2label: Option<HashMap<i64, String>>,
    /// Mapping from label name to label index (classification heads)
    pub label2id: Option<HashMap<String, i64>>,
}

impl Config for LongformerConfig {}
125
impl Default for LongformerConfig {
    // NOTE(review): these are generic BERT-base-style defaults (hidden 768, 12 layers,
    // 512 positions) and do NOT match the published longformer-base-4096 checkpoint
    // (which uses a larger max_position_embeddings); real checkpoints should be loaded
    // from their config.json rather than relying on these values — confirm intended.
    fn default() -> Self {
        LongformerConfig {
            hidden_act: Activation::gelu,
            attention_window: vec![512],
            attention_probs_dropout_prob: 0.1,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30522,
            sep_token_id: 2,
            pad_token_id: None,
            layer_norm_eps: None,
            output_attentions: None,
            output_hidden_states: None,
            position_embedding_type: None,
            is_decoder: None,
            id2label: None,
            label2id: None,
        }
    }
}
153
/// Returns, for each batch item, the position of the first separator token in `input_ids`.
///
/// # Arguments
/// * `input_ids` - token id tensor of shape (*batch size*, *sequence length*)
/// * `sep_token_id` - id of the separator token to locate
///
/// # Returns
/// 1D tensor of shape (*batch size*,) with the index of the first separator per sequence.
///
/// NOTE(review): the `view([batch_size, 3, 2])` reshape assumes every sequence contains
/// exactly 3 separator tokens (RoBERTa-style `<s> question </s></s> context </s>` pairs);
/// any other separator count makes the reshape panic — confirm callers guarantee this.
fn get_question_end_index(input_ids: &Tensor, sep_token_id: i64) -> Tensor {
    input_ids
        .eq(sep_token_id)
        .nonzero() // (num_matches, 2) coordinates of separator hits
        .view([input_ids.size()[0], 3, 2])
        .select(2, 1) // keep the position coordinate (drop the batch coordinate)
        .select(1, 0) // first separator of each batch item
}
162
/// Builds a global attention mask from `input_ids`, flagging either the tokens before
/// the first separator (typically the question) or the tokens after it (the context).
///
/// # Arguments
/// * `input_ids` - token id tensor of shape (*batch size*, *sequence length*)
/// * `sep_token_id` - separator token id marking the question/context boundary
/// * `before_sep_token` - if true, positions strictly before the first separator are
///   flagged; otherwise positions strictly after `separator index + 1` are flagged,
///   bounded above by the sequence length.
///
/// # Returns
/// Mask tensor of shape (*batch size*, *sequence length*), non-zero at positions that
/// should receive global attention.
fn compute_global_attention_mask(
    input_ids: &Tensor,
    sep_token_id: i64,
    before_sep_token: bool,
) -> Tensor {
    // First-separator position per batch item, shape (batch size, 1) for broadcasting
    let question_end_index = get_question_end_index(input_ids, sep_token_id).unsqueeze(1);
    // Position indices 0..sequence_length
    let attention_mask = Tensor::arange(input_ids.size()[1], (Kind::Int64, input_ids.device()));

    if before_sep_token {
        attention_mask
            .expand_as(input_ids)
            .lt_tensor(&question_end_index)
    } else {
        // Positions strictly after the separator block; the `lt` factor keeps
        // indices within the sequence length.
        attention_mask
            .expand_as(input_ids)
            .gt_tensor(&(question_end_index + 1))
            * attention_mask
                .expand_as(input_ids)
                .lt(*input_ids.size().last().unwrap())
    }
}
184
185#[derive(Debug)]
186pub struct LongformerPooler {
187    dense: nn::Linear,
188    activation: TensorFunction,
189}
190
191impl LongformerPooler {
192    pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerPooler
193    where
194        P: Borrow<nn::Path<'p>>,
195    {
196        let p = p.borrow();
197
198        let dense = nn::linear(
199            p / "dense",
200            config.hidden_size,
201            config.hidden_size,
202            Default::default(),
203        );
204
205        let activation = TensorFunction::new(Box::new(_tanh));
206
207        LongformerPooler { dense, activation }
208    }
209}
210
211impl Module for LongformerPooler {
212    fn forward(&self, hidden_states: &Tensor) -> Tensor {
213        self.activation.get_fn()(&hidden_states.select(1, 0).apply(&self.dense))
214    }
215}
216
217#[derive(Debug)]
218pub struct LongformerLMHead {
219    dense: nn::Linear,
220    layer_norm: nn::LayerNorm,
221    decoder: nn::Linear,
222    bias: Tensor,
223}
224
225impl LongformerLMHead {
226    pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerLMHead
227    where
228        P: Borrow<nn::Path<'p>>,
229    {
230        let p = p.borrow();
231
232        let dense = nn::linear(
233            p / "dense",
234            config.hidden_size,
235            config.hidden_size,
236            Default::default(),
237        );
238
239        let layer_norm_config = nn::LayerNormConfig {
240            eps: config.layer_norm_eps.unwrap_or(1e-12),
241            ..Default::default()
242        };
243
244        let layer_norm = nn::layer_norm(
245            p / "layer_norm",
246            vec![config.hidden_size],
247            layer_norm_config,
248        );
249
250        let linear_config = nn::LinearConfig {
251            bias: false,
252            ..Default::default()
253        };
254
255        let decoder = nn::linear(
256            p / "decoder",
257            config.hidden_size,
258            config.vocab_size,
259            linear_config,
260        );
261
262        let bias = p.var("bias", &[config.vocab_size], Init::Const(0f64));
263
264        LongformerLMHead {
265            dense,
266            layer_norm,
267            decoder,
268            bias,
269        }
270    }
271}
272
273impl Module for LongformerLMHead {
274    fn forward(&self, hidden_states: &Tensor) -> Tensor {
275        hidden_states
276            .apply(&self.dense)
277            .gelu("none")
278            .apply(&self.layer_norm)
279            .apply(&self.decoder)
280            + &self.bias
281    }
282}
283
/// Container for the outputs of `LongformerModel::pad_to_window_size`.
/// Each field is `Some` only when the corresponding input was provided to the
/// padding routine (and was therefore padded); absent inputs stay `None`.
struct PaddedInput {
    input_ids: Option<Tensor>,
    position_ids: Option<Tensor>,
    inputs_embeds: Option<Tensor>,
    attention_mask: Option<Tensor>,
    token_type_ids: Option<Tensor>,
}
291
/// # LongformerModel Base model
/// Base architecture for LongformerModel models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: LongformerEmbeddings containing word, position and segment id embeddings
/// - `encoder`: LongformerEncoder
/// - `pooler`: Optional pooling layer extracting the representation of the first token for each batch item
pub struct LongformerModel {
    embeddings: LongformerEmbeddings,
    encoder: LongformerEncoder,
    pooler: Option<LongformerPooler>,
    /// Largest attention window across layers; sequences are padded to a multiple of this value
    max_attention_window: i64,
    /// Token id used for padding (falls back to 1 when absent from the configuration)
    pad_token_id: i64,
    /// When true, a causal mask is applied in `forward_t`
    is_decoder: bool,
}
306
impl LongformerModel {
    /// Build a new `LongformerModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the Longformer model
    /// * `config` - `LongformerConfig` object defining the model architecture
    /// * `add_pooling_layer` - if true, a pooling head over the first token is added
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::longformer::{LongformerConfig, LongformerModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = LongformerConfig::from_file(config_path);
    /// let add_pooling_layer = false;
    /// let longformer_model = LongformerModel::new(&p.root(), &config, add_pooling_layer);
    /// ```
    pub fn new<'p, P>(p: P, config: &LongformerConfig, add_pooling_layer: bool) -> LongformerModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();

        // Sequences are later padded to a multiple of the largest window.
        let max_attention_window = config.attention_window.iter().copied().max().unwrap();

        LongformerModel {
            embeddings: LongformerEmbeddings::new(path / "embeddings", config),
            encoder: LongformerEncoder::new(path / "encoder", config),
            pooler: add_pooling_layer.then(|| LongformerPooler::new(path / "pooler", config)),
            max_attention_window,
            pad_token_id: config.pad_token_id.unwrap_or(1),
            is_decoder: config.is_decoder.unwrap_or(false),
        }
    }
357
358    fn pad_with_nonzero_value(
359        &self,
360        tensor: &Tensor,
361        padding: &[i64],
362        padding_value: i64,
363    ) -> Tensor {
364        (tensor - padding_value).constant_pad_nd(padding) + padding_value
365    }
366
367    fn pad_with_boolean(&self, tensor: &Tensor, padding: &[i64], padding_value: bool) -> Tensor {
368        if !padding_value {
369            tensor.constant_pad_nd(padding)
370        } else {
371            ((tensor.logical_not()).constant_pad_nd(padding)).logical_not()
372        }
373    }
374
    /// Pads every provided input along the sequence dimension by `padding_length`
    /// positions so the total length becomes a multiple of the attention window.
    ///
    /// Inputs that are `None` stay `None` in the returned `PaddedInput`.
    /// Token ids and position ids are padded with `pad_token_id`; the attention mask
    /// is padded with `false` (masked); token type ids are zero-padded; input
    /// embeddings are extended with the embeddings of `pad_token_id` tokens.
    ///
    /// NOTE(review): position ids being padded with `pad_token_id` mirrors the
    /// upstream implementation but looks surprising — confirm downstream position
    /// embedding lookup tolerates it.
    fn pad_to_window_size(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        pad_token_id: i64,
        padding_length: i64,
        train: bool,
    ) -> Result<PaddedInput, RustBertError> {
        let (input_shape, _) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
        let batch_size = input_shape[0];

        let input_ids = input_ids
            .map(|value| self.pad_with_nonzero_value(value, &[0, padding_length], pad_token_id));
        let position_ids = position_ids
            .map(|value| self.pad_with_nonzero_value(value, &[0, padding_length], pad_token_id));
        let inputs_embeds = input_embeds.map(|value| {
            // Build a block of pad-token ids, embed it, and append it to the
            // provided embeddings along the sequence dimension.
            let input_ids_padding = Tensor::full(
                [batch_size, padding_length],
                pad_token_id,
                (Kind::Int64, value.device()),
            );
            let input_embeds_padding = self
                .embeddings
                .forward_t(Some(&input_ids_padding), None, None, None, train)
                .unwrap();

            Tensor::cat(&[value, &input_embeds_padding], -2)
        });

        // Padded positions must not be attended to: pad the mask with `false`.
        let attention_mask =
            attention_mask.map(|value| self.pad_with_boolean(value, &[0, padding_length], false));
        let token_type_ids = token_type_ids.map(|value| value.constant_pad_nd([0, padding_length]));
        Ok(PaddedInput {
            input_ids,
            position_ids,
            inputs_embeds,
            attention_mask,
            token_type_ids,
        })
    }
419
    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
    /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LongformerModelOutput` containing:
    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
    ///   - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*)  where x is the number of tokens with global attention
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::longformer::{LongformerConfig, LongformerModel};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = LongformerConfig::from_file(config_path);
    /// let longformer_model = LongformerModel::new(&vs.root(), &config, false);
    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     longformer_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Some(&attention_mask),
    ///             Some(&global_attention_mask),
    ///             None,
    ///             None,
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        global_attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<LongformerModelOutput, RustBertError> {
        let (input_shape, device) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

        let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);

        // Default to attending everywhere / all-zero segment ids when the caller
        // does not provide masks. The calc_* locals keep the owned tensors alive
        // while Option<&Tensor> references are threaded through below.
        let calc_attention_mask = if attention_mask.is_none() {
            Some(Tensor::ones(input_shape.as_slice(), (Kind::Int, device)))
        } else {
            None
        };
        let calc_token_type_ids = if token_type_ids.is_none() {
            Some(Tensor::zeros(input_shape.as_slice(), (Kind::Int64, device)))
        } else {
            None
        };
        let attention_mask = if attention_mask.is_some() {
            attention_mask
        } else {
            calc_attention_mask.as_ref()
        };
        let token_type_ids = if token_type_ids.is_some() {
            token_type_ids
        } else {
            calc_token_type_ids.as_ref()
        };

        // Merge the global attention mask into the attention mask. Assuming binary
        // 0/1 masks, masked positions stay 0, locally-attended positions become 1
        // and globally-attended positions become 2 — presumably decoded downstream
        // by the encoder to route local vs global attention.
        let merged_attention_mask = if let Some(global_attention_mask) = global_attention_mask {
            attention_mask.map(|tensor| tensor * (global_attention_mask + 1))
        } else {
            None
        };
        let attention_mask = if merged_attention_mask.is_some() {
            merged_attention_mask.as_ref()
        } else {
            attention_mask
        };

        // Longformer requires the sequence length to be a multiple of the largest
        // attention window; compute the shortfall and pad every input accordingly.
        let padding_length = (self.max_attention_window
            - sequence_length % self.max_attention_window)
            % self.max_attention_window;
        let (
            calc_padded_input_ids,
            calc_padded_position_ids,
            calc_padded_inputs_embeds,
            calc_padded_attention_mask,
            calc_padded_token_type_ids,
        ) = if padding_length > 0 {
            let padded_input = self.pad_to_window_size(
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                self.pad_token_id,
                padding_length,
                train,
            )?;
            (
                padded_input.input_ids,
                padded_input.position_ids,
                padded_input.inputs_embeds,
                padded_input.attention_mask,
                padded_input.token_type_ids,
            )
        } else {
            (None, None, None, None, None)
        };
        // Fall back to the unpadded inputs when no padding was necessary.
        let padded_input_ids = if calc_padded_input_ids.is_some() {
            calc_padded_input_ids.as_ref()
        } else {
            input_ids
        };
        let padded_position_ids = if calc_padded_position_ids.is_some() {
            calc_padded_position_ids.as_ref()
        } else {
            position_ids
        };
        let padded_inputs_embeds = if calc_padded_inputs_embeds.is_some() {
            calc_padded_inputs_embeds.as_ref()
        } else {
            input_embeds
        };
        let padded_attention_mask = calc_padded_attention_mask
            .as_ref()
            .unwrap_or_else(|| attention_mask.as_ref().unwrap());
        let padded_token_type_ids = if calc_padded_token_type_ids.is_some() {
            calc_padded_token_type_ids.as_ref()
        } else {
            token_type_ids
        };

        // Broadcast the mask to 4D; for decoders, combine with a causal
        // (lower-triangular) mask so positions cannot attend to the future.
        let extended_attention_mask = match padded_attention_mask.dim() {
            3 => padded_attention_mask.unsqueeze(1),
            2 => {
                if !self.is_decoder {
                    padded_attention_mask.unsqueeze(1).unsqueeze(1)
                } else {
                    let sequence_ids = Tensor::arange(sequence_length, (Kind::Int64, device));
                    let mut causal_mask = sequence_ids
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .repeat([batch_size, sequence_length, 1])
                        .le_tensor(&sequence_ids.unsqueeze(-1).unsqueeze(0))
                        .totype(Kind::Int);
                    // NOTE(review): the causal mask is built from the *unpadded*
                    // sequence_length while the attention mask may have been padded
                    // above; the prefix of ones below compensates for the length
                    // difference — confirm this matches the intended decoder behavior.
                    if causal_mask.size()[1] < padded_attention_mask.size()[1] {
                        let prefix_sequence_length =
                            padded_attention_mask.size()[1] - causal_mask.size()[1];
                        causal_mask = Tensor::cat(
                            &[
                                Tensor::ones(
                                    [batch_size, sequence_length, prefix_sequence_length],
                                    (Kind::Int, device),
                                ),
                                causal_mask,
                            ],
                            -1,
                        );
                    }
                    causal_mask.unsqueeze(1) * padded_attention_mask.unsqueeze(1).unsqueeze(1)
                }
            }
            _ => {
                return Err(RustBertError::ValueError(
                    "Invalid attention mask dimension, must be 2 or 3".into(),
                ));
            }
        }
        // Collapse the broadcast dimensions inserted above back to 2D.
        .select(2, 0)
        .select(1, 0);
        // Convert the multiplicative mask into an additive one: 0 where attention is
        // allowed, -10000 where it is masked.
        let extended_attention_mask =
            (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;

        let embedding_output = self.embeddings.forward_t(
            padded_input_ids,
            padded_token_type_ids,
            padded_position_ids,
            padded_inputs_embeds,
            train,
        )?;

        let encoder_outputs =
            self.encoder
                .forward_t(&embedding_output, &extended_attention_mask, train);

        let pooled_output = self
            .pooler
            .as_ref()
            .map(|pooler| pooler.forward(&encoder_outputs.hidden_states));

        // Strip the window padding from the output sequence before returning.
        let sequence_output = if padding_length > 0 {
            encoder_outputs
                .hidden_states
                .slice(1, 0, -padding_length, 1)
        } else {
            encoder_outputs.hidden_states
        };

        Ok(LongformerModelOutput {
            hidden_state: sequence_output,
            pooled_output,
            all_hidden_states: encoder_outputs.all_hidden_states,
            all_attentions: encoder_outputs.all_attentions,
            all_global_attentions: encoder_outputs.all_global_attentions,
        })
    }
}
651
/// # Longformer for masked language model
/// Base Longformer model with a masked language model head to predict missing tokens, for example `"Looks like one <mask> is missing" -> "person"`
/// It is made of the following blocks:
/// - `longformer`: Base LongformerModel
/// - `lm_head`: Longformer LM prediction head
pub struct LongformerForMaskedLM {
    /// Base model, built without a pooling layer (token-level predictions only)
    longformer: LongformerModel,
    /// Projection from hidden states to vocabulary logits
    lm_head: LongformerLMHead,
}
661
impl LongformerForMaskedLM {
    /// Build a new `LongformerForMaskedLM`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the Longformer model
    /// * `config` - `LongformerConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::longformer::{LongformerConfig, LongformerForMaskedLM};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = LongformerConfig::from_file(config_path);
    /// let longformer_model = LongformerForMaskedLM::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();

        // The base model is built without a pooling layer: masked-LM predictions
        // only use the per-token hidden states.
        LongformerForMaskedLM {
            longformer: LongformerModel::new(path / "longformer", config, false),
            lm_head: LongformerLMHead::new(path / "lm_head", config),
        }
    }
698
    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
    /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `LongformerMaskedLMOutput` containing:
    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
    ///   - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*)  where x is the number of tokens with global attention
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::{Int64, Double};
    /// use rust_bert::longformer::{LongformerConfig, LongformerForMaskedLM};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = LongformerConfig::from_file(config_path);
    /// let longformer_model = LongformerForMaskedLM::new(&vs.root(), &config);
    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     longformer_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Some(&attention_mask),
    ///             Some(&global_attention_mask),
    ///             None,
    ///             None,
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        global_attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<LongformerMaskedLMOutput, RustBertError> {
        // Run the base model, then project its hidden states to vocabulary logits.
        let longformer_outputs = self.longformer.forward_t(
            input_ids,
            attention_mask,
            global_attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;

        let prediction_scores = self
            .lm_head
            .forward_t(&longformer_outputs.hidden_state, train);

        Ok(LongformerMaskedLMOutput {
            prediction_scores,
            all_hidden_states: longformer_outputs.all_hidden_states,
            all_attentions: longformer_outputs.all_attentions,
            all_global_attentions: longformer_outputs.all_global_attentions,
        })
    }
}
785
/// Classification head for Longformer sequence-level tasks.
/// Pools the hidden state of the first token, then applies
/// dropout -> dense -> tanh -> dropout -> output projection (see `forward_t`).
pub struct LongformerClassificationHead {
    // Intermediate projection: hidden_size -> hidden_size (tanh applied in `forward_t`).
    dense: nn::Linear,
    // Dropout applied both before and after the dense projection.
    dropout: Dropout,
    // Output projection: hidden_size -> num_labels (label count taken from `config.id2label`).
    out_proj: nn::Linear,
}
791
792impl LongformerClassificationHead {
793    pub fn new<'p, P>(
794        p: P,
795        config: &LongformerConfig,
796    ) -> Result<LongformerClassificationHead, RustBertError>
797    where
798        P: Borrow<nn::Path<'p>>,
799    {
800        let p = p.borrow();
801
802        let dense = nn::linear(
803            p / "dense",
804            config.hidden_size,
805            config.hidden_size,
806            Default::default(),
807        );
808        let dropout = Dropout::new(config.hidden_dropout_prob);
809
810        let num_labels = config
811            .id2label
812            .as_ref()
813            .ok_or_else(|| {
814                RustBertError::InvalidConfigurationError(
815                    "num_labels not provided in configuration".to_string(),
816                )
817            })?
818            .len() as i64;
819        let out_proj = nn::linear(
820            p / "out_proj",
821            config.hidden_size,
822            num_labels,
823            Default::default(),
824        );
825
826        Ok(LongformerClassificationHead {
827            dense,
828            dropout,
829            out_proj,
830        })
831    }
832
833    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
834        hidden_states
835            .select(1, 0)
836            .apply_t(&self.dropout, train)
837            .apply(&self.dense)
838            .tanh()
839            .apply_t(&self.dropout, train)
840            .apply(&self.out_proj)
841    }
842}
843
844/// # Longformer for sequence classification
845/// Base Longformer model with a classifier head to perform sentence or document-level classification
846/// It is made of the following blocks:
847/// - `longformer`: Base Longformer
848/// - `classifier`: Longformer classification head
pub struct LongformerForSequenceClassification {
    // Base Longformer encoder (constructed with the third flag set to `false`,
    // as for the other single-head tasks in this file).
    longformer: LongformerModel,
    // Head mapping the first token's hidden state to class logits.
    classifier: LongformerClassificationHead,
}
853
854impl LongformerForSequenceClassification {
855    /// Build a new `LongformerForSequenceClassification`
856    ///
857    /// # Arguments
858    ///
859    /// * `p` - Variable store path for the root of the Longformer model
860    /// * `config` - `LongformerConfig` object defining the model architecture
861    ///
862    /// # Example
863    ///
864    /// ```no_run
865    /// use rust_bert::longformer::{LongformerConfig, LongformerForSequenceClassification};
866    /// use rust_bert::Config;
867    /// use std::path::Path;
868    /// use tch::{nn, Device};
869    ///
870    /// let config_path = Path::new("path/to/config.json");
871    /// let device = Device::Cpu;
872    /// let p = nn::VarStore::new(device);
873    /// let config = LongformerConfig::from_file(config_path);
874    /// let longformer_model = LongformerForSequenceClassification::new(&p.root(), &config).unwrap();
875    /// ```
876    pub fn new<'p, P>(
877        p: P,
878        config: &LongformerConfig,
879    ) -> Result<LongformerForSequenceClassification, RustBertError>
880    where
881        P: Borrow<nn::Path<'p>>,
882    {
883        let p = p.borrow();
884
885        let longformer = LongformerModel::new(p / "longformer", config, false);
886        let classifier = LongformerClassificationHead::new(p / "classifier", config)?;
887
888        Ok(LongformerForSequenceClassification {
889            longformer,
890            classifier,
891        })
892    }
893
894    /// Forward pass through the model
895    ///
896    /// # Arguments
897    ///
898    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
899    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
900    /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
901    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
902    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
903    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
904    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
905    ///
906    /// # Returns
907    ///
908    /// * `LongformerSequenceClassificationOutput` containing:
909    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_classes*)
910    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
911    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
912    ///   - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*)  where x is the number of tokens with global attention
913    ///
914    /// # Example
915    ///
916    /// ```no_run
917    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
918    /// # use rust_bert::Config;
919    /// # use std::path::Path;
920    /// # use tch::kind::Kind::{Int64, Double};
921    /// use rust_bert::longformer::{LongformerConfig, LongformerForSequenceClassification};
922    /// # let config_path = Path::new("path/to/config.json");
923    /// # let vocab_path = Path::new("path/to/vocab.txt");
924    /// # let device = Device::Cpu;
925    /// # let vs = nn::VarStore::new(device);
926    /// # let config = LongformerConfig::from_file(config_path);
927    /// let longformer_model = LongformerForSequenceClassification::new(&vs.root(), &config).unwrap();
928    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
929    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
930    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
931    /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
932    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
933    ///
934    /// let model_output = no_grad(|| {
935    ///     longformer_model
936    ///         .forward_t(
937    ///             Some(&input_tensor),
938    ///             Some(&attention_mask),
939    ///             Some(&global_attention_mask),
940    ///             None,
941    ///             None,
942    ///             None,
943    ///             false,
944    ///         )
945    ///         .unwrap()
946    /// });
947    /// ```
948    pub fn forward_t(
949        &self,
950        input_ids: Option<&Tensor>,
951        attention_mask: Option<&Tensor>,
952        global_attention_mask: Option<&Tensor>,
953        token_type_ids: Option<&Tensor>,
954        position_ids: Option<&Tensor>,
955        input_embeds: Option<&Tensor>,
956        train: bool,
957    ) -> Result<LongformerSequenceClassificationOutput, RustBertError> {
958        let calc_global_attention_mask = if global_attention_mask.is_none() {
959            let (input_shape, device) = if let Some(input_ids) = input_ids {
960                if input_embeds.is_none() {
961                    (input_ids.size(), input_ids.device())
962                } else {
963                    return Err(RustBertError::ValueError(
964                        "Only one of input ids or input embeddings may be set".into(),
965                    ));
966                }
967            } else if let Some(input_embeds) = input_embeds {
968                (input_embeds.size()[..2].to_vec(), input_embeds.device())
969            } else {
970                return Err(RustBertError::ValueError(
971                    "At least one of input ids or input embeddings must be set".into(),
972                ));
973            };
974
975            let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
976            let global_attention_mask =
977                Tensor::zeros([batch_size, sequence_length], (Kind::Int, device));
978            let _ = global_attention_mask.select(1, 0).fill_(1);
979            Some(global_attention_mask)
980        } else {
981            None
982        };
983
984        let global_attention_mask = if global_attention_mask.is_some() {
985            global_attention_mask
986        } else {
987            calc_global_attention_mask.as_ref()
988        };
989
990        let base_model_output = self.longformer.forward_t(
991            input_ids,
992            attention_mask,
993            global_attention_mask,
994            token_type_ids,
995            position_ids,
996            input_embeds,
997            train,
998        )?;
999
1000        let logits = self
1001            .classifier
1002            .forward_t(&base_model_output.hidden_state, train);
1003        Ok(LongformerSequenceClassificationOutput {
1004            logits,
1005            all_hidden_states: base_model_output.all_hidden_states,
1006            all_attentions: base_model_output.all_attentions,
1007            all_global_attentions: base_model_output.all_global_attentions,
1008        })
1009    }
1010}
1011
1012/// # Longformer for question answering
1013/// Extractive question-answering model based on a Longformer language model. Identifies the segment of a context that answers a provided question.
1014/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
1015/// See the question answering pipeline (also provided in this crate) for more details.
1016/// It is made of the following blocks:
1017/// - `longformer`: Base Longformer
1018/// - `qa_outputs`: Linear layer for question answering
pub struct LongformerForQuestionAnswering {
    // Base Longformer encoder.
    longformer: LongformerModel,
    // Linear layer producing 2 values per token, split into start and end logits in `forward_t`.
    qa_outputs: nn::Linear,
    // Separator token id from the config; used to derive a global attention mask
    // when the caller does not supply one.
    sep_token_id: i64,
}
1024
1025impl LongformerForQuestionAnswering {
1026    /// Build a new `LongformerForQuestionAnswering`
1027    ///
1028    /// # Arguments
1029    ///
1030    /// * `p` - Variable store path for the root of the Longformer model
1031    /// * `config` - `LongformerConfig` object defining the model architecture
1032    ///
1033    /// # Example
1034    ///
1035    /// ```no_run
1036    /// use rust_bert::longformer::{LongformerConfig, LongformerForQuestionAnswering};
1037    /// use rust_bert::Config;
1038    /// use std::path::Path;
1039    /// use tch::{nn, Device};
1040    ///
1041    /// let config_path = Path::new("path/to/config.json");
1042    /// let device = Device::Cpu;
1043    /// let p = nn::VarStore::new(device);
1044    /// let config = LongformerConfig::from_file(config_path);
1045    /// let longformer_model = LongformerForQuestionAnswering::new(&p.root(), &config);
1046    /// ```
1047    pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForQuestionAnswering
1048    where
1049        P: Borrow<nn::Path<'p>>,
1050    {
1051        let p = p.borrow();
1052
1053        let longformer = LongformerModel::new(p / "longformer", config, false);
1054        let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, 2, Default::default());
1055        let sep_token_id = config.sep_token_id;
1056
1057        LongformerForQuestionAnswering {
1058            longformer,
1059            qa_outputs,
1060            sep_token_id,
1061        }
1062    }
1063
1064    /// Forward pass through the model
1065    ///
1066    /// # Arguments
1067    ///
1068    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
1069    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
1070    /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
1071    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1072    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1073    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1074    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1075    ///
1076    /// # Returns
1077    ///
1078    /// * `LongformerForQuestionAnsweringOutput` containing:
1079    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
1080    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
1081    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1082    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
1083    ///   - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*)  where x is the number of tokens with global attention
1084    ///
1085    /// # Example
1086    ///
1087    /// ```no_run
1088    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
1089    /// # use rust_bert::Config;
1090    /// # use std::path::Path;
1091    /// # use tch::kind::Kind::{Int64, Double};
1092    /// use rust_bert::longformer::{LongformerConfig, LongformerForQuestionAnswering};
1093    /// # let config_path = Path::new("path/to/config.json");
1094    /// # let vocab_path = Path::new("path/to/vocab.txt");
1095    /// # let device = Device::Cpu;
1096    /// # let vs = nn::VarStore::new(device);
1097    /// # let config = LongformerConfig::from_file(config_path);
1098    /// let longformer_model = LongformerForQuestionAnswering::new(&vs.root(), &config);
1099    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
1100    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1101    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1102    /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1103    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1104    ///
1105    /// let model_output = no_grad(|| {
1106    ///     longformer_model
1107    ///         .forward_t(
1108    ///             Some(&input_tensor),
1109    ///             Some(&attention_mask),
1110    ///             Some(&global_attention_mask),
1111    ///             None,
1112    ///             None,
1113    ///             None,
1114    ///             false,
1115    ///         )
1116    ///         .unwrap()
1117    /// });
1118    /// ```
1119    pub fn forward_t(
1120        &self,
1121        input_ids: Option<&Tensor>,
1122        attention_mask: Option<&Tensor>,
1123        global_attention_mask: Option<&Tensor>,
1124        token_type_ids: Option<&Tensor>,
1125        position_ids: Option<&Tensor>,
1126        input_embeds: Option<&Tensor>,
1127        train: bool,
1128    ) -> Result<LongformerQuestionAnsweringOutput, RustBertError> {
1129        let calc_global_attention_mask = if global_attention_mask.is_none() {
1130            if let Some(input_ids) = input_ids {
1131                Some(compute_global_attention_mask(
1132                    input_ids,
1133                    self.sep_token_id,
1134                    true,
1135                ))
1136            } else {
1137                return Err(RustBertError::ValueError(
1138                    "Inputs ids must be provided to LongformerQuestionAnsweringOutput if the global_attention_mask is not given".into(),
1139                ));
1140            }
1141        } else {
1142            None
1143        };
1144
1145        let global_attention_mask = if global_attention_mask.is_some() {
1146            global_attention_mask
1147        } else {
1148            calc_global_attention_mask.as_ref()
1149        };
1150
1151        let base_model_output = self.longformer.forward_t(
1152            input_ids,
1153            attention_mask,
1154            global_attention_mask,
1155            token_type_ids,
1156            position_ids,
1157            input_embeds,
1158            train,
1159        )?;
1160
1161        let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1162        let logits = sequence_output.split(1, -1);
1163        let (start_logits, end_logits) = (&logits[0], &logits[1]);
1164        let start_logits = start_logits.squeeze_dim(-1);
1165        let end_logits = end_logits.squeeze_dim(-1);
1166
1167        Ok(LongformerQuestionAnsweringOutput {
1168            start_logits,
1169            end_logits,
1170            all_hidden_states: base_model_output.all_hidden_states,
1171            all_attentions: base_model_output.all_attentions,
1172            all_global_attentions: base_model_output.all_global_attentions,
1173        })
1174    }
1175}
1176
1177/// # Longformer for token classification (e.g. NER, POS)
1178/// Token-level classifier predicting a label for each token provided.
1179/// It is made of the following blocks:
1180/// - `longformer`: Base Longformer model
1181/// - `classifier`: Linear layer for token classification
pub struct LongformerForTokenClassification {
    // Base Longformer encoder.
    longformer: LongformerModel,
    // Dropout applied to the hidden states before the classifier.
    dropout: Dropout,
    // Per-token projection: hidden_size -> num_labels (label count from `config.id2label`).
    classifier: nn::Linear,
}
1187
1188impl LongformerForTokenClassification {
1189    /// Build a new `LongformerForTokenClassification`
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `p` - Variable store path for the root of the Longformer model
1194    /// * `config` - `LongformerConfig` object defining the model architecture
1195    ///
1196    /// # Example
1197    ///
1198    /// ```no_run
1199    /// use rust_bert::longformer::{LongformerConfig, LongformerForTokenClassification};
1200    /// use rust_bert::Config;
1201    /// use std::path::Path;
1202    /// use tch::{nn, Device};
1203    ///
1204    /// let config_path = Path::new("path/to/config.json");
1205    /// let device = Device::Cpu;
1206    /// let p = nn::VarStore::new(device);
1207    /// let config = LongformerConfig::from_file(config_path);
1208    /// let longformer_model = LongformerForTokenClassification::new(&p.root(), &config).unwrap();
1209    /// ```
1210    pub fn new<'p, P>(
1211        p: P,
1212        config: &LongformerConfig,
1213    ) -> Result<LongformerForTokenClassification, RustBertError>
1214    where
1215        P: Borrow<nn::Path<'p>>,
1216    {
1217        let p = p.borrow();
1218
1219        let longformer = LongformerModel::new(p / "longformer", config, false);
1220        let dropout = Dropout::new(config.hidden_dropout_prob);
1221
1222        let num_labels = config
1223            .id2label
1224            .as_ref()
1225            .ok_or_else(|| {
1226                RustBertError::InvalidConfigurationError(
1227                    "num_labels not provided in configuration".to_string(),
1228                )
1229            })?
1230            .len() as i64;
1231
1232        let classifier = nn::linear(
1233            p / "classifier",
1234            config.hidden_size,
1235            num_labels,
1236            Default::default(),
1237        );
1238
1239        Ok(LongformerForTokenClassification {
1240            longformer,
1241            dropout,
1242            classifier,
1243        })
1244    }
1245
1246    /// Forward pass through the model
1247    ///
1248    /// # Arguments
1249    ///
1250    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
1251    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
1252    /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
1253    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1254    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1255    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1256    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1257    ///
1258    /// # Returns
1259    ///
1260    /// * `LongformerTokenClassificationOutput` containing:
1261    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
1262    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1263    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
1264    ///   - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*)  where x is the number of tokens with global attention
1265    ///
1266    /// # Example
1267    ///
1268    /// ```no_run
1269    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
1270    /// # use rust_bert::Config;
1271    /// # use std::path::Path;
1272    /// # use tch::kind::Kind::{Int64, Double};
1273    /// use rust_bert::longformer::{LongformerConfig, LongformerForTokenClassification};
1274    /// # let config_path = Path::new("path/to/config.json");
1275    /// # let vocab_path = Path::new("path/to/vocab.txt");
1276    /// # let device = Device::Cpu;
1277    /// # let vs = nn::VarStore::new(device);
1278    /// # let config = LongformerConfig::from_file(config_path);
1279    /// let longformer_model = LongformerForTokenClassification::new(&vs.root(), &config).unwrap();
1280    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
1281    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1282    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1283    /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1284    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1285    ///
1286    /// let model_output = no_grad(|| {
1287    ///     longformer_model
1288    ///         .forward_t(
1289    ///             Some(&input_tensor),
1290    ///             Some(&attention_mask),
1291    ///             Some(&global_attention_mask),
1292    ///             None,
1293    ///             None,
1294    ///             None,
1295    ///             false,
1296    ///         )
1297    ///         .unwrap()
1298    /// });
1299    /// ```
1300    pub fn forward_t(
1301        &self,
1302        input_ids: Option<&Tensor>,
1303        attention_mask: Option<&Tensor>,
1304        global_attention_mask: Option<&Tensor>,
1305        token_type_ids: Option<&Tensor>,
1306        position_ids: Option<&Tensor>,
1307        input_embeds: Option<&Tensor>,
1308        train: bool,
1309    ) -> Result<LongformerTokenClassificationOutput, RustBertError> {
1310        let base_model_output = self.longformer.forward_t(
1311            input_ids,
1312            attention_mask,
1313            global_attention_mask,
1314            token_type_ids,
1315            position_ids,
1316            input_embeds,
1317            train,
1318        )?;
1319
1320        let logits = base_model_output
1321            .hidden_state
1322            .apply_t(&self.dropout, train)
1323            .apply(&self.classifier);
1324
1325        Ok(LongformerTokenClassificationOutput {
1326            logits,
1327            all_hidden_states: base_model_output.all_hidden_states,
1328            all_attentions: base_model_output.all_attentions,
1329            all_global_attentions: base_model_output.all_global_attentions,
1330        })
1331    }
1332}
1333
1334/// # Longformer for multiple choices
1335/// Multiple choices model using a Longformer base model and a linear classifier.
1336/// Input should be in the form `<cls> Context <sep><sep> Possible choice <sep>`. The choice is made along the batch axis,
1337/// assuming all elements of the batch are alternatives to be chosen from for a given context.
1338/// It is made of the following blocks:
1339/// - `longformer`: Base LongformerModel model
1340/// - `classifier`: Linear layer for multiple choices
pub struct LongformerForMultipleChoice {
    // Base Longformer encoder (note: built with the third constructor flag set
    // to `true` in `new`, unlike the other task heads in this file).
    longformer: LongformerModel,
    // Dropout applied before the classifier.
    dropout: Dropout,
    // Projection to a single score per choice (hidden_size -> 1).
    classifier: nn::Linear,
    // Separator token id from the config.
    sep_token_id: i64,
}
1347
1348impl LongformerForMultipleChoice {
1349    /// Build a new `LongformerForMultipleChoice`
1350    ///
1351    /// # Arguments
1352    ///
1353    /// * `p` - Variable store path for the root of the Longformer model
1354    /// * `config` - `LongformerConfig` object defining the model architecture
1355    ///
1356    /// # Example
1357    ///
1358    /// ```no_run
1359    /// use rust_bert::longformer::{LongformerConfig, LongformerForMultipleChoice};
1360    /// use rust_bert::Config;
1361    /// use std::path::Path;
1362    /// use tch::{nn, Device};
1363    ///
1364    /// let config_path = Path::new("path/to/config.json");
1365    /// let device = Device::Cpu;
1366    /// let p = nn::VarStore::new(device);
1367    /// let config = LongformerConfig::from_file(config_path);
1368    /// let longformer_model = LongformerForMultipleChoice::new(&p.root(), &config);
1369    /// ```
1370    pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForMultipleChoice
1371    where
1372        P: Borrow<nn::Path<'p>>,
1373    {
1374        let p = p.borrow();
1375
1376        let longformer = LongformerModel::new(p / "longformer", config, true);
1377        let dropout = Dropout::new(config.hidden_dropout_prob);
1378        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
1379        let sep_token_id = config.sep_token_id;
1380
1381        LongformerForMultipleChoice {
1382            longformer,
1383            dropout,
1384            classifier,
1385            sep_token_id,
1386        }
1387    }
1388
1389    /// Forward pass through the model
1390    ///
1391    /// # Arguments
1392    ///
1393    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
1394    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
1395    /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
1396    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1397    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1398    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1399    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1400    ///
1401    /// # Returns
1402    ///
1403    /// * `LongformerSequenceClassificationOutput` containing:
1404    ///   - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
1405    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1406    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
1407    ///   - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*)  where x is the number of tokens with global attention
1408    ///
1409    /// # Example
1410    ///
1411    /// ```no_run
1412    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
1413    /// # use rust_bert::Config;
1414    /// # use std::path::Path;
1415    /// # use tch::kind::Kind::{Int64, Double};
1416    /// use rust_bert::longformer::{LongformerConfig, LongformerForMultipleChoice};
1417    /// # let config_path = Path::new("path/to/config.json");
1418    /// # let vocab_path = Path::new("path/to/vocab.txt");
1419    /// # let device = Device::Cpu;
1420    /// # let vs = nn::VarStore::new(device);
1421    /// # let config = LongformerConfig::from_file(config_path);
1422    /// let longformer_model = LongformerForMultipleChoice::new(&vs.root(), &config);
1423    /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
1424    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1425    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1426    /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1427    /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1428    ///
1429    /// let model_output = no_grad(|| {
1430    ///     longformer_model
1431    ///         .forward_t(
1432    ///             Some(&input_tensor),
1433    ///             Some(&attention_mask),
1434    ///             Some(&global_attention_mask),
1435    ///             None,
1436    ///             None,
1437    ///             None,
1438    ///             false,
1439    ///         )
1440    ///         .unwrap()
1441    /// });
1442    /// ```
1443    pub fn forward_t(
1444        &self,
1445        input_ids: Option<&Tensor>,
1446        attention_mask: Option<&Tensor>,
1447        global_attention_mask: Option<&Tensor>,
1448        token_type_ids: Option<&Tensor>,
1449        position_ids: Option<&Tensor>,
1450        input_embeds: Option<&Tensor>,
1451        train: bool,
1452    ) -> Result<LongformerSequenceClassificationOutput, RustBertError> {
1453        let num_choices = match (input_ids, input_embeds) {
1454            (Some(input_ids_value), None) => input_ids_value.size()[1],
1455            (None, Some(input_embeds_value)) => input_embeds_value.size()[1],
1456            (Some(_), Some(_)) => {
1457                return Err(RustBertError::ValueError(
1458                    "Only one of input ids or input embeddings may be set".into(),
1459                ));
1460            }
1461            (None, None) => {
1462                return Err(RustBertError::ValueError(
1463                    "At least one of input ids or input embeddings must be set".into(),
1464                ));
1465            }
1466        };
1467
1468        let calc_global_attention_mask = if global_attention_mask.is_none() {
1469            if let Some(input_ids) = input_ids {
1470                let mut masks = Vec::with_capacity(num_choices as usize);
1471                for i in 0..num_choices {
1472                    masks.push(compute_global_attention_mask(
1473                        &input_ids.select(1, i),
1474                        self.sep_token_id,
1475                        false,
1476                    ));
1477                }
1478                Some(Tensor::stack(masks.as_slice(), 1))
1479            } else {
1480                return Err(RustBertError::ValueError(
1481                    "Inputs ids must be provided to LongformerQuestionAnsweringOutput if the global_attention_mask is not given".into(),
1482                ));
1483            }
1484        } else {
1485            None
1486        };
1487
1488        let flat_input_ids =
1489            input_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1490        let flat_attention_mask =
1491            attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1492        let flat_token_type_ids =
1493            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1494        let flat_position_ids =
1495            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1496        let flat_input_embeds =
1497            input_embeds.map(|tensor| tensor.view((-1, tensor.size()[1], tensor.size()[2])));
1498
1499        let global_attention_mask = if global_attention_mask.is_some() {
1500            global_attention_mask
1501        } else {
1502            calc_global_attention_mask.as_ref()
1503        };
1504        let flat_global_attention_mask =
1505            global_attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1506
1507        let base_model_output = self.longformer.forward_t(
1508            flat_input_ids.as_ref(),
1509            flat_attention_mask.as_ref(),
1510            flat_global_attention_mask.as_ref(),
1511            flat_token_type_ids.as_ref(),
1512            flat_position_ids.as_ref(),
1513            flat_input_embeds.as_ref(),
1514            train,
1515        )?;
1516
1517        let logits = base_model_output
1518            .pooled_output
1519            .unwrap()
1520            .apply_t(&self.dropout, train)
1521            .apply(&self.classifier)
1522            .view((-1, num_choices));
1523
1524        Ok(LongformerSequenceClassificationOutput {
1525            logits,
1526            all_hidden_states: base_model_output.all_hidden_states,
1527            all_attentions: base_model_output.all_attentions,
1528            all_global_attentions: base_model_output.all_global_attentions,
1529        })
1530    }
1531}
1532
/// Container for the Longformer model output.
pub struct LongformerModelOutput {
    /// Last hidden states from the model, of shape (*batch size*, *sequence_length*, *hidden_size*)
    pub hidden_state: Tensor,
    /// Pooled output (hidden state for the first token); `None` when the model is built without a pooler
    pub pooled_output: Option<Tensor>,
    /// Hidden states for all intermediate layers (populated when the model is configured to output hidden states)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1546
/// Container for the Longformer masked LM model output.
pub struct LongformerMaskedLMOutput {
    /// Logits over the vocabulary for each sequence position, of shape (*batch size*, *sequence_length*, *vocab_size*)
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers (populated when the model is configured to output hidden states)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1558
/// Container for the Longformer sequence classification model output.
/// Also used by the multiple-choice head, where the logits are reshaped to one column per choice.
pub struct LongformerSequenceClassificationOutput {
    /// Logits for each input sequence for each target class
    /// (the previous "each sequence item (token)" wording was copied from the token classification output)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers (populated when the model is configured to output hidden states)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1570
/// Container for the Longformer token classification model output.
pub struct LongformerTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers (populated when the model is configured to output hidden states)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1582
/// Container for the Longformer question answering model output.
pub struct LongformerQuestionAnsweringOutput {
    /// Answer start-position logits for each token of each input sequence
    pub start_logits: Tensor,
    /// Answer end-position logits for each token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers (populated when the model is configured to output hidden states)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers (populated when the model is configured to output attentions)
    pub all_global_attentions: Option<Vec<Tensor>>,
}