rust_bert/models/deberta/
deberta_model.rs

1// Copyright 2020, Microsoft and the HuggingFace Inc. team.
2// Copyright 2020 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::bert::{
14    BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput,
15};
16use crate::common::activations::TensorFunction;
17use crate::common::dropout::{Dropout, XDropout};
18use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
19use crate::common::kind::get_min;
20use crate::deberta::embeddings::DebertaEmbeddings;
21use crate::deberta::encoder::{DebertaEncoder, DebertaEncoderOutput};
22use crate::{Activation, Config, RustBertError};
23use serde::de::{SeqAccess, Visitor};
24use serde::{de, Deserialize, Deserializer, Serialize};
25use std::borrow::Borrow;
26use std::collections::HashMap;
27use std::fmt;
28use std::str::FromStr;
29use tch::nn::{Init, Module, ModuleT};
30use tch::{nn, Kind, Tensor};
31
/// # DeBERTa Pretrained model weight files
pub struct DebertaModelResources;

/// # DeBERTa Pretrained model config files
pub struct DebertaConfigResources;

/// # DeBERTa Pretrained model vocab files
pub struct DebertaVocabResources;

/// # DeBERTa Pretrained model merges files
pub struct DebertaMergesResources;

/// Each constant is a `(resource name, URL)` pair: a short identifier for the resource and the
/// URL of the file.
impl DebertaModelResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/model",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>. Modified with conversion to C-array format.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/model",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/rust_model.ot",
    );
}

/// Each constant is a `(resource name, URL)` pair pointing at the original `config.json`.
impl DebertaConfigResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/config",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/config.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/config",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/config.json",
    );
}

/// Each constant is a `(resource name, URL)` pair pointing at the original `vocab.json`.
impl DebertaVocabResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/vocab",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/vocab",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
    );
}

/// Each constant is a `(resource name, URL)` pair pointing at the original `merges.txt`.
impl DebertaMergesResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base>.
    pub const DEBERTA_BASE: (&'static str, &'static str) = (
        "deberta-base/merges",
        "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
    );
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-base-mnli>.
    pub const DEBERTA_BASE_MNLI: (&'static str, &'static str) = (
        "deberta-base-mnli/merges",
        "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
    );
}
95
96#[allow(non_camel_case_types)]
97#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq, Eq)]
98/// # Position attention type to use for the DeBERTa model.
99pub enum PositionAttentionType {
100    p2c,
101    c2p,
102    p2p,
103}
104
105impl FromStr for PositionAttentionType {
106    type Err = RustBertError;
107
108    fn from_str(s: &str) -> Result<Self, Self::Err> {
109        match s {
110            "p2c" => Ok(PositionAttentionType::p2c),
111            "c2p" => Ok(PositionAttentionType::c2p),
112            "p2p" => Ok(PositionAttentionType::p2p),
113            _ => Err(RustBertError::InvalidConfigurationError(format!(
114                "Position attention type `{s}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
115            ))),
116        }
117    }
118}
119
120#[allow(non_camel_case_types)]
121#[derive(Clone, Debug, Serialize, Deserialize, Default)]
122pub struct PositionAttentionTypes {
123    types: Vec<PositionAttentionType>,
124}
125
126impl FromStr for PositionAttentionTypes {
127    type Err = RustBertError;
128
129    fn from_str(s: &str) -> Result<Self, Self::Err> {
130        let types = s
131            .to_lowercase()
132            .split('|')
133            .map(PositionAttentionType::from_str)
134            .collect::<Result<Vec<_>, _>>()?;
135        Ok(PositionAttentionTypes { types })
136    }
137}
138
139impl PositionAttentionTypes {
140    pub fn has_type(&self, attention_type: PositionAttentionType) -> bool {
141        self.types
142            .iter()
143            .any(|self_type| *self_type == attention_type)
144    }
145
146    pub fn len(&self) -> usize {
147        self.types.len()
148    }
149}
150
#[derive(Debug, Serialize, Deserialize)]
/// # DeBERTa model configuration
/// Defines the DeBERTa model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// The whole configuration is passed to `DebertaEmbeddings::new` and `DebertaEncoder::new`.
/// Fields without a usage note below are not read directly in this file.
pub struct DebertaConfig {
    /// Hidden activation function; also used by the masked LM prediction head transform
    pub hidden_act: Activation,
    /// Dropout probability for attention probabilities
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability for hidden states; fallback for `classifier_dropout`
    pub hidden_dropout_prob: f64,
    /// Hidden state dimension; fallback for `pooler_hidden_size`
    pub hidden_size: i64,
    /// Weight initializer range
    pub initializer_range: f64,
    /// Intermediate layer size
    pub intermediate_size: i64,
    /// Maximum number of positions (sequence length) supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of attention heads
    pub num_attention_heads: i64,
    /// Number of hidden (transformer) layers
    pub num_hidden_layers: i64,
    /// Number of token type (segment) ids; `0` in the default configuration
    pub type_vocab_size: i64,
    /// Vocabulary size; output dimension of the masked LM decoder
    pub vocab_size: i64,
    /// Whether absolute position information is added to the input embeddings
    pub position_biased_input: Option<bool>,
    /// Position attention types, given either as a `|`-separated string (e.g. `"c2p|p2c"`)
    /// or as a list of strings; `None` when the field is absent
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pub pos_att_type: Option<PositionAttentionTypes>,
    /// Dropout of the `ContextPooler`; `0.0` when unset
    pub pooler_dropout: Option<f64>,
    /// Activation of the `ContextPooler`; GELU when unset
    pub pooler_hidden_act: Option<Activation>,
    /// Input/output size of the `ContextPooler`; `hidden_size` when unset
    pub pooler_hidden_size: Option<i64>,
    /// Layer normalization epsilon
    pub layer_norm_eps: Option<f64>,
    /// Padding token id
    pub pad_token_id: Option<i64>,
    /// Whether relative attention is enabled
    pub relative_attention: Option<bool>,
    /// Maximum relative position distance (`-1` in the default configuration)
    pub max_relative_positions: Option<i64>,
    /// Embedding size
    pub embedding_size: Option<i64>,
    /// Talking-head attention flag
    pub talking_head: Option<bool>,
    /// Whether all hidden states are returned
    pub output_hidden_states: Option<bool>,
    /// Whether all attention weights are returned
    pub output_attentions: Option<bool>,
    /// Dropout before the sequence classification layer; `hidden_dropout_prob` when unset
    pub classifier_dropout: Option<f64>,
    /// Decoder flag
    pub is_decoder: Option<bool>,
    /// Class index to label mapping; its length sets the number of outputs of
    /// `DebertaForSequenceClassification`, which requires it
    pub id2label: Option<HashMap<i64, String>>,
    /// Label to class index mapping
    pub label2id: Option<HashMap<String, i64>>,
    /// Attention key sharing flag
    pub share_att_key: Option<bool>,
    /// Number of relative position buckets
    pub position_buckets: Option<i64>,
}
187
188impl Default for DebertaConfig {
189    fn default() -> Self {
190        DebertaConfig {
191            hidden_act: Activation::gelu,
192            attention_probs_dropout_prob: 0.1,
193            hidden_dropout_prob: 0.1,
194            hidden_size: 768,
195            initializer_range: 0.02,
196            intermediate_size: 3072,
197            max_position_embeddings: 512,
198            num_attention_heads: 12,
199            num_hidden_layers: 12,
200            type_vocab_size: 0,
201            vocab_size: 50265,
202            position_biased_input: Some(true),
203            pos_att_type: None,
204            pooler_dropout: Some(0.0),
205            pooler_hidden_act: Some(Activation::gelu),
206            pooler_hidden_size: Some(768),
207            layer_norm_eps: Some(1e-7),
208            pad_token_id: Some(0),
209            relative_attention: Some(false),
210            max_relative_positions: Some(-1),
211            embedding_size: None,
212            talking_head: None,
213            output_hidden_states: None,
214            output_attentions: None,
215            classifier_dropout: None,
216            is_decoder: None,
217            id2label: None,
218            label2id: None,
219            share_att_key: None,
220            position_buckets: None,
221        }
222    }
223}
224
225pub fn deserialize_attention_type<'de, D>(
226    deserializer: D,
227) -> Result<Option<PositionAttentionTypes>, D::Error>
228where
229    D: Deserializer<'de>,
230{
231    struct AttentionTypeVisitor;
232
233    impl<'de> Visitor<'de> for AttentionTypeVisitor {
234        type Value = PositionAttentionTypes;
235
236        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
237            formatter.write_str("null, string or sequence")
238        }
239
240        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
241        where
242            E: de::Error,
243        {
244            Ok(FromStr::from_str(value).unwrap())
245        }
246
247        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
248        where
249            S: SeqAccess<'de>,
250        {
251            let mut types = vec![];
252            while let Some(attention_type) = seq.next_element::<String>()? {
253                types.push(FromStr::from_str(attention_type.as_str()).unwrap())
254            }
255            Ok(PositionAttentionTypes { types })
256        }
257    }
258
259    deserializer.deserialize_any(AttentionTypeVisitor).map(Some)
260}
261
// Empty impl: relies on the `Config` trait's provided methods (e.g. `DebertaConfig::from_file`,
// used in the examples of this module).
impl Config for DebertaConfig {}
263
264pub fn x_softmax(input: &Tensor, mask: &Tensor, dim: i64) -> Tensor {
265    let inverse_mask = ((1 - mask) as Tensor).to_kind(Kind::Bool);
266    input
267        .masked_fill(&inverse_mask, get_min(input.kind()).unwrap())
268        .softmax(dim, input.kind())
269        .masked_fill(&inverse_mask, 0.0)
270}
271
/// Constructor interface for DeBERTa layer normalization modules
/// (implemented by `DebertaLayerNorm` in this file).
pub trait BaseDebertaLayerNorm {
    /// Creates a layer normalization module.
    ///
    /// # Arguments
    /// * `p` - variable store path under which the module parameters are created
    /// * `size` - number of features of the normalized dimension (shape of the learnable
    ///   parameters in `DebertaLayerNorm`)
    /// * `variance_epsilon` - value added to the variance for numerical stability
    fn new<'p, P>(p: P, size: i64, variance_epsilon: f64) -> Self
    where
        P: Borrow<nn::Path<'p>>;
}
277
278#[derive(Debug)]
279pub struct DebertaLayerNorm {
280    weight: Tensor,
281    bias: Tensor,
282    variance_epsilon: f64,
283}
284
285impl BaseDebertaLayerNorm for DebertaLayerNorm {
286    fn new<'p, P>(p: P, size: i64, variance_epsilon: f64) -> DebertaLayerNorm
287    where
288        P: Borrow<nn::Path<'p>>,
289    {
290        let p = p.borrow();
291        let weight = p.var("weight", &[size], Init::Const(1.0));
292        let bias = p.var("bias", &[size], Init::Const(0.0));
293        DebertaLayerNorm {
294            weight,
295            bias,
296            variance_epsilon,
297        }
298    }
299}
300
301impl Module for DebertaLayerNorm {
302    fn forward(&self, hidden_states: &Tensor) -> Tensor {
303        let input_type = hidden_states.kind();
304        let hidden_states = hidden_states.to_kind(Kind::Float);
305        let mean = hidden_states.mean_dim([-1].as_slice(), true, hidden_states.kind());
306        let variance = (&hidden_states - &mean).pow_tensor_scalar(2.0).mean_dim(
307            [-1].as_slice(),
308            true,
309            hidden_states.kind(),
310        );
311        let hidden_states = (hidden_states - mean)
312            / (variance + self.variance_epsilon)
313                .sqrt()
314                .to_kind(input_type);
315        &self.weight * hidden_states + &self.bias
316    }
317}
318
319/// # DeBERTa Base model
320/// Base architecture for DeBERTa models. Task-specific models will be built from this common base model
321/// It is made of the following blocks:
322/// - `embeddings`: `DeBERTa` embeddings
323/// - `encoder`: `DeBERTaEncoder` (transformer) made of a vector of layers.
324pub struct DebertaModel {
325    embeddings: DebertaEmbeddings,
326    encoder: DebertaEncoder,
327}
328
329impl DebertaModel {
330    /// Build a new `DebertaModel`
331    ///
332    /// # Arguments
333    ///
334    /// * `p` - Variable store path for the root of the BERT model
335    /// * `config` - `DebertaConfig` object defining the model architecture and decoder status
336    ///
337    /// # Example
338    ///
339    /// ```no_run
340    /// use rust_bert::deberta::{DebertaConfig, DebertaModel};
341    /// use rust_bert::Config;
342    /// use std::path::Path;
343    /// use tch::{nn, Device};
344    ///
345    /// let config_path = Path::new("path/to/config.json");
346    /// let device = Device::Cpu;
347    /// let p = nn::VarStore::new(device);
348    /// let config = DebertaConfig::from_file(config_path);
349    /// let model: DebertaModel = DebertaModel::new(&p.root() / "deberta", &config);
350    /// ```
351    pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaModel
352    where
353        P: Borrow<nn::Path<'p>>,
354    {
355        let p = p.borrow();
356
357        let embeddings = DebertaEmbeddings::new(p / "embeddings", config);
358        let encoder = DebertaEncoder::new(p / "encoder", config);
359
360        DebertaModel {
361            embeddings,
362            encoder,
363        }
364    }
365
366    /// Forward pass through the model
367    ///
368    /// # Arguments
369    ///
370    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
371    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
372    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
373    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
374    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
375    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
376    ///
377    /// # Returns
378    ///
379    /// * `DebertaOutput` containing:
380    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
381    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
382    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
383    ///
384    /// # Example
385    ///
386    /// ```no_run
387    /// # use rust_bert::deberta::{DebertaModel, DebertaConfig};
388    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
389    /// # use rust_bert::Config;
390    /// # use std::path::Path;
391    /// # let config_path = Path::new("path/to/config.json");
392    /// # let device = Device::Cpu;
393    /// # let vs = nn::VarStore::new(device);
394    /// # let config = DebertaConfig::from_file(config_path);
395    /// # let model = DebertaModel::new(&vs.root(), &config);
396    /// let (batch_size, sequence_length) = (64, 128);
397    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
398    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
399    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
400    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
401    ///     .expand(&[batch_size, sequence_length], true);
402    ///
403    /// let model_output = no_grad(|| {
404    ///     model
405    ///         .forward_t(
406    ///             Some(&input_tensor),
407    ///             Some(&attention_mask),
408    ///             Some(&token_type_ids),
409    ///             Some(&position_ids),
410    ///             None,
411    ///             false,
412    ///         )
413    ///         .unwrap()
414    /// });
415    /// ```
416    pub fn forward_t(
417        &self,
418        input_ids: Option<&Tensor>,
419        attention_mask: Option<&Tensor>,
420        token_type_ids: Option<&Tensor>,
421        position_ids: Option<&Tensor>,
422        input_embeds: Option<&Tensor>,
423        train: bool,
424    ) -> Result<DebertaModelOutput, RustBertError> {
425        let (input_shape, device) =
426            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
427
428        let calc_attention_mask = if attention_mask.is_none() {
429            Some(Tensor::ones(input_shape.as_slice(), (Kind::Bool, device)))
430        } else {
431            None
432        };
433
434        let attention_mask =
435            attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
436
437        let embedding_output = self.embeddings.forward_t(
438            input_ids,
439            token_type_ids,
440            position_ids,
441            attention_mask,
442            input_embeds,
443            train,
444        )?;
445
446        let encoder_output =
447            self.encoder
448                .forward_t(&embedding_output, attention_mask, None, None, train)?;
449
450        Ok(encoder_output)
451    }
452}
453
454#[derive(Debug)]
455struct DebertaPredictionHeadTransform {
456    dense: nn::Linear,
457    activation: TensorFunction,
458    layer_norm: nn::LayerNorm,
459}
460
461impl DebertaPredictionHeadTransform {
462    pub fn new<'p, P>(
463        p: P,
464        config: &DebertaConfig,
465        transform_bias: bool,
466    ) -> DebertaPredictionHeadTransform
467    where
468        P: Borrow<nn::Path<'p>>,
469    {
470        let p = p.borrow();
471
472        let dense = nn::linear(
473            p / "dense",
474            config.hidden_size,
475            config.hidden_size,
476            nn::LinearConfig {
477                bias: transform_bias,
478                ..Default::default()
479            },
480        );
481        let activation = config.hidden_act.get_function();
482        let layer_norm_config = nn::LayerNormConfig {
483            eps: 1e-7,
484            ..Default::default()
485        };
486        let layer_norm =
487            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
488
489        DebertaPredictionHeadTransform {
490            dense,
491            activation,
492            layer_norm,
493        }
494    }
495}
496
497impl Module for DebertaPredictionHeadTransform {
498    fn forward(&self, hidden_states: &Tensor) -> Tensor {
499        self.activation.get_fn()(&hidden_states.apply(&self.dense)).apply(&self.layer_norm)
500    }
501}
502
503#[derive(Debug)]
504pub(crate) struct DebertaLMPredictionHead {
505    transform: DebertaPredictionHeadTransform,
506    decoder: nn::Linear,
507}
508
509impl DebertaLMPredictionHead {
510    pub fn new<'p, P>(p: P, config: &DebertaConfig, transform_bias: bool) -> DebertaLMPredictionHead
511    where
512        P: Borrow<nn::Path<'p>>,
513    {
514        let p = p.borrow();
515
516        let transform =
517            DebertaPredictionHeadTransform::new(p / "transform", config, transform_bias);
518        let decoder = nn::linear(
519            p / "decoder",
520            config.hidden_size,
521            config.vocab_size,
522            Default::default(),
523        );
524
525        DebertaLMPredictionHead { transform, decoder }
526    }
527}
528
529impl Module for DebertaLMPredictionHead {
530    fn forward(&self, hidden_states: &Tensor) -> Tensor {
531        hidden_states.apply(&self.transform).apply(&self.decoder)
532    }
533}
534
535/// # DeBERTa for masked language model
536/// Base DeBERTa model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
537/// It is made of the following blocks:
538/// - `deberta`: Base DeBERTa model
539/// - `cls`: LM prediction head
540pub struct DebertaForMaskedLM {
541    deberta: DebertaModel,
542    cls: DebertaLMPredictionHead,
543}
544
545impl DebertaForMaskedLM {
546    /// Build a new `DebertaForMaskedLM`
547    ///
548    /// # Arguments
549    ///
550    /// * `p` - Variable store path for the root of the BertForMaskedLM model
551    /// * `config` - `DebertaConfig` object defining the model architecture and vocab size
552    ///
553    /// # Example
554    ///
555    /// ```no_run
556    /// use rust_bert::deberta::{DebertaConfig, DebertaForMaskedLM};
557    /// use rust_bert::Config;
558    /// use std::path::Path;
559    /// use tch::{nn, Device};
560    ///
561    /// let config_path = Path::new("path/to/config.json");
562    /// let device = Device::Cpu;
563    /// let p = nn::VarStore::new(device);
564    /// let config = DebertaConfig::from_file(config_path);
565    /// let model = DebertaForMaskedLM::new(&p.root(), &config);
566    /// ```
567    pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForMaskedLM
568    where
569        P: Borrow<nn::Path<'p>>,
570    {
571        let p = p.borrow();
572
573        let deberta = DebertaModel::new(p / "deberta", config);
574        let cls = DebertaLMPredictionHead::new(p.sub("cls").sub("predictions"), config, false);
575
576        DebertaForMaskedLM { deberta, cls }
577    }
578
579    /// Forward pass through the model
580    ///
581    /// # Arguments
582    ///
583    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
584    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
585    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
586    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
587    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
588    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
589    ///
590    /// # Returns
591    ///
592    /// * `DebertaMaskedLMOutput` containing:
593    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
594    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
595    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
596    ///
597    /// # Example
598    ///
599    /// ```no_run
600    /// # use rust_bert::deberta::{DebertaForMaskedLM, DebertaConfig};
601    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
602    /// # use rust_bert::Config;
603    /// # use std::path::Path;
604    /// # let config_path = Path::new("path/to/config.json");
605    /// # let device = Device::Cpu;
606    /// # let vs = nn::VarStore::new(device);
607    /// # let config = DebertaConfig::from_file(config_path);
608    /// # let model = DebertaForMaskedLM::new(&vs.root(), &config);
609    /// let (batch_size, sequence_length) = (64, 128);
610    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
611    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
612    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
613    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
614    ///     .expand(&[batch_size, sequence_length], true);
615    ///
616    /// let model_output = no_grad(|| {
617    ///     model.forward_t(
618    ///         Some(&input_tensor),
619    ///         Some(&mask),
620    ///         Some(&token_type_ids),
621    ///         Some(&position_ids),
622    ///         None,
623    ///         false,
624    ///     )
625    /// });
626    /// ```
627    pub fn forward_t(
628        &self,
629        input_ids: Option<&Tensor>,
630        attention_mask: Option<&Tensor>,
631        token_type_ids: Option<&Tensor>,
632        position_ids: Option<&Tensor>,
633        input_embeds: Option<&Tensor>,
634        train: bool,
635    ) -> Result<DebertaMaskedLMOutput, RustBertError> {
636        let model_outputs = self.deberta.forward_t(
637            input_ids,
638            attention_mask,
639            token_type_ids,
640            position_ids,
641            input_embeds,
642            train,
643        )?;
644
645        let logits = model_outputs.hidden_state.apply(&self.cls);
646        Ok(DebertaMaskedLMOutput {
647            logits,
648            all_hidden_states: model_outputs.all_hidden_states,
649            all_attentions: model_outputs.all_attentions,
650        })
651    }
652}
653
654#[derive(Debug)]
655pub struct ContextPooler {
656    dense: nn::Linear,
657    dropout: XDropout,
658    activation: TensorFunction,
659    pub output_dim: i64,
660}
661
662impl ContextPooler {
663    pub fn new<'p, P>(p: P, config: &DebertaConfig) -> ContextPooler
664    where
665        P: Borrow<nn::Path<'p>>,
666    {
667        let p = p.borrow();
668        let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
669
670        let dense = nn::linear(
671            p / "dense",
672            pooler_hidden_size,
673            pooler_hidden_size,
674            Default::default(),
675        );
676        let dropout = XDropout::new(config.pooler_dropout.unwrap_or(0.0));
677        let activation = config
678            .pooler_hidden_act
679            .unwrap_or(Activation::gelu)
680            .get_function();
681
682        ContextPooler {
683            dense,
684            dropout,
685            activation,
686            output_dim: pooler_hidden_size,
687        }
688    }
689}
690
691impl ModuleT for ContextPooler {
692    fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
693        self.activation.get_fn()(
694            &hidden_states
695                .select(1, 0)
696                .apply_t(&self.dropout, train)
697                .apply(&self.dense),
698        )
699    }
700}
701
702/// # DeBERTa for sequence classification
703/// Base DeBERTa model with a classifier head to perform sentence or document-level classification
704/// It is made of the following blocks:
705/// - `deberta`: Base BertModel
706/// - `classifier`: BERT linear layer for classification
707pub struct DebertaForSequenceClassification {
708    deberta: DebertaModel,
709    pooler: ContextPooler,
710    classifier: nn::Linear,
711    dropout: XDropout,
712}
713
714impl DebertaForSequenceClassification {
715    /// Build a new `DebertaForSequenceClassification`
716    ///
717    /// # Arguments
718    ///
719    /// * `p` - Variable store path for the root of the DebertaForSequenceClassification model
720    /// * `config` - `DebertaConfig` object defining the model architecture and number of classes
721    ///
722    /// # Example
723    ///
724    /// ```no_run
725    /// use rust_bert::deberta::{DebertaConfig, DebertaForSequenceClassification};
726    /// use rust_bert::Config;
727    /// use std::path::Path;
728    /// use tch::{nn, Device};
729    ///
730    /// let config_path = Path::new("path/to/config.json");
731    /// let device = Device::Cpu;
732    /// let p = nn::VarStore::new(device);
733    /// let config = DebertaConfig::from_file(config_path);
734    /// let model = DebertaForSequenceClassification::new(&p.root(), &config).unwrap();
735    /// ```
736    pub fn new<'p, P>(
737        p: P,
738        config: &DebertaConfig,
739    ) -> Result<DebertaForSequenceClassification, RustBertError>
740    where
741        P: Borrow<nn::Path<'p>>,
742    {
743        let p = p.borrow();
744
745        let deberta = DebertaModel::new(p / "deberta", config);
746        let pooler = ContextPooler::new(p / "pooler", config);
747        let dropout = XDropout::new(
748            config
749                .classifier_dropout
750                .unwrap_or(config.hidden_dropout_prob),
751        );
752
753        let num_labels = config
754            .id2label
755            .as_ref()
756            .ok_or_else(|| {
757                RustBertError::InvalidConfigurationError(
758                    "num_labels not provided in configuration".to_string(),
759                )
760            })?
761            .len() as i64;
762
763        let classifier = nn::linear(
764            p / "classifier",
765            pooler.output_dim,
766            num_labels,
767            Default::default(),
768        );
769
770        Ok(DebertaForSequenceClassification {
771            deberta,
772            pooler,
773            classifier,
774            dropout,
775        })
776    }
777
778    /// Forward pass through the model
779    ///
780    /// # Arguments
781    ///
782    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
783    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
784    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
785    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
786    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
787    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
788    ///
789    /// # Returns
790    ///
791    /// * `DebertaSequenceClassificationOutput` containing:
792    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
793    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
794    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
795    ///
796    /// # Example
797    ///
798    /// ```no_run
799    /// # use rust_bert::deberta::{DebertaForSequenceClassification, DebertaConfig};
800    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
801    /// # use rust_bert::Config;
802    /// # use std::path::Path;
803    /// # let config_path = Path::new("path/to/config.json");
804    /// # let device = Device::Cpu;
805    /// # let vs = nn::VarStore::new(device);
806    /// # let config = DebertaConfig::from_file(config_path);
807    /// # let model = DebertaForSequenceClassification::new(&vs.root(), &config).unwrap();;
808    /// let (batch_size, sequence_length) = (64, 128);
809    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
810    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
811    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
812    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
813    ///     .expand(&[batch_size, sequence_length], true);
814    ///
815    /// let model_output = no_grad(|| {
816    ///     model.forward_t(
817    ///         Some(&input_tensor),
818    ///         Some(&mask),
819    ///         Some(&token_type_ids),
820    ///         Some(&position_ids),
821    ///         None,
822    ///         false,
823    ///     )
824    /// });
825    /// ```
826    pub fn forward_t(
827        &self,
828        input_ids: Option<&Tensor>,
829        attention_mask: Option<&Tensor>,
830        token_type_ids: Option<&Tensor>,
831        position_ids: Option<&Tensor>,
832        input_embeds: Option<&Tensor>,
833        train: bool,
834    ) -> Result<DebertaSequenceClassificationOutput, RustBertError> {
835        let base_model_output = self.deberta.forward_t(
836            input_ids,
837            attention_mask,
838            token_type_ids,
839            position_ids,
840            input_embeds,
841            train,
842        )?;
843
844        let logits = base_model_output
845            .hidden_state
846            .apply_t(&self.pooler, train)
847            .apply_t(&self.dropout, train)
848            .apply(&self.classifier);
849
850        Ok(DebertaSequenceClassificationOutput {
851            logits,
852            all_hidden_states: base_model_output.all_hidden_states,
853            all_attentions: base_model_output.all_attentions,
854        })
855    }
856}
857
858/// # DeBERTa for token classification (e.g. NER, POS)
859/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
860/// not necessarily aligned with words in the sentence.
861/// It is made of the following blocks:
862/// - `deberta`: Base DeBERTa
863/// - `dropout`: Dropout layer before the last token-level predictions layer
864/// - `classifier`: Linear layer for token classification
865pub struct DebertaForTokenClassification {
866    deberta: DebertaModel,
867    dropout: Dropout,
868    classifier: nn::Linear,
869}
870
871impl DebertaForTokenClassification {
872    /// Build a new `DebertaForTokenClassification`
873    ///
874    /// # Arguments
875    ///
876    /// * `p` - Variable store path for the root of the Deberta model
877    /// * `config` - `DebertaConfig` object defining the model architecture
878    ///
879    /// # Example
880    ///
881    /// ```no_run
882    /// use rust_bert::deberta::{DebertaConfig, DebertaForTokenClassification};
883    /// use rust_bert::Config;
884    /// use std::path::Path;
885    /// use tch::{nn, Device};
886    ///
887    /// let config_path = Path::new("path/to/config.json");
888    /// let device = Device::Cpu;
889    /// let p = nn::VarStore::new(device);
890    /// let config = DebertaConfig::from_file(config_path);
891    /// let model = DebertaForTokenClassification::new(&p.root(), &config).unwrap();
892    /// ```
893    pub fn new<'p, P>(
894        p: P,
895        config: &DebertaConfig,
896    ) -> Result<DebertaForTokenClassification, RustBertError>
897    where
898        P: Borrow<nn::Path<'p>>,
899    {
900        let p = p.borrow();
901
902        let deberta = DebertaModel::new(p / "deberta", config);
903        let dropout = Dropout::new(config.hidden_dropout_prob);
904        let num_labels = config
905            .id2label
906            .as_ref()
907            .ok_or_else(|| {
908                RustBertError::InvalidConfigurationError(
909                    "num_labels not provided in configuration".to_string(),
910                )
911            })?
912            .len() as i64;
913        let classifier = nn::linear(
914            p / "classifier",
915            config.hidden_size,
916            num_labels,
917            Default::default(),
918        );
919
920        Ok(DebertaForTokenClassification {
921            deberta,
922            dropout,
923            classifier,
924        })
925    }
926
927    /// Forward pass through the model
928    ///
929    /// # Arguments
930    ///
931    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
932    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
933    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
934    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
935    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
936    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
937    ///
938    /// # Returns
939    ///
940    /// * `DebertaTokenClassificationOutput` containing:
941    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*)
942    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
943    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
944    ///
945    /// # Example
946    ///
947    /// ```no_run
948    /// # use rust_bert::deberta::{DebertaForTokenClassification, DebertaConfig};
949    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
950    /// # use rust_bert::Config;
951    /// # use std::path::Path;
952    /// # let config_path = Path::new("path/to/config.json");
953    /// # let device = Device::Cpu;
954    /// # let vs = nn::VarStore::new(device);
955    /// # let config = DebertaConfig::from_file(config_path);
956    /// # let model = DebertaForTokenClassification::new(&vs.root(), &config).unwrap();
957    /// let (batch_size, sequence_length) = (64, 128);
958    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
959    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
960    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
961    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
962    ///     .expand(&[batch_size, sequence_length], true);
963    ///
964    /// let model_output = no_grad(|| {
965    ///     model.forward_t(
966    ///         Some(&input_tensor),
967    ///         Some(&mask),
968    ///         Some(&token_type_ids),
969    ///         Some(&position_ids),
970    ///         None,
971    ///         false,
972    ///     )
973    /// });
974    /// ```
975    pub fn forward_t(
976        &self,
977        input_ids: Option<&Tensor>,
978        attention_mask: Option<&Tensor>,
979        token_type_ids: Option<&Tensor>,
980        position_ids: Option<&Tensor>,
981        input_embeds: Option<&Tensor>,
982        train: bool,
983    ) -> Result<DebertaTokenClassificationOutput, RustBertError> {
984        let base_model_output = self.deberta.forward_t(
985            input_ids,
986            attention_mask,
987            token_type_ids,
988            position_ids,
989            input_embeds,
990            train,
991        )?;
992
993        let logits = base_model_output
994            .hidden_state
995            .apply_t(&self.dropout, train)
996            .apply(&self.classifier);
997
998        Ok(DebertaTokenClassificationOutput {
999            logits,
1000            all_hidden_states: base_model_output.all_hidden_states,
1001            all_attentions: base_model_output.all_attentions,
1002        })
1003    }
1004}
1005
1006/// # DeBERTa for question answering
1007/// Extractive question-answering model based on a DeBERTa language model. Identifies the segment of a context that answers a provided question.
1008/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
1009/// See the question answering pipeline (also provided in this crate) for more details.
1010/// It is made of the following blocks:
1011/// - `deberta`: Base DeBERTa model
1012/// - `qa_outputs`: Linear layer for question answering
1013pub struct DebertaForQuestionAnswering {
1014    deberta: DebertaModel,
1015    qa_outputs: nn::Linear,
1016}
1017
1018impl DebertaForQuestionAnswering {
1019    /// Build a new `DebertaForQuestionAnswering`
1020    ///
1021    /// # Arguments
1022    ///
1023    /// * `p` - Variable store path for the root of the BertForQuestionAnswering model
1024    /// * `config` - `DebertaConfig` object defining the model architecture
1025    ///
1026    /// # Example
1027    ///
1028    /// ```no_run
1029    /// use rust_bert::deberta::{DebertaConfig, DebertaForQuestionAnswering};
1030    /// use rust_bert::Config;
1031    /// use std::path::Path;
1032    /// use tch::{nn, Device};
1033    ///
1034    /// let config_path = Path::new("path/to/config.json");
1035    /// let device = Device::Cpu;
1036    /// let p = nn::VarStore::new(device);
1037    /// let config = DebertaConfig::from_file(config_path);
1038    /// let model = DebertaForQuestionAnswering::new(&p.root(), &config);
1039    /// ```
1040    pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForQuestionAnswering
1041    where
1042        P: Borrow<nn::Path<'p>>,
1043    {
1044        let p = p.borrow();
1045
1046        let deberta = DebertaModel::new(p / "deberta", config);
1047        let num_labels = 2;
1048        let qa_outputs = nn::linear(
1049            p / "qa_outputs",
1050            config.hidden_size,
1051            num_labels,
1052            Default::default(),
1053        );
1054
1055        DebertaForQuestionAnswering {
1056            deberta,
1057            qa_outputs,
1058        }
1059    }
1060
1061    /// Forward pass through the model
1062    ///
1063    /// # Arguments
1064    ///
1065    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1066    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1067    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1068    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1069    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1070    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1071    ///
1072    /// # Returns
1073    ///
1074    /// * `DebertaQuestionAnsweringOutput` containing:
1075    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
1076    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
1077    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1078    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1079    ///
1080    /// # Example
1081    ///
1082    /// ```no_run
1083    /// # use rust_bert::deberta::{DebertaForQuestionAnswering, DebertaConfig};
1084    /// # use tch::{nn, Device, Tensor, no_grad};
1085    /// # use rust_bert::Config;
1086    /// # use std::path::Path;
1087    /// # use tch::kind::Kind::Int64;
1088    /// # let config_path = Path::new("path/to/config.json");
1089    /// # let device = Device::Cpu;
1090    /// # let vs = nn::VarStore::new(device);
1091    /// # let config = DebertaConfig::from_file(config_path);
1092    /// # let model = DebertaForQuestionAnswering::new(&vs.root(), &config);
1093    /// let (batch_size, sequence_length) = (64, 128);
1094    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1095    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1096    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1097    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1098    ///     .expand(&[batch_size, sequence_length], true);
1099    ///
1100    /// let model_output = no_grad(|| {
1101    ///     model.forward_t(
1102    ///         Some(&input_tensor),
1103    ///         Some(&mask),
1104    ///         Some(&token_type_ids),
1105    ///         Some(&position_ids),
1106    ///         None,
1107    ///         false,
1108    ///     )
1109    /// });
1110    /// ```
1111    pub fn forward_t(
1112        &self,
1113        input_ids: Option<&Tensor>,
1114        attention_mask: Option<&Tensor>,
1115        token_type_ids: Option<&Tensor>,
1116        position_ids: Option<&Tensor>,
1117        input_embeds: Option<&Tensor>,
1118        train: bool,
1119    ) -> Result<DebertaQuestionAnsweringOutput, RustBertError> {
1120        let base_model_output = self.deberta.forward_t(
1121            input_ids,
1122            attention_mask,
1123            token_type_ids,
1124            position_ids,
1125            input_embeds,
1126            train,
1127        )?;
1128
1129        let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1130        let logits = sequence_output.split(1, -1);
1131        let (start_logits, end_logits) = (&logits[0], &logits[1]);
1132        let start_logits = start_logits.squeeze_dim(-1);
1133        let end_logits = end_logits.squeeze_dim(-1);
1134
1135        Ok(DebertaQuestionAnsweringOutput {
1136            start_logits,
1137            end_logits,
1138            all_hidden_states: base_model_output.all_hidden_states,
1139            all_attentions: base_model_output.all_attentions,
1140        })
1141    }
1142}
1143
1144/// Container for the DeBERTa model output.
1145pub type DebertaModelOutput = DebertaEncoderOutput;
1146
1147/// Container for the DeBERTa masked LM model output.
1148pub struct DebertaMaskedLMOutput {
1149    /// Logits for the vocabulary items at each sequence position
1150    pub logits: Tensor,
1151    /// Hidden states for all intermediate layers
1152    pub all_hidden_states: Option<Vec<Tensor>>,
1153    /// Attention weights for all intermediate layers
1154    pub all_attentions: Option<Vec<Tensor>>,
1155}
1156
1157/// Container for the DeBERTa sequence classification model output.
1158pub type DebertaSequenceClassificationOutput = BertSequenceClassificationOutput;
1159
1160/// Container for the DeBERTa token classification model output.
1161pub type DebertaTokenClassificationOutput = BertTokenClassificationOutput;
1162
1163/// Container for the DeBERTa question answering model output.
1164pub type DebertaQuestionAnsweringOutput = BertQuestionAnsweringOutput;