// rust_bert/models/deberta_v2/mod.rs

//! # DeBERTa V2 (He et al.)
//!
//! Implementation of the DeBERTa V2/V3 language model ([DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing](https://arxiv.org/abs/2111.09543) He, Gao, Chen, 2021).
//! The base model is implemented in the `deberta_v2_model::DebertaV2Model` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `deberta_v2_model::DebertaV2ForMaskedLM`
//! - Question answering: `deberta_v2_model::DebertaV2ForQuestionAnswering`
//! - Sequence classification: `deberta_v2_model::DebertaV2ForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `deberta_v2_model::DebertaV2ForTokenClassification`.
//!
//! # Model set-up and pre-trained weights loading
//!
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `DebertaV2Tokenizer` using a `spiece.model` SentencePiece model file
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::deberta_v2::{
//!     DebertaV2Config, DebertaV2ConfigResources, DebertaV2ForSequenceClassification,
//!     DebertaV2ModelResources, DebertaV2VocabResources,
//! };
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::DeBERTaV2Tokenizer;
//!
//! let config_resource =
//!     RemoteResource::from_pretrained(DebertaV2ConfigResources::DEBERTA_V3_BASE);
//! let vocab_resource = RemoteResource::from_pretrained(DebertaV2VocabResources::DEBERTA_V3_BASE);
//! let weights_resource =
//!     RemoteResource::from_pretrained(DebertaV2ModelResources::DEBERTA_V3_BASE);
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer =
//!     DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
//! let config = DebertaV2Config::from_file(config_path);
//! let deberta_model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
// Internal building blocks of the DeBERTa V2 implementation; kept private —
// the supported API is the re-export list below.
mod attention;
mod deberta_v2_model;
mod embeddings;
mod encoder;

// Public API surface: the base model, task-specific heads, configuration,
// task output structs, and remote resource pointers for pretrained weights.
pub use deberta_v2_model::{
    DebertaV2Config, DebertaV2ConfigResources, DebertaV2ForMaskedLM, DebertaV2ForQuestionAnswering,
    DebertaV2ForSequenceClassification, DebertaV2ForTokenClassification, DebertaV2Model,
    DebertaV2ModelResources, DebertaV2QuestionAnsweringOutput,
    DebertaV2SequenceClassificationOutput, DebertaV2TokenClassificationOutput,
    DebertaV2VocabResources,
};