Module rust_bert::mobilebert

MobileBERT (A Compact Task-agnostic BERT for Resource-Limited Devices)

Implementation of the MobileBERT language model (MobileBERT: A Compact Task-agnostic BERT for Resource-Limited Devices, Sun, Yu, Song, Liu, Yang, Zhou, 2020). The base model is implemented in the mobilebert_model::MobileBertModel struct. Several language model heads have also been implemented, including:

  • Multiple choice: mobilebert_model::MobileBertForMultipleChoice
  • Question answering: mobilebert_model::MobileBertForQuestionAnswering
  • Sequence classification: mobilebert_model::MobileBertForSequenceClassification
  • Token classification (e.g. NER, POS tagging): mobilebert_model::MobileBertForTokenClassification
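
Each head wraps the base model and is built from a variable store path and a configuration in the same way. A minimal construction sketch (the configuration path is illustrative; see the set-up section below for loading remote resources, and note that classification heads expect the configuration to define a label mapping):

use rust_bert::mobilebert::{MobileBertConfig, MobileBertForSequenceClassification};
use rust_bert::Config;
use tch::{nn, Device};

let device = Device::cuda_if_available();
let vs = nn::VarStore::new(device);
// Illustrative local path to a Transformers-style configuration file
let config = MobileBertConfig::from_file("path/to/config.json");
// The configuration's label mapping determines the number of output classes
let model = MobileBertForSequenceClassification::new(&vs.root(), &config);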

Model set-up and pre-trained weights loading

A full working example (masked language modeling) is provided in examples/mobilebert_masked_lm, run with cargo run --example mobilebert_masked_lm. All models expect the following resources:

  • Configuration file expected to have a structure following the Transformers library
  • Model weights are expected to have a structure and parameter names following the Transformers library. A conversion using the Python utility scripts is required to convert the .bin weights to the .ot format.
  • BertTokenizer using a vocab.txt vocabulary

Pretrained models are available and can be downloaded using RemoteResources.
use tch::{nn, Device};
use rust_bert::mobilebert::{
    MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM,
    MobileBertModelResources, MobileBertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::BertTokenizer;

// Define the remote resources pointing to the pretrained configuration, vocabulary and weights
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
    MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
    MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
    MobileBertModelResources::MOBILEBERT_UNCASED,
));

// Download the resources if needed and resolve their local paths
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;

// Set up the device, variable store, tokenizer and model configuration
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
// Tokenizer set to lower-case inputs and strip accents
let tokenizer: BertTokenizer =
    BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = MobileBertConfig::from_file(config_path);

// Build the model and load the pretrained weights into the variable store
let bert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
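
The loaded model can then be run on a batch of token ids. A minimal inference sketch follows; the input values are illustrative, and the forward_t argument order (input ids, token type ids, position ids, input embeddings, attention mask, train flag) is an assumption to verify against the crate documentation:

use tch::{no_grad, Tensor};

// Illustrative batch of one tokenized sequence (shape [1, sequence_length]);
// in practice the ids would come from the tokenizer above
let input_tensor = Tensor::of_slice(&[101i64, 2054, 2003, 103, 102])
    .unsqueeze(0)
    .to(device);

// Run a forward pass in inference mode (train = false) without tracking gradients
let model_output = no_grad(|| {
    bert_model.forward_t(Some(&input_tensor), None, None, None, None, false)
})?;
// model_output holds the prediction scores over the vocabulary for each position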

Structs

MobileBertConfig

MobileBERT model configuration

MobileBertConfigResources

MobileBERT Pretrained model config files

MobileBertForMaskedLM

MobileBERT for masked language modeling

MobileBertForMultipleChoice

MobileBERT for multiple choice tasks

MobileBertForQuestionAnswering

MobileBERT for question answering

MobileBertForSequenceClassification

MobileBERT for sequence classification

MobileBertForTokenClassification

MobileBERT for token classification (e.g. NER, POS)

MobileBertModel

MobileBERT base model

MobileBertModelResources

MobileBERT Pretrained model weight files

MobileBertVocabResources

MobileBERT Pretrained model vocab files

NoNorm

No-normalization option for MobileBERT

Enums

NormalizationType

Normalization type to use for the MobileBERT model.
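
The NormalizationType enum selects between standard layer normalization and the NoNorm variant; NoNorm applies a per-feature affine transform but skips the mean/variance normalization. A minimal sketch of the NoNorm operation as described in the MobileBERT paper (illustrative only, not the crate's implementation):

use tch::Tensor;

// No-normalization: an element-wise affine transform over the hidden dimension,
// y = x * weight + bias, with no normalization statistics computed
fn no_norm(hidden_states: &Tensor, weight: &Tensor, bias: &Tensor) -> Tensor {
    hidden_states * weight + bias
}

Broadcasting applies the per-feature weight and bias across the batch and sequence dimensions, avoiding the runtime cost of computing normalization statistics on resource-limited devices.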