Struct rust_bert::bert::BertEncoder [−][src]
BERT Encoder
Encoder used in BERT models.
It is made of a Vector of BertLayer through which hidden states will be passed. The encoder can also be
used as a decoder (with cross-attention) if encoder_hidden_states are provided.
Implementations
impl BertEncoder[src]
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertEncoder where
    P: Borrow<Path<'p>>, [src]
Build a new BertEncoder
Arguments
- `p` - Variable store path for the root of the BERT model
- `config` - `BertConfig` object defining the model architecture
Example
use rust_bert::bert::{BertConfig, BertEncoder};
use rust_bert::Config;
use std::path::Path;
use tch::{nn, Device};

let config_path = Path::new("path/to/config.json");
let device = Device::Cpu;
let p = nn::VarStore::new(device);
let config = BertConfig::from_file(config_path);
let encoder: BertEncoder = BertEncoder::new(&p.root(), &config);
pub fn forward_t(
    &self,
    hidden_states: &Tensor,
    mask: &Option<Tensor>,
    encoder_hidden_states: &Option<Tensor>,
    encoder_mask: &Option<Tensor>,
    train: bool
) -> BertEncoderOutput[src]
Forward pass through the encoder
Arguments
- `hidden_states` - input tensor of shape (batch size, sequence_length, hidden_size).
- `mask` - Optional mask of shape (batch size, sequence_length). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
- `encoder_hidden_states` - Optional encoder hidden state of shape (batch size, encoder_sequence_length, hidden_size). If the model is defined as a decoder and `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
- `encoder_mask` - Optional encoder attention mask of shape (batch size, encoder_sequence_length). If the model is defined as a decoder and `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
- `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
Returns
`BertEncoderOutput` containing:
- `hidden_state` - `Tensor` of shape (batch size, sequence_length, hidden_size)
- `all_hidden_states` - `Option<Vec<Tensor>>` of length num_hidden_layers with shape (batch size, sequence_length, hidden_size)
- `all_attentions` - `Option<Vec<Tensor>>` of length num_hidden_layers with shape (batch size, num_heads, sequence_length, sequence_length)
Example
let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device));
let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
let encoder_output =
    no_grad(|| encoder.forward_t(&input_tensor, &Some(mask), &None, &None, false));
Auto Trait Implementations
impl RefUnwindSafe for BertEncoder
impl Send for BertEncoder
impl !Sync for BertEncoder
impl Unpin for BertEncoder
impl UnwindSafe for BertEncoder
Blanket Implementations
impl<T> Any for T where
    T: 'static + ?Sized, [src]
impl<T> Borrow<T> for T where
    T: ?Sized, [src]
impl<T> BorrowMut<T> for T where
    T: ?Sized, [src]
pub fn borrow_mut(&mut self) -> &mut T[src]
impl<T> From<T> for T[src]
impl<T> Instrument for T[src]
pub fn instrument(self, span: Span) -> Instrumented<Self>[src]
pub fn in_current_span(self) -> Instrumented<Self>[src]
impl<T, U> Into<U> for T where
    U: From<T>, [src]
impl<T> Pointable for T
pub const ALIGN: usize
type Init = T
The type for initializers.
pub unsafe fn init(init: <T as Pointable>::Init) -> usize
pub unsafe fn deref<'a>(ptr: usize) -> &'a T
pub unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T
pub unsafe fn drop(ptr: usize)
impl<T> Same<T> for T
type Output = T
Should always be Self
impl<T, U> TryFrom<U> for T where
    U: Into<T>, [src]
type Error = Infallible
The type returned in the event of a conversion error.
pub fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>[src]
impl<T, U> TryInto<U> for T where
    U: TryFrom<T>, [src]
type Error = <U as TryFrom<T>>::Error
The type returned in the event of a conversion error.
pub fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>[src]
impl<V, T> VZip<V> for T where
    V: MultiLane<T>,