// rust_bert/models/deberta_v2/deberta_v2_model.rs

// Copyright 2020, Microsoft and the HuggingFace Inc. team.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

13use crate::common::dropout::{Dropout, XDropout};
14use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
15use crate::deberta::{
16    deserialize_attention_type, ContextPooler, DebertaConfig, DebertaLMPredictionHead,
17    DebertaMaskedLMOutput, DebertaModelOutput, DebertaQuestionAnsweringOutput,
18    DebertaSequenceClassificationOutput, DebertaTokenClassificationOutput, PositionAttentionTypes,
19};
20use crate::deberta_v2::embeddings::DebertaV2Embeddings;
21use crate::deberta_v2::encoder::DebertaV2Encoder;
22use crate::{Activation, Config, RustBertError};
23use serde::de::{SeqAccess, Visitor};
24use serde::{de, Deserialize, Deserializer, Serialize};
25use std::borrow::Borrow;
26use std::collections::HashMap;
27use std::fmt;
28use std::str::FromStr;
29use tch::{nn, Kind, Tensor};
30
/// # DeBERTaV2 Pretrained model weight files
pub struct DebertaV2ModelResources;

impl DebertaV2ModelResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-v3-base>. Modified with conversion to C-array format.
    ///
    /// `(identifier, URL)` pair pointing at the converted `rust_model.ot` weights.
    pub const DEBERTA_V3_BASE: (&'static str, &'static str) = (
        "deberta-v3-base/model",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/rust_model.ot",
    );
}

/// # DeBERTaV2 Pretrained model config files
pub struct DebertaV2ConfigResources;

impl DebertaV2ConfigResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-v3-base>.
    ///
    /// `(identifier, URL)` pair pointing at the original `config.json` of the model repository.
    pub const DEBERTA_V3_BASE: (&'static str, &'static str) = (
        "deberta-v3-base/config",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/config.json",
    );
}

/// # DeBERTaV2 Pretrained model vocab files
pub struct DebertaV2VocabResources;

impl DebertaV2VocabResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-v3-base>.
    ///
    /// `(identifier, URL)` pair pointing at the original SentencePiece `spm.model` of the model repository.
    pub const DEBERTA_V3_BASE: (&'static str, &'static str) = (
        "deberta-v3-base/vocab",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model",
    );
}
63
#[derive(Debug, Serialize, Deserialize)]
/// # DeBERTa (v2) model configuration
/// Defines the DeBERTa (v2) model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// The fields `norm_rel_ebd`, `conv_kernel_size`, `conv_groups`, `conv_act` and
/// `classifier_activation` are specific to this V2 configuration: they are not carried over
/// when converting into a `DebertaConfig` with the `From` implementations of this module.
pub struct DebertaV2Config {
    pub vocab_size: i64,
    pub hidden_size: i64,
    pub num_hidden_layers: i64,
    pub hidden_act: Activation,
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability; also used by the token classification head and as the fallback
    /// for the sequence classification head when `classifier_dropout` is `None`.
    pub hidden_dropout_prob: f64,
    pub initializer_range: f64,
    pub intermediate_size: i64,
    pub max_position_embeddings: i64,
    pub position_buckets: Option<i64>,
    pub num_attention_heads: i64,
    pub type_vocab_size: i64,
    pub position_biased_input: Option<bool>,
    /// Relative position attention types, parsed by `deserialize_attention_type`
    /// (`None` when the field is absent from the configuration file).
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pub pos_att_type: Option<PositionAttentionTypes>,
    /// Normalization applied to the relative embeddings, parsed by `deserialize_norm_type`
    /// (`None` when the field is absent from the configuration file).
    #[serde(default, deserialize_with = "deserialize_norm_type")]
    pub norm_rel_ebd: Option<NormRelEmbedTypes>,
    pub share_att_key: Option<bool>,
    pub conv_kernel_size: Option<i64>,
    pub conv_groups: Option<i64>,
    pub conv_act: Option<Activation>,
    pub pooler_dropout: Option<f64>,
    pub pooler_hidden_act: Option<Activation>,
    pub pooler_hidden_size: Option<i64>,
    pub layer_norm_eps: Option<f64>,
    pub pad_token_id: Option<i64>,
    pub relative_attention: Option<bool>,
    pub max_relative_positions: Option<i64>,
    pub embedding_size: Option<i64>,
    pub talking_head: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub output_attentions: Option<bool>,
    pub classifier_activation: Option<bool>,
    /// Dropout of the sequence classification head; `hidden_dropout_prob` is used when `None`.
    pub classifier_dropout: Option<f64>,
    pub is_decoder: Option<bool>,
    /// Label mapping; required by the classification heads, whose output size is its length.
    pub id2label: Option<HashMap<i64, String>>,
    pub label2id: Option<HashMap<String, i64>>,
}
106
// The variant is intentionally snake_case so that the derived serde implementation
// (de)serializes it as `"layer_norm"`, the same spelling accepted by its `FromStr` impl.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq, Eq)]
/// # Layer normalization layer for the DeBERTa model's relative embeddings.
pub enum NormRelEmbedType {
    /// Layer normalization of the relative embeddings (configuration value `"layer_norm"`).
    layer_norm,
}
113
114impl FromStr for NormRelEmbedType {
115    type Err = RustBertError;
116
117    fn from_str(s: &str) -> Result<Self, Self::Err> {
118        match s {
119            "layer_norm" => Ok(NormRelEmbedType::layer_norm),
120            _ => Err(RustBertError::InvalidConfigurationError(format!(
121                "Layer normalization type `{s}` not in accepted variants (`layer_norm`)",
122            ))),
123        }
124    }
125}
126
127#[allow(non_camel_case_types)]
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct NormRelEmbedTypes {
130    types: Vec<NormRelEmbedType>,
131}
132
133impl FromStr for NormRelEmbedTypes {
134    type Err = RustBertError;
135
136    fn from_str(s: &str) -> Result<Self, Self::Err> {
137        let types = s
138            .to_lowercase()
139            .split('|')
140            .map(NormRelEmbedType::from_str)
141            .collect::<Result<Vec<_>, _>>()?;
142        Ok(NormRelEmbedTypes { types })
143    }
144}
145
146impl NormRelEmbedTypes {
147    pub fn has_type(&self, norm_type: NormRelEmbedType) -> bool {
148        self.types.iter().any(|self_type| *self_type == norm_type)
149    }
150
151    pub fn len(&self) -> usize {
152        self.types.len()
153    }
154}
155
156pub fn deserialize_norm_type<'de, D>(deserializer: D) -> Result<Option<NormRelEmbedTypes>, D::Error>
157where
158    D: Deserializer<'de>,
159{
160    struct NormTypeVisitor;
161
162    impl<'de> Visitor<'de> for NormTypeVisitor {
163        type Value = NormRelEmbedTypes;
164
165        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
166            formatter.write_str("null, string or sequence")
167        }
168
169        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
170        where
171            E: de::Error,
172        {
173            Ok(FromStr::from_str(value).unwrap())
174        }
175
176        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
177        where
178            S: SeqAccess<'de>,
179        {
180            let mut types = vec![];
181            while let Some(norm_type) = seq.next_element::<String>()? {
182                types.push(FromStr::from_str(norm_type.as_str()).unwrap())
183            }
184            Ok(NormRelEmbedTypes { types })
185        }
186    }
187
188    deserializer.deserialize_any(NormTypeVisitor).map(Some)
189}
190
191impl Config for DebertaV2Config {}
192
193impl Default for DebertaV2Config {
194    fn default() -> Self {
195        DebertaV2Config {
196            vocab_size: 128100,
197            hidden_size: 1536,
198            num_hidden_layers: 24,
199            hidden_act: Activation::gelu,
200            attention_probs_dropout_prob: 0.1,
201            hidden_dropout_prob: 0.1,
202            initializer_range: 0.02,
203            intermediate_size: 6144,
204            max_position_embeddings: 512,
205            position_buckets: None,
206            num_attention_heads: 24,
207            type_vocab_size: 0,
208            position_biased_input: Some(true),
209            pos_att_type: None,
210            norm_rel_ebd: None,
211            share_att_key: None,
212            conv_kernel_size: None,
213            conv_groups: None,
214            conv_act: None,
215            pooler_dropout: Some(0.0),
216            pooler_hidden_act: Some(Activation::gelu),
217            pooler_hidden_size: None,
218            layer_norm_eps: Some(1e-7),
219            pad_token_id: Some(0),
220            relative_attention: None,
221            max_relative_positions: None,
222            embedding_size: None,
223            talking_head: None,
224            output_hidden_states: None,
225            output_attentions: None,
226            classifier_activation: None,
227            classifier_dropout: None,
228            is_decoder: None,
229            id2label: None,
230            label2id: None,
231        }
232    }
233}
234
235impl From<DebertaV2Config> for DebertaConfig {
236    fn from(v2_config: DebertaV2Config) -> Self {
237        DebertaConfig {
238            hidden_act: v2_config.hidden_act,
239            attention_probs_dropout_prob: v2_config.attention_probs_dropout_prob,
240            hidden_dropout_prob: v2_config.hidden_dropout_prob,
241            hidden_size: v2_config.hidden_size,
242            initializer_range: v2_config.initializer_range,
243            intermediate_size: v2_config.intermediate_size,
244            max_position_embeddings: v2_config.max_position_embeddings,
245            num_attention_heads: v2_config.num_attention_heads,
246            num_hidden_layers: v2_config.num_hidden_layers,
247            type_vocab_size: v2_config.type_vocab_size,
248            vocab_size: v2_config.vocab_size,
249            position_biased_input: v2_config.position_biased_input,
250            pos_att_type: v2_config.pos_att_type,
251            pooler_dropout: v2_config.pooler_dropout,
252            pooler_hidden_act: v2_config.pooler_hidden_act,
253            pooler_hidden_size: v2_config.pooler_hidden_size,
254            layer_norm_eps: v2_config.layer_norm_eps,
255            pad_token_id: v2_config.pad_token_id,
256            relative_attention: v2_config.relative_attention,
257            max_relative_positions: v2_config.max_relative_positions,
258            embedding_size: v2_config.embedding_size,
259            talking_head: v2_config.talking_head,
260            output_hidden_states: v2_config.output_hidden_states,
261            output_attentions: v2_config.output_attentions,
262            classifier_dropout: v2_config.classifier_dropout,
263            is_decoder: v2_config.is_decoder,
264            id2label: v2_config.id2label,
265            label2id: v2_config.label2id,
266            share_att_key: v2_config.share_att_key,
267            position_buckets: v2_config.position_buckets,
268        }
269    }
270}
271
272impl From<&DebertaV2Config> for DebertaConfig {
273    fn from(v2_config: &DebertaV2Config) -> Self {
274        DebertaConfig {
275            hidden_act: v2_config.hidden_act,
276            attention_probs_dropout_prob: v2_config.attention_probs_dropout_prob,
277            hidden_dropout_prob: v2_config.hidden_dropout_prob,
278            hidden_size: v2_config.hidden_size,
279            initializer_range: v2_config.initializer_range,
280            intermediate_size: v2_config.intermediate_size,
281            max_position_embeddings: v2_config.max_position_embeddings,
282            num_attention_heads: v2_config.num_attention_heads,
283            num_hidden_layers: v2_config.num_hidden_layers,
284            type_vocab_size: v2_config.type_vocab_size,
285            vocab_size: v2_config.vocab_size,
286            position_biased_input: v2_config.position_biased_input,
287            pos_att_type: v2_config.pos_att_type.clone(),
288            pooler_dropout: v2_config.pooler_dropout,
289            pooler_hidden_act: v2_config.pooler_hidden_act,
290            pooler_hidden_size: v2_config.pooler_hidden_size,
291            layer_norm_eps: v2_config.layer_norm_eps,
292            pad_token_id: v2_config.pad_token_id,
293            relative_attention: v2_config.relative_attention,
294            max_relative_positions: v2_config.max_relative_positions,
295            embedding_size: v2_config.embedding_size,
296            talking_head: v2_config.talking_head,
297            output_hidden_states: v2_config.output_hidden_states,
298            output_attentions: v2_config.output_attentions,
299            classifier_dropout: v2_config.classifier_dropout,
300            is_decoder: v2_config.is_decoder,
301            id2label: v2_config.id2label.clone(),
302            label2id: v2_config.label2id.clone(),
303            share_att_key: v2_config.share_att_key,
304            position_buckets: v2_config.position_buckets,
305        }
306    }
307}
308
309/// # DeBERTa V2 Base model
310/// Base architecture for DeBERTa V2 models. Task-specific models will be built from this common base model
311/// It is made of the following blocks:
312/// - `embeddings`: `DeBERTa` V2 embeddings
313/// - `encoder`: `DeBERTaV2Encoder` (transformer) made of a vector of layers.
314pub struct DebertaV2Model {
315    embeddings: DebertaV2Embeddings,
316    encoder: DebertaV2Encoder,
317}
318
319impl DebertaV2Model {
320    /// Build a new `DebertaV2Model`
321    ///
322    /// # Arguments
323    ///
324    /// * `p` - Variable store path for the root of the BERT model
325    /// * `config` - `DebertaV2Config` object defining the model architecture and decoder status
326    ///
327    /// # Example
328    ///
329    /// ```no_run
330    /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2Model};
331    /// use rust_bert::Config;
332    /// use std::path::Path;
333    /// use tch::{nn, Device};
334    ///
335    /// let config_path = Path::new("path/to/config.json");
336    /// let device = Device::Cpu;
337    /// let p = nn::VarStore::new(device);
338    /// let config = DebertaV2Config::from_file(config_path);
339    /// let model: DebertaV2Model = DebertaV2Model::new(&p.root() / "deberta", &config);
340    /// ```
341    pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2Model
342    where
343        P: Borrow<nn::Path<'p>>,
344    {
345        let p = p.borrow();
346
347        let embeddings = DebertaV2Embeddings::new(p / "embeddings", &config.into());
348        let encoder = DebertaV2Encoder::new(p / "encoder", config);
349
350        DebertaV2Model {
351            embeddings,
352            encoder,
353        }
354    }
355
356    /// Forward pass through the model
357    ///
358    /// # Arguments
359    ///
360    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
361    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
362    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
363    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
364    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
365    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
366    ///
367    /// # Returns
368    ///
369    /// * `DebertaV2Output` containing:
370    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
371    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
372    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
373    ///
374    /// # Example
375    ///
376    /// ```no_run
377    /// # use rust_bert::deberta_v2::{DebertaV2Model, DebertaV2Config};
378    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
379    /// # use rust_bert::Config;
380    /// # use std::path::Path;
381    /// # let config_path = Path::new("path/to/config.json");
382    /// # let device = Device::Cpu;
383    /// # let vs = nn::VarStore::new(device);
384    /// # let config = DebertaV2Config::from_file(config_path);
385    /// # let model = DebertaV2Model::new(&vs.root(), &config);
386    /// let (batch_size, sequence_length) = (64, 128);
387    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
388    /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
389    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
390    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
391    ///     .expand(&[batch_size, sequence_length], true);
392    ///
393    /// let model_output = no_grad(|| {
394    ///     model
395    ///         .forward_t(
396    ///             Some(&input_tensor),
397    ///             Some(&attention_mask),
398    ///             Some(&token_type_ids),
399    ///             Some(&position_ids),
400    ///             None,
401    ///             false,
402    ///         )
403    ///         .unwrap()
404    /// });
405    /// ```
406    pub fn forward_t(
407        &self,
408        input_ids: Option<&Tensor>,
409        attention_mask: Option<&Tensor>,
410        token_type_ids: Option<&Tensor>,
411        position_ids: Option<&Tensor>,
412        input_embeds: Option<&Tensor>,
413        train: bool,
414    ) -> Result<DebertaV2ModelOutput, RustBertError> {
415        let (input_shape, device) =
416            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
417
418        let calc_attention_mask = if attention_mask.is_none() {
419            Some(Tensor::ones(input_shape.as_slice(), (Kind::Bool, device)))
420        } else {
421            None
422        };
423
424        let attention_mask =
425            attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
426
427        let embedding_output = self.embeddings.forward_t(
428            input_ids,
429            token_type_ids,
430            position_ids,
431            attention_mask,
432            input_embeds,
433            train,
434        )?;
435
436        let encoder_output =
437            self.encoder
438                .forward_t(&embedding_output, attention_mask, None, None, train)?;
439
440        Ok(encoder_output)
441    }
442}
443
444/// # DeBERTa V2 for masked language model
445/// Base DeBERTa V2 model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
446/// It is made of the following blocks:
447/// - `deberta`: Base DeBERTa V2 model
448/// - `cls`: LM prediction head
449pub struct DebertaV2ForMaskedLM {
450    deberta: DebertaV2Model,
451    cls: DebertaLMPredictionHead,
452}
453
454impl DebertaV2ForMaskedLM {
455    /// Build a new `DebertaV2ForMaskedLM`
456    ///
457    /// # Arguments
458    ///
459    /// * `p` - Variable store path for the root of the BertForMaskedLM model
460    /// * `config` - `DebertaConfig` object defining the model architecture and vocab size
461    ///
462    /// # Example
463    ///
464    /// ```no_run
465    /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForMaskedLM};
466    /// use rust_bert::Config;
467    /// use std::path::Path;
468    /// use tch::{nn, Device};
469    ///
470    /// let config_path = Path::new("path/to/config.json");
471    /// let device = Device::Cpu;
472    /// let p = nn::VarStore::new(device);
473    /// let config = DebertaV2Config::from_file(config_path);
474    /// let model = DebertaV2ForMaskedLM::new(&p.root(), &config);
475    /// ```
476    pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForMaskedLM
477    where
478        P: Borrow<nn::Path<'p>>,
479    {
480        let p = p.borrow();
481
482        let deberta = DebertaV2Model::new(p / "deberta", config);
483        let cls =
484            DebertaLMPredictionHead::new(p.sub("cls").sub("predictions"), &config.into(), false);
485
486        DebertaV2ForMaskedLM { deberta, cls }
487    }
488
489    /// Forward pass through the model
490    ///
491    /// # Arguments
492    ///
493    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
494    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
495    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
496    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
497    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
498    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
499    ///
500    /// # Returns
501    ///
502    /// * `DebertaMaskedLMOutput` containing:
503    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
504    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
505    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
506    ///
507    /// # Example
508    ///
509    /// ```no_run
510    /// # use rust_bert::deberta_v2::{DebertaV2ForMaskedLM, DebertaV2Config};
511    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
512    /// # use rust_bert::Config;
513    /// # use std::path::Path;
514    /// # let config_path = Path::new("path/to/config.json");
515    /// # let device = Device::Cpu;
516    /// # let vs = nn::VarStore::new(device);
517    /// # let config = DebertaV2Config::from_file(config_path);
518    /// # let model = DebertaV2ForMaskedLM::new(&vs.root(), &config);
519    /// let (batch_size, sequence_length) = (64, 128);
520    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
521    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
522    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
523    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
524    ///     .expand(&[batch_size, sequence_length], true);
525    ///
526    /// let model_output = no_grad(|| {
527    ///     model.forward_t(
528    ///         Some(&input_tensor),
529    ///         Some(&mask),
530    ///         Some(&token_type_ids),
531    ///         Some(&position_ids),
532    ///         None,
533    ///         false,
534    ///     )
535    /// });
536    /// ```
537    pub fn forward_t(
538        &self,
539        input_ids: Option<&Tensor>,
540        attention_mask: Option<&Tensor>,
541        token_type_ids: Option<&Tensor>,
542        position_ids: Option<&Tensor>,
543        input_embeds: Option<&Tensor>,
544        train: bool,
545    ) -> Result<DebertaV2MaskedLMOutput, RustBertError> {
546        let model_outputs = self.deberta.forward_t(
547            input_ids,
548            attention_mask,
549            token_type_ids,
550            position_ids,
551            input_embeds,
552            train,
553        )?;
554
555        let logits = model_outputs.hidden_state.apply(&self.cls);
556        Ok(DebertaV2MaskedLMOutput {
557            logits,
558            all_hidden_states: model_outputs.all_hidden_states,
559            all_attentions: model_outputs.all_attentions,
560        })
561    }
562}
563
564/// # DeBERTa V2 for sequence classification
565/// Base DeBERTa V2 model with a classifier head to perform sentence or document-level classification
566/// It is made of the following blocks:
567/// - `deberta`: Base Deberta (V2) Model
568/// - `classifier`: BERT linear layer for classification
569pub struct DebertaV2ForSequenceClassification {
570    deberta: DebertaV2Model,
571    pooler: ContextPooler,
572    classifier: nn::Linear,
573    dropout: XDropout,
574}
575
576impl DebertaV2ForSequenceClassification {
577    /// Build a new `DebertaV2ForSequenceClassification`
578    ///
579    /// # Arguments
580    ///
581    /// * `p` - Variable store path for the root of the DebertaForSequenceClassification model
582    /// * `config` - `DebertaV2Config` object defining the model architecture and number of classes
583    ///
584    /// # Example
585    ///
586    /// ```no_run
587    /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForSequenceClassification};
588    /// use rust_bert::Config;
589    /// use std::path::Path;
590    /// use tch::{nn, Device};
591    ///
592    /// let config_path = Path::new("path/to/config.json");
593    /// let device = Device::Cpu;
594    /// let p = nn::VarStore::new(device);
595    /// let config = DebertaV2Config::from_file(config_path);
596    /// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config).unwrap();
597    /// ```
598    pub fn new<'p, P>(
599        p: P,
600        config: &DebertaV2Config,
601    ) -> Result<DebertaV2ForSequenceClassification, RustBertError>
602    where
603        P: Borrow<nn::Path<'p>>,
604    {
605        let p = p.borrow();
606
607        let deberta = DebertaV2Model::new(p / "deberta", config);
608        let pooler = ContextPooler::new(p / "pooler", &config.into());
609        let dropout = XDropout::new(
610            config
611                .classifier_dropout
612                .unwrap_or(config.hidden_dropout_prob),
613        );
614
615        let num_labels = config
616            .id2label
617            .as_ref()
618            .ok_or_else(|| {
619                RustBertError::InvalidConfigurationError(
620                    "num_labels not provided in configuration".to_string(),
621                )
622            })?
623            .len() as i64;
624
625        let classifier = nn::linear(
626            p / "classifier",
627            pooler.output_dim,
628            num_labels,
629            Default::default(),
630        );
631
632        Ok(DebertaV2ForSequenceClassification {
633            deberta,
634            pooler,
635            classifier,
636            dropout,
637        })
638    }
639
640    /// Forward pass through the model
641    ///
642    /// # Arguments
643    ///
644    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
645    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
646    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
647    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
648    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
649    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
650    ///
651    /// # Returns
652    ///
653    /// * `DebertaV2SequenceClassificationOutput` containing:
654    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
655    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
656    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
657    ///
658    /// # Example
659    ///
660    /// ```no_run
661    /// # use rust_bert::deberta_v2::{DebertaV2ForSequenceClassification, DebertaV2Config};
662    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
663    /// # use rust_bert::Config;
664    /// # use std::path::Path;
665    /// # let config_path = Path::new("path/to/config.json");
666    /// # let device = Device::Cpu;
667    /// # let vs = nn::VarStore::new(device);
668    /// # let config = DebertaV2Config::from_file(config_path);
669    /// # let model = DebertaV2ForSequenceClassification::new(&vs.root(), &config).unwrap();;
670    /// let (batch_size, sequence_length) = (64, 128);
671    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
672    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
673    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
674    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
675    ///     .expand(&[batch_size, sequence_length], true);
676    ///
677    /// let model_output = no_grad(|| {
678    ///     model.forward_t(
679    ///         Some(&input_tensor),
680    ///         Some(&mask),
681    ///         Some(&token_type_ids),
682    ///         Some(&position_ids),
683    ///         None,
684    ///         false,
685    ///     )
686    /// });
687    /// ```
688    pub fn forward_t(
689        &self,
690        input_ids: Option<&Tensor>,
691        attention_mask: Option<&Tensor>,
692        token_type_ids: Option<&Tensor>,
693        position_ids: Option<&Tensor>,
694        input_embeds: Option<&Tensor>,
695        train: bool,
696    ) -> Result<DebertaV2SequenceClassificationOutput, RustBertError> {
697        let base_model_output = self.deberta.forward_t(
698            input_ids,
699            attention_mask,
700            token_type_ids,
701            position_ids,
702            input_embeds,
703            train,
704        )?;
705
706        let logits = base_model_output
707            .hidden_state
708            .apply_t(&self.pooler, train)
709            .apply_t(&self.dropout, train)
710            .apply(&self.classifier);
711
712        Ok(DebertaV2SequenceClassificationOutput {
713            logits,
714            all_hidden_states: base_model_output.all_hidden_states,
715            all_attentions: base_model_output.all_attentions,
716        })
717    }
718}
719
/// # DeBERTa V2 for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of subword (SentencePiece) tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa (V2) model
/// - `dropout`: Dropout layer before the last token-level predictions layer
/// - `classifier`: Linear layer for token classification
pub struct DebertaV2ForTokenClassification {
    deberta: DebertaV2Model,
    dropout: Dropout,
    classifier: nn::Linear,
}
732
733impl DebertaV2ForTokenClassification {
734    /// Build a new `DebertaV2ForTokenClassification`
735    ///
736    /// # Arguments
737    ///
738    /// * `p` - Variable store path for the root of the Deberta V2 model
739    /// * `config` - `DebertaV2Config` object defining the model architecture
740    ///
741    /// # Example
742    ///
743    /// ```no_run
744    /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForTokenClassification};
745    /// use rust_bert::Config;
746    /// use std::path::Path;
747    /// use tch::{nn, Device};
748    ///
749    /// let config_path = Path::new("path/to/config.json");
750    /// let device = Device::Cpu;
751    /// let p = nn::VarStore::new(device);
752    /// let config = DebertaV2Config::from_file(config_path);
753    /// let model = DebertaV2ForTokenClassification::new(&p.root(), &config);
754    /// ```
755    pub fn new<'p, P>(
756        p: P,
757        config: &DebertaV2Config,
758    ) -> Result<DebertaV2ForTokenClassification, RustBertError>
759    where
760        P: Borrow<nn::Path<'p>>,
761    {
762        let p = p.borrow();
763
764        let deberta = DebertaV2Model::new(p / "deberta", config);
765        let dropout = Dropout::new(config.hidden_dropout_prob);
766        let num_labels = config
767            .id2label
768            .as_ref()
769            .ok_or_else(|| {
770                RustBertError::InvalidConfigurationError(
771                    "num_labels not provided in configuration".to_string(),
772                )
773            })?
774            .len() as i64;
775        let classifier = nn::linear(
776            p / "classifier",
777            config.hidden_size,
778            num_labels,
779            Default::default(),
780        );
781
782        Ok(DebertaV2ForTokenClassification {
783            deberta,
784            dropout,
785            classifier,
786        })
787    }
788
789    /// Forward pass through the model
790    ///
791    /// # Arguments
792    ///
793    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
794    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
795    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
796    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
797    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
798    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
799    ///
800    /// # Returns
801    ///
802    /// * `DebertaV2TokenClassificationOutput` containing:
803    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*)
804    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
805    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
806    ///
807    /// # Example
808    ///
809    /// ```no_run
810    /// # use rust_bert::deberta_v2::{DebertaV2ForTokenClassification, DebertaV2Config};
811    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
812    /// # use rust_bert::Config;
813    /// # use std::path::Path;
814    /// # let config_path = Path::new("path/to/config.json");
815    /// # let device = Device::Cpu;
816    /// # let vs = nn::VarStore::new(device);
817    /// # let config = DebertaV2Config::from_file(config_path);
818    /// # let model = DebertaV2ForTokenClassification::new(&vs.root(), &config).unwrap();
819    /// let (batch_size, sequence_length) = (64, 128);
820    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
821    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
822    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
823    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
824    ///     .expand(&[batch_size, sequence_length], true);
825    ///
826    /// let model_output = no_grad(|| {
827    ///     model.forward_t(
828    ///         Some(&input_tensor),
829    ///         Some(&mask),
830    ///         Some(&token_type_ids),
831    ///         Some(&position_ids),
832    ///         None,
833    ///         false,
834    ///     )
835    /// });
836    /// ```
837    pub fn forward_t(
838        &self,
839        input_ids: Option<&Tensor>,
840        attention_mask: Option<&Tensor>,
841        token_type_ids: Option<&Tensor>,
842        position_ids: Option<&Tensor>,
843        input_embeds: Option<&Tensor>,
844        train: bool,
845    ) -> Result<DebertaV2TokenClassificationOutput, RustBertError> {
846        let base_model_output = self.deberta.forward_t(
847            input_ids,
848            attention_mask,
849            token_type_ids,
850            position_ids,
851            input_embeds,
852            train,
853        )?;
854
855        let logits = base_model_output
856            .hidden_state
857            .apply_t(&self.dropout, train)
858            .apply(&self.classifier);
859
860        Ok(DebertaV2TokenClassificationOutput {
861            logits,
862            all_hidden_states: base_model_output.all_hidden_states,
863            all_attentions: base_model_output.all_attentions,
864        })
865    }
866}
867
868/// # DeBERTa V2 for question answering
869/// Extractive question-answering model based on a DeBERTa V2 language model. Identifies the segment of a context that answers a provided question.
870/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
871/// See the question answering pipeline (also provided in this crate) for more details.
872/// It is made of the following blocks:
873/// - `deberta`: Base DeBERTa V2 model
874/// - `qa_outputs`: Linear layer for question answering
875pub struct DebertaV2ForQuestionAnswering {
876    deberta: DebertaV2Model,
877    qa_outputs: nn::Linear,
878}
879
880impl DebertaV2ForQuestionAnswering {
881    /// Build a new `DebertaV2ForQuestionAnswering`
882    ///
883    /// # Arguments
884    ///
885    /// * `p` - Variable store path for the root of the BertForQuestionAnswering model
886    /// * `config` - `DebertaV2Config` object defining the model architecture
887    ///
888    /// # Example
889    ///
890    /// ```no_run
891    /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForQuestionAnswering};
892    /// use rust_bert::Config;
893    /// use std::path::Path;
894    /// use tch::{nn, Device};
895    ///
896    /// let config_path = Path::new("path/to/config.json");
897    /// let device = Device::Cpu;
898    /// let p = nn::VarStore::new(device);
899    /// let config = DebertaV2Config::from_file(config_path);
900    /// let model = DebertaV2ForQuestionAnswering::new(&p.root(), &config);
901    /// ```
902    pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForQuestionAnswering
903    where
904        P: Borrow<nn::Path<'p>>,
905    {
906        let p = p.borrow();
907
908        let deberta = DebertaV2Model::new(p / "deberta", config);
909        let num_labels = 2;
910        let qa_outputs = nn::linear(
911            p / "qa_outputs",
912            config.hidden_size,
913            num_labels,
914            Default::default(),
915        );
916
917        DebertaV2ForQuestionAnswering {
918            deberta,
919            qa_outputs,
920        }
921    }
922
923    /// Forward pass through the model
924    ///
925    /// # Arguments
926    ///
927    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
928    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
929    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
930    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
931    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
932    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
933    ///
934    /// # Returns
935    ///
936    /// * `DebertaQuestionAnsweringOutput` containing:
937    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
938    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
939    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
940    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
941    ///
942    /// # Example
943    ///
944    /// ```no_run
945    /// # use rust_bert::deberta_v2::{DebertaV2ForQuestionAnswering, DebertaV2Config};
946    /// # use tch::{nn, Device, Tensor, no_grad};
947    /// # use rust_bert::Config;
948    /// # use std::path::Path;
949    /// # use tch::kind::Kind::Int64;
950    /// # let config_path = Path::new("path/to/config.json");
951    /// # let device = Device::Cpu;
952    /// # let vs = nn::VarStore::new(device);
953    /// # let config = DebertaV2Config::from_file(config_path);
954    /// # let model = DebertaV2ForQuestionAnswering::new(&vs.root(), &config);
955    /// let (batch_size, sequence_length) = (64, 128);
956    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
957    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
958    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
959    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
960    ///     .expand(&[batch_size, sequence_length], true);
961    ///
962    /// let model_output = no_grad(|| {
963    ///     model.forward_t(
964    ///         Some(&input_tensor),
965    ///         Some(&mask),
966    ///         Some(&token_type_ids),
967    ///         Some(&position_ids),
968    ///         None,
969    ///         false,
970    ///     )
971    /// });
972    /// ```
973    pub fn forward_t(
974        &self,
975        input_ids: Option<&Tensor>,
976        attention_mask: Option<&Tensor>,
977        token_type_ids: Option<&Tensor>,
978        position_ids: Option<&Tensor>,
979        input_embeds: Option<&Tensor>,
980        train: bool,
981    ) -> Result<DebertaV2QuestionAnsweringOutput, RustBertError> {
982        let base_model_output = self.deberta.forward_t(
983            input_ids,
984            attention_mask,
985            token_type_ids,
986            position_ids,
987            input_embeds,
988            train,
989        )?;
990
991        let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
992        let logits = sequence_output.split(1, -1);
993        let (start_logits, end_logits) = (&logits[0], &logits[1]);
994        let start_logits = start_logits.squeeze_dim(-1);
995        let end_logits = end_logits.squeeze_dim(-1);
996
997        Ok(DebertaV2QuestionAnsweringOutput {
998            start_logits,
999            end_logits,
1000            all_hidden_states: base_model_output.all_hidden_states,
1001            all_attentions: base_model_output.all_attentions,
1002        })
1003    }
1004}
1005
1006/// Container for the DeBERTa V2 model output.
1007pub type DebertaV2ModelOutput = DebertaModelOutput;
1008
1009/// Container for the DeBERTa V2masked LM model output.
1010pub type DebertaV2MaskedLMOutput = DebertaMaskedLMOutput;
1011
1012/// Container for the DeBERTa sequence classification model output.
1013pub type DebertaV2SequenceClassificationOutput = DebertaSequenceClassificationOutput;
1014
1015/// Container for the DeBERTa token classification model output.
1016pub type DebertaV2TokenClassificationOutput = DebertaTokenClassificationOutput;
1017
1018/// Container for the DeBERTa question answering model output.
1019pub type DebertaV2QuestionAnsweringOutput = DebertaQuestionAnsweringOutput;