pub struct GptNeoModel { /* private fields */ }
§GPT-Neo Base model
Base architecture for GPT-Neo models. Task-specific models are built from this common base model. It is made of the following blocks:
word_embeddings: Word embeddings
position_embeddings: Position embeddings
layers: Vector of GptNeoBlock (transformer part of the model)
Implementations§
impl GptNeoModel
pub fn new<'p, P>(
p: P,
config: &GptNeoConfig,
) -> Result<GptNeoModel, RustBertError>
Build a new GptNeoModel
§Arguments
p - Variable store path for the root of the GPT-Neo model
config - GptNeoConfig object defining the model architecture
§Example
use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
use rust_bert::Config;
use std::path::Path;
use tch::{nn, Device};
let config_path = Path::new("path/to/config.json");
let device = Device::Cpu;
let p = nn::VarStore::new(device);
let config = GptNeoConfig::from_file(config_path);
let gpt_neo_model = GptNeoModel::new(&p.root(), &config).unwrap();
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
layer_states: Option<Vec<Option<LayerState>>>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<GptNeoModelOutput, RustBertError>
Forward pass through the model
§Arguments
input_ids - Optional input tensor of shape (batch size, sequence_length). This or input_embeds must be provided.
input_embeds - Optional input tensor of shape (batch size, sequence_length, embeddings dimension). This or input_ids must be provided.
token_type_ids - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
position_ids - Optional position ids of shape (batch size, sequence_length). If None, they are incremented starting from the length of the past input.
layer_states - Optional vector Option<Vec<Option<LayerState>>> of length n_layer containing the past keys and values for the self-attention of each layer.
attention_mask - Optional attention mask of shape (batch size, sequence_length) over the input positions. Positions with a mask value of 0 will be masked.
train - Boolean flag to turn the dropout layers in the model on or off. Should be set to false for inference.
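The two rules in the argument list (one of input_ids or input_embeds must be provided, and position ids default to counting up from the length of the past input) can be illustrated with a minimal sketch in plain Rust. It uses slices rather than tensors, and both helper names (`input_sequence_length`, `default_position_ids`) are hypothetical and not part of rust-bert. Rejecting the case where both inputs are set is an assumption of this sketch.

```rust
/// Hypothetical helper: checks the input arguments and returns the
/// sequence length of whichever input was provided.
/// Assumption: passing both inputs at once is treated as an error.
fn input_sequence_length(
    input_ids: Option<&[i64]>,
    input_embeds: Option<&[Vec<f32>]>,
) -> Result<usize, String> {
    match (input_ids, input_embeds) {
        (Some(ids), None) => Ok(ids.len()),
        (None, Some(embeds)) => Ok(embeds.len()),
        (Some(_), Some(_)) => {
            Err("only one of input_ids or input_embeds may be set".to_string())
        }
        (None, None) => {
            Err("one of input_ids or input_embeds must be set".to_string())
        }
    }
}

/// Hypothetical helper: default position ids for one sequence, counting
/// up from the number of positions already held in the cached past.
fn default_position_ids(past_length: usize, sequence_length: usize) -> Vec<i64> {
    (past_length..past_length + sequence_length)
        .map(|position| position as i64)
        .collect()
}
```

With no cached past, a four-token input gets positions 0 to 3. When five positions are already cached and two new tokens are fed in, the new tokens get positions 5 and 6.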
§Returns
Result<GptNeoModelOutput, RustBertError> containing:
hidden_states - Tensor of shape (batch size, sequence_length, hidden_size) representing the activations of the last hidden state
next_cache - Option<Vec<Option<LayerState>>> of length n_layer containing the past content for the attention layers
all_hidden_states - Option<Vec<Tensor>> of length n_layer + 1 with shape (batch size, sequence_length, hidden_size)
all_attentions - Option<Vec<Tensor>> of length n_layer containing the attention weights for each layer
§Example
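use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
use tch::{no_grad, Kind, Tensor};
// `device` and `gpt_neo_model` are defined as in the GptNeoModel::new example.
let (batch_size, sequence_length) = (64, 128);
// Token ids must be integers; ids below 100 are assumed to fit the vocabulary.
let input_tensor = Tensor::randint(100, &[batch_size, sequence_length], (Kind::Int64, device));
let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
let model_output = no_grad(|| {
    gpt_neo_model.forward_t(
        Some(&input_tensor),
        None, // input_embeds
        None, // token_type_ids
        None, // position_ids
        None, // layer_states
        Some(&attention_mask),
        false, // train
    )
});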
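The example above uses an all-ones attention mask, which fits a batch where every sequence fills the full length. For a padded batch, a mask with 0 at the padded positions could be derived from the true sequence lengths. The following is a minimal sketch in plain Rust, assuming right-padded sequences; the helper name `padding_attention_mask` is hypothetical and not part of rust-bert.

```rust
/// Hypothetical helper: builds a (batch size, sequence_length) mask with
/// 1 for real tokens and 0 for padding.
/// Assumption: sequences are right-padded, so the first `len` positions
/// of each row hold real tokens.
fn padding_attention_mask(lengths: &[usize], sequence_length: usize) -> Vec<Vec<i64>> {
    lengths
        .iter()
        .map(|&len| {
            (0..sequence_length)
                .map(|position| if position < len { 1 } else { 0 })
                .collect()
        })
        .collect()
}
```

The nested vectors would still need to be converted to an Int64 tensor of shape (batch size, sequence_length) before being passed as attention_mask.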
Auto Trait Implementations§
impl Freeze for GptNeoModel
impl RefUnwindSafe for GptNeoModel
impl Send for GptNeoModel
impl !Sync for GptNeoModel
impl Unpin for GptNeoModel
impl UnwindSafe for GptNeoModel
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.