rust_bert/models/longformer/longformer_model.rs
1// Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
2// Copyright 2021 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::common::activations::{TensorFunction, _tanh};
14use crate::common::dropout::Dropout;
15use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
16use crate::longformer::embeddings::LongformerEmbeddings;
17use crate::longformer::encoder::LongformerEncoder;
18use crate::{Activation, Config, RustBertError};
19use serde::{Deserialize, Serialize};
20use std::borrow::Borrow;
21use std::collections::HashMap;
22use tch::nn::{Init, Module, ModuleT};
23use tch::{nn, Kind, Tensor};
24
/// # Longformer Pretrained model weight files
pub struct LongformerModelResources;

/// # Longformer Pretrained model config files
pub struct LongformerConfigResources;

/// # Longformer Pretrained model vocab files
pub struct LongformerVocabResources;

/// # Longformer Pretrained model merges files
pub struct LongformerMergesResources;

// Each constant below is a `(cache subdirectory, download URL)` pair pointing at a
// pretrained resource hosted on the HuggingFace model hub.
impl LongformerModelResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/model",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096/model",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/rust_model.ot",
    );
}

impl LongformerConfigResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/config",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/config.json",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096/config",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/config.json",
    );
}

impl LongformerVocabResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/vocab",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096/vocab",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/vocab.json",
    );
}

impl LongformerMergesResources {
    /// Shared under Apache 2.0 license by the AllenAI team at <https://github.com/allenai/longformer>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_4096: (&'static str, &'static str) = (
        "longformer-base-4096/merges",
        "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
    );
    /// Shared under MIT license at <https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1>. Modified with conversion to C-array format.
    pub const LONGFORMER_BASE_SQUAD1: (&'static str, &'static str) = (
        "longformer-base-4096/merges",
        "https://huggingface.co/valhalla/longformer-base-4096-finetuned-squadv1/resolve/main/merges.txt",
    );
}
88
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "camelCase")]
/// # Longformer Position embeddings type
pub enum PositionEmbeddingType {
    /// Learned absolute position embeddings (serialized as `"absolute"`).
    Absolute,
    /// Relative key position embeddings (serialized as `"relativeKey"`).
    RelativeKey,
}
96
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Longformer model configuration
/// Defines the Longformer model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct LongformerConfig {
    /// Activation function used in the encoder intermediate layers
    pub hidden_act: Activation,
    /// Local attention window size(s); the model pads inputs to a multiple of the
    /// largest entry (see `LongformerModel::new`)
    pub attention_window: Vec<i64>,
    /// Dropout probability applied to the attention probabilities
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability applied to hidden states
    pub hidden_dropout_prob: f64,
    /// Dimension of the hidden representations
    pub hidden_size: i64,
    /// Weight initializer range
    pub initializer_range: f32,
    /// Dimension of the encoder intermediate (feed-forward) layers
    pub intermediate_size: i64,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of attention heads per layer
    pub num_attention_heads: i64,
    /// Number of encoder layers
    pub num_hidden_layers: i64,
    /// Number of token type (segment) ids
    pub type_vocab_size: i64,
    /// Vocabulary size
    pub vocab_size: i64,
    /// Separator token id, used to derive global attention masks for QA-style inputs
    pub sep_token_id: i64,
    /// Padding token id; defaults to 1 when not provided (see `LongformerModel::new`)
    pub pad_token_id: Option<i64>,
    /// Layer normalization epsilon; defaults to 1e-12 when not provided
    pub layer_norm_eps: Option<f64>,
    /// Whether attention weights should be returned by the encoder
    pub output_attentions: Option<bool>,
    /// Whether intermediate hidden states should be returned by the encoder
    pub output_hidden_states: Option<bool>,
    /// Position embedding flavour (absolute or relative key)
    pub position_embedding_type: Option<PositionEmbeddingType>,
    /// Whether the model is used as a decoder (causal masking); defaults to false
    pub is_decoder: Option<bool>,
    /// Label index to label string mapping (required by the classification head)
    pub id2label: Option<HashMap<i64, String>>,
    /// Label string to label index mapping
    pub label2id: Option<HashMap<String, i64>>,
}
123
// Marker implementation: provides the `Config` trait helpers (e.g. `from_file`) for this type.
impl Config for LongformerConfig {}
125
impl Default for LongformerConfig {
    /// BERT-base-like default configuration (12 layers, 768 hidden units, 12 heads,
    /// gelu activation, single attention window of 512).
    fn default() -> Self {
        LongformerConfig {
            hidden_act: Activation::gelu,
            attention_window: vec![512],
            attention_probs_dropout_prob: 0.1,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30522,
            sep_token_id: 2,
            pad_token_id: None,
            layer_norm_eps: None,
            output_attentions: None,
            output_hidden_states: None,
            position_embedding_type: None,
            is_decoder: None,
            id2label: None,
            label2id: None,
        }
    }
}
153
/// Returns, for each batch item, the position of the first separator token.
///
/// NOTE(review): the `.view([batch_size, 3, 2])` reshape assumes every sequence in the
/// batch contains exactly three `sep_token_id` occurrences (RoBERTa-style QA encoding
/// `<s> question </s></s> context </s>`) — the reshape panics otherwise; confirm callers
/// guarantee this format.
fn get_question_end_index(input_ids: &Tensor, sep_token_id: i64) -> Tensor {
    input_ids
        .eq(sep_token_id)
        .nonzero()
        // (batch_size, 3, 2): three separator hits per sequence, each as (row, col)
        .view([input_ids.size()[0], 3, 2])
        // keep the column coordinate ...
        .select(2, 1)
        // ... of the first separator in each sequence -> shape (batch_size,)
        .select(1, 0)
}
162
/// Computes a boolean global-attention mask from the position of the first separator token.
///
/// With `before_sep_token == true`, positions strictly before the first separator
/// (the "question" part of a QA input) are marked for global attention. Otherwise,
/// positions strictly after `first separator + 1` and within sequence bounds are marked;
/// the element-wise product of the two boolean masks acts as a logical AND.
fn compute_global_attention_mask(
    input_ids: &Tensor,
    sep_token_id: i64,
    before_sep_token: bool,
) -> Tensor {
    // (batch_size, 1): first separator position per sequence, ready to broadcast
    let question_end_index = get_question_end_index(input_ids, sep_token_id).unsqueeze(1);
    // Position indices 0..sequence_length, expanded against each batch item below
    let attention_mask = Tensor::arange(input_ids.size()[1], (Kind::Int64, input_ids.device()));

    if before_sep_token {
        attention_mask
            .expand_as(input_ids)
            .lt_tensor(&question_end_index)
    } else {
        // Positions after the separator pair, clipped to the sequence length
        attention_mask
            .expand_as(input_ids)
            .gt_tensor(&(question_end_index + 1))
            * attention_mask
                .expand_as(input_ids)
                .lt(*input_ids.size().last().unwrap())
    }
}
184
/// Pooling layer transforming the representation of the first token of each sequence.
#[derive(Debug)]
pub struct LongformerPooler {
    // Square dense projection (hidden_size -> hidden_size)
    dense: nn::Linear,
    // Tanh activation applied after the projection
    activation: TensorFunction,
}
190
191impl LongformerPooler {
192 pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerPooler
193 where
194 P: Borrow<nn::Path<'p>>,
195 {
196 let p = p.borrow();
197
198 let dense = nn::linear(
199 p / "dense",
200 config.hidden_size,
201 config.hidden_size,
202 Default::default(),
203 );
204
205 let activation = TensorFunction::new(Box::new(_tanh));
206
207 LongformerPooler { dense, activation }
208 }
209}
210
211impl Module for LongformerPooler {
212 fn forward(&self, hidden_states: &Tensor) -> Tensor {
213 self.activation.get_fn()(&hidden_states.select(1, 0).apply(&self.dense))
214 }
215}
216
/// Language modeling head mapping hidden states to vocabulary logits.
#[derive(Debug)]
pub struct LongformerLMHead {
    // Dense projection (hidden_size -> hidden_size) applied before layer normalization
    dense: nn::Linear,
    // Layer normalization applied after the dense projection + activation
    layer_norm: nn::LayerNorm,
    // Bias-free projection to vocabulary logits (hidden_size -> vocab_size)
    decoder: nn::Linear,
    // Output bias stored separately from the decoder weights
    bias: Tensor,
}
224
225impl LongformerLMHead {
226 pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerLMHead
227 where
228 P: Borrow<nn::Path<'p>>,
229 {
230 let p = p.borrow();
231
232 let dense = nn::linear(
233 p / "dense",
234 config.hidden_size,
235 config.hidden_size,
236 Default::default(),
237 );
238
239 let layer_norm_config = nn::LayerNormConfig {
240 eps: config.layer_norm_eps.unwrap_or(1e-12),
241 ..Default::default()
242 };
243
244 let layer_norm = nn::layer_norm(
245 p / "layer_norm",
246 vec![config.hidden_size],
247 layer_norm_config,
248 );
249
250 let linear_config = nn::LinearConfig {
251 bias: false,
252 ..Default::default()
253 };
254
255 let decoder = nn::linear(
256 p / "decoder",
257 config.hidden_size,
258 config.vocab_size,
259 linear_config,
260 );
261
262 let bias = p.var("bias", &[config.vocab_size], Init::Const(0f64));
263
264 LongformerLMHead {
265 dense,
266 layer_norm,
267 decoder,
268 bias,
269 }
270 }
271}
272
273impl Module for LongformerLMHead {
274 fn forward(&self, hidden_states: &Tensor) -> Tensor {
275 hidden_states
276 .apply(&self.dense)
277 .gelu("none")
278 .apply(&self.layer_norm)
279 .apply(&self.decoder)
280 + &self.bias
281 }
282}
283
/// Inputs right-padded to a multiple of the attention window size
/// (see `LongformerModel::pad_to_window_size`).
/// Each field is `Some` only when the corresponding input was provided by the caller.
struct PaddedInput {
    input_ids: Option<Tensor>,
    position_ids: Option<Tensor>,
    inputs_embeds: Option<Tensor>,
    attention_mask: Option<Tensor>,
    token_type_ids: Option<Tensor>,
}
291
/// # LongformerModel Base model
/// Base architecture for LongformerModel models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: LongformerEmbeddings containing word, position and segment id embeddings
/// - `encoder`: LongformerEncoder
/// - `pooler`: Optional pooling layer extracting the representation of the first token for each batch item
pub struct LongformerModel {
    embeddings: LongformerEmbeddings,
    encoder: LongformerEncoder,
    pooler: Option<LongformerPooler>,
    // Largest configured attention window; inputs are padded to a multiple of this length
    max_attention_window: i64,
    // Token id used for padding (config value, defaulting to 1)
    pad_token_id: i64,
    // When true, a causal mask is combined with the attention mask in `forward_t`
    is_decoder: bool,
}
306
307impl LongformerModel {
308 /// Build a new `LongformerModel`
309 ///
310 /// # Arguments
311 ///
312 /// * `p` - Variable store path for the root of the Longformer model
313 /// * `config` - `LongformerConfig` object defining the model architecture
314 ///
315 /// # Example
316 ///
317 /// ```no_run
318 /// use rust_bert::longformer::{LongformerConfig, LongformerModel};
319 /// use rust_bert::Config;
320 /// use std::path::Path;
321 /// use tch::{nn, Device};
322 ///
323 /// let config_path = Path::new("path/to/config.json");
324 /// let device = Device::Cpu;
325 /// let p = nn::VarStore::new(device);
326 /// let config = LongformerConfig::from_file(config_path);
327 /// let add_pooling_layer = false;
328 /// let longformer_model = LongformerModel::new(&p.root(), &config, add_pooling_layer);
329 /// ```
330 pub fn new<'p, P>(p: P, config: &LongformerConfig, add_pooling_layer: bool) -> LongformerModel
331 where
332 P: Borrow<nn::Path<'p>>,
333 {
334 let p = p.borrow();
335
336 let embeddings = LongformerEmbeddings::new(p / "embeddings", config);
337 let encoder = LongformerEncoder::new(p / "encoder", config);
338 let pooler = if add_pooling_layer {
339 Some(LongformerPooler::new(p / "pooler", config))
340 } else {
341 None
342 };
343
344 let max_attention_window = *config.attention_window.iter().max().unwrap();
345 let pad_token_id = config.pad_token_id.unwrap_or(1);
346 let is_decoder = config.is_decoder.unwrap_or(false);
347
348 LongformerModel {
349 embeddings,
350 encoder,
351 pooler,
352 max_attention_window,
353 pad_token_id,
354 is_decoder,
355 }
356 }
357
358 fn pad_with_nonzero_value(
359 &self,
360 tensor: &Tensor,
361 padding: &[i64],
362 padding_value: i64,
363 ) -> Tensor {
364 (tensor - padding_value).constant_pad_nd(padding) + padding_value
365 }
366
367 fn pad_with_boolean(&self, tensor: &Tensor, padding: &[i64], padding_value: bool) -> Tensor {
368 if !padding_value {
369 tensor.constant_pad_nd(padding)
370 } else {
371 ((tensor.logical_not()).constant_pad_nd(padding)).logical_not()
372 }
373 }
374
    /// Right-pads the provided inputs by `padding_length` positions so the sequence length
    /// becomes a multiple of the attention window:
    /// * token ids and position ids are padded with `pad_token_id`
    /// * the attention mask is padded with `false` (padded positions are masked out)
    /// * token type ids are padded with zeros
    /// * input embeddings are extended with the embeddings of `pad_token_id` tokens
    ///
    /// Returns a `PaddedInput` where each field is `Some` only if the corresponding
    /// input was provided, or an error if neither (or both) of `input_ids`/`input_embeds`
    /// was given.
    fn pad_to_window_size(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        pad_token_id: i64,
        padding_length: i64,
        train: bool,
    ) -> Result<PaddedInput, RustBertError> {
        let (input_shape, _) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
        let batch_size = input_shape[0];

        let input_ids = input_ids
            .map(|value| self.pad_with_nonzero_value(value, &[0, padding_length], pad_token_id));
        let position_ids = position_ids
            .map(|value| self.pad_with_nonzero_value(value, &[0, padding_length], pad_token_id));
        let inputs_embeds = input_embeds.map(|value| {
            // Embed a (batch_size, padding_length) block of pad tokens and append it
            // along the sequence dimension
            let input_ids_padding = Tensor::full(
                [batch_size, padding_length],
                pad_token_id,
                (Kind::Int64, value.device()),
            );
            let input_embeds_padding = self
                .embeddings
                .forward_t(Some(&input_ids_padding), None, None, None, train)
                .unwrap();

            Tensor::cat(&[value, &input_embeds_padding], -2)
        });

        let attention_mask =
            attention_mask.map(|value| self.pad_with_boolean(value, &[0, padding_length], false));
        let token_type_ids = token_type_ids.map(|value| value.constant_pad_nd([0, padding_length]));
        Ok(PaddedInput {
            input_ids,
            position_ids,
            inputs_embeds,
            attention_mask,
            token_type_ids,
        })
    }
419
420 /// Forward pass through the model
421 ///
422 /// # Arguments
423 ///
424 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
425 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
426 /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
427 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
428 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
429 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
430 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
431 ///
432 /// # Returns
433 ///
434 /// * `LongformerModelOutput` containing:
435 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
436 /// - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
437 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
438 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
439 /// - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*) where x is the number of tokens with global attention
440 ///
441 /// # Example
442 ///
443 /// ```no_run
444 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
445 /// # use rust_bert::Config;
446 /// # use std::path::Path;
447 /// # use tch::kind::Kind::{Int64, Double};
448 /// use rust_bert::longformer::{LongformerConfig, LongformerModel};
449 /// # let config_path = Path::new("path/to/config.json");
450 /// # let vocab_path = Path::new("path/to/vocab.txt");
451 /// # let device = Device::Cpu;
452 /// # let vs = nn::VarStore::new(device);
453 /// # let config = LongformerConfig::from_file(config_path);
454 /// let longformer_model = LongformerModel::new(&vs.root(), &config, false);
455 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
456 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
457 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
458 /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
459 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
460 ///
461 /// let model_output = no_grad(|| {
462 /// longformer_model
463 /// .forward_t(
464 /// Some(&input_tensor),
465 /// Some(&attention_mask),
466 /// Some(&global_attention_mask),
467 /// None,
468 /// None,
469 /// None,
470 /// false,
471 /// )
472 /// .unwrap()
473 /// });
474 /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        global_attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<LongformerModelOutput, RustBertError> {
        // Errors if neither (or both) of input_ids / input_embeds were provided
        let (input_shape, device) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

        let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);

        // Default attention mask: attend everywhere.
        // The calc_* bindings own the fallback tensors so the `Option<&Tensor>`
        // rebindings below can borrow from them.
        let calc_attention_mask = if attention_mask.is_none() {
            Some(Tensor::ones(input_shape.as_slice(), (Kind::Int, device)))
        } else {
            None
        };
        // Default token type ids: all zeros (single-segment input)
        let calc_token_type_ids = if token_type_ids.is_none() {
            Some(Tensor::zeros(input_shape.as_slice(), (Kind::Int64, device)))
        } else {
            None
        };
        let attention_mask = if attention_mask.is_some() {
            attention_mask
        } else {
            calc_attention_mask.as_ref()
        };
        let token_type_ids = if token_type_ids.is_some() {
            token_type_ids
        } else {
            calc_token_type_ids.as_ref()
        };

        // Merge local and global masks into a single integer mask:
        // 0 -> masked, 1 -> local attention, 2 -> global attention
        let merged_attention_mask = if let Some(global_attention_mask) = global_attention_mask {
            attention_mask.map(|tensor| tensor * (global_attention_mask + 1))
        } else {
            None
        };
        let attention_mask = if merged_attention_mask.is_some() {
            merged_attention_mask.as_ref()
        } else {
            attention_mask
        };

        // Amount of right-padding needed so sequence_length becomes a multiple of the
        // largest attention window (0 when already aligned)
        let padding_length = (self.max_attention_window
            - sequence_length % self.max_attention_window)
            % self.max_attention_window;
        let (
            calc_padded_input_ids,
            calc_padded_position_ids,
            calc_padded_inputs_embeds,
            calc_padded_attention_mask,
            calc_padded_token_type_ids,
        ) = if padding_length > 0 {
            let padded_input = self.pad_to_window_size(
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                self.pad_token_id,
                padding_length,
                train,
            )?;
            (
                padded_input.input_ids,
                padded_input.position_ids,
                padded_input.inputs_embeds,
                padded_input.attention_mask,
                padded_input.token_type_ids,
            )
        } else {
            (None, None, None, None, None)
        };
        // For each input, prefer the padded version when one was computed
        let padded_input_ids = if calc_padded_input_ids.is_some() {
            calc_padded_input_ids.as_ref()
        } else {
            input_ids
        };
        let padded_position_ids = if calc_padded_position_ids.is_some() {
            calc_padded_position_ids.as_ref()
        } else {
            position_ids
        };
        let padded_inputs_embeds = if calc_padded_inputs_embeds.is_some() {
            calc_padded_inputs_embeds.as_ref()
        } else {
            input_embeds
        };
        // attention_mask is guaranteed Some at this point (a default was created above)
        let padded_attention_mask = calc_padded_attention_mask
            .as_ref()
            .unwrap_or_else(|| attention_mask.as_ref().unwrap());
        let padded_token_type_ids = if calc_padded_token_type_ids.is_some() {
            calc_padded_token_type_ids.as_ref()
        } else {
            token_type_ids
        };

        // Broadcast the mask to 4 dimensions; in decoder mode, combine it with a
        // lower-triangular causal mask so each position only attends to itself and the past
        let extended_attention_mask = match padded_attention_mask.dim() {
            3 => padded_attention_mask.unsqueeze(1),
            2 => {
                if !self.is_decoder {
                    padded_attention_mask.unsqueeze(1).unsqueeze(1)
                } else {
                    let sequence_ids = Tensor::arange(sequence_length, (Kind::Int64, device));
                    let mut causal_mask = sequence_ids
                        .unsqueeze(0)
                        .unsqueeze(0)
                        .repeat([batch_size, sequence_length, 1])
                        .le_tensor(&sequence_ids.unsqueeze(-1).unsqueeze(0))
                        .totype(Kind::Int);
                    // If the mask covers a longer (prefix) span than the causal mask,
                    // allow full attention on the extra prefix positions
                    if causal_mask.size()[1] < padded_attention_mask.size()[1] {
                        let prefix_sequence_length =
                            padded_attention_mask.size()[1] - causal_mask.size()[1];
                        causal_mask = Tensor::cat(
                            &[
                                Tensor::ones(
                                    [batch_size, sequence_length, prefix_sequence_length],
                                    (Kind::Int, device),
                                ),
                                causal_mask,
                            ],
                            -1,
                        );
                    }
                    causal_mask.unsqueeze(1) * padded_attention_mask.unsqueeze(1).unsqueeze(1)
                }
            }
            _ => {
                return Err(RustBertError::ValueError(
                    "Invalid attention mask dimension, must be 2 or 3".into(),
                ));
            }
        }
        // Squeeze the two broadcast dimensions back to (batch_size, sequence_length)
        .select(2, 0)
        .select(1, 0);
        // Additive mask: 0 for local attention, -10000 for masked positions; merged global
        // positions (value 2) become positive — presumably flagged as global attention by
        // the encoder (NOTE(review): confirm against LongformerEncoder)
        let extended_attention_mask =
            (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;

        let embedding_output = self.embeddings.forward_t(
            padded_input_ids,
            padded_token_type_ids,
            padded_position_ids,
            padded_inputs_embeds,
            train,
        )?;

        let encoder_outputs =
            self.encoder
                .forward_t(&embedding_output, &extended_attention_mask, train);

        let pooled_output = self
            .pooler
            .as_ref()
            .map(|pooler| pooler.forward(&encoder_outputs.hidden_states));

        // Strip the window-alignment padding added earlier from the returned sequence
        let sequence_output = if padding_length > 0 {
            encoder_outputs
                .hidden_states
                .slice(1, 0, -padding_length, 1)
        } else {
            encoder_outputs.hidden_states
        };

        Ok(LongformerModelOutput {
            hidden_state: sequence_output,
            pooled_output,
            all_hidden_states: encoder_outputs.all_hidden_states,
            all_attentions: encoder_outputs.all_attentions,
            all_global_attentions: encoder_outputs.all_global_attentions,
        })
    }
650}
651
/// # Longformer for masked language model
/// Base Longformer model with a masked language model head to predict missing tokens, for example `"Looks like one <mask> is missing" -> "person"`
/// It is made of the following blocks:
/// - `longformer`: Base LongformerModel
/// - `lm_head`: Longformer LM prediction head
pub struct LongformerForMaskedLM {
    // Base model, built without a pooling layer
    longformer: LongformerModel,
    // Head projecting hidden states to vocabulary logits
    lm_head: LongformerLMHead,
}
661
662impl LongformerForMaskedLM {
663 /// Build a new `LongformerForMaskedLM`
664 ///
665 /// # Arguments
666 ///
667 /// * `p` - Variable store path for the root of the Longformer model
668 /// * `config` - `LongformerConfig` object defining the model architecture
669 ///
670 /// # Example
671 ///
672 /// ```no_run
673 /// use rust_bert::longformer::{LongformerConfig, LongformerForMaskedLM};
674 /// use rust_bert::Config;
675 /// use std::path::Path;
676 /// use tch::{nn, Device};
677 ///
678 /// let config_path = Path::new("path/to/config.json");
679 /// let device = Device::Cpu;
680 /// let p = nn::VarStore::new(device);
681 /// let config = LongformerConfig::from_file(config_path);
682 /// let longformer_model = LongformerForMaskedLM::new(&p.root(), &config);
683 /// ```
684 pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForMaskedLM
685 where
686 P: Borrow<nn::Path<'p>>,
687 {
688 let p = p.borrow();
689
690 let longformer = LongformerModel::new(p / "longformer", config, false);
691 let lm_head = LongformerLMHead::new(p / "lm_head", config);
692
693 LongformerForMaskedLM {
694 longformer,
695 lm_head,
696 }
697 }
698
699 /// Forward pass through the model
700 ///
701 /// # Arguments
702 ///
703 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
704 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
705 /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
706 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
707 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
708 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
709 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
710 ///
711 /// # Returns
712 ///
713 /// * `LongformerMaskedLMOutput` containing:
714 /// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
715 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
716 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
717 /// - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*) where x is the number of tokens with global attention
718 ///
719 /// # Example
720 ///
721 /// ```no_run
722 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
723 /// # use rust_bert::Config;
724 /// # use std::path::Path;
725 /// # use tch::kind::Kind::{Int64, Double};
726 /// use rust_bert::longformer::{LongformerConfig, LongformerForMaskedLM};
727 /// # let config_path = Path::new("path/to/config.json");
728 /// # let vocab_path = Path::new("path/to/vocab.txt");
729 /// # let device = Device::Cpu;
730 /// # let vs = nn::VarStore::new(device);
731 /// # let config = LongformerConfig::from_file(config_path);
732 /// let longformer_model = LongformerForMaskedLM::new(&vs.root(), &config);
733 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
734 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
735 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
736 /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
737 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
738 ///
739 /// let model_output = no_grad(|| {
740 /// longformer_model
741 /// .forward_t(
742 /// Some(&input_tensor),
743 /// Some(&attention_mask),
744 /// Some(&global_attention_mask),
745 /// None,
746 /// None,
747 /// None,
748 /// false,
749 /// )
750 /// .unwrap()
751 /// });
752 /// ```
753 pub fn forward_t(
754 &self,
755 input_ids: Option<&Tensor>,
756 attention_mask: Option<&Tensor>,
757 global_attention_mask: Option<&Tensor>,
758 token_type_ids: Option<&Tensor>,
759 position_ids: Option<&Tensor>,
760 input_embeds: Option<&Tensor>,
761 train: bool,
762 ) -> Result<LongformerMaskedLMOutput, RustBertError> {
763 let longformer_outputs = self.longformer.forward_t(
764 input_ids,
765 attention_mask,
766 global_attention_mask,
767 token_type_ids,
768 position_ids,
769 input_embeds,
770 train,
771 )?;
772
773 let prediction_scores = self
774 .lm_head
775 .forward_t(&longformer_outputs.hidden_state, train);
776
777 Ok(LongformerMaskedLMOutput {
778 prediction_scores,
779 all_hidden_states: longformer_outputs.all_hidden_states,
780 all_attentions: longformer_outputs.all_attentions,
781 all_global_attentions: longformer_outputs.all_global_attentions,
782 })
783 }
784}
785
/// Classification head mapping the first token's representation to class logits.
pub struct LongformerClassificationHead {
    // Dense projection (hidden_size -> hidden_size) with tanh activation applied in forward_t
    dense: nn::Linear,
    // Dropout applied both before the dense layer and before the output projection
    dropout: Dropout,
    // Output projection (hidden_size -> num_labels)
    out_proj: nn::Linear,
}
791
792impl LongformerClassificationHead {
793 pub fn new<'p, P>(
794 p: P,
795 config: &LongformerConfig,
796 ) -> Result<LongformerClassificationHead, RustBertError>
797 where
798 P: Borrow<nn::Path<'p>>,
799 {
800 let p = p.borrow();
801
802 let dense = nn::linear(
803 p / "dense",
804 config.hidden_size,
805 config.hidden_size,
806 Default::default(),
807 );
808 let dropout = Dropout::new(config.hidden_dropout_prob);
809
810 let num_labels = config
811 .id2label
812 .as_ref()
813 .ok_or_else(|| {
814 RustBertError::InvalidConfigurationError(
815 "num_labels not provided in configuration".to_string(),
816 )
817 })?
818 .len() as i64;
819 let out_proj = nn::linear(
820 p / "out_proj",
821 config.hidden_size,
822 num_labels,
823 Default::default(),
824 );
825
826 Ok(LongformerClassificationHead {
827 dense,
828 dropout,
829 out_proj,
830 })
831 }
832
833 pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
834 hidden_states
835 .select(1, 0)
836 .apply_t(&self.dropout, train)
837 .apply(&self.dense)
838 .tanh()
839 .apply_t(&self.dropout, train)
840 .apply(&self.out_proj)
841 }
842}
843
/// # Longformer for sequence classification
/// Base Longformer model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `longformer`: Base Longformer
/// - `classifier`: Longformer classification head
pub struct LongformerForSequenceClassification {
    // Base model, built without a pooling layer
    longformer: LongformerModel,
    // Classification head operating on the first token's representation
    classifier: LongformerClassificationHead,
}
853
854impl LongformerForSequenceClassification {
855 /// Build a new `LongformerForSequenceClassification`
856 ///
857 /// # Arguments
858 ///
859 /// * `p` - Variable store path for the root of the Longformer model
860 /// * `config` - `LongformerConfig` object defining the model architecture
861 ///
862 /// # Example
863 ///
864 /// ```no_run
865 /// use rust_bert::longformer::{LongformerConfig, LongformerForSequenceClassification};
866 /// use rust_bert::Config;
867 /// use std::path::Path;
868 /// use tch::{nn, Device};
869 ///
870 /// let config_path = Path::new("path/to/config.json");
871 /// let device = Device::Cpu;
872 /// let p = nn::VarStore::new(device);
873 /// let config = LongformerConfig::from_file(config_path);
874 /// let longformer_model = LongformerForSequenceClassification::new(&p.root(), &config).unwrap();
875 /// ```
876 pub fn new<'p, P>(
877 p: P,
878 config: &LongformerConfig,
879 ) -> Result<LongformerForSequenceClassification, RustBertError>
880 where
881 P: Borrow<nn::Path<'p>>,
882 {
883 let p = p.borrow();
884
885 let longformer = LongformerModel::new(p / "longformer", config, false);
886 let classifier = LongformerClassificationHead::new(p / "classifier", config)?;
887
888 Ok(LongformerForSequenceClassification {
889 longformer,
890 classifier,
891 })
892 }
893
894 /// Forward pass through the model
895 ///
896 /// # Arguments
897 ///
898 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
899 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
900 /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
901 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
902 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
903 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
904 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
905 ///
906 /// # Returns
907 ///
908 /// * `LongformerSequenceClassificationOutput` containing:
909 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_classes*)
910 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
911 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
912 /// - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*) where x is the number of tokens with global attention
913 ///
914 /// # Example
915 ///
916 /// ```no_run
917 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
918 /// # use rust_bert::Config;
919 /// # use std::path::Path;
920 /// # use tch::kind::Kind::{Int64, Double};
921 /// use rust_bert::longformer::{LongformerConfig, LongformerForSequenceClassification};
922 /// # let config_path = Path::new("path/to/config.json");
923 /// # let vocab_path = Path::new("path/to/vocab.txt");
924 /// # let device = Device::Cpu;
925 /// # let vs = nn::VarStore::new(device);
926 /// # let config = LongformerConfig::from_file(config_path);
927 /// let longformer_model = LongformerForSequenceClassification::new(&vs.root(), &config).unwrap();
928 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
929 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
930 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
931 /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
932 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
933 ///
934 /// let model_output = no_grad(|| {
935 /// longformer_model
936 /// .forward_t(
937 /// Some(&input_tensor),
938 /// Some(&attention_mask),
939 /// Some(&global_attention_mask),
940 /// None,
941 /// None,
942 /// None,
943 /// false,
944 /// )
945 /// .unwrap()
946 /// });
947 /// ```
948 pub fn forward_t(
949 &self,
950 input_ids: Option<&Tensor>,
951 attention_mask: Option<&Tensor>,
952 global_attention_mask: Option<&Tensor>,
953 token_type_ids: Option<&Tensor>,
954 position_ids: Option<&Tensor>,
955 input_embeds: Option<&Tensor>,
956 train: bool,
957 ) -> Result<LongformerSequenceClassificationOutput, RustBertError> {
958 let calc_global_attention_mask = if global_attention_mask.is_none() {
959 let (input_shape, device) = if let Some(input_ids) = input_ids {
960 if input_embeds.is_none() {
961 (input_ids.size(), input_ids.device())
962 } else {
963 return Err(RustBertError::ValueError(
964 "Only one of input ids or input embeddings may be set".into(),
965 ));
966 }
967 } else if let Some(input_embeds) = input_embeds {
968 (input_embeds.size()[..2].to_vec(), input_embeds.device())
969 } else {
970 return Err(RustBertError::ValueError(
971 "At least one of input ids or input embeddings must be set".into(),
972 ));
973 };
974
975 let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
976 let global_attention_mask =
977 Tensor::zeros([batch_size, sequence_length], (Kind::Int, device));
978 let _ = global_attention_mask.select(1, 0).fill_(1);
979 Some(global_attention_mask)
980 } else {
981 None
982 };
983
984 let global_attention_mask = if global_attention_mask.is_some() {
985 global_attention_mask
986 } else {
987 calc_global_attention_mask.as_ref()
988 };
989
990 let base_model_output = self.longformer.forward_t(
991 input_ids,
992 attention_mask,
993 global_attention_mask,
994 token_type_ids,
995 position_ids,
996 input_embeds,
997 train,
998 )?;
999
1000 let logits = self
1001 .classifier
1002 .forward_t(&base_model_output.hidden_state, train);
1003 Ok(LongformerSequenceClassificationOutput {
1004 logits,
1005 all_hidden_states: base_model_output.all_hidden_states,
1006 all_attentions: base_model_output.all_attentions,
1007 all_global_attentions: base_model_output.all_global_attentions,
1008 })
1009 }
1010}
1011
1012/// # Longformer for question answering
1013/// Extractive question-answering model based on a Longformer language model. Identifies the segment of a context that answers a provided question.
1014/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
1015/// See the question answering pipeline (also provided in this crate) for more details.
1016/// It is made of the following blocks:
1017/// - `longformer`: Base Longformer
1018/// - `qa_outputs`: Linear layer for question answering
pub struct LongformerForQuestionAnswering {
    // Base Longformer model, built without its pooling layer
    longformer: LongformerModel,
    // Linear projection of each token's hidden state to 2 values (start and end logits)
    qa_outputs: nn::Linear,
    // Separator token id, used to derive a default global attention mask from input ids
    sep_token_id: i64,
}
1024
1025impl LongformerForQuestionAnswering {
1026 /// Build a new `LongformerForQuestionAnswering`
1027 ///
1028 /// # Arguments
1029 ///
1030 /// * `p` - Variable store path for the root of the Longformer model
1031 /// * `config` - `LongformerConfig` object defining the model architecture
1032 ///
1033 /// # Example
1034 ///
1035 /// ```no_run
1036 /// use rust_bert::longformer::{LongformerConfig, LongformerForQuestionAnswering};
1037 /// use rust_bert::Config;
1038 /// use std::path::Path;
1039 /// use tch::{nn, Device};
1040 ///
1041 /// let config_path = Path::new("path/to/config.json");
1042 /// let device = Device::Cpu;
1043 /// let p = nn::VarStore::new(device);
1044 /// let config = LongformerConfig::from_file(config_path);
1045 /// let longformer_model = LongformerForQuestionAnswering::new(&p.root(), &config);
1046 /// ```
1047 pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForQuestionAnswering
1048 where
1049 P: Borrow<nn::Path<'p>>,
1050 {
1051 let p = p.borrow();
1052
1053 let longformer = LongformerModel::new(p / "longformer", config, false);
1054 let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, 2, Default::default());
1055 let sep_token_id = config.sep_token_id;
1056
1057 LongformerForQuestionAnswering {
1058 longformer,
1059 qa_outputs,
1060 sep_token_id,
1061 }
1062 }
1063
1064 /// Forward pass through the model
1065 ///
1066 /// # Arguments
1067 ///
1068 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
1069 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
1070 /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
1071 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1072 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1073 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1074 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1075 ///
1076 /// # Returns
1077 ///
1078 /// * `LongformerForQuestionAnsweringOutput` containing:
1079 /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
1080 /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
1081 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1082 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
1083 /// - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*) where x is the number of tokens with global attention
1084 ///
1085 /// # Example
1086 ///
1087 /// ```no_run
1088 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
1089 /// # use rust_bert::Config;
1090 /// # use std::path::Path;
1091 /// # use tch::kind::Kind::{Int64, Double};
1092 /// use rust_bert::longformer::{LongformerConfig, LongformerForQuestionAnswering};
1093 /// # let config_path = Path::new("path/to/config.json");
1094 /// # let vocab_path = Path::new("path/to/vocab.txt");
1095 /// # let device = Device::Cpu;
1096 /// # let vs = nn::VarStore::new(device);
1097 /// # let config = LongformerConfig::from_file(config_path);
1098 /// let longformer_model = LongformerForQuestionAnswering::new(&vs.root(), &config);
1099 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
1100 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1101 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1102 /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1103 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1104 ///
1105 /// let model_output = no_grad(|| {
1106 /// longformer_model
1107 /// .forward_t(
1108 /// Some(&input_tensor),
1109 /// Some(&attention_mask),
1110 /// Some(&global_attention_mask),
1111 /// None,
1112 /// None,
1113 /// None,
1114 /// false,
1115 /// )
1116 /// .unwrap()
1117 /// });
1118 /// ```
1119 pub fn forward_t(
1120 &self,
1121 input_ids: Option<&Tensor>,
1122 attention_mask: Option<&Tensor>,
1123 global_attention_mask: Option<&Tensor>,
1124 token_type_ids: Option<&Tensor>,
1125 position_ids: Option<&Tensor>,
1126 input_embeds: Option<&Tensor>,
1127 train: bool,
1128 ) -> Result<LongformerQuestionAnsweringOutput, RustBertError> {
1129 let calc_global_attention_mask = if global_attention_mask.is_none() {
1130 if let Some(input_ids) = input_ids {
1131 Some(compute_global_attention_mask(
1132 input_ids,
1133 self.sep_token_id,
1134 true,
1135 ))
1136 } else {
1137 return Err(RustBertError::ValueError(
1138 "Inputs ids must be provided to LongformerQuestionAnsweringOutput if the global_attention_mask is not given".into(),
1139 ));
1140 }
1141 } else {
1142 None
1143 };
1144
1145 let global_attention_mask = if global_attention_mask.is_some() {
1146 global_attention_mask
1147 } else {
1148 calc_global_attention_mask.as_ref()
1149 };
1150
1151 let base_model_output = self.longformer.forward_t(
1152 input_ids,
1153 attention_mask,
1154 global_attention_mask,
1155 token_type_ids,
1156 position_ids,
1157 input_embeds,
1158 train,
1159 )?;
1160
1161 let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1162 let logits = sequence_output.split(1, -1);
1163 let (start_logits, end_logits) = (&logits[0], &logits[1]);
1164 let start_logits = start_logits.squeeze_dim(-1);
1165 let end_logits = end_logits.squeeze_dim(-1);
1166
1167 Ok(LongformerQuestionAnsweringOutput {
1168 start_logits,
1169 end_logits,
1170 all_hidden_states: base_model_output.all_hidden_states,
1171 all_attentions: base_model_output.all_attentions,
1172 all_global_attentions: base_model_output.all_global_attentions,
1173 })
1174 }
1175}
1176
1177/// # Longformer for token classification (e.g. NER, POS)
1178/// Token-level classifier predicting a label for each token provided.
1179/// It is made of the following blocks:
1180/// - `longformer`: Base Longformer model
1181/// - `classifier`: Linear layer for token classification
pub struct LongformerForTokenClassification {
    // Base Longformer model, built without its pooling layer
    longformer: LongformerModel,
    // Dropout applied to the hidden states before classification
    dropout: Dropout,
    // Linear projection of each token's hidden state to the label space
    classifier: nn::Linear,
}
1187
1188impl LongformerForTokenClassification {
1189 /// Build a new `LongformerForTokenClassification`
1190 ///
1191 /// # Arguments
1192 ///
1193 /// * `p` - Variable store path for the root of the Longformer model
1194 /// * `config` - `LongformerConfig` object defining the model architecture
1195 ///
1196 /// # Example
1197 ///
1198 /// ```no_run
1199 /// use rust_bert::longformer::{LongformerConfig, LongformerForTokenClassification};
1200 /// use rust_bert::Config;
1201 /// use std::path::Path;
1202 /// use tch::{nn, Device};
1203 ///
1204 /// let config_path = Path::new("path/to/config.json");
1205 /// let device = Device::Cpu;
1206 /// let p = nn::VarStore::new(device);
1207 /// let config = LongformerConfig::from_file(config_path);
1208 /// let longformer_model = LongformerForTokenClassification::new(&p.root(), &config).unwrap();
1209 /// ```
1210 pub fn new<'p, P>(
1211 p: P,
1212 config: &LongformerConfig,
1213 ) -> Result<LongformerForTokenClassification, RustBertError>
1214 where
1215 P: Borrow<nn::Path<'p>>,
1216 {
1217 let p = p.borrow();
1218
1219 let longformer = LongformerModel::new(p / "longformer", config, false);
1220 let dropout = Dropout::new(config.hidden_dropout_prob);
1221
1222 let num_labels = config
1223 .id2label
1224 .as_ref()
1225 .ok_or_else(|| {
1226 RustBertError::InvalidConfigurationError(
1227 "num_labels not provided in configuration".to_string(),
1228 )
1229 })?
1230 .len() as i64;
1231
1232 let classifier = nn::linear(
1233 p / "classifier",
1234 config.hidden_size,
1235 num_labels,
1236 Default::default(),
1237 );
1238
1239 Ok(LongformerForTokenClassification {
1240 longformer,
1241 dropout,
1242 classifier,
1243 })
1244 }
1245
1246 /// Forward pass through the model
1247 ///
1248 /// # Arguments
1249 ///
1250 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
1251 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
1252 /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
1253 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1254 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1255 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1256 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1257 ///
1258 /// # Returns
1259 ///
1260 /// * `LongformerTokenClassificationOutput` containing:
1261 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
1262 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1263 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
1264 /// - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*) where x is the number of tokens with global attention
1265 ///
1266 /// # Example
1267 ///
1268 /// ```no_run
1269 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
1270 /// # use rust_bert::Config;
1271 /// # use std::path::Path;
1272 /// # use tch::kind::Kind::{Int64, Double};
1273 /// use rust_bert::longformer::{LongformerConfig, LongformerForTokenClassification};
1274 /// # let config_path = Path::new("path/to/config.json");
1275 /// # let vocab_path = Path::new("path/to/vocab.txt");
1276 /// # let device = Device::Cpu;
1277 /// # let vs = nn::VarStore::new(device);
1278 /// # let config = LongformerConfig::from_file(config_path);
1279 /// let longformer_model = LongformerForTokenClassification::new(&vs.root(), &config).unwrap();
1280 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
1281 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1282 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1283 /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1284 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1285 ///
1286 /// let model_output = no_grad(|| {
1287 /// longformer_model
1288 /// .forward_t(
1289 /// Some(&input_tensor),
1290 /// Some(&attention_mask),
1291 /// Some(&global_attention_mask),
1292 /// None,
1293 /// None,
1294 /// None,
1295 /// false,
1296 /// )
1297 /// .unwrap()
1298 /// });
1299 /// ```
1300 pub fn forward_t(
1301 &self,
1302 input_ids: Option<&Tensor>,
1303 attention_mask: Option<&Tensor>,
1304 global_attention_mask: Option<&Tensor>,
1305 token_type_ids: Option<&Tensor>,
1306 position_ids: Option<&Tensor>,
1307 input_embeds: Option<&Tensor>,
1308 train: bool,
1309 ) -> Result<LongformerTokenClassificationOutput, RustBertError> {
1310 let base_model_output = self.longformer.forward_t(
1311 input_ids,
1312 attention_mask,
1313 global_attention_mask,
1314 token_type_ids,
1315 position_ids,
1316 input_embeds,
1317 train,
1318 )?;
1319
1320 let logits = base_model_output
1321 .hidden_state
1322 .apply_t(&self.dropout, train)
1323 .apply(&self.classifier);
1324
1325 Ok(LongformerTokenClassificationOutput {
1326 logits,
1327 all_hidden_states: base_model_output.all_hidden_states,
1328 all_attentions: base_model_output.all_attentions,
1329 all_global_attentions: base_model_output.all_global_attentions,
1330 })
1331 }
1332}
1333
1334/// # Longformer for multiple choices
1335/// Multiple choices model using a Longformer base model and a linear classifier.
1336/// Input should be in the form `<cls> Context <sep><sep> Possible choice <sep>`. The choice is made along the batch axis,
1337/// assuming all elements of the batch are alternatives to be chosen from for a given context.
1338/// It is made of the following blocks:
1339/// - `longformer`: Base LongformerModel model
1340/// - `classifier`: Linear layer for multiple choices
pub struct LongformerForMultipleChoice {
    // Base Longformer model, built WITH its pooling layer (the pooled output is
    // consumed by the classifier)
    longformer: LongformerModel,
    // Dropout applied to the pooled output before classification
    dropout: Dropout,
    // Linear projection of the pooled output to a single score per choice
    classifier: nn::Linear,
    // Separator token id, used to derive a default global attention mask from input ids
    sep_token_id: i64,
}
1347
1348impl LongformerForMultipleChoice {
1349 /// Build a new `LongformerForMultipleChoice`
1350 ///
1351 /// # Arguments
1352 ///
1353 /// * `p` - Variable store path for the root of the Longformer model
1354 /// * `config` - `LongformerConfig` object defining the model architecture
1355 ///
1356 /// # Example
1357 ///
1358 /// ```no_run
1359 /// use rust_bert::longformer::{LongformerConfig, LongformerForMultipleChoice};
1360 /// use rust_bert::Config;
1361 /// use std::path::Path;
1362 /// use tch::{nn, Device};
1363 ///
1364 /// let config_path = Path::new("path/to/config.json");
1365 /// let device = Device::Cpu;
1366 /// let p = nn::VarStore::new(device);
1367 /// let config = LongformerConfig::from_file(config_path);
1368 /// let longformer_model = LongformerForMultipleChoice::new(&p.root(), &config);
1369 /// ```
1370 pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerForMultipleChoice
1371 where
1372 P: Borrow<nn::Path<'p>>,
1373 {
1374 let p = p.borrow();
1375
1376 let longformer = LongformerModel::new(p / "longformer", config, true);
1377 let dropout = Dropout::new(config.hidden_dropout_prob);
1378 let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
1379 let sep_token_id = config.sep_token_id;
1380
1381 LongformerForMultipleChoice {
1382 longformer,
1383 dropout,
1384 classifier,
1385 sep_token_id,
1386 }
1387 }
1388
1389 /// Forward pass through the model
1390 ///
1391 /// # Arguments
1392 ///
1393 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
1394 /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
1395 /// * `global_attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 1 will attend all other positions in the sequence.
1396 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1397 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1398 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1399 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1400 ///
1401 /// # Returns
1402 ///
1403 /// * `LongformerSequenceClassificationOutput` containing:
1404 /// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
1405 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1406 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, * attention_window_size*, *x + attention_window_size + 1*) where x is the number of tokens with global attention
1407 /// - `all_global_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_heads*, *sequence_length*, *attention_window_size*, *x*) where x is the number of tokens with global attention
1408 ///
1409 /// # Example
1410 ///
1411 /// ```no_run
1412 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
1413 /// # use rust_bert::Config;
1414 /// # use std::path::Path;
1415 /// # use tch::kind::Kind::{Int64, Double};
1416 /// use rust_bert::longformer::{LongformerConfig, LongformerForMultipleChoice};
1417 /// # let config_path = Path::new("path/to/config.json");
1418 /// # let vocab_path = Path::new("path/to/vocab.txt");
1419 /// # let device = Device::Cpu;
1420 /// # let vs = nn::VarStore::new(device);
1421 /// # let config = LongformerConfig::from_file(config_path);
1422 /// let longformer_model = LongformerForMultipleChoice::new(&vs.root(), &config);
1423 /// let (batch_size, sequence_length, target_sequence_length) = (64, 128, 32);
1424 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1425 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1426 /// let global_attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1427 /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
1428 ///
1429 /// let model_output = no_grad(|| {
1430 /// longformer_model
1431 /// .forward_t(
1432 /// Some(&input_tensor),
1433 /// Some(&attention_mask),
1434 /// Some(&global_attention_mask),
1435 /// None,
1436 /// None,
1437 /// None,
1438 /// false,
1439 /// )
1440 /// .unwrap()
1441 /// });
1442 /// ```
1443 pub fn forward_t(
1444 &self,
1445 input_ids: Option<&Tensor>,
1446 attention_mask: Option<&Tensor>,
1447 global_attention_mask: Option<&Tensor>,
1448 token_type_ids: Option<&Tensor>,
1449 position_ids: Option<&Tensor>,
1450 input_embeds: Option<&Tensor>,
1451 train: bool,
1452 ) -> Result<LongformerSequenceClassificationOutput, RustBertError> {
1453 let num_choices = match (input_ids, input_embeds) {
1454 (Some(input_ids_value), None) => input_ids_value.size()[1],
1455 (None, Some(input_embeds_value)) => input_embeds_value.size()[1],
1456 (Some(_), Some(_)) => {
1457 return Err(RustBertError::ValueError(
1458 "Only one of input ids or input embeddings may be set".into(),
1459 ));
1460 }
1461 (None, None) => {
1462 return Err(RustBertError::ValueError(
1463 "At least one of input ids or input embeddings must be set".into(),
1464 ));
1465 }
1466 };
1467
1468 let calc_global_attention_mask = if global_attention_mask.is_none() {
1469 if let Some(input_ids) = input_ids {
1470 let mut masks = Vec::with_capacity(num_choices as usize);
1471 for i in 0..num_choices {
1472 masks.push(compute_global_attention_mask(
1473 &input_ids.select(1, i),
1474 self.sep_token_id,
1475 false,
1476 ));
1477 }
1478 Some(Tensor::stack(masks.as_slice(), 1))
1479 } else {
1480 return Err(RustBertError::ValueError(
1481 "Inputs ids must be provided to LongformerQuestionAnsweringOutput if the global_attention_mask is not given".into(),
1482 ));
1483 }
1484 } else {
1485 None
1486 };
1487
1488 let flat_input_ids =
1489 input_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1490 let flat_attention_mask =
1491 attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1492 let flat_token_type_ids =
1493 token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1494 let flat_position_ids =
1495 position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1496 let flat_input_embeds =
1497 input_embeds.map(|tensor| tensor.view((-1, tensor.size()[1], tensor.size()[2])));
1498
1499 let global_attention_mask = if global_attention_mask.is_some() {
1500 global_attention_mask
1501 } else {
1502 calc_global_attention_mask.as_ref()
1503 };
1504 let flat_global_attention_mask =
1505 global_attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1506
1507 let base_model_output = self.longformer.forward_t(
1508 flat_input_ids.as_ref(),
1509 flat_attention_mask.as_ref(),
1510 flat_global_attention_mask.as_ref(),
1511 flat_token_type_ids.as_ref(),
1512 flat_position_ids.as_ref(),
1513 flat_input_embeds.as_ref(),
1514 train,
1515 )?;
1516
1517 let logits = base_model_output
1518 .pooled_output
1519 .unwrap()
1520 .apply_t(&self.dropout, train)
1521 .apply(&self.classifier)
1522 .view((-1, num_choices));
1523
1524 Ok(LongformerSequenceClassificationOutput {
1525 logits,
1526 all_hidden_states: base_model_output.all_hidden_states,
1527 all_attentions: base_model_output.all_attentions,
1528 all_global_attentions: base_model_output.all_global_attentions,
1529 })
1530 }
1531}
1532
1533/// Container for the Longformer model output.
/// Container for the Longformer model output.
pub struct LongformerModelOutput {
    /// Last hidden states from the model
    pub hidden_state: Tensor,
    /// Pooled output (hidden state for the first token).
    /// Expected to be `Some` only when the model was built with its pooling
    /// layer enabled (see the boolean flag of `LongformerModel::new`).
    pub pooled_output: Option<Tensor>,
    /// Hidden states for all intermediate layers, if output of hidden states was enabled
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if output of attentions was enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers, if output of attentions was enabled
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1546
1547/// Container for the Longformer masked LM model output.
/// Container for the Longformer masked LM model output.
pub struct LongformerMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, if output of hidden states was enabled
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if output of attentions was enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers, if output of attentions was enabled
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1558
1559/// Container for the Longformer sequence classification model output.
/// Container for the Longformer sequence classification model output.
/// Also reused by `LongformerForMultipleChoice`, where `logits` holds one score per choice.
pub struct LongformerSequenceClassificationOutput {
    /// Classification logits (scores for each target class per input sequence)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, if output of hidden states was enabled
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if output of attentions was enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers, if output of attentions was enabled
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1570
1571/// Container for the Longformer token classification model output.
/// Container for the Longformer token classification model output.
pub struct LongformerTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, if output of hidden states was enabled
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if output of attentions was enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers, if output of attentions was enabled
    pub all_global_attentions: Option<Vec<Tensor>>,
}
1582
1583/// Container for the Longformer question answering model output.
/// Container for the Longformer question answering model output.
pub struct LongformerQuestionAnsweringOutput {
    /// Logits for the answer start position, for each token of each input sequence
    pub start_logits: Tensor,
    /// Logits for the answer end position, for each token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers, if output of hidden states was enabled
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if output of attentions was enabled
    pub all_attentions: Option<Vec<Tensor>>,
    /// Global attention weights for all intermediate layers, if output of attentions was enabled
    pub all_global_attentions: Option<Vec<Tensor>>,
}