//! # XLNet (Generalized Autoregressive Pretraining for Language Understanding)
//!
//! Implementation of the XLNet language model ([Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) Yang, Dai, Yang, Carbonell, Salakhutdinov, Le, 2019).
//! The base model is implemented in the `xlnet_model::XLNetModel` struct (a minimal instantiation sketch follows the list below). Several language model heads have also been implemented, including:
//! - Language generation: `xlnet_model::XLNetLMHeadModel` implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information)
//! - Multiple choices: `xlnet_model::XLNetForMultipleChoice`
//! - Question answering: `xlnet_model::XLNetForQuestionAnswering`
//! - Sequence classification: `xlnet_model::XLNetForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `xlnet_model::XLNetForTokenClassification`
//!
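//! The following is a minimal sketch of instantiating one of these structs (here the base `XLNetModel`) from a local configuration file; the `config.json` path is a placeholder:
//!
//! ```no_run
//! use rust_bert::xlnet::{XLNetConfig, XLNetModel};
//! use rust_bert::Config;
//! use std::path::Path;
//! use tch::{nn, Device};
//!
//! // Placeholder path to a Transformers-style configuration file
//! let config_path = Path::new("path/to/config.json");
//! let device = Device::cuda_if_available();
//! let var_store = nn::VarStore::new(device);
//! let config = XLNetConfig::from_file(config_path);
//! // Builds the model graph; pre-trained weights would then be loaded into `var_store`
//! let xlnet_model = XLNetModel::new(&var_store.root() / "xlnet", &config);
//! ```
//!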
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (generation) is provided in `examples/generation_xlnet`, run with `cargo run --example generation_xlnet`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format (a local-resource sketch is provided after the example below).
//! - `XLNetTokenizer` using a `spiece.model` sentence piece model
//!
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::{ModelResource, ModelType};
//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::RemoteResource;
//! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//!     XLNetConfigResources::XLNET_BASE_CASED,
//! ));
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//!     XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let model_resource = Box::new(RemoteResource::from_pretrained(
//!     XLNetModelResources::XLNET_BASE_CASED,
//! ));
//! let generate_config = TextGenerationConfig {
//!     model_type: ModelType::XLNet,
//!     model_resource: ModelResource::Torch(model_resource),
//!     config_resource,
//!     vocab_resource,
//!     merges_resource: None,
//!     max_length: Some(56),
//!     do_sample: true,
//!     num_beams: 3,
//!     temperature: 1.0,
//!     num_return_sequences: 1,
//!     ..Default::default()
//! };
//! let model = TextGenerationModel::new(generate_config)?;
//! let input_context = "Once upon a time,";
//! let output = model.generate(&[input_context], None);
//!
//! # Ok(())
//! # }
//! ```
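//!
//! If the weights have already been converted locally with the Python utility scripts, `LocalResource` pointers can be substituted for the `RemoteResource`s above; the paths below are placeholders:
//!
//! ```no_run
//! use rust_bert::resources::LocalResource;
//! use std::path::PathBuf;
//!
//! // Placeholder paths to locally stored, converted resources
//! let config_resource = Box::new(LocalResource {
//!     local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Box::new(LocalResource {
//!     local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let model_resource = Box::new(LocalResource {
//!     local_path: PathBuf::from("path/to/rust_model.ot"),
//! });
//! ```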

mod attention;
mod encoder;
mod xlnet_model;

pub use attention::LayerState;
pub use xlnet_model::{
    XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering,
    XLNetForSequenceClassification, XLNetForTokenClassification, XLNetGenerator, XLNetLMHeadModel,
    XLNetModel, XLNetModelResources, XLNetVocabResources,
};