rust_bert/models/gpt_neo/mod.rs
1//! # GPT-Neo
2//!
3//! Implementation of the GPT-Neo language model ([The Pile: An 800GB Dataset of Diverse Text for Language Modeling](https://arxiv.org/abs/2101.00027) Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others, 2020).
4//! The base model is implemented in the `gpt_neo_model::GptNeoModel` struct. A causal language modeling head is implemented in `gpt_neo_model::GptNeoForCausalLM`. Both are re-exported from this module.
5//!
6//! # Model set-up and pre-trained weights loading
7//!
8//! A full working example is provided in `examples/generation_gpt_neo`; run it with `cargo run --example generation_gpt_neo`.
9//! All models expect the following resources:
10//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
11//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
12//! - `GPT2Tokenizer` using a `vocab.json` vocabulary and a `merges.txt` merges file
13//!
14//! The following pre-trained checkpoints are readily available:
15//! - 125M-parameter model (`GptNeoModelResources::GPT_NEO_125M`)
16//! - 1.3B-parameter model (`GptNeoModelResources::GPT_NEO_1_3B`)
17//! - 2.7B-parameter model (`GptNeoModelResources::GPT_NEO_2_7B`)
18//!
19//! ```no_run
20//! use rust_bert::gpt_neo::{
21//! GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
22//! };
23//! use rust_bert::pipelines::common::ModelType;
24//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
25//! use rust_bert::resources::RemoteResource;
26//! use tch::Device;
27//!
28//! fn main() -> anyhow::Result<()> {
29//! use rust_bert::pipelines::common::ModelResource;
30//! let config_resource = Box::new(RemoteResource::from_pretrained(
31//! GptNeoConfigResources::GPT_NEO_1_3B,
32//! ));
33//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
34//! GptNeoVocabResources::GPT_NEO_1_3B,
35//! ));
36//! let merges_resource = Box::new(RemoteResource::from_pretrained(
37//! GptNeoMergesResources::GPT_NEO_1_3B,
38//! ));
39//! let model_resource = Box::new(RemoteResource::from_pretrained(
40//! GptNeoModelResources::GPT_NEO_1_3B,
41//! ));
42//!
43//! let text_generation_config = TextGenerationConfig {
44//! model_type: ModelType::GPTNeo,
45//! model_resource: ModelResource::Torch(model_resource),
46//! config_resource,
47//! vocab_resource,
48//! merges_resource: Some(merges_resource),
49//! num_beams: 4,
50//! no_repeat_ngram_size: 3,
51//! device: Device::cuda_if_available(),
52//! ..Default::default()
53//! };
54//! let model = TextGenerationModel::new(text_generation_config)?;
55//!
56//! let input_context_1 = "It was a very nice and sunny";
57//! let input_context_2 = "It was a gloom winter night, and";
58//! let output = model.generate(&[input_context_1, input_context_2], None)?;
59//!
60//! for sentence in output {
61//! println!("{}", sentence);
62//! }
63//!
64//! Ok(())
65//! }
66//! ```
67
68mod attention;
69mod decoder;
70mod gpt_neo_model;
71
72pub use gpt_neo_model::{
73 GptNeoConfig, GptNeoConfigResources, GptNeoForCausalLM, GptNeoGenerator, GptNeoMergesResources,
74 GptNeoModel, GptNeoModelResources, GptNeoVocabResources,
75};
76
77pub use attention::LayerState;