§GPT-Neo
Implementation of the GPT-Neo language model (The Pile: An 800GB Dataset of Diverse Text for Language Modeling Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others, 2020).
The base model is implemented in the gpt_neo_model::GptNeoModel struct. A causal language modeling head is implemented in gpt_neo_model::GptNeoForCausalLM.
§Model set-up and pre-trained weights loading
A full working example is provided in examples/generation_gpt_neo and can be run with cargo run --example generation_gpt_neo.
All models expect the following resources:
- A configuration file with a structure following the Transformers library
- Model weights with a structure and parameter names following the Transformers library. The .bin weights must be converted to the .ot format using the Python utility scripts.
- A GPT2Tokenizer using a vocab.json vocabulary and a merges.txt merges file
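When the resources above are stored locally rather than downloaded, a quick pre-flight check can catch a missing file before the model is loaded. The following is a minimal sketch using only the standard library. The file names config.json and rust_model.ot and the directory path are assumptions for illustration and are not prescribed by this module; only vocab.json and merges.txt are named above.

```rust
use std::path::{Path, PathBuf};

/// File names assumed for illustration; adjust them to the local layout.
const EXPECTED_FILES: [&str; 4] = ["config.json", "rust_model.ot", "vocab.json", "merges.txt"];

/// Returns the expected resource files that are missing from `model_dir`.
fn missing_resources(model_dir: &Path) -> Vec<PathBuf> {
    EXPECTED_FILES
        .iter()
        .map(|name| model_dir.join(name))
        .filter(|path| !path.is_file())
        .collect()
}

fn main() {
    // Hypothetical directory holding the converted model files.
    let model_dir = Path::new("path/to/gpt-neo");
    let missing = missing_resources(model_dir);
    if missing.is_empty() {
        println!("all resources found in {}", model_dir.display());
    } else {
        for path in &missing {
            println!("missing: {}", path.display());
        }
    }
}
```

An empty result means every expected file is present; the check says nothing about whether the file contents are valid.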
The following pre-trained checkpoints are readily available:
- 125M parameters model (GptNeoModelResources::GPT_NEO_125M)
- 1.3B parameters model (GptNeoModelResources::GPT_NEO_1_3B)
- 2.7B parameters model (GptNeoModelResources::GPT_NEO_2_7B)
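The checkpoints differ mainly in size, which determines how much memory the weights need. As a rough guide, the sketch below multiplies the nominal parameter count by the bytes per parameter. It assumes 4-byte (f32) weights and uses the rounded counts from the checkpoint names; actual usage will be higher once activations and generation caches are included.

```rust
/// Bytes needed to hold `num_params` weights at `bytes_per_param` precision.
fn weight_bytes(num_params: u64, bytes_per_param: u64) -> u64 {
    num_params * bytes_per_param
}

/// Converts a byte count to gibibytes (2^30 bytes).
fn to_gib(bytes: u64) -> f64 {
    bytes as f64 / (1024.0 * 1024.0 * 1024.0)
}

fn main() {
    // Nominal parameter counts taken from the checkpoint names (approximate).
    let checkpoints = [
        ("GPT_NEO_125M", 125_000_000u64),
        ("GPT_NEO_1_3B", 1_300_000_000),
        ("GPT_NEO_2_7B", 2_700_000_000),
    ];
    for (name, params) in checkpoints {
        // Assumes 4 bytes per weight (f32).
        let gib = to_gib(weight_bytes(params, 4));
        println!("{}: ~{:.1} GiB of weights", name, gib);
    }
}
```

The estimate is only a lower bound, but it is usually enough to decide whether a checkpoint fits on the target device.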
use rust_bert::gpt_neo::{
    GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;

fn main() -> anyhow::Result<()> {
    // Remote resources for the 1.3B checkpoint: configuration, vocabulary, merges and weights.
    let config_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoConfigResources::GPT_NEO_1_3B,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoVocabResources::GPT_NEO_1_3B,
    ));
    let merges_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoMergesResources::GPT_NEO_1_3B,
    ));
    let model_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoModelResources::GPT_NEO_1_3B,
    ));

    // Generation settings; fields not listed keep their default values.
    let text_generation_config = TextGenerationConfig {
        model_type: ModelType::GPTNeo,
        model_resource: ModelResource::Torch(model_resource),
        config_resource,
        vocab_resource,
        merges_resource: Some(merges_resource),
        num_beams: 4,
        no_repeat_ngram_size: 3,
        device: Device::cuda_if_available(),
        ..Default::default()
    };
    let model = TextGenerationModel::new(text_generation_config)?;

    // Generate a continuation for each prompt and print it.
    let input_context_1 = "It was a very nice and sunny";
    let input_context_2 = "It was a gloomy winter night, and";
    let output = model.generate(&[input_context_1, input_context_2], None)?;
    for sentence in output {
        println!("{}", sentence);
    }
    Ok(())
}
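The example above sets no_repeat_ngram_size to 3. As commonly implemented, this option prevents the generator from emitting a token that would complete an n-gram already present in the generated sequence. The sketch below illustrates that idea on plain token ids using only the standard library. The function banned_next_tokens is hypothetical and is not part of rust_bert, whose internal implementation may differ.

```rust
use std::collections::HashSet;

/// Returns the tokens that, if generated next, would repeat an n-gram
/// already present in `tokens`. Hypothetical helper for illustration.
fn banned_next_tokens(tokens: &[u32], n: usize) -> HashSet<u32> {
    let mut banned = HashSet::new();
    // No n-gram exists yet if the sequence is shorter than n.
    if n == 0 || tokens.len() < n {
        return banned;
    }
    // The last n-1 tokens are the prefix of the n-gram about to be completed.
    let prefix = &tokens[tokens.len() - (n - 1)..];
    for window in tokens.windows(n) {
        // If an earlier n-gram starts with the same prefix, ban its last token.
        if &window[..n - 1] == prefix {
            banned.insert(window[n - 1]);
        }
    }
    banned
}

fn main() {
    // Hypothetical token ids; the sequence ends with the bigram (1, 2),
    // which was previously followed by 3.
    let generated: [u32; 5] = [1, 2, 3, 1, 2];
    let banned = banned_next_tokens(&generated, 3);
    println!("banned next tokens: {:?}", banned);
}
```

In a generator, the scores of the banned tokens would typically be set to negative infinity before the next token is selected.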
Structs§
- GptNeoConfig - GPT-Neo model configuration
- GptNeoConfigResources - GPT-Neo Pretrained model config files
- GptNeoForCausalLM - GPT-Neo Model for causal language modeling
- GptNeoGenerator - Language generation model based on the GPT-Neo architecture
- GptNeoMergesResources - GPT-Neo Pretrained model merges files
- GptNeoModel - GPT-Neo Base model
- GptNeoModelResources - GPT-Neo Pretrained model weight files
- GptNeoVocabResources - GPT-Neo Pretrained model vocab files
- LayerState - Cache for GPT-Neo attention layers