// rust_bert/models/xlnet/mod.rs
1//! # XLNet (Generalized Autoregressive Pretraining for Language Understanding)
2//!
3//! Implementation of the XLNet language model ([Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) Yang, Dai, Yang, Carbonell, Salakhutdinov, Le, 2019).
4//! The base model is implemented in the `xlnet_model::XLNetModel` struct. Several language model heads have also been implemented, including:
5//! - Language generation: `xlnet_model::XLNetLMHeadModel` implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information)
//! - Multiple choices: `xlnet_model::XLNetForMultipleChoice`
7//! - Question answering: `xlnet_model::XLNetForQuestionAnswering`
8//! - Sequence classification: `xlnet_model::XLNetForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `xlnet_model::XLNetForTokenClassification`
10//!
11//! # Model set-up and pre-trained weights loading
12//!
13//! A full working example (generation) is provided in `examples/generation_xlnet`, run with `cargo run --example generation_xlnet`.
14//! All models expect the following resources:
15//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
16//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
17//! - `XLNetTokenizer` using a `spiece.model` sentence piece model
18//!
19//!
20//! ```no_run
21//! # fn main() -> anyhow::Result<()> {
22//! use rust_bert::pipelines::common::{ModelResource, ModelType};
23//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
24//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
25//! use rust_bert::resources::RemoteResource;
26//! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
27//! let config_resource = Box::new(RemoteResource::from_pretrained(
28//! XLNetConfigResources::XLNET_BASE_CASED,
29//! ));
30//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
31//! XLNetVocabResources::XLNET_BASE_CASED,
32//! ));
33//! let model_resource = Box::new(RemoteResource::from_pretrained(
34//! XLNetModelResources::XLNET_BASE_CASED,
35//! ));
36//! let generate_config = TextGenerationConfig {
37//! model_type: ModelType::XLNet,
38//! model_resource: ModelResource::Torch(model_resource),
39//! config_resource,
40//! vocab_resource,
41//! merges_resource: None,
42//! max_length: Some(56),
43//! do_sample: true,
44//! num_beams: 3,
45//! temperature: 1.0,
46//! num_return_sequences: 1,
47//! ..Default::default()
48//! };
49//! let model = TextGenerationModel::new(generate_config)?;
50//! let input_context = "Once upon a time,";
51//! let output = model.generate(&[input_context], None);
52//!
53//! # Ok(())
54//! # }
55//! ```
56
57mod attention;
58mod encoder;
59mod xlnet_model;
60
61pub use attention::LayerState;
62pub use xlnet_model::{
63 XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering,
64 XLNetForSequenceClassification, XLNetForTokenClassification, XLNetGenerator, XLNetLMHeadModel,
65 XLNetModel, XLNetModelResources, XLNetVocabResources,
66};