Module rust_bert::mobilebert


MobileBERT (A Compact Task-agnostic BERT for Resource-Limited Devices)

Implementation of the MobileBERT language model (MobileBERT: A Compact Task-agnostic BERT for Resource-Limited Devices; Sun, Yu, Song, Liu, Yang, Zhou, 2020). The base model is implemented in the mobilebert_model::MobileBertModel struct. Several language model heads have also been implemented, including:

  • Multiple choice: mobilebert_model::MobileBertForMultipleChoice
  • Question answering: mobilebert_model::MobileBertForQuestionAnswering
  • Sequence classification: mobilebert_model::MobileBertForSequenceClassification
  • Token classification (e.g. NER, POS tagging): mobilebert_model::MobileBertForTokenClassification

Model set-up and pre-trained weights loading

All models expect the following resources:

  • Configuration file expected to have a structure following the Transformers library
  • Model weights are expected to have a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the .bin weights to the .ot format.
  • BertTokenizer using a vocab.txt vocabulary

Pretrained models are available and can be downloaded using RemoteResource.

use tch::{nn, Device};
use rust_bert::mobilebert::{
    MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM,
    MobileBertModelResources, MobileBertVocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::BertTokenizer;

// The `?` operators below assume an enclosing function that returns a Result.

// Remote resources for the configuration, vocabulary and weights.
let config_resource =
    RemoteResource::from_pretrained(MobileBertConfigResources::MOBILEBERT_UNCASED);
let vocab_resource =
    RemoteResource::from_pretrained(MobileBertVocabResources::MOBILEBERT_UNCASED);
let weights_resource =
    RemoteResource::from_pretrained(MobileBertModelResources::MOBILEBERT_UNCASED);

// Resolve each resource to a local file path.
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;

// Create the variable store on the GPU if one is available, otherwise on the CPU.
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);

// Build the tokenizer, configuration and model, then load the pretrained weights.
let tokenizer: BertTokenizer =
    BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = MobileBertConfig::from_file(config_path);
let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
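
The vocab.txt resource used by BertTokenizer is a plain text file with one token per line, where a token's id is its zero-based line index. The following is a minimal standard-library sketch of that layout only. It is not the rust_tokenizers implementation, and parse_vocab and the tiny vocabulary are hypothetical names and data chosen for illustration.

```rust
use std::collections::HashMap;

/// Builds a token -> id map from the contents of a BERT-style vocab.txt.
/// Assumption: one token per line, id = zero-based line index.
/// `parse_vocab` is a hypothetical helper, not part of rust_bert or rust_tokenizers.
fn parse_vocab(contents: &str) -> HashMap<String, i64> {
    contents
        .lines()
        .enumerate()
        .map(|(index, token)| (token.trim_end().to_string(), index as i64))
        .collect()
}

fn main() {
    // Tiny made-up vocabulary; a real vocab.txt is far larger.
    let contents = "[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nhello\n##world\n";
    let vocab = parse_vocab(contents);
    println!("{} tokens, id of [CLS] = {}", vocab.len(), vocab["[CLS]"]);
}
```

In practice the tokenizer handles this file itself, so a helper like this is only useful for inspecting a vocabulary or checking that a converted file has the expected number of entries.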

Structs

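MobileBertConfig: MobileBERT model configuration
MobileBertConfigResources: MobileBERT pretrained model config files
MobileBertForMaskedLM: MobileBERT for masked language modeling
MobileBertForMultipleChoice: MobileBERT for multiple choice
MobileBertForQuestionAnswering: MobileBERT for question answering
MobileBertForSequenceClassification: MobileBERT for sequence classification
MobileBertForTokenClassification: MobileBERT for token classification (e.g. NER, POS)
MobileBertModel: MobileBERT base model
MobileBertModelResources: MobileBERT pretrained model weight files
MobileBertVocabResources: MobileBERT pretrained model vocab files
NoNorm: No-normalization option for MobileBERT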

Enums

NormalizationType: Normalization type to use for the MobileBERT model
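
To make the difference between the two normalization options concrete, here is a plain-Rust sketch on a single feature vector. It assumes the no-normalization option follows the definition in the MobileBERT paper (an element-wise scale and shift with no mean or variance statistics) and compares it with standard layer normalization. The functions no_norm and layer_norm are hypothetical illustrations and do not reflect how the crate implements these layers on tch tensors.

```rust
/// No-normalization as described in the MobileBERT paper:
/// y_i = gamma_i * x_i + beta_i, with no statistics computed over the input.
fn no_norm(x: &[f32], gamma: &[f32], beta: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(gamma)
        .zip(beta)
        .map(|((x, g), b)| g * x + b)
        .collect()
}

/// Standard layer normalization over the feature dimension:
/// y_i = gamma_i * (x_i - mean) / sqrt(var + eps) + beta_i.
fn layer_norm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(gamma)
        .zip(beta)
        .map(|((x, g), b)| g * (x - mean) * inv_std + b)
        .collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0];
    let gamma = [2.0_f32, 2.0, 2.0];
    let beta = [1.0_f32, 1.0, 1.0];
    // no_norm only rescales and shifts each element independently.
    println!("no_norm:    {:?}", no_norm(&x, &gamma, &beta));
    // layer_norm first centers and rescales using the vector's own statistics.
    println!("layer_norm: {:?}", layer_norm(&x, &gamma, &beta, 1e-12));
}
```

Because the element-wise variant needs no reduction over the hidden dimension, it is cheaper at inference time, which is the motivation the paper gives for it on resource-limited devices.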