rust_bert/models/deberta_v2/deberta_v2_model.rs
1// Copyright 2020, Microsoft and the HuggingFace Inc. team.
2// Copyright 2022 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::common::dropout::{Dropout, XDropout};
14use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
15use crate::deberta::{
16 deserialize_attention_type, ContextPooler, DebertaConfig, DebertaLMPredictionHead,
17 DebertaMaskedLMOutput, DebertaModelOutput, DebertaQuestionAnsweringOutput,
18 DebertaSequenceClassificationOutput, DebertaTokenClassificationOutput, PositionAttentionTypes,
19};
20use crate::deberta_v2::embeddings::DebertaV2Embeddings;
21use crate::deberta_v2::encoder::DebertaV2Encoder;
22use crate::{Activation, Config, RustBertError};
23use serde::de::{SeqAccess, Visitor};
24use serde::{de, Deserialize, Deserializer, Serialize};
25use std::borrow::Borrow;
26use std::collections::HashMap;
27use std::fmt;
28use std::str::FromStr;
29use tch::{nn, Kind, Tensor};
30
31/// # DeBERTaV2 Pretrained model weight files
/// # DeBERTaV2 Pretrained model weight files
pub struct DebertaV2ModelResources;

/// # DeBERTaV2 Pretrained model config files
pub struct DebertaV2ConfigResources;

/// # DeBERTaV2 Pretrained model vocab files
pub struct DebertaV2VocabResources;

impl DebertaV2ModelResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-v3-base>. Modified with conversion to C-array format.
    ///
    /// The tuple holds a resource identifier and the URL of the converted weights.
    pub const DEBERTA_V3_BASE: (&str, &str) = (
        "deberta-v3-base/model",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/rust_model.ot",
    );
}

impl DebertaV2ConfigResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-v3-base>.
    ///
    /// The tuple holds a resource identifier and the URL of the JSON configuration file.
    pub const DEBERTA_V3_BASE: (&str, &str) = (
        "deberta-v3-base/config",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/config.json",
    );
}

impl DebertaV2VocabResources {
    /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/deberta-v3-base>.
    ///
    /// The tuple holds a resource identifier and the URL of the SentencePiece model file.
    pub const DEBERTA_V3_BASE: (&str, &str) = (
        "deberta-v3-base/vocab",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model",
    );
}
63
#[derive(Debug, Serialize, Deserialize)]
/// # DeBERTa (v2) model configuration
/// Defines the DeBERTa (v2) model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// When deserializing a configuration file, the non-`Option` fields must be present, while the
/// `Option` fields may be omitted (they then deserialize to `None`).
///
/// The `From<DebertaV2Config>` and `From<&DebertaV2Config>` conversions to `DebertaConfig` in
/// this module forward every field except `norm_rel_ebd`, `conv_kernel_size`, `conv_groups`,
/// `conv_act` and `classifier_activation`.
pub struct DebertaV2Config {
    pub vocab_size: i64,
    pub hidden_size: i64,
    pub num_hidden_layers: i64,
    pub hidden_act: Activation,
    pub attention_probs_dropout_prob: f64,
    /// Also used as the dropout probability of the token classification head, and as the
    /// fallback dropout probability of the sequence classification head.
    pub hidden_dropout_prob: f64,
    pub initializer_range: f64,
    pub intermediate_size: i64,
    pub max_position_embeddings: i64,
    pub position_buckets: Option<i64>,
    pub num_attention_heads: i64,
    pub type_vocab_size: i64,
    pub position_biased_input: Option<bool>,
    /// Parsed by `deserialize_attention_type`; `None` when the entry is absent.
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pub pos_att_type: Option<PositionAttentionTypes>,
    /// Normalization types for the relative embeddings, parsed by `deserialize_norm_type`
    /// (string or sequence form); `None` when the entry is absent.
    #[serde(default, deserialize_with = "deserialize_norm_type")]
    pub norm_rel_ebd: Option<NormRelEmbedTypes>,
    pub share_att_key: Option<bool>,
    pub conv_kernel_size: Option<i64>,
    pub conv_groups: Option<i64>,
    pub conv_act: Option<Activation>,
    pub pooler_dropout: Option<f64>,
    pub pooler_hidden_act: Option<Activation>,
    pub pooler_hidden_size: Option<i64>,
    pub layer_norm_eps: Option<f64>,
    pub pad_token_id: Option<i64>,
    pub relative_attention: Option<bool>,
    pub max_relative_positions: Option<i64>,
    pub embedding_size: Option<i64>,
    pub talking_head: Option<bool>,
    pub output_hidden_states: Option<bool>,
    pub output_attentions: Option<bool>,
    pub classifier_activation: Option<bool>,
    /// Dropout probability of the sequence classification head; when `None`,
    /// `hidden_dropout_prob` is used instead.
    pub classifier_dropout: Option<f64>,
    pub is_decoder: Option<bool>,
    /// Mapping from class index to label. Required by the sequence and token classification
    /// heads, whose number of output classes is the number of entries in this map.
    pub id2label: Option<HashMap<i64, String>>,
    pub label2id: Option<HashMap<String, i64>>,
}
106
107#[allow(non_camel_case_types)]
108#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq, Eq)]
109/// # Layer normalization layer for the DeBERTa model's relative embeddings.
110pub enum NormRelEmbedType {
111 layer_norm,
112}
113
114impl FromStr for NormRelEmbedType {
115 type Err = RustBertError;
116
117 fn from_str(s: &str) -> Result<Self, Self::Err> {
118 match s {
119 "layer_norm" => Ok(NormRelEmbedType::layer_norm),
120 _ => Err(RustBertError::InvalidConfigurationError(format!(
121 "Layer normalization type `{s}` not in accepted variants (`layer_norm`)",
122 ))),
123 }
124 }
125}
126
127#[allow(non_camel_case_types)]
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct NormRelEmbedTypes {
130 types: Vec<NormRelEmbedType>,
131}
132
133impl FromStr for NormRelEmbedTypes {
134 type Err = RustBertError;
135
136 fn from_str(s: &str) -> Result<Self, Self::Err> {
137 let types = s
138 .to_lowercase()
139 .split('|')
140 .map(NormRelEmbedType::from_str)
141 .collect::<Result<Vec<_>, _>>()?;
142 Ok(NormRelEmbedTypes { types })
143 }
144}
145
146impl NormRelEmbedTypes {
147 pub fn has_type(&self, norm_type: NormRelEmbedType) -> bool {
148 self.types.iter().any(|self_type| *self_type == norm_type)
149 }
150
151 pub fn len(&self) -> usize {
152 self.types.len()
153 }
154}
155
156pub fn deserialize_norm_type<'de, D>(deserializer: D) -> Result<Option<NormRelEmbedTypes>, D::Error>
157where
158 D: Deserializer<'de>,
159{
160 struct NormTypeVisitor;
161
162 impl<'de> Visitor<'de> for NormTypeVisitor {
163 type Value = NormRelEmbedTypes;
164
165 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
166 formatter.write_str("null, string or sequence")
167 }
168
169 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
170 where
171 E: de::Error,
172 {
173 Ok(FromStr::from_str(value).unwrap())
174 }
175
176 fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
177 where
178 S: SeqAccess<'de>,
179 {
180 let mut types = vec![];
181 while let Some(norm_type) = seq.next_element::<String>()? {
182 types.push(FromStr::from_str(norm_type.as_str()).unwrap())
183 }
184 Ok(NormRelEmbedTypes { types })
185 }
186 }
187
188 deserializer.deserialize_any(NormTypeVisitor).map(Some)
189}
190
/// Enables loading a `DebertaV2Config` from a JSON file with `Config::from_file`, as done in the
/// examples of this module.
impl Config for DebertaV2Config {}
192
193impl Default for DebertaV2Config {
194 fn default() -> Self {
195 DebertaV2Config {
196 vocab_size: 128100,
197 hidden_size: 1536,
198 num_hidden_layers: 24,
199 hidden_act: Activation::gelu,
200 attention_probs_dropout_prob: 0.1,
201 hidden_dropout_prob: 0.1,
202 initializer_range: 0.02,
203 intermediate_size: 6144,
204 max_position_embeddings: 512,
205 position_buckets: None,
206 num_attention_heads: 24,
207 type_vocab_size: 0,
208 position_biased_input: Some(true),
209 pos_att_type: None,
210 norm_rel_ebd: None,
211 share_att_key: None,
212 conv_kernel_size: None,
213 conv_groups: None,
214 conv_act: None,
215 pooler_dropout: Some(0.0),
216 pooler_hidden_act: Some(Activation::gelu),
217 pooler_hidden_size: None,
218 layer_norm_eps: Some(1e-7),
219 pad_token_id: Some(0),
220 relative_attention: None,
221 max_relative_positions: None,
222 embedding_size: None,
223 talking_head: None,
224 output_hidden_states: None,
225 output_attentions: None,
226 classifier_activation: None,
227 classifier_dropout: None,
228 is_decoder: None,
229 id2label: None,
230 label2id: None,
231 }
232 }
233}
234
235impl From<DebertaV2Config> for DebertaConfig {
236 fn from(v2_config: DebertaV2Config) -> Self {
237 DebertaConfig {
238 hidden_act: v2_config.hidden_act,
239 attention_probs_dropout_prob: v2_config.attention_probs_dropout_prob,
240 hidden_dropout_prob: v2_config.hidden_dropout_prob,
241 hidden_size: v2_config.hidden_size,
242 initializer_range: v2_config.initializer_range,
243 intermediate_size: v2_config.intermediate_size,
244 max_position_embeddings: v2_config.max_position_embeddings,
245 num_attention_heads: v2_config.num_attention_heads,
246 num_hidden_layers: v2_config.num_hidden_layers,
247 type_vocab_size: v2_config.type_vocab_size,
248 vocab_size: v2_config.vocab_size,
249 position_biased_input: v2_config.position_biased_input,
250 pos_att_type: v2_config.pos_att_type,
251 pooler_dropout: v2_config.pooler_dropout,
252 pooler_hidden_act: v2_config.pooler_hidden_act,
253 pooler_hidden_size: v2_config.pooler_hidden_size,
254 layer_norm_eps: v2_config.layer_norm_eps,
255 pad_token_id: v2_config.pad_token_id,
256 relative_attention: v2_config.relative_attention,
257 max_relative_positions: v2_config.max_relative_positions,
258 embedding_size: v2_config.embedding_size,
259 talking_head: v2_config.talking_head,
260 output_hidden_states: v2_config.output_hidden_states,
261 output_attentions: v2_config.output_attentions,
262 classifier_dropout: v2_config.classifier_dropout,
263 is_decoder: v2_config.is_decoder,
264 id2label: v2_config.id2label,
265 label2id: v2_config.label2id,
266 share_att_key: v2_config.share_att_key,
267 position_buckets: v2_config.position_buckets,
268 }
269 }
270}
271
272impl From<&DebertaV2Config> for DebertaConfig {
273 fn from(v2_config: &DebertaV2Config) -> Self {
274 DebertaConfig {
275 hidden_act: v2_config.hidden_act,
276 attention_probs_dropout_prob: v2_config.attention_probs_dropout_prob,
277 hidden_dropout_prob: v2_config.hidden_dropout_prob,
278 hidden_size: v2_config.hidden_size,
279 initializer_range: v2_config.initializer_range,
280 intermediate_size: v2_config.intermediate_size,
281 max_position_embeddings: v2_config.max_position_embeddings,
282 num_attention_heads: v2_config.num_attention_heads,
283 num_hidden_layers: v2_config.num_hidden_layers,
284 type_vocab_size: v2_config.type_vocab_size,
285 vocab_size: v2_config.vocab_size,
286 position_biased_input: v2_config.position_biased_input,
287 pos_att_type: v2_config.pos_att_type.clone(),
288 pooler_dropout: v2_config.pooler_dropout,
289 pooler_hidden_act: v2_config.pooler_hidden_act,
290 pooler_hidden_size: v2_config.pooler_hidden_size,
291 layer_norm_eps: v2_config.layer_norm_eps,
292 pad_token_id: v2_config.pad_token_id,
293 relative_attention: v2_config.relative_attention,
294 max_relative_positions: v2_config.max_relative_positions,
295 embedding_size: v2_config.embedding_size,
296 talking_head: v2_config.talking_head,
297 output_hidden_states: v2_config.output_hidden_states,
298 output_attentions: v2_config.output_attentions,
299 classifier_dropout: v2_config.classifier_dropout,
300 is_decoder: v2_config.is_decoder,
301 id2label: v2_config.id2label.clone(),
302 label2id: v2_config.label2id.clone(),
303 share_att_key: v2_config.share_att_key,
304 position_buckets: v2_config.position_buckets,
305 }
306 }
307}
308
/// # DeBERTa V2 Base model
/// Base architecture for DeBERTa V2 models. Task-specific models will be built from this common base model.
/// It is made of the following blocks:
/// - `embeddings`: `DebertaV2Embeddings`, applied to the input ids (or pre-computed input embeddings)
/// - `encoder`: `DebertaV2Encoder` (transformer) made of a vector of layers, applied to the embeddings output
pub struct DebertaV2Model {
    embeddings: DebertaV2Embeddings,
    encoder: DebertaV2Encoder,
}
318
319impl DebertaV2Model {
320 /// Build a new `DebertaV2Model`
321 ///
322 /// # Arguments
323 ///
324 /// * `p` - Variable store path for the root of the BERT model
325 /// * `config` - `DebertaV2Config` object defining the model architecture and decoder status
326 ///
327 /// # Example
328 ///
329 /// ```no_run
330 /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2Model};
331 /// use rust_bert::Config;
332 /// use std::path::Path;
333 /// use tch::{nn, Device};
334 ///
335 /// let config_path = Path::new("path/to/config.json");
336 /// let device = Device::Cpu;
337 /// let p = nn::VarStore::new(device);
338 /// let config = DebertaV2Config::from_file(config_path);
339 /// let model: DebertaV2Model = DebertaV2Model::new(&p.root() / "deberta", &config);
340 /// ```
341 pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2Model
342 where
343 P: Borrow<nn::Path<'p>>,
344 {
345 let p = p.borrow();
346
347 let embeddings = DebertaV2Embeddings::new(p / "embeddings", &config.into());
348 let encoder = DebertaV2Encoder::new(p / "encoder", config);
349
350 DebertaV2Model {
351 embeddings,
352 encoder,
353 }
354 }
355
356 /// Forward pass through the model
357 ///
358 /// # Arguments
359 ///
360 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
361 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
362 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
363 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
364 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
365 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
366 ///
367 /// # Returns
368 ///
369 /// * `DebertaV2Output` containing:
370 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
371 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
372 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
373 ///
374 /// # Example
375 ///
376 /// ```no_run
377 /// # use rust_bert::deberta_v2::{DebertaV2Model, DebertaV2Config};
378 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
379 /// # use rust_bert::Config;
380 /// # use std::path::Path;
381 /// # let config_path = Path::new("path/to/config.json");
382 /// # let device = Device::Cpu;
383 /// # let vs = nn::VarStore::new(device);
384 /// # let config = DebertaV2Config::from_file(config_path);
385 /// # let model = DebertaV2Model::new(&vs.root(), &config);
386 /// let (batch_size, sequence_length) = (64, 128);
387 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
388 /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
389 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
390 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
391 /// .expand(&[batch_size, sequence_length], true);
392 ///
393 /// let model_output = no_grad(|| {
394 /// model
395 /// .forward_t(
396 /// Some(&input_tensor),
397 /// Some(&attention_mask),
398 /// Some(&token_type_ids),
399 /// Some(&position_ids),
400 /// None,
401 /// false,
402 /// )
403 /// .unwrap()
404 /// });
405 /// ```
406 pub fn forward_t(
407 &self,
408 input_ids: Option<&Tensor>,
409 attention_mask: Option<&Tensor>,
410 token_type_ids: Option<&Tensor>,
411 position_ids: Option<&Tensor>,
412 input_embeds: Option<&Tensor>,
413 train: bool,
414 ) -> Result<DebertaV2ModelOutput, RustBertError> {
415 let (input_shape, device) =
416 get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
417
418 let calc_attention_mask = if attention_mask.is_none() {
419 Some(Tensor::ones(input_shape.as_slice(), (Kind::Bool, device)))
420 } else {
421 None
422 };
423
424 let attention_mask =
425 attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
426
427 let embedding_output = self.embeddings.forward_t(
428 input_ids,
429 token_type_ids,
430 position_ids,
431 attention_mask,
432 input_embeds,
433 train,
434 )?;
435
436 let encoder_output =
437 self.encoder
438 .forward_t(&embedding_output, attention_mask, None, None, train)?;
439
440 Ok(encoder_output)
441 }
442}
443
/// # DeBERTa V2 for masked language model
/// Base DeBERTa V2 model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa V2 model
/// - `cls`: LM prediction head (`DebertaLMPredictionHead`), applied to the final hidden states
pub struct DebertaV2ForMaskedLM {
    deberta: DebertaV2Model,
    cls: DebertaLMPredictionHead,
}
453
454impl DebertaV2ForMaskedLM {
455 /// Build a new `DebertaV2ForMaskedLM`
456 ///
457 /// # Arguments
458 ///
459 /// * `p` - Variable store path for the root of the BertForMaskedLM model
460 /// * `config` - `DebertaConfig` object defining the model architecture and vocab size
461 ///
462 /// # Example
463 ///
464 /// ```no_run
465 /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForMaskedLM};
466 /// use rust_bert::Config;
467 /// use std::path::Path;
468 /// use tch::{nn, Device};
469 ///
470 /// let config_path = Path::new("path/to/config.json");
471 /// let device = Device::Cpu;
472 /// let p = nn::VarStore::new(device);
473 /// let config = DebertaV2Config::from_file(config_path);
474 /// let model = DebertaV2ForMaskedLM::new(&p.root(), &config);
475 /// ```
476 pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForMaskedLM
477 where
478 P: Borrow<nn::Path<'p>>,
479 {
480 let p = p.borrow();
481
482 let deberta = DebertaV2Model::new(p / "deberta", config);
483 let cls =
484 DebertaLMPredictionHead::new(p.sub("cls").sub("predictions"), &config.into(), false);
485
486 DebertaV2ForMaskedLM { deberta, cls }
487 }
488
489 /// Forward pass through the model
490 ///
491 /// # Arguments
492 ///
493 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
494 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
495 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
496 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
497 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
498 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
499 ///
500 /// # Returns
501 ///
502 /// * `DebertaMaskedLMOutput` containing:
503 /// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
504 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
505 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
506 ///
507 /// # Example
508 ///
509 /// ```no_run
510 /// # use rust_bert::deberta_v2::{DebertaV2ForMaskedLM, DebertaV2Config};
511 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
512 /// # use rust_bert::Config;
513 /// # use std::path::Path;
514 /// # let config_path = Path::new("path/to/config.json");
515 /// # let device = Device::Cpu;
516 /// # let vs = nn::VarStore::new(device);
517 /// # let config = DebertaV2Config::from_file(config_path);
518 /// # let model = DebertaV2ForMaskedLM::new(&vs.root(), &config);
519 /// let (batch_size, sequence_length) = (64, 128);
520 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
521 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
522 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
523 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
524 /// .expand(&[batch_size, sequence_length], true);
525 ///
526 /// let model_output = no_grad(|| {
527 /// model.forward_t(
528 /// Some(&input_tensor),
529 /// Some(&mask),
530 /// Some(&token_type_ids),
531 /// Some(&position_ids),
532 /// None,
533 /// false,
534 /// )
535 /// });
536 /// ```
537 pub fn forward_t(
538 &self,
539 input_ids: Option<&Tensor>,
540 attention_mask: Option<&Tensor>,
541 token_type_ids: Option<&Tensor>,
542 position_ids: Option<&Tensor>,
543 input_embeds: Option<&Tensor>,
544 train: bool,
545 ) -> Result<DebertaV2MaskedLMOutput, RustBertError> {
546 let model_outputs = self.deberta.forward_t(
547 input_ids,
548 attention_mask,
549 token_type_ids,
550 position_ids,
551 input_embeds,
552 train,
553 )?;
554
555 let logits = model_outputs.hidden_state.apply(&self.cls);
556 Ok(DebertaV2MaskedLMOutput {
557 logits,
558 all_hidden_states: model_outputs.all_hidden_states,
559 all_attentions: model_outputs.all_attentions,
560 })
561 }
562}
563
/// # DeBERTa V2 for sequence classification
/// Base DeBERTa V2 model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `deberta`: Base Deberta (V2) Model
/// - `pooler`: `ContextPooler` applied to the final hidden states
/// - `dropout`: `XDropout` layer applied to the pooled output
/// - `classifier`: linear layer projecting the pooled output to the number of labels
pub struct DebertaV2ForSequenceClassification {
    deberta: DebertaV2Model,
    pooler: ContextPooler,
    classifier: nn::Linear,
    dropout: XDropout,
}
575
576impl DebertaV2ForSequenceClassification {
577 /// Build a new `DebertaV2ForSequenceClassification`
578 ///
579 /// # Arguments
580 ///
581 /// * `p` - Variable store path for the root of the DebertaForSequenceClassification model
582 /// * `config` - `DebertaV2Config` object defining the model architecture and number of classes
583 ///
584 /// # Example
585 ///
586 /// ```no_run
587 /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForSequenceClassification};
588 /// use rust_bert::Config;
589 /// use std::path::Path;
590 /// use tch::{nn, Device};
591 ///
592 /// let config_path = Path::new("path/to/config.json");
593 /// let device = Device::Cpu;
594 /// let p = nn::VarStore::new(device);
595 /// let config = DebertaV2Config::from_file(config_path);
596 /// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config).unwrap();
597 /// ```
598 pub fn new<'p, P>(
599 p: P,
600 config: &DebertaV2Config,
601 ) -> Result<DebertaV2ForSequenceClassification, RustBertError>
602 where
603 P: Borrow<nn::Path<'p>>,
604 {
605 let p = p.borrow();
606
607 let deberta = DebertaV2Model::new(p / "deberta", config);
608 let pooler = ContextPooler::new(p / "pooler", &config.into());
609 let dropout = XDropout::new(
610 config
611 .classifier_dropout
612 .unwrap_or(config.hidden_dropout_prob),
613 );
614
615 let num_labels = config
616 .id2label
617 .as_ref()
618 .ok_or_else(|| {
619 RustBertError::InvalidConfigurationError(
620 "num_labels not provided in configuration".to_string(),
621 )
622 })?
623 .len() as i64;
624
625 let classifier = nn::linear(
626 p / "classifier",
627 pooler.output_dim,
628 num_labels,
629 Default::default(),
630 );
631
632 Ok(DebertaV2ForSequenceClassification {
633 deberta,
634 pooler,
635 classifier,
636 dropout,
637 })
638 }
639
640 /// Forward pass through the model
641 ///
642 /// # Arguments
643 ///
644 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
645 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
646 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
647 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
648 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
649 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
650 ///
651 /// # Returns
652 ///
653 /// * `DebertaV2SequenceClassificationOutput` containing:
654 /// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
655 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
656 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
657 ///
658 /// # Example
659 ///
660 /// ```no_run
661 /// # use rust_bert::deberta_v2::{DebertaV2ForSequenceClassification, DebertaV2Config};
662 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
663 /// # use rust_bert::Config;
664 /// # use std::path::Path;
665 /// # let config_path = Path::new("path/to/config.json");
666 /// # let device = Device::Cpu;
667 /// # let vs = nn::VarStore::new(device);
668 /// # let config = DebertaV2Config::from_file(config_path);
669 /// # let model = DebertaV2ForSequenceClassification::new(&vs.root(), &config).unwrap();;
670 /// let (batch_size, sequence_length) = (64, 128);
671 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
672 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
673 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
674 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
675 /// .expand(&[batch_size, sequence_length], true);
676 ///
677 /// let model_output = no_grad(|| {
678 /// model.forward_t(
679 /// Some(&input_tensor),
680 /// Some(&mask),
681 /// Some(&token_type_ids),
682 /// Some(&position_ids),
683 /// None,
684 /// false,
685 /// )
686 /// });
687 /// ```
688 pub fn forward_t(
689 &self,
690 input_ids: Option<&Tensor>,
691 attention_mask: Option<&Tensor>,
692 token_type_ids: Option<&Tensor>,
693 position_ids: Option<&Tensor>,
694 input_embeds: Option<&Tensor>,
695 train: bool,
696 ) -> Result<DebertaV2SequenceClassificationOutput, RustBertError> {
697 let base_model_output = self.deberta.forward_t(
698 input_ids,
699 attention_mask,
700 token_type_ids,
701 position_ids,
702 input_embeds,
703 train,
704 )?;
705
706 let logits = base_model_output
707 .hidden_state
708 .apply_t(&self.pooler, train)
709 .apply_t(&self.dropout, train)
710 .apply(&self.classifier);
711
712 Ok(DebertaV2SequenceClassificationOutput {
713 logits,
714 all_hidden_states: base_model_output.all_hidden_states,
715 all_attentions: base_model_output.all_attentions,
716 })
717 }
718}
719
/// # DeBERTa V2 for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa (V2) model
/// - `dropout`: Dropout layer before the last token-level predictions layer
/// - `classifier`: Linear layer for token classification, projecting each hidden state to the number of labels
pub struct DebertaV2ForTokenClassification {
    deberta: DebertaV2Model,
    dropout: Dropout,
    classifier: nn::Linear,
}
732
733impl DebertaV2ForTokenClassification {
734 /// Build a new `DebertaV2ForTokenClassification`
735 ///
736 /// # Arguments
737 ///
738 /// * `p` - Variable store path for the root of the Deberta V2 model
739 /// * `config` - `DebertaV2Config` object defining the model architecture
740 ///
741 /// # Example
742 ///
743 /// ```no_run
744 /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForTokenClassification};
745 /// use rust_bert::Config;
746 /// use std::path::Path;
747 /// use tch::{nn, Device};
748 ///
749 /// let config_path = Path::new("path/to/config.json");
750 /// let device = Device::Cpu;
751 /// let p = nn::VarStore::new(device);
752 /// let config = DebertaV2Config::from_file(config_path);
753 /// let model = DebertaV2ForTokenClassification::new(&p.root(), &config);
754 /// ```
755 pub fn new<'p, P>(
756 p: P,
757 config: &DebertaV2Config,
758 ) -> Result<DebertaV2ForTokenClassification, RustBertError>
759 where
760 P: Borrow<nn::Path<'p>>,
761 {
762 let p = p.borrow();
763
764 let deberta = DebertaV2Model::new(p / "deberta", config);
765 let dropout = Dropout::new(config.hidden_dropout_prob);
766 let num_labels = config
767 .id2label
768 .as_ref()
769 .ok_or_else(|| {
770 RustBertError::InvalidConfigurationError(
771 "num_labels not provided in configuration".to_string(),
772 )
773 })?
774 .len() as i64;
775 let classifier = nn::linear(
776 p / "classifier",
777 config.hidden_size,
778 num_labels,
779 Default::default(),
780 );
781
782 Ok(DebertaV2ForTokenClassification {
783 deberta,
784 dropout,
785 classifier,
786 })
787 }
788
789 /// Forward pass through the model
790 ///
791 /// # Arguments
792 ///
793 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
794 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
795 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
796 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
797 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
798 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
799 ///
800 /// # Returns
801 ///
802 /// * `DebertaV2TokenClassificationOutput` containing:
803 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*)
804 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
805 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
806 ///
807 /// # Example
808 ///
809 /// ```no_run
810 /// # use rust_bert::deberta_v2::{DebertaV2ForTokenClassification, DebertaV2Config};
811 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
812 /// # use rust_bert::Config;
813 /// # use std::path::Path;
814 /// # let config_path = Path::new("path/to/config.json");
815 /// # let device = Device::Cpu;
816 /// # let vs = nn::VarStore::new(device);
817 /// # let config = DebertaV2Config::from_file(config_path);
818 /// # let model = DebertaV2ForTokenClassification::new(&vs.root(), &config).unwrap();
819 /// let (batch_size, sequence_length) = (64, 128);
820 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
821 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
822 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
823 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
824 /// .expand(&[batch_size, sequence_length], true);
825 ///
826 /// let model_output = no_grad(|| {
827 /// model.forward_t(
828 /// Some(&input_tensor),
829 /// Some(&mask),
830 /// Some(&token_type_ids),
831 /// Some(&position_ids),
832 /// None,
833 /// false,
834 /// )
835 /// });
836 /// ```
837 pub fn forward_t(
838 &self,
839 input_ids: Option<&Tensor>,
840 attention_mask: Option<&Tensor>,
841 token_type_ids: Option<&Tensor>,
842 position_ids: Option<&Tensor>,
843 input_embeds: Option<&Tensor>,
844 train: bool,
845 ) -> Result<DebertaV2TokenClassificationOutput, RustBertError> {
846 let base_model_output = self.deberta.forward_t(
847 input_ids,
848 attention_mask,
849 token_type_ids,
850 position_ids,
851 input_embeds,
852 train,
853 )?;
854
855 let logits = base_model_output
856 .hidden_state
857 .apply_t(&self.dropout, train)
858 .apply(&self.classifier);
859
860 Ok(DebertaV2TokenClassificationOutput {
861 logits,
862 all_hidden_states: base_model_output.all_hidden_states,
863 all_attentions: base_model_output.all_attentions,
864 })
865 }
866}
867
/// # DeBERTa V2 for question answering
/// Extractive question-answering model based on a DeBERTa V2 language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `deberta`: Base DeBERTa V2 model
/// - `qa_outputs`: Linear layer for question answering
pub struct DebertaV2ForQuestionAnswering {
    /// Base DeBERTa V2 model; its last hidden state is fed to `qa_outputs`.
    deberta: DebertaV2Model,
    /// Linear layer projecting `hidden_size` to 2 scores per position
    /// (index 0: answer start, index 1: answer end).
    qa_outputs: nn::Linear,
}
879
880impl DebertaV2ForQuestionAnswering {
881 /// Build a new `DebertaV2ForQuestionAnswering`
882 ///
883 /// # Arguments
884 ///
885 /// * `p` - Variable store path for the root of the BertForQuestionAnswering model
886 /// * `config` - `DebertaV2Config` object defining the model architecture
887 ///
888 /// # Example
889 ///
890 /// ```no_run
891 /// use rust_bert::deberta_v2::{DebertaV2Config, DebertaV2ForQuestionAnswering};
892 /// use rust_bert::Config;
893 /// use std::path::Path;
894 /// use tch::{nn, Device};
895 ///
896 /// let config_path = Path::new("path/to/config.json");
897 /// let device = Device::Cpu;
898 /// let p = nn::VarStore::new(device);
899 /// let config = DebertaV2Config::from_file(config_path);
900 /// let model = DebertaV2ForQuestionAnswering::new(&p.root(), &config);
901 /// ```
902 pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForQuestionAnswering
903 where
904 P: Borrow<nn::Path<'p>>,
905 {
906 let p = p.borrow();
907
908 let deberta = DebertaV2Model::new(p / "deberta", config);
909 let num_labels = 2;
910 let qa_outputs = nn::linear(
911 p / "qa_outputs",
912 config.hidden_size,
913 num_labels,
914 Default::default(),
915 );
916
917 DebertaV2ForQuestionAnswering {
918 deberta,
919 qa_outputs,
920 }
921 }
922
923 /// Forward pass through the model
924 ///
925 /// # Arguments
926 ///
927 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
928 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
929 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
930 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
931 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
932 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
933 ///
934 /// # Returns
935 ///
936 /// * `DebertaQuestionAnsweringOutput` containing:
937 /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
938 /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
939 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
940 /// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
941 ///
942 /// # Example
943 ///
944 /// ```no_run
945 /// # use rust_bert::deberta_v2::{DebertaV2ForQuestionAnswering, DebertaV2Config};
946 /// # use tch::{nn, Device, Tensor, no_grad};
947 /// # use rust_bert::Config;
948 /// # use std::path::Path;
949 /// # use tch::kind::Kind::Int64;
950 /// # let config_path = Path::new("path/to/config.json");
951 /// # let device = Device::Cpu;
952 /// # let vs = nn::VarStore::new(device);
953 /// # let config = DebertaV2Config::from_file(config_path);
954 /// # let model = DebertaV2ForQuestionAnswering::new(&vs.root(), &config);
955 /// let (batch_size, sequence_length) = (64, 128);
956 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
957 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
958 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
959 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
960 /// .expand(&[batch_size, sequence_length], true);
961 ///
962 /// let model_output = no_grad(|| {
963 /// model.forward_t(
964 /// Some(&input_tensor),
965 /// Some(&mask),
966 /// Some(&token_type_ids),
967 /// Some(&position_ids),
968 /// None,
969 /// false,
970 /// )
971 /// });
972 /// ```
973 pub fn forward_t(
974 &self,
975 input_ids: Option<&Tensor>,
976 attention_mask: Option<&Tensor>,
977 token_type_ids: Option<&Tensor>,
978 position_ids: Option<&Tensor>,
979 input_embeds: Option<&Tensor>,
980 train: bool,
981 ) -> Result<DebertaV2QuestionAnsweringOutput, RustBertError> {
982 let base_model_output = self.deberta.forward_t(
983 input_ids,
984 attention_mask,
985 token_type_ids,
986 position_ids,
987 input_embeds,
988 train,
989 )?;
990
991 let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
992 let logits = sequence_output.split(1, -1);
993 let (start_logits, end_logits) = (&logits[0], &logits[1]);
994 let start_logits = start_logits.squeeze_dim(-1);
995 let end_logits = end_logits.squeeze_dim(-1);
996
997 Ok(DebertaV2QuestionAnsweringOutput {
998 start_logits,
999 end_logits,
1000 all_hidden_states: base_model_output.all_hidden_states,
1001 all_attentions: base_model_output.all_attentions,
1002 })
1003 }
1004}
1005
1006/// Container for the DeBERTa V2 model output.
1007pub type DebertaV2ModelOutput = DebertaModelOutput;
1008
1009/// Container for the DeBERTa V2masked LM model output.
1010pub type DebertaV2MaskedLMOutput = DebertaMaskedLMOutput;
1011
1012/// Container for the DeBERTa sequence classification model output.
1013pub type DebertaV2SequenceClassificationOutput = DebertaSequenceClassificationOutput;
1014
1015/// Container for the DeBERTa token classification model output.
1016pub type DebertaV2TokenClassificationOutput = DebertaTokenClassificationOutput;
1017
1018/// Container for the DeBERTa question answering model output.
1019pub type DebertaV2QuestionAnsweringOutput = DebertaQuestionAnsweringOutput;