Module rust_bert::deberta_v2
DeBERTa V2 (He et al.)
Implementation of the DeBERTa V2/V3 language model (DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing, He, Gao, Chen, 2021).
The base model is implemented in the deberta_v2_model::DebertaV2Model struct. Several language model heads have also been implemented, including:
- Question answering: deberta_v2_model::DebertaV2ForQuestionAnswering
- Sequence classification: deberta_v2_model::DebertaV2ForSequenceClassification
- Token classification (e.g. NER, POS tagging): deberta_v2_model::DebertaV2ForTokenClassification
Model set-up and pre-trained weights loading
All models expect the following resources:
- Configuration file expected to have a structure following the Transformers library
- Model weights are expected to have a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the .bin weights to the .ot format.
- DebertaV2Tokenizer using a spiece.model SentencePiece model file

Pretrained models are available and can be downloaded using RemoteResources.
use rust_bert::deberta_v2::{
    DebertaV2Config, DebertaV2ConfigResources, DebertaV2ForSequenceClassification,
    DebertaV2ModelResources, DebertaV2VocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::DeBERTaV2Tokenizer;
use tch::{nn, Device};

fn main() -> anyhow::Result<()> {
    // Define the remote resources for the configuration, SentencePiece model and weights
    let config_resource =
        RemoteResource::from_pretrained(DebertaV2ConfigResources::DEBERTA_V3_BASE);
    let vocab_resource =
        RemoteResource::from_pretrained(DebertaV2VocabResources::DEBERTA_V3_BASE);
    let weights_resource =
        RemoteResource::from_pretrained(DebertaV2ModelResources::DEBERTA_V3_BASE);

    // Download the resources (or fetch them from the local cache)
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let weights_path = weights_resource.get_local_path()?;

    // Set up the model on GPU if available, CPU otherwise
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer =
        DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
    let config = DebertaV2Config::from_file(config_path);
    let deberta_model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
    vs.load(weights_path)?;
    Ok(())
}
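Once the weights are loaded, the model can run a forward pass. The following continues the set-up example above as a sketch only: the forward_t argument order mirrors the companion DeBERTa (v1) sequence classification head, and the Tensor::of_slice constructor is named from_slice in newer tch versions, so both are assumptions that may differ between rust_bert / tch releases.

```rust
use rust_tokenizers::tokenizer::{Tokenizer, TruncationStrategy};
use tch::{no_grad, Tensor};

// Tokenize an input sentence and build a batch of size 1
let tokenized = tokenizer.encode(
    "This is a sentence to classify",
    None,
    128,
    &TruncationStrategy::LongestFirst,
    0,
);
let input_tensor = Tensor::of_slice(&tokenized.token_ids)
    .unsqueeze(0)
    .to(device);

// Forward pass without gradient tracking (inference only)
let model_output = no_grad(|| {
    deberta_model.forward_t(Some(&input_tensor), None, None, None, None, false)
})?;

// The returned container exposes the classification logits,
// with shape [batch_size, num_labels]
println!("{:?}", model_output.logits.size());
```

Running inference inside no_grad avoids building the autograd graph, which reduces memory use when no training is performed.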
Structs
DebertaV2Config: DeBERTa (v2) model configuration
DebertaV2ConfigResources: DeBERTaV2 pretrained model config files
DebertaV2ForMaskedLM: DeBERTa V2 for masked language model
DebertaV2ForQuestionAnswering: DeBERTa V2 for question answering
DebertaV2ForSequenceClassification: DeBERTa V2 for sequence classification
DebertaV2ForTokenClassification: DeBERTa V2 for token classification (e.g. NER, POS)
DebertaV2Model: DeBERTa V2 base model
DebertaV2ModelResources: DeBERTaV2 pretrained model weight files
DebertaV2VocabResources: DeBERTaV2 pretrained model vocab files
Type Definitions
DebertaV2QuestionAnsweringOutput: Container for the DeBERTa question answering model output.
DebertaV2SequenceClassificationOutput: Container for the DeBERTa sequence classification model output.
DebertaV2TokenClassificationOutput: Container for the DeBERTa token classification model output.