rust_bert/models/deberta/deberta_model.rs
1// Copyright 2020, Microsoft and the HuggingFace Inc. team.
2// Copyright 2020 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::bert::{
14 BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput,
15};
16use crate::common::activations::TensorFunction;
17use crate::common::dropout::{Dropout, XDropout};
18use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
19use crate::common::kind::get_min;
20use crate::deberta::embeddings::DebertaEmbeddings;
21use crate::deberta::encoder::{DebertaEncoder, DebertaEncoderOutput};
22use crate::{Activation, Config, RustBertError};
23use serde::de::{SeqAccess, Visitor};
24use serde::{de, Deserialize, Deserializer, Serialize};
25use std::borrow::Borrow;
26use std::collections::HashMap;
27use std::fmt;
28use std::str::FromStr;
29use tch::nn::{Init, Module, ModuleT};
30use tch::{nn, Kind, Tensor};
31
/// # DeBERTa Pretrained model weight files
///
/// Namespace type: it holds no data and only carries the associated constants
/// that point at the pretrained weight files.
pub struct DebertaModelResources;
34
/// # DeBERTa Pretrained model config files
///
/// Namespace type: it holds no data and only carries the associated constants
/// that point at the pretrained configuration files.
pub struct DebertaConfigResources;
37
/// # DeBERTa Pretrained model vocab files
///
/// Namespace type: it holds no data and only carries the associated constants
/// that point at the pretrained vocabulary files.
pub struct DebertaVocabResources;
40
/// # DeBERTa Pretrained model merges files
///
/// Namespace type: it holds no data and only carries the associated constants
/// that point at the pretrained merges files.
pub struct DebertaMergesResources;
43
/// Pretrained weight locations.
///
/// Every constant is a pair of strings; the second element is the download URL of the
/// `rust_model.ot` weights. The first element looks like a local cache identifier for
/// the resource (assumption - confirm against the resource loading code).
impl DebertaModelResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/model",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/model",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/rust_model.ot",
    );
}
56
/// Pretrained configuration locations.
///
/// Every constant is a pair of strings; the second element is the download URL of the
/// `config.json` file. The first element looks like a local cache identifier for the
/// resource (assumption - confirm against the resource loading code).
impl DebertaConfigResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/config",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/config.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/config",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/config.json",
    );
}
69
/// Pretrained vocabulary locations.
///
/// Every constant is a pair of strings; the second element is the download URL of the
/// `vocab.json` file. The first element looks like a local cache identifier for the
/// resource (assumption - confirm against the resource loading code).
impl DebertaVocabResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/vocab",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/vocab",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
    );
}
82
/// Pretrained merges locations.
///
/// Every constant is a pair of strings; the second element is the download URL of the
/// `merges.txt` file. The first element looks like a local cache identifier for the
/// resource (assumption - confirm against the resource loading code).
impl DebertaMergesResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/merges",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/merges",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
    );
}
95
// Variant names are deliberately lowercase: they are spelled exactly like the strings
// accepted by the `FromStr` implementation (`p2c`, `c2p`, `p2p`), and the derived serde
// implementations use the variant names as-is.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq, Eq)]
/// # Position attention type to use for the DeBERTa model.
///
/// The abbreviations presumably follow the DeBERTa terminology (position-to-content,
/// content-to-position, position-to-position) - confirm against the attention layer.
pub enum PositionAttentionType {
    /// Parsed from the string `p2c`.
    p2c,
    /// Parsed from the string `c2p`.
    c2p,
    /// Parsed from the string `p2p`.
    p2p,
}
104
105impl FromStr for PositionAttentionType {
106 type Err = RustBertError;
107
108 fn from_str(s: &str) -> Result<Self, Self::Err> {
109 match s {
110 "p2c" => Ok(PositionAttentionType::p2c),
111 "c2p" => Ok(PositionAttentionType::c2p),
112 "p2p" => Ok(PositionAttentionType::p2p),
113 _ => Err(RustBertError::InvalidConfigurationError(format!(
114 "Position attention type `{s}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
115 ))),
116 }
117 }
118}
119
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
/// # Collection of position attention types
///
/// Wraps the list of `PositionAttentionType` values enabled for a model. It can be parsed
/// from a `|`-separated string through its `FromStr` implementation; the `Default`
/// value is an empty list.
pub struct PositionAttentionTypes {
    // Enabled attention types, in the order they were parsed.
    types: Vec<PositionAttentionType>,
}
125
126impl FromStr for PositionAttentionTypes {
127 type Err = RustBertError;
128
129 fn from_str(s: &str) -> Result<Self, Self::Err> {
130 let types = s
131 .to_lowercase()
132 .split('|')
133 .map(PositionAttentionType::from_str)
134 .collect::<Result<Vec<_>, _>>()?;
135 Ok(PositionAttentionTypes { types })
136 }
137}
138
139impl PositionAttentionTypes {
140 pub fn has_type(&self, attention_type: PositionAttentionType) -> bool {
141 self.types
142 .iter()
143 .any(|self_type| *self_type == attention_type)
144 }
145
146 pub fn len(&self) -> usize {
147 self.types.len()
148 }
149}
150
#[derive(Debug, Serialize, Deserialize)]
/// # DeBERTa model configuration
/// Defines the DeBERTa model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// Fields typed as `Option` may be missing from the serialized configuration; the code
/// consuming them falls back to a default where noted below. Fields without a comment
/// are consumed outside of this section of the file.
pub struct DebertaConfig {
    /// Activation applied by the language-model head transform.
    pub hidden_act: Activation,
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability of the token classification head; also the fallback for the
    /// sequence classification head when `classifier_dropout` is not set.
    pub hidden_dropout_prob: f64,
    /// Width of the hidden representation (input size of the task heads).
    pub hidden_size: i64,
    pub initializer_range: f64,
    pub intermediate_size: i64,
    pub max_position_embeddings: i64,
    pub num_attention_heads: i64,
    pub num_hidden_layers: i64,
    pub type_vocab_size: i64,
    /// Output size of the language-model decoder.
    pub vocab_size: i64,
    pub position_biased_input: Option<bool>,
    /// Enabled position attention types. Deserialized by `deserialize_attention_type`,
    /// which accepts either a `|`-separated string or a sequence of strings.
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pub pos_att_type: Option<PositionAttentionTypes>,
    /// Dropout probability of the context pooler (0.0 when not set).
    pub pooler_dropout: Option<f64>,
    /// Activation of the context pooler (`gelu` when not set).
    pub pooler_hidden_act: Option<Activation>,
    /// Size of the context pooler dense layer (`hidden_size` when not set).
    pub pooler_hidden_size: Option<i64>,
    pub layer_norm_eps: Option<f64>,
    pub pad_token_id: Option<i64>,
    pub relative_attention: Option<bool>,
    pub max_relative_positions: Option<i64>,
    pub embedding_size: Option<i64>,
    pub talking_head: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub output_attentions: Option<bool>,
    /// Dropout probability of the sequence classification head.
    pub classifier_dropout: Option<f64>,
    pub is_decoder: Option<bool>,
    /// Label mapping; its length defines the number of labels of the classification
    /// heads, which fail to build when it is missing.
    pub id2label: Option<HashMap<i64, String>>,
    pub label2id: Option<HashMap<String, i64>>,
    pub share_att_key: Option<bool>,
    pub position_buckets: Option<i64>,
}
187
/// Default configuration with hard-coded hyper-parameters.
impl Default for DebertaConfig {
    /// Builds a configuration with 12 layers, 12 attention heads, a hidden size of 768
    /// and a vocabulary of 50265 entries. No label mapping is set, so the classification
    /// heads cannot be built from this value without filling `id2label` first.
    fn default() -> Self {
        DebertaConfig {
            hidden_act: Activation::gelu,
            attention_probs_dropout_prob: 0.1,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 0,
            vocab_size: 50265,
            position_biased_input: Some(true),
            pos_att_type: None,
            pooler_dropout: Some(0.0),
            pooler_hidden_act: Some(Activation::gelu),
            pooler_hidden_size: Some(768),
            layer_norm_eps: Some(1e-7),
            pad_token_id: Some(0),
            relative_attention: Some(false),
            max_relative_positions: Some(-1),
            embedding_size: None,
            talking_head: None,
            output_hidden_states: None,
            output_attentions: None,
            classifier_dropout: None,
            is_decoder: None,
            id2label: None,
            label2id: None,
            share_att_key: None,
            position_buckets: None,
        }
    }
}
224
225pub fn deserialize_attention_type<'de, D>(
226 deserializer: D,
227) -> Result<Option<PositionAttentionTypes>, D::Error>
228where
229 D: Deserializer<'de>,
230{
231 struct AttentionTypeVisitor;
232
233 impl<'de> Visitor<'de> for AttentionTypeVisitor {
234 type Value = PositionAttentionTypes;
235
236 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
237 formatter.write_str("null, string or sequence")
238 }
239
240 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
241 where
242 E: de::Error,
243 {
244 Ok(FromStr::from_str(value).unwrap())
245 }
246
247 fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
248 where
249 S: SeqAccess<'de>,
250 {
251 let mut types = vec![];
252 while let Some(attention_type) = seq.next_element::<String>()? {
253 types.push(FromStr::from_str(attention_type.as_str()).unwrap())
254 }
255 Ok(PositionAttentionTypes { types })
256 }
257 }
258
259 deserializer.deserialize_any(AttentionTypeVisitor).map(Some)
260}
261
// Relies entirely on the default methods of the `Config` trait (the documentation
// examples in this file load a configuration with `DebertaConfig::from_file`).
impl Config for DebertaConfig {}
263
264pub fn x_softmax(input: &Tensor, mask: &Tensor, dim: i64) -> Tensor {
265 let inverse_mask = ((1 - mask) as Tensor).to_kind(Kind::Bool);
266 input
267 .masked_fill(&inverse_mask, get_min(input.kind()).unwrap())
268 .softmax(dim, input.kind())
269 .masked_fill(&inverse_mask, 0.0)
270}
271
/// Common constructor interface for the layer normalization variants used by DeBERTa.
pub trait BaseDebertaLayerNorm {
    /// Builds the layer normalization.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path under which the parameters are registered
    /// * `size` - Size of the normalized (last) dimension
    /// * `variance_epsilon` - Value added to the variance for numerical stability
    fn new<'p, P>(p: P, size: i64, variance_epsilon: f64) -> Self
    where
        P: Borrow<nn::Path<'p>>;
}
277
/// Layer normalization over the last dimension with a learned scale and offset.
#[derive(Debug)]
pub struct DebertaLayerNorm {
    // Learned scale, initialized to 1
    weight: Tensor,
    // Learned offset, initialized to 0
    bias: Tensor,
    // Added to the variance before taking the square root
    variance_epsilon: f64,
}
284
285impl BaseDebertaLayerNorm for DebertaLayerNorm {
286 fn new<'p, P>(p: P, size: i64, variance_epsilon: f64) -> DebertaLayerNorm
287 where
288 P: Borrow<nn::Path<'p>>,
289 {
290 let p = p.borrow();
291 let weight = p.var("weight", &[size], Init::Const(1.0));
292 let bias = p.var("bias", &[size], Init::Const(0.0));
293 DebertaLayerNorm {
294 weight,
295 bias,
296 variance_epsilon,
297 }
298 }
299}
300
301impl Module for DebertaLayerNorm {
302 fn forward(&self, hidden_states: &Tensor) -> Tensor {
303 let input_type = hidden_states.kind();
304 let hidden_states = hidden_states.to_kind(Kind::Float);
305 let mean = hidden_states.mean_dim([-1].as_slice(), true, hidden_states.kind());
306 let variance = (&hidden_states - &mean).pow_tensor_scalar(2.0).mean_dim(
307 [-1].as_slice(),
308 true,
309 hidden_states.kind(),
310 );
311 let hidden_states = (hidden_states - mean)
312 / (variance + self.variance_epsilon)
313 .sqrt()
314 .to_kind(input_type);
315 &self.weight * hidden_states + &self.bias
316 }
317}
318
/// # DeBERTa Base model
/// Base architecture for DeBERTa models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: `DebertaEmbeddings`, turning input ids (or pre-computed embeddings) into the encoder input
/// - `encoder`: `DebertaEncoder`, applied to the embeddings output
pub struct DebertaModel {
    embeddings: DebertaEmbeddings,
    encoder: DebertaEncoder,
}
328
329impl DebertaModel {
330 /// Build a new `DebertaModel`
331 ///
332 /// # Arguments
333 ///
334 /// * `p` - Variable store path for the root of the BERT model
335 /// * `config` - `DebertaConfig` object defining the model architecture and decoder status
336 ///
337 /// # Example
338 ///
339 /// ```no_run
340 /// use rust_bert::deberta::{DebertaConfig, DebertaModel};
341 /// use rust_bert::Config;
342 /// use std::path::Path;
343 /// use tch::{nn, Device};
344 ///
345 /// let config_path = Path::new("path/to/config.json");
346 /// let device = Device::Cpu;
347 /// let p = nn::VarStore::new(device);
348 /// let config = DebertaConfig::from_file(config_path);
349 /// let model: DebertaModel = DebertaModel::new(&p.root() / "deberta", &config);
350 /// ```
351 pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaModel
352 where
353 P: Borrow<nn::Path<'p>>,
354 {
355 let p = p.borrow();
356
357 let embeddings = DebertaEmbeddings::new(p / "embeddings", config);
358 let encoder = DebertaEncoder::new(p / "encoder", config);
359
360 DebertaModel {
361 embeddings,
362 encoder,
363 }
364 }
365
366 /// Forward pass through the model
367 ///
368 /// # Arguments
369 ///
370 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
371 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
372 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
373 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
374 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
375 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
376 ///
377 /// # Returns
378 ///
379 /// * `DebertaOutput` containing:
380 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
381 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
382 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
383 ///
384 /// # Example
385 ///
386 /// ```no_run
387 /// # use rust_bert::deberta::{DebertaModel, DebertaConfig};
388 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
389 /// # use rust_bert::Config;
390 /// # use std::path::Path;
391 /// # let config_path = Path::new("path/to/config.json");
392 /// # let device = Device::Cpu;
393 /// # let vs = nn::VarStore::new(device);
394 /// # let config = DebertaConfig::from_file(config_path);
395 /// # let model = DebertaModel::new(&vs.root(), &config);
396 /// let (batch_size, sequence_length) = (64, 128);
397 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
398 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
399 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
400 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
401 /// .expand(&[batch_size, sequence_length], true);
402 ///
403 /// let model_output = no_grad(|| {
404 /// model
405 /// .forward_t(
406 /// Some(&input_tensor),
407 /// Some(&attention_mask),
408 /// Some(&token_type_ids),
409 /// Some(&position_ids),
410 /// None,
411 /// false,
412 /// )
413 /// .unwrap()
414 /// });
415 /// ```
416 pub fn forward_t(
417 &self,
418 input_ids: Option<&Tensor>,
419 attention_mask: Option<&Tensor>,
420 token_type_ids: Option<&Tensor>,
421 position_ids: Option<&Tensor>,
422 input_embeds: Option<&Tensor>,
423 train: bool,
424 ) -> Result<DebertaModelOutput, RustBertError> {
425 let (input_shape, device) =
426 get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
427
428 let calc_attention_mask = if attention_mask.is_none() {
429 Some(Tensor::ones(input_shape.as_slice(), (Kind::Bool, device)))
430 } else {
431 None
432 };
433
434 let attention_mask =
435 attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
436
437 let embedding_output = self.embeddings.forward_t(
438 input_ids,
439 token_type_ids,
440 position_ids,
441 attention_mask,
442 input_embeds,
443 train,
444 )?;
445
446 let encoder_output =
447 self.encoder
448 .forward_t(&embedding_output, attention_mask, None, None, train)?;
449
450 Ok(encoder_output)
451 }
452}
453
/// Transformation applied to the hidden states before the language-model decoder:
/// dense projection, activation, then layer normalization.
#[derive(Debug)]
struct DebertaPredictionHeadTransform {
    // `hidden_size` -> `hidden_size` projection
    dense: nn::Linear,
    // Activation selected by `hidden_act` in the configuration
    activation: TensorFunction,
    // Normalization applied after the activation
    layer_norm: nn::LayerNorm,
}
460
461impl DebertaPredictionHeadTransform {
462 pub fn new<'p, P>(
463 p: P,
464 config: &DebertaConfig,
465 transform_bias: bool,
466 ) -> DebertaPredictionHeadTransform
467 where
468 P: Borrow<nn::Path<'p>>,
469 {
470 let p = p.borrow();
471
472 let dense = nn::linear(
473 p / "dense",
474 config.hidden_size,
475 config.hidden_size,
476 nn::LinearConfig {
477 bias: transform_bias,
478 ..Default::default()
479 },
480 );
481 let activation = config.hidden_act.get_function();
482 let layer_norm_config = nn::LayerNormConfig {
483 eps: 1e-7,
484 ..Default::default()
485 };
486 let layer_norm =
487 nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
488
489 DebertaPredictionHeadTransform {
490 dense,
491 activation,
492 layer_norm,
493 }
494 }
495}
496
497impl Module for DebertaPredictionHeadTransform {
498 fn forward(&self, hidden_states: &Tensor) -> Tensor {
499 self.activation.get_fn()(&hidden_states.apply(&self.dense)).apply(&self.layer_norm)
500 }
501}
502
/// Language-model prediction head: a `DebertaPredictionHeadTransform` followed by a
/// linear decoder projecting from `hidden_size` to `vocab_size`.
#[derive(Debug)]
pub(crate) struct DebertaLMPredictionHead {
    transform: DebertaPredictionHeadTransform,
    decoder: nn::Linear,
}
508
509impl DebertaLMPredictionHead {
510 pub fn new<'p, P>(p: P, config: &DebertaConfig, transform_bias: bool) -> DebertaLMPredictionHead
511 where
512 P: Borrow<nn::Path<'p>>,
513 {
514 let p = p.borrow();
515
516 let transform =
517 DebertaPredictionHeadTransform::new(p / "transform", config, transform_bias);
518 let decoder = nn::linear(
519 p / "decoder",
520 config.hidden_size,
521 config.vocab_size,
522 Default::default(),
523 );
524
525 DebertaLMPredictionHead { transform, decoder }
526 }
527}
528
529impl Module for DebertaLMPredictionHead {
530 fn forward(&self, hidden_states: &Tensor) -> Tensor {
531 hidden_states.apply(&self.transform).apply(&self.decoder)
532 }
533}
534
/// # DeBERTa for masked language model
/// Base DeBERTa model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa model
/// - `cls`: LM prediction head, applied to the last hidden state of the base model
pub struct DebertaForMaskedLM {
    deberta: DebertaModel,
    cls: DebertaLMPredictionHead,
}
544
545impl DebertaForMaskedLM {
546 /// Build a new `DebertaForMaskedLM`
547 ///
548 /// # Arguments
549 ///
550 /// * `p` - Variable store path for the root of the BertForMaskedLM model
551 /// * `config` - `DebertaConfig` object defining the model architecture and vocab size
552 ///
553 /// # Example
554 ///
555 /// ```no_run
556 /// use rust_bert::deberta::{DebertaConfig, DebertaForMaskedLM};
557 /// use rust_bert::Config;
558 /// use std::path::Path;
559 /// use tch::{nn, Device};
560 ///
561 /// let config_path = Path::new("path/to/config.json");
562 /// let device = Device::Cpu;
563 /// let p = nn::VarStore::new(device);
564 /// let config = DebertaConfig::from_file(config_path);
565 /// let model = DebertaForMaskedLM::new(&p.root(), &config);
566 /// ```
567 pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForMaskedLM
568 where
569 P: Borrow<nn::Path<'p>>,
570 {
571 let p = p.borrow();
572
573 let deberta = DebertaModel::new(p / "deberta", config);
574 let cls = DebertaLMPredictionHead::new(p.sub("cls").sub("predictions"), config, false);
575
576 DebertaForMaskedLM { deberta, cls }
577 }
578
579 /// Forward pass through the model
580 ///
581 /// # Arguments
582 ///
583 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
584 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
585 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
586 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
587 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
588 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
589 ///
590 /// # Returns
591 ///
592 /// * `DebertaMaskedLMOutput` containing:
593 /// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
594 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
595 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
596 ///
597 /// # Example
598 ///
599 /// ```no_run
600 /// # use rust_bert::deberta::{DebertaForMaskedLM, DebertaConfig};
601 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
602 /// # use rust_bert::Config;
603 /// # use std::path::Path;
604 /// # let config_path = Path::new("path/to/config.json");
605 /// # let device = Device::Cpu;
606 /// # let vs = nn::VarStore::new(device);
607 /// # let config = DebertaConfig::from_file(config_path);
608 /// # let model = DebertaForMaskedLM::new(&vs.root(), &config);
609 /// let (batch_size, sequence_length) = (64, 128);
610 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
611 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
612 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
613 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
614 /// .expand(&[batch_size, sequence_length], true);
615 ///
616 /// let model_output = no_grad(|| {
617 /// model.forward_t(
618 /// Some(&input_tensor),
619 /// Some(&mask),
620 /// Some(&token_type_ids),
621 /// Some(&position_ids),
622 /// None,
623 /// false,
624 /// )
625 /// });
626 /// ```
627 pub fn forward_t(
628 &self,
629 input_ids: Option<&Tensor>,
630 attention_mask: Option<&Tensor>,
631 token_type_ids: Option<&Tensor>,
632 position_ids: Option<&Tensor>,
633 input_embeds: Option<&Tensor>,
634 train: bool,
635 ) -> Result<DebertaMaskedLMOutput, RustBertError> {
636 let model_outputs = self.deberta.forward_t(
637 input_ids,
638 attention_mask,
639 token_type_ids,
640 position_ids,
641 input_embeds,
642 train,
643 )?;
644
645 let logits = model_outputs.hidden_state.apply(&self.cls);
646 Ok(DebertaMaskedLMOutput {
647 logits,
648 all_hidden_states: model_outputs.all_hidden_states,
649 all_attentions: model_outputs.all_attentions,
650 })
651 }
652}
653
/// Pools a sequence of hidden states into a single vector per batch entry by passing the
/// hidden state at index 0 of dimension 1 through dropout, a dense layer and an activation.
#[derive(Debug)]
pub struct ContextPooler {
    // Square projection of size `pooler_hidden_size`
    dense: nn::Linear,
    // Applied to the selected hidden state before the projection
    dropout: XDropout,
    // Applied after the projection
    activation: TensorFunction,
    /// Size of the pooled output (the pooler hidden size)
    pub output_dim: i64,
}
661
662impl ContextPooler {
663 pub fn new<'p, P>(p: P, config: &DebertaConfig) -> ContextPooler
664 where
665 P: Borrow<nn::Path<'p>>,
666 {
667 let p = p.borrow();
668 let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
669
670 let dense = nn::linear(
671 p / "dense",
672 pooler_hidden_size,
673 pooler_hidden_size,
674 Default::default(),
675 );
676 let dropout = XDropout::new(config.pooler_dropout.unwrap_or(0.0));
677 let activation = config
678 .pooler_hidden_act
679 .unwrap_or(Activation::gelu)
680 .get_function();
681
682 ContextPooler {
683 dense,
684 dropout,
685 activation,
686 output_dim: pooler_hidden_size,
687 }
688 }
689}
690
691impl ModuleT for ContextPooler {
692 fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
693 self.activation.get_fn()(
694 &hidden_states
695 .select(1, 0)
696 .apply_t(&self.dropout, train)
697 .apply(&self.dense),
698 )
699 }
700}
701
/// # DeBERTa for sequence classification
/// Base DeBERTa model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `deberta`: Base `DebertaModel`
/// - `pooler`: `ContextPooler` reducing the last hidden state to one vector per batch entry
/// - `dropout`: Dropout layer applied to the pooled output
/// - `classifier`: Linear layer for classification, from the pooler output size to the number of labels
pub struct DebertaForSequenceClassification {
    deberta: DebertaModel,
    pooler: ContextPooler,
    classifier: nn::Linear,
    dropout: XDropout,
}
713
714impl DebertaForSequenceClassification {
715 /// Build a new `DebertaForSequenceClassification`
716 ///
717 /// # Arguments
718 ///
719 /// * `p` - Variable store path for the root of the DebertaForSequenceClassification model
720 /// * `config` - `DebertaConfig` object defining the model architecture and number of classes
721 ///
722 /// # Example
723 ///
724 /// ```no_run
725 /// use rust_bert::deberta::{DebertaConfig, DebertaForSequenceClassification};
726 /// use rust_bert::Config;
727 /// use std::path::Path;
728 /// use tch::{nn, Device};
729 ///
730 /// let config_path = Path::new("path/to/config.json");
731 /// let device = Device::Cpu;
732 /// let p = nn::VarStore::new(device);
733 /// let config = DebertaConfig::from_file(config_path);
734 /// let model = DebertaForSequenceClassification::new(&p.root(), &config).unwrap();
735 /// ```
736 pub fn new<'p, P>(
737 p: P,
738 config: &DebertaConfig,
739 ) -> Result<DebertaForSequenceClassification, RustBertError>
740 where
741 P: Borrow<nn::Path<'p>>,
742 {
743 let p = p.borrow();
744
745 let deberta = DebertaModel::new(p / "deberta", config);
746 let pooler = ContextPooler::new(p / "pooler", config);
747 let dropout = XDropout::new(
748 config
749 .classifier_dropout
750 .unwrap_or(config.hidden_dropout_prob),
751 );
752
753 let num_labels = config
754 .id2label
755 .as_ref()
756 .ok_or_else(|| {
757 RustBertError::InvalidConfigurationError(
758 "num_labels not provided in configuration".to_string(),
759 )
760 })?
761 .len() as i64;
762
763 let classifier = nn::linear(
764 p / "classifier",
765 pooler.output_dim,
766 num_labels,
767 Default::default(),
768 );
769
770 Ok(DebertaForSequenceClassification {
771 deberta,
772 pooler,
773 classifier,
774 dropout,
775 })
776 }
777
778 /// Forward pass through the model
779 ///
780 /// # Arguments
781 ///
782 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
783 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
784 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
785 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
786 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
787 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
788 ///
789 /// # Returns
790 ///
791 /// * `DebertaSequenceClassificationOutput` containing:
792 /// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
793 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
794 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
795 ///
796 /// # Example
797 ///
798 /// ```no_run
799 /// # use rust_bert::deberta::{DebertaForSequenceClassification, DebertaConfig};
800 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
801 /// # use rust_bert::Config;
802 /// # use std::path::Path;
803 /// # let config_path = Path::new("path/to/config.json");
804 /// # let device = Device::Cpu;
805 /// # let vs = nn::VarStore::new(device);
806 /// # let config = DebertaConfig::from_file(config_path);
807 /// # let model = DebertaForSequenceClassification::new(&vs.root(), &config).unwrap();;
808 /// let (batch_size, sequence_length) = (64, 128);
809 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
810 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
811 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
812 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
813 /// .expand(&[batch_size, sequence_length], true);
814 ///
815 /// let model_output = no_grad(|| {
816 /// model.forward_t(
817 /// Some(&input_tensor),
818 /// Some(&mask),
819 /// Some(&token_type_ids),
820 /// Some(&position_ids),
821 /// None,
822 /// false,
823 /// )
824 /// });
825 /// ```
826 pub fn forward_t(
827 &self,
828 input_ids: Option<&Tensor>,
829 attention_mask: Option<&Tensor>,
830 token_type_ids: Option<&Tensor>,
831 position_ids: Option<&Tensor>,
832 input_embeds: Option<&Tensor>,
833 train: bool,
834 ) -> Result<DebertaSequenceClassificationOutput, RustBertError> {
835 let base_model_output = self.deberta.forward_t(
836 input_ids,
837 attention_mask,
838 token_type_ids,
839 position_ids,
840 input_embeds,
841 train,
842 )?;
843
844 let logits = base_model_output
845 .hidden_state
846 .apply_t(&self.pooler, train)
847 .apply_t(&self.dropout, train)
848 .apply(&self.classifier);
849
850 Ok(DebertaSequenceClassificationOutput {
851 logits,
852 all_hidden_states: base_model_output.all_hidden_states,
853 all_attentions: base_model_output.all_attentions,
854 })
855 }
856}
857
/// # DeBERTa for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of sub-word tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa
/// - `dropout`: Dropout layer before the last token-level predictions layer
/// - `classifier`: Linear layer for token classification
pub struct DebertaForTokenClassification {
    deberta: DebertaModel,
    dropout: Dropout,
    classifier: nn::Linear,
}
870
871impl DebertaForTokenClassification {
872 /// Build a new `DebertaForTokenClassification`
873 ///
874 /// # Arguments
875 ///
876 /// * `p` - Variable store path for the root of the Deberta model
877 /// * `config` - `DebertaConfig` object defining the model architecture
878 ///
879 /// # Example
880 ///
881 /// ```no_run
882 /// use rust_bert::deberta::{DebertaConfig, DebertaForTokenClassification};
883 /// use rust_bert::Config;
884 /// use std::path::Path;
885 /// use tch::{nn, Device};
886 ///
887 /// let config_path = Path::new("path/to/config.json");
888 /// let device = Device::Cpu;
889 /// let p = nn::VarStore::new(device);
890 /// let config = DebertaConfig::from_file(config_path);
891 /// let model = DebertaForTokenClassification::new(&p.root(), &config).unwrap();
892 /// ```
893 pub fn new<'p, P>(
894 p: P,
895 config: &DebertaConfig,
896 ) -> Result<DebertaForTokenClassification, RustBertError>
897 where
898 P: Borrow<nn::Path<'p>>,
899 {
900 let p = p.borrow();
901
902 let deberta = DebertaModel::new(p / "deberta", config);
903 let dropout = Dropout::new(config.hidden_dropout_prob);
904 let num_labels = config
905 .id2label
906 .as_ref()
907 .ok_or_else(|| {
908 RustBertError::InvalidConfigurationError(
909 "num_labels not provided in configuration".to_string(),
910 )
911 })?
912 .len() as i64;
913 let classifier = nn::linear(
914 p / "classifier",
915 config.hidden_size,
916 num_labels,
917 Default::default(),
918 );
919
920 Ok(DebertaForTokenClassification {
921 deberta,
922 dropout,
923 classifier,
924 })
925 }
926
927 /// Forward pass through the model
928 ///
929 /// # Arguments
930 ///
931 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
932 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
933 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
934 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
935 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
936 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
937 ///
938 /// # Returns
939 ///
940 /// * `DebertaTokenClassificationOutput` containing:
941 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*)
942 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
943 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
944 ///
945 /// # Example
946 ///
947 /// ```no_run
948 /// # use rust_bert::deberta::{DebertaForTokenClassification, DebertaConfig};
949 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
950 /// # use rust_bert::Config;
951 /// # use std::path::Path;
952 /// # let config_path = Path::new("path/to/config.json");
953 /// # let device = Device::Cpu;
954 /// # let vs = nn::VarStore::new(device);
955 /// # let config = DebertaConfig::from_file(config_path);
956 /// # let model = DebertaForTokenClassification::new(&vs.root(), &config).unwrap();
957 /// let (batch_size, sequence_length) = (64, 128);
958 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
959 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
960 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
961 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
962 /// .expand(&[batch_size, sequence_length], true);
963 ///
964 /// let model_output = no_grad(|| {
965 /// model.forward_t(
966 /// Some(&input_tensor),
967 /// Some(&mask),
968 /// Some(&token_type_ids),
969 /// Some(&position_ids),
970 /// None,
971 /// false,
972 /// )
973 /// });
974 /// ```
975 pub fn forward_t(
976 &self,
977 input_ids: Option<&Tensor>,
978 attention_mask: Option<&Tensor>,
979 token_type_ids: Option<&Tensor>,
980 position_ids: Option<&Tensor>,
981 input_embeds: Option<&Tensor>,
982 train: bool,
983 ) -> Result<DebertaTokenClassificationOutput, RustBertError> {
984 let base_model_output = self.deberta.forward_t(
985 input_ids,
986 attention_mask,
987 token_type_ids,
988 position_ids,
989 input_embeds,
990 train,
991 )?;
992
993 let logits = base_model_output
994 .hidden_state
995 .apply_t(&self.dropout, train)
996 .apply(&self.classifier);
997
998 Ok(DebertaTokenClassificationOutput {
999 logits,
1000 all_hidden_states: base_model_output.all_hidden_states,
1001 all_attentions: base_model_output.all_attentions,
1002 })
1003 }
1004}
1005
/// # DeBERTa for question answering
/// Extractive question-answering model based on a DeBERTa language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa model
/// - `qa_outputs`: Linear layer for question answering (two outputs per token: start and end logits)
pub struct DebertaForQuestionAnswering {
    // Base DeBERTa model producing the hidden state fed to the question answering head.
    deberta: DebertaModel,
    // Projection from `hidden_size` to the start/end logits.
    qa_outputs: nn::Linear,
}
1017
1018impl DebertaForQuestionAnswering {
1019 /// Build a new `DebertaForQuestionAnswering`
1020 ///
1021 /// # Arguments
1022 ///
1023 /// * `p` - Variable store path for the root of the BertForQuestionAnswering model
1024 /// * `config` - `DebertaConfig` object defining the model architecture
1025 ///
1026 /// # Example
1027 ///
1028 /// ```no_run
1029 /// use rust_bert::deberta::{DebertaConfig, DebertaForQuestionAnswering};
1030 /// use rust_bert::Config;
1031 /// use std::path::Path;
1032 /// use tch::{nn, Device};
1033 ///
1034 /// let config_path = Path::new("path/to/config.json");
1035 /// let device = Device::Cpu;
1036 /// let p = nn::VarStore::new(device);
1037 /// let config = DebertaConfig::from_file(config_path);
1038 /// let model = DebertaForQuestionAnswering::new(&p.root(), &config);
1039 /// ```
1040 pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForQuestionAnswering
1041 where
1042 P: Borrow<nn::Path<'p>>,
1043 {
1044 let p = p.borrow();
1045
1046 let deberta = DebertaModel::new(p / "deberta", config);
1047 let num_labels = 2;
1048 let qa_outputs = nn::linear(
1049 p / "qa_outputs",
1050 config.hidden_size,
1051 num_labels,
1052 Default::default(),
1053 );
1054
1055 DebertaForQuestionAnswering {
1056 deberta,
1057 qa_outputs,
1058 }
1059 }
1060
1061 /// Forward pass through the model
1062 ///
1063 /// # Arguments
1064 ///
1065 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1066 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1067 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1068 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1069 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1070 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1071 ///
1072 /// # Returns
1073 ///
1074 /// * `DebertaQuestionAnsweringOutput` containing:
1075 /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
1076 /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
1077 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1078 /// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1079 ///
1080 /// # Example
1081 ///
1082 /// ```no_run
1083 /// # use rust_bert::deberta::{DebertaForQuestionAnswering, DebertaConfig};
1084 /// # use tch::{nn, Device, Tensor, no_grad};
1085 /// # use rust_bert::Config;
1086 /// # use std::path::Path;
1087 /// # use tch::kind::Kind::Int64;
1088 /// # let config_path = Path::new("path/to/config.json");
1089 /// # let device = Device::Cpu;
1090 /// # let vs = nn::VarStore::new(device);
1091 /// # let config = DebertaConfig::from_file(config_path);
1092 /// # let model = DebertaForQuestionAnswering::new(&vs.root(), &config);
1093 /// let (batch_size, sequence_length) = (64, 128);
1094 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1095 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1096 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1097 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1098 /// .expand(&[batch_size, sequence_length], true);
1099 ///
1100 /// let model_output = no_grad(|| {
1101 /// model.forward_t(
1102 /// Some(&input_tensor),
1103 /// Some(&mask),
1104 /// Some(&token_type_ids),
1105 /// Some(&position_ids),
1106 /// None,
1107 /// false,
1108 /// )
1109 /// });
1110 /// ```
1111 pub fn forward_t(
1112 &self,
1113 input_ids: Option<&Tensor>,
1114 attention_mask: Option<&Tensor>,
1115 token_type_ids: Option<&Tensor>,
1116 position_ids: Option<&Tensor>,
1117 input_embeds: Option<&Tensor>,
1118 train: bool,
1119 ) -> Result<DebertaQuestionAnsweringOutput, RustBertError> {
1120 let base_model_output = self.deberta.forward_t(
1121 input_ids,
1122 attention_mask,
1123 token_type_ids,
1124 position_ids,
1125 input_embeds,
1126 train,
1127 )?;
1128
1129 let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1130 let logits = sequence_output.split(1, -1);
1131 let (start_logits, end_logits) = (&logits[0], &logits[1]);
1132 let start_logits = start_logits.squeeze_dim(-1);
1133 let end_logits = end_logits.squeeze_dim(-1);
1134
1135 Ok(DebertaQuestionAnsweringOutput {
1136 start_logits,
1137 end_logits,
1138 all_hidden_states: base_model_output.all_hidden_states,
1139 all_attentions: base_model_output.all_attentions,
1140 })
1141 }
1142}
1143
/// Container for the DeBERTa model output.
///
/// Alias of `DebertaEncoderOutput`; the model heads in this file read its `hidden_state`,
/// `all_hidden_states` and `all_attentions` fields.
pub type DebertaModelOutput = DebertaEncoderOutput;
1146
/// Container for the DeBERTa masked LM model output.
pub struct DebertaMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    pub logits: Tensor,
    /// Hidden states for all intermediate layers (optional)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (optional)
    pub all_attentions: Option<Vec<Tensor>>,
}
1156
/// Container for the DeBERTa sequence classification model output.
///
/// Alias of `BertSequenceClassificationOutput`: the DeBERTa head reuses the BERT output container.
pub type DebertaSequenceClassificationOutput = BertSequenceClassificationOutput;
1159
/// Container for the DeBERTa token classification model output.
///
/// Alias of `BertTokenClassificationOutput`; populated with `logits`, `all_hidden_states`
/// and `all_attentions` by `DebertaForTokenClassification::forward_t`.
pub type DebertaTokenClassificationOutput = BertTokenClassificationOutput;
1162
/// Container for the DeBERTa question answering model output.
///
/// Alias of `BertQuestionAnsweringOutput`; populated with `start_logits`, `end_logits`,
/// `all_hidden_states` and `all_attentions` by `DebertaForQuestionAnswering::forward_t`.
pub type DebertaQuestionAnsweringOutput = BertQuestionAnsweringOutput;