Struct rust_bert::models::gpt_neo::GptNeoModel
pub struct GptNeoModel { /* private fields */ }
GPT-Neo Base model
Base architecture for GPT-Neo models. Task-specific models are built from this common base model. It is made of the following blocks:
word_embeddings: Word embeddings
position_embeddings: Position embeddings
layers: Vector of GptNeoBlock (transformer part of the model)
Implementations
impl GptNeoModel
pub fn new<'p, P>(
    p: P,
    config: &GptNeoConfig
) -> Result<GptNeoModel, RustBertError>
Build a new GptNeoModel
Arguments
p - Variable store path for the root of the GPT-Neo model
config - GptNeoConfig object defining the model architecture
Example
use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
use rust_bert::Config;
use std::path::Path;
use tch::{nn, Device};
let config_path = Path::new("path/to/config.json");
let device = Device::Cpu;
let p = nn::VarStore::new(device);
let config = GptNeoConfig::from_file(config_path);
let gpt_neo_model = GptNeoModel::new(&p.root(), &config).unwrap();
pub fn forward_t(
    &self,
    input_ids: Option<&Tensor>,
    input_embeds: Option<&Tensor>,
    token_type_ids: Option<&Tensor>,
    position_ids: Option<&Tensor>,
    layer_states: Option<Vec<Option<LayerState>>>,
    attention_mask: Option<&Tensor>,
    train: bool
) -> Result<GptNeoModelOutput, RustBertError>
Forward pass through the model
Arguments
input_ids - Optional input tensor of shape (batch size, sequence_length). Either this or input_embeds must be provided.
input_embeds - Optional input tensor of shape (batch size, sequence_length, embeddings dimension). Either this or input_ids must be provided.
token_type_ids - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
position_ids - Optional position ids of shape (batch size, sequence_length). If None, they are generated incrementally, starting from the length of the past input.
layer_states - Optional Option<Vec<Option<LayerState>>> of length n_layer containing the past keys and values for the self attention of each layer.
attention_mask - Optional attention mask of shape (batch size, sequence_length) for the input positions. Positions with a mask value of 0 will be masked.
train - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
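The default described for position_ids can be illustrated with plain Rust. This is a minimal sketch of the stated behavior only, not the library's implementation: `default_position_ids` is a hypothetical helper, and it works on a `Vec<i64>` for a single sequence rather than on a `Tensor`.

```rust
/// Hypothetical helper (not part of rust-bert): positions for the tokens of
/// the current call, continuing from the number of tokens already cached.
fn default_position_ids(past_length: i64, sequence_length: i64) -> Vec<i64> {
    // Half-open range: past_length, past_length + 1, ...,
    // up to (but excluding) past_length + sequence_length.
    (past_length..past_length + sequence_length).collect()
}
```

With no cached state and a 4-token prompt, `default_position_ids(0, 4)` yields `[0, 1, 2, 3]`. When those 4 tokens are cached in layer_states and a single new token is passed, `default_position_ids(4, 1)` yields `[4]`.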
Returns
Result<GptNeoModelOutput, RustBertError> containing:
hidden_states - Tensor of shape (batch size, sequence_length, hidden_size) representing the activations of the last hidden state
next_cache - Option<Vec<Option<LayerState>>> of length n_layer containing the past content for the attention layers
all_hidden_states - Option<Vec<Tensor>> of length n_layer + 1 with shape (batch size, sequence_length, hidden_size)
all_attentions - Option<Vec<Tensor>> of length n_layer containing the attention weights for each layer
Example
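use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
use tch::{no_grad, Device, Kind, Tensor};

// `gpt_neo_model` is the model built in the `new` example above.
let device = Device::Cpu;
let (batch_size, sequence_length) = (64, 128);
// Placeholder token ids (all zeros); real ids come from a tokenizer.
let input_tensor = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
let model_output = no_grad(|| {
    gpt_neo_model.forward_t(
        Some(&input_tensor),
        None, // input_embeds: omitted because input_ids is provided
        None, // token_type_ids
        None, // position_ids
        None, // layer_states
        Some(&attention_mask),
        false, // train
    )
});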
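The example uses an all-ones mask, which fits a batch where every sequence has the same length. For a padded batch, the mask values could be derived from the unpadded lengths. The sketch below assumes right-padding. `padding_mask` is a hypothetical helper, not a rust-bert function, and it returns nested vectors rather than a `Tensor`.

```rust
/// Hypothetical helper (not part of rust-bert): one mask row per sequence,
/// with 1 for real tokens and 0 for padding, assuming right-padded inputs.
fn padding_mask(lengths: &[usize], max_length: usize) -> Vec<Vec<i64>> {
    lengths
        .iter()
        .map(|&len| {
            // Positions before `len` hold real tokens; the rest are padding.
            (0..max_length).map(|pos| i64::from(pos < len)).collect()
        })
        .collect()
}
```

For lengths `[3, 1]` padded to 4 tokens, this gives `[[1, 1, 1, 0], [1, 0, 0, 0]]`. The values would still need to be converted to a tensor of shape (batch size, sequence_length) before being passed as attention_mask.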
Auto Trait Implementations
impl RefUnwindSafe for GptNeoModel
impl Send for GptNeoModel
impl !Sync for GptNeoModel
impl Unpin for GptNeoModel
impl UnwindSafe for GptNeoModel
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.