Crate rust_bert
Ready-to-use NLP pipelines and Transformer-based models
Rust-native implementation of Transformer-based models, ported from the Transformers library, using the tch-rs crate and pre-processing from rust-tokenizers. Supports multithreaded tokenization and GPU inference. This repository exposes the model base architectures, task-specific heads (see below) and ready-to-use pipelines.
Quick Start
This crate can be used in two different ways:
- Ready-to-use NLP pipelines for:
- Translation
- Summarization
- Sentiment Analysis
- Named Entity Recognition
- Question-Answering
- Language Generation

More information on these can be found in the `pipelines` module.
```rust
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};

let qa_model = QuestionAnsweringModel::new(Default::default())?;

let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");

let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
```
- Transformer models base architectures with customized heads. These allow loading pre-trained models for customized inference in Rust
| | DistilBERT | BERT | RoBERTa | GPT | GPT2 | BART | Electra | Marian | ALBERT |
|---|---|---|---|---|---|---|---|---|---|
| Masked LM | ✅ | ✅ | ✅ | | | | ✅ | | ✅ |
| Sequence classification | ✅ | ✅ | ✅ | | | | | | ✅ |
| Token classification | ✅ | ✅ | ✅ | | | | ✅ | | ✅ |
| Question answering | ✅ | ✅ | ✅ | | | | | | ✅ |
| Multiple choices | | ✅ | ✅ | | | | | | ✅ |
| Next token prediction | | | | ✅ | ✅ | | | | |
| Natural Language Generation | | | | ✅ | ✅ | | | | |
| Summarization | | | | | | ✅ | | | |
| Translation | | | | | | | | ✅ | |
Loading pre-trained models
A number of pretrained model configurations, weights and vocabularies are downloaded directly from Huggingface's model repository. The list of models available with Rust-compatible weights is given in the example ./examples/download_all_dependencies.rs. Additional models can be added if of interest; please raise an issue.
In order to load custom weights into the library, these need to be converted to a binary format that can be read by Libtorch (the original .bin files are pickles and cannot be used directly).
Several Python scripts to load PyTorch weights and convert them to the appropriate format are provided, and can be adapted based on the model's needs.
The procedure for building custom weights or re-building pretrained weights is as follows:
- Compile the package: `cargo build --release`
- Download the model files & perform necessary conversions
  - Set up a virtual environment and install dependencies
  - Run the conversion script: `python ./utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py`. The dependencies will be downloaded to the user's home directory, under `~/rustbert/{}`
- Run the example: `cargo run --release`
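Put together, the workflow looks roughly like the following shell session (the virtual-environment setup and Python dependencies are assumptions; the conversion script name keeps the placeholder from the step above):

```shell
# Compile the library with optimizations
cargo build --release

# Set up a Python environment for the conversion scripts (dependencies assumed)
python -m venv venv && source venv/bin/activate
pip install torch numpy

# Download and convert the weights for the chosen model
python ./utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py

# Run one of the crate's examples against the converted weights
cargo run --release
```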
Modules
| Module | Description |
|---|---|
| albert | ALBERT: A Lite BERT for Self-supervised Learning of Language Representations (Lan et al.) |
| bart | BART (Lewis et al.) |
| bert | BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al.) |
| distilbert | DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (Sanh et al.) |
| electra | Electra: Pre-training Text Encoders as Discriminators Rather Than Generators (Clark et al.) |
| gpt2 | GPT2 (Radford et al.) |
| marian | Marian |
| openai_gpt | GPT (Radford et al.) |
| pipelines | Ready-to-use NLP pipelines and models |
| resources | Resource definitions for model weights, vocabularies and configuration files |
| roberta | RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al.) |
Traits
| Trait | Description |
|---|---|
| Config | Utility to deserialize JSON config files |
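As a minimal sketch of the Config trait (the file path is a placeholder), deserializing a model configuration looks like:

```rust
use rust_bert::bert::BertConfig;
use rust_bert::Config;
use std::path::Path;

// Config::from_file deserializes a JSON configuration file
// into the corresponding model configuration struct
let config = BertConfig::from_file(Path::new("path/to/config.json"));
```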