Module deberta_v2

DeBERTa V2 (He et al.)

Implementation of the DeBERTa V2/V3 language model (DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing, He, Gao, Chen, 2021). The base model is implemented in the deberta_v2_model::DebertaV2Model struct. Several language model heads have also been implemented (a construction sketch follows the list), including:

  • Masked language model: deberta_v2_model::DebertaV2ForMaskedLM
  • Question answering: deberta_v2_model::DebertaV2ForQuestionAnswering
  • Sequence classification: deberta_v2_model::DebertaV2ForSequenceClassification
  • Token classification (e.g. NER, POS tagging): deberta_v2_model::DebertaV2ForTokenClassification
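
All of the heads above are built from the same DebertaV2Config as the base model. The following is a minimal construction sketch, assuming a locally available configuration file (the path is hypothetical); depending on the crate version, new may return the model directly or a Result, so check each struct's documentation:

use rust_bert::deberta_v2::{
    DebertaV2Config, DebertaV2ForQuestionAnswering, DebertaV2ForTokenClassification,
};
use rust_bert::Config;
use tch::{nn, Device};

// Hypothetical local configuration file; see the set-up section below for
// downloading the actual resources via RemoteResource.
let config = DebertaV2Config::from_file("path/to/config.json");
let vs = nn::VarStore::new(Device::Cpu);

// Each head is constructed from a variable store path and the shared configuration.
let qa_model = DebertaV2ForQuestionAnswering::new(&vs.root(), &config);
let token_classification_model = DebertaV2ForTokenClassification::new(&vs.root(), &config);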

Model set-up and pre-trained weights loading

All models expect the following resources:

  • Configuration file expected to have a structure following the Transformers library
  • Model weights are expected to have a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the .bin weights to the .ot format.
  • DeBERTaV2Tokenizer using a spiece.model SentencePiece model file

Pretrained models are available and can be downloaded using RemoteResources.

use tch::{nn, Device};
use rust_bert::deberta_v2::{
    DebertaV2Config, DebertaV2ConfigResources, DebertaV2ForSequenceClassification,
    DebertaV2ModelResources, DebertaV2VocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::DeBERTaV2Tokenizer;

// Error handling via the `anyhow` crate for the fallible steps below.
fn main() -> anyhow::Result<()> {
    // Resource definitions for the pretrained DeBERTa V3 base model
    let config_resource =
        RemoteResource::from_pretrained(DebertaV2ConfigResources::DEBERTA_V3_BASE);
    let vocab_resource =
        RemoteResource::from_pretrained(DebertaV2VocabResources::DEBERTA_V3_BASE);
    let weights_resource =
        RemoteResource::from_pretrained(DebertaV2ModelResources::DEBERTA_V3_BASE);

    // Download the resources (or fetch them from the local cache)
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let weights_path = weights_resource.get_local_path()?;

    // Set up the tokenizer, configuration and model, then load the pretrained weights
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer =
        DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
    let config = DebertaV2Config::from_file(config_path);
    let deberta_model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
    vs.load(weights_path)?;
    Ok(())
}
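
Continuing from the set-up above (the code below is meant to sit in the same function, after vs.load), the following sketch runs a forward pass with the sequence classification head. It is illustrative only: the input tensors are placeholders standing in for ids produced by the tokenizer, and the forward_t argument order (input ids, attention mask, token type ids, position ids, input embeddings, train flag) and return type are assumed to follow the pattern of the other BERT-like heads in this crate; check the DebertaV2ForSequenceClassification documentation for the exact signature.

use tch::{no_grad, Kind, Tensor};

// Placeholder token ids and attention mask (batch of 2 sequences of length 16);
// in practice these are built from the tokenizer output.
let (batch_size, sequence_length) = (2, 16);
let input_ids = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));

// Forward pass without gradient tracking. The argument order is an assumption
// based on the other BERT-like heads in this crate.
let model_output = no_grad(|| {
    deberta_model.forward_t(
        Some(&input_ids),
        Some(&attention_mask),
        None,
        None,
        None,
        false,
    )
});
// `model_output` yields the DebertaV2SequenceClassificationOutput listed under
// Type Aliases below (possibly wrapped in a Result, depending on the crate
// version); its logits field has shape [batch_size, num_labels].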

Structs

DebertaV2Config
DeBERTa (v2) model configuration
DebertaV2ConfigResources
DeBERTaV2 Pretrained model config files
DebertaV2ForMaskedLM
DeBERTa V2 for masked language model
DebertaV2ForQuestionAnswering
DeBERTa V2 for question answering
DebertaV2ForSequenceClassification
DeBERTa V2 for sequence classification
DebertaV2ForTokenClassification
DeBERTa V2 for token classification (e.g. NER, POS)
DebertaV2Model
DeBERTa V2 Base model
DebertaV2ModelResources
DeBERTaV2 Pretrained model weight files
DebertaV2VocabResources
DeBERTaV2 Pretrained model vocab files

Type Aliases

DebertaV2QuestionAnsweringOutput
Container for the DeBERTa question answering model output.
DebertaV2SequenceClassificationOutput
Container for the DeBERTa sequence classification model output.
DebertaV2TokenClassificationOutput
Container for the DeBERTa token classification model output.