§DeBERTa V2 (He et al.)
Implementation of the DeBERTa V2/V3 language model (DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing, He, Gao, Chen, 2021).
The base model is implemented in the `deberta_v2_model::DebertaV2Model` struct. Several language model heads have also been implemented, including:

- Question answering: `deberta_v2_model::DebertaV2ForQuestionAnswering`
- Sequence classification: `deberta_v2_model::DebertaV2ForSequenceClassification`
- Token classification (e.g. NER, POS tagging): `deberta_v2_model::DebertaV2ForTokenClassification`
§Model set-up and pre-trained weights loading
All models expect the following resources:

- A configuration file with a structure following the Transformers library
- Model weights with a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
- A `DebertaV2Tokenizer` built from a `spiece.model` SentencePiece model file

Pretrained models are available and can be downloaded using RemoteResources.
```rust
use tch::{nn, Device};
use rust_bert::deberta_v2::{
    DebertaV2Config, DebertaV2ConfigResources, DebertaV2ForSequenceClassification,
    DebertaV2ModelResources, DebertaV2VocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::DeBERTaV2Tokenizer;

fn main() -> anyhow::Result<()> {
    // Download (or fetch from the local cache) the configuration,
    // SentencePiece model and converted weights.
    let config_resource =
        RemoteResource::from_pretrained(DebertaV2ConfigResources::DEBERTA_V3_BASE);
    let vocab_resource =
        RemoteResource::from_pretrained(DebertaV2VocabResources::DEBERTA_V3_BASE);
    let weights_resource =
        RemoteResource::from_pretrained(DebertaV2ModelResources::DEBERTA_V3_BASE);
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let weights_path = weights_resource.get_local_path()?;

    // Set up the model on GPU when available, falling back to CPU.
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer =
        DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
    let config = DebertaV2Config::from_file(config_path);
    let deberta_model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
    vs.load(weights_path)?;
    Ok(())
}
```
§Structs

- `DebertaV2Config` - DeBERTa (v2) model configuration
- `DebertaV2ConfigResources` - DeBERTa V2 pretrained model config files
- `DebertaV2ForMaskedLM` - DeBERTa V2 for masked language model
- `DebertaV2ForQuestionAnswering` - DeBERTa V2 for question answering
- `DebertaV2ForSequenceClassification` - DeBERTa V2 for sequence classification
- `DebertaV2ForTokenClassification` - DeBERTa V2 for token classification (e.g. NER, POS)
- `DebertaV2Model` - DeBERTa V2 base model
- `DebertaV2ModelResources` - DeBERTa V2 pretrained model weight files
- `DebertaV2VocabResources` - DeBERTa V2 pretrained model vocab files
§Type Aliases

- `DebertaV2QuestionAnsweringOutput` - Container for the DeBERTa question answering model output.
- `DebertaV2SequenceClassificationOutput` - Container for the DeBERTa sequence classification model output.
- `DebertaV2TokenClassificationOutput` - Container for the DeBERTa token classification model output.