§BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al.)
Implementation of the BERT language model (https://arxiv.org/abs/1810.04805 Devlin, Chang, Lee, Toutanova, 2018).
The base model is implemented in the bert_model::BertModel struct. Several language model heads have also been implemented, including:
- Masked language model: bert_model::BertForMaskedLM
- Multiple choices: bert_model::BertForMultipleChoice
- Question answering: bert_model::BertForQuestionAnswering
- Sequence classification: bert_model::BertForSequenceClassification
- Token classification (e.g. NER, POS tagging): bert_model::BertForTokenClassification
§Model set-up and pre-trained weights loading
A full working example is provided in examples/masked_language_model_bert; run it with cargo run --example masked_language_model_bert.
The example below illustrates a masked language model; the structure is similar for the other models.
All models expect the following resources:
- Configuration file expected to have a structure following the Transformers library
- Model weights are expected to have a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the .bin weights to the .ot format.
- BertTokenizer using a vocab.txt vocabulary
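As an illustration of the vocabulary resource, the sketch below parses the contents of a BERT-style vocab.txt, assuming the usual layout of one token per line with the id given by the zero-based line index. The helper names parse_vocab and token_to_id are hypothetical and are not part of rust_bert or rust_tokenizers; in practice BertTokenizer handles this for you.

```rust
use std::collections::HashMap;

/// Builds a token -> id map from the contents of a BERT-style `vocab.txt`
/// (assumption: one token per line, id = zero-based line index).
fn parse_vocab(contents: &str) -> HashMap<String, i64> {
    contents
        .lines()
        .enumerate()
        .map(|(index, token)| (token.trim_end().to_string(), index as i64))
        .collect()
}

/// Looks up a token, falling back to the id of `[UNK]` when the token is absent.
/// Returns `None` only if neither the token nor `[UNK]` is in the vocabulary.
fn token_to_id(vocab: &HashMap<String, i64>, token: &str) -> Option<i64> {
    vocab.get(token).or_else(|| vocab.get("[UNK]")).copied()
}
```

Reading the file with std::fs::read_to_string and passing the result to parse_vocab is enough to inspect a vocabulary before handing its path to the tokenizer.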
Pretrained models are available and can be downloaded using RemoteResources.
use std::path::PathBuf;

use rust_bert::bert::{BertConfig, BertForMaskedLM};
use rust_bert::resources::{LocalResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::BertTokenizer;
use tch::{nn, Device};

// Point each resource at a local file.
let config_resource = LocalResource {
    local_path: PathBuf::from("path/to/config.json"),
};
let vocab_resource = LocalResource {
    local_path: PathBuf::from("path/to/vocab.txt"),
};
let weights_resource = LocalResource {
    local_path: PathBuf::from("path/to/model.ot"),
};

// Resolve the resources to paths on disk.
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;

// Use a GPU when one is available, otherwise fall back to the CPU.
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);

// Build the tokenizer and the model, then load the weights into the variable store.
let tokenizer: BertTokenizer =
    BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = BertConfig::from_file(config_path);
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
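Once a masked language model has produced a score for every vocabulary entry at every position, the prediction for a masked token is typically the vocabulary id with the highest score at that position. The sketch below shows only that selection step on plain Rust slices. It is a simplified stand-in: the crate's models operate on tch tensors, and the names argmax and predict_masked are hypothetical.

```rust
/// Returns the index of the largest logit, i.e. the most likely vocabulary id.
/// Returns `None` for an empty slice.
fn argmax(logits: &[f32]) -> Option<usize> {
    logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(index, _)| index)
}

/// Picks the predicted vocabulary id at `masked_position`, with logits laid out
/// as `logits[position][vocab_id]` (an assumed layout for this sketch).
fn predict_masked(logits: &[Vec<f32>], masked_position: usize) -> Option<usize> {
    logits.get(masked_position).and_then(|row| argmax(row))
}
```

The returned id can then be mapped back to a token through the same vocabulary the tokenizer was built from.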
Structs§
- BertConfig: BERT model configuration
- BertConfigResources: BERT Pretrained model config files
- BertEmbeddings: BertEmbeddings implementation for BERT model
- BertEncoder: BERT Encoder
- BertEncoderOutput: Container for the BERT encoder output.
- BertForMaskedLM: BERT for masked language model
- BertForMultipleChoice: BERT for multiple choices
- BertForQuestionAnswering: BERT for question answering
- BertForSequenceClassification: BERT for sequence classification
- BertForTokenClassification: BERT for token classification (e.g. NER, POS)
- BertLayer: BERT Layer
- BertLayerOutput: Container for the BERT layer output.
- BertMaskedLMOutput: Container for the BERT masked LM model output.
- BertModel: BERT Base model
- BertModelOutput: Container for the BERT model output.
- BertModelResources: BERT Pretrained model weight files
- BertPooler: BERT Pooler
- BertQuestionAnsweringOutput: Container for the BERT question answering model output.
- BertSequenceClassificationOutput: Container for the BERT sequence classification model output.
- BertTokenClassificationOutput: Container for the BERT token classification model output.
- BertVocabResources: BERT Pretrained model vocab files
Traits§
- BertEmbedding: BertEmbedding trait (for use in BertModel or RoBERTaModel)
Type Aliases§
- BertForSentenceEmbeddings: BERT for sentence embeddings