rust_bert/models/gpt_j/
mod.rs

1//! # GPT-J
2//!
3//! Implementation of the GPT-J language model
4//!
5//! # Model set-up and pre-trained weights loading
6//!
7//! ```no_run
8//! # fn main() -> anyhow::Result<()> {
9//! #
10//! use tch::{nn, Device};
11//! # use std::path::PathBuf;
12//! use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel};
13//! use rust_bert::resources::{LocalResource, ResourceProvider};
14//! use rust_bert::Config;
15//! use rust_tokenizers::tokenizer::Gpt2Tokenizer;
16//!
17//! let config_resource = LocalResource {
18//!     local_path: PathBuf::from("path/to/config.json"),
19//! };
20//! let vocab_resource = LocalResource {
21//!     local_path: PathBuf::from("path/to/vocab.txt"),
22//! };
23//! let merges_resource = LocalResource {
24//!     local_path: PathBuf::from("path/to/merges.txt"),
25//! };
26//! let weights_resource = LocalResource {
27//!     local_path: PathBuf::from("path/to/model.ot"),
28//! };
29//! let config_path = config_resource.get_local_path()?;
30//! let vocab_path = vocab_resource.get_local_path()?;
31//! let merges_path = merges_resource.get_local_path()?;
32//! let weights_path = weights_resource.get_local_path()?;
33//!
34//! let device = Device::cuda_if_available();
35//! let mut vs = nn::VarStore::new(device);
36//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
37//!     vocab_path.to_str().unwrap(),
38//!     merges_path.to_str().unwrap(),
39//!     true,
40//! )?;
41//! let config = GptJConfig::from_file(config_path);
42//! let gpt_j_model = GptJLMHeadModel::new(&vs.root(), &config);
43//! vs.load(weights_path)?;
44//!
45//! # Ok(())
46//! # }
47//! ```
48
// Submodules are declared without `pub`, so they are private to this module;
// only the items named in the `pub use` statements below are exposed.
49mod attention;
50mod gpt_j_model;
51mod transformer;
52
// Re-export the public items of `gpt_j_model` that make up this module's API.
53pub use gpt_j_model::{
54    GptJConfig, GptJConfigResources, GptJGenerator, GptJLMHeadModel, GptJMergesResources,
55    GptJModel, GptJModelOutput, GptJModelResources, GptJVocabResources,
56};
57
// Re-export `LayerState` from the private `attention` module.
58pub use attention::LayerState;