Module deberta

§DeBERTa: Decoding-enhanced BERT with Disentangled Attention (He et al.)

Implementation of the DeBERTa language model (DeBERTa: Decoding-enhanced BERT with Disentangled Attention, He, Liu, Gao, Chen, 2021). The base model is implemented in the deberta_model::DebertaModel struct. Several language model heads have also been implemented, including:

  • Question answering: deberta_model::DebertaForQuestionAnswering
  • Sequence classification: deberta_model::DebertaForSequenceClassification
  • Token classification (e.g. NER, POS tagging): deberta_model::DebertaForTokenClassification

§Model set-up and pre-trained weights loading

All models expect the following resources:

  • A configuration file whose structure follows the Transformers library
  • Model weights whose structure and parameter names follow the Transformers library. The .bin weights must be converted to the .ot format using the Python utility scripts.
  • A DeBERTaTokenizer using a vocab.json vocabulary file and a merges.txt merges file
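
When the resources listed above are stored locally instead of being downloaded, it can help to confirm that all of them are present before building the model. The following is a minimal sketch using only the standard library, not part of the crate. The helper `missing_resources` is hypothetical, and the file names `config.json` and `rust_model.ot` are assumed conventions; only `vocab.json` and `merges.txt` are named in the list above.

```rust
use std::path::{Path, PathBuf};

/// File names assumed for a local DeBERTa model directory.
/// `vocab.json` and `merges.txt` come from the resource list above;
/// `config.json` and `rust_model.ot` are assumptions, adjust as needed.
const EXPECTED_FILES: [&str; 4] = [
    "config.json",
    "vocab.json",
    "merges.txt",
    "rust_model.ot",
];

/// Hypothetical helper: returns the expected files that are not present
/// (as regular files) in `model_dir`, in the order listed above.
pub fn missing_resources(model_dir: &Path) -> Vec<PathBuf> {
    EXPECTED_FILES
        .iter()
        .map(|name| model_dir.join(name))
        .filter(|path| !path.is_file())
        .collect()
}
```

An empty result means every expected file was found. A non-empty result lists the paths to obtain or convert before loading the configuration, tokenizer, and weights.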

Pretrained models are available and can be downloaded using RemoteResource.

use tch::{nn, Device};
use rust_bert::deberta::{
    DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
    DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::DeBERTaTokenizer;

let config_resource =
    RemoteResource::from_pretrained(DebertaConfigResources::DEBERTA_BASE_MNLI);
let vocab_resource = RemoteResource::from_pretrained(DebertaVocabResources::DEBERTA_BASE_MNLI);
let merges_resource =
    RemoteResource::from_pretrained(DebertaMergesResources::DEBERTA_BASE_MNLI);
let weights_resource =
    RemoteResource::from_pretrained(DebertaModelResources::DEBERTA_BASE_MNLI);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
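// Use the GPU when CUDA is available, otherwise fall back to the CPU.
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
// The final argument is the tokenizer's lower-casing flag.
let tokenizer = DeBERTaTokenizer::from_file(
    vocab_path.to_str().unwrap(),
    merges_path.to_str().unwrap(),
    true,
)?;
let config = DebertaConfig::from_file(config_path);
// Build the model first so its variables are registered in the variable store,
// then load the pre-trained weights into those variables.
let deberta_model = DebertaForSequenceClassification::new(&vs.root(), &config);
vs.load(weights_path)?;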

Structs§

DebertaConfig
DeBERTa model configuration
DebertaConfigResources
DeBERTa Pretrained model config files
DebertaForMaskedLM
DeBERTa for masked language model
DebertaForQuestionAnswering
DeBERTa for question answering
DebertaForSequenceClassification
DeBERTa for sequence classification
DebertaForTokenClassification
DeBERTa for token classification (e.g. NER, POS)
DebertaMaskedLMOutput
Container for the DeBERTa masked LM model output.
DebertaMergesResources
DeBERTa Pretrained model merges files
DebertaModel
DeBERTa Base model
DebertaModelResources
DeBERTa Pretrained model weight files
DebertaVocabResources
DeBERTa Pretrained model vocab files

Type Aliases§

DebertaQuestionAnsweringOutput
Container for the DeBERTa question answering model output.
DebertaSequenceClassificationOutput
Container for the DeBERTa sequence classification model output.
DebertaTokenClassificationOutput
Container for the DeBERTa token classification model output.