rust_bert/models/fnet/mod.rs
//! # FNet, Mixing Tokens with Fourier Transforms (Lee-Thorp et al.)
//!
//! Implementation of the FNet language model ([https://arxiv.org/abs/2105.03824](https://arxiv.org/abs/2105.03824) Lee-Thorp, Ainslie, Eckstein, Ontanon, 2021).
//! The base model is implemented in the `fnet_model::FNetModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `fnet_model::FNetForMaskedLM`
//! - Multiple choices: `fnet_model::FNetForMultipleChoice`
//! - Question answering: `fnet_model::FNetForQuestionAnswering`
//! - Sequence classification: `fnet_model::FNetForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `fnet_model::FNetForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! The example below illustrates an FNet masked language model; the structure is similar for the other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `FNetTokenizer` using a `spiece.model` SentencePiece model file
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::FNetTokenizer;
//!
//! // Local resources for the configuration, SentencePiece model and converted weights
//! let config_resource = LocalResource {
//!     local_path: PathBuf::from("path/to/config.json"),
//! };
//! let vocab_resource = LocalResource {
//!     local_path: PathBuf::from("path/to/spiece.model"),
//! };
//! let weights_resource = LocalResource {
//!     local_path: PathBuf::from("path/to/model.ot"),
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! // Set up the tokenizer, configuration and model, then load the weights into the variable store
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: FNetTokenizer =
//!     FNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
//! let config = FNetConfig::from_file(config_path);
//! let fnet_model = FNetForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
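//!
//! A minimal sketch of the remote-resource variant mentioned above (assuming the `BASE` constants
//! re-exported by this module point at the hosted pretrained checkpoint):
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
//!
//! // `BASE` is assumed here as the identifier of the pretrained fnet-base resources
//! let config_resource = RemoteResource::from_pretrained(FNetConfigResources::BASE);
//! let vocab_resource = RemoteResource::from_pretrained(FNetVocabResources::BASE);
//! let weights_resource = RemoteResource::from_pretrained(FNetModelResources::BASE);
//!
//! // Downloads (and caches) the files; the resulting paths can be used as in the example above
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! # Ok(())
//! # }
//! ```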
mod attention;
mod embeddings;
mod encoder;
mod fnet_model;

pub use fnet_model::{
    FNetConfig, FNetConfigResources, FNetForMaskedLM, FNetForMultipleChoice,
    FNetForQuestionAnswering, FNetForSequenceClassification, FNetForTokenClassification,
    FNetMaskedLMOutput, FNetModel, FNetModelOutput, FNetModelResources,
    FNetQuestionAnsweringOutput, FNetSequenceClassificationOutput, FNetTokenClassificationOutput,
    FNetVocabResources,
};