Trait tch::nn::RNN

/// Trait for Recurrent Neural Networks.
///
/// Known implementors: `GRU` (with `State = GRUState`) and `LSTM`
/// (with `State = LSTMState`).
pub trait RNN {
    /// The recurrent state: produced by `zero_state`, taken by reference by
    /// `step` and `seq_init`, and returned by `step`, `seq_init` and `seq`.
    type State;
    /// A zero state from which the recurrent network is usually initialized.
    ///
    /// NOTE(review): `batch_dim` presumably sizes the state for the batch
    /// dimension of the inputs — confirm against implementors.
    fn zero_state(&self, batch_dim: i64) -> Self::State;
    /// Applies a single step of the recurrent network.
    ///
    /// The input should have dimensions [batch_size, features].
    /// Returns the new state computed from `input` and the previous `state`.
    fn step(&self, input: &Tensor, state: &Self::State) -> Self::State;
    /// Applies multiple steps of the recurrent network, starting from the
    /// caller-supplied `state`.
    ///
    /// The input should have dimensions [batch_size, seq_len, features].
    ///
    /// NOTE(review): the returned pair is presumably (outputs, state after the
    /// last step) — confirm against implementors.
    fn seq_init(
        &self,
        input: &Tensor,
        state: &Self::State
    ) -> (Tensor, Self::State);
    /// Applies multiple steps of the recurrent network.
    ///
    /// The input should have dimensions [batch_size, seq_len, features].
    /// The initial state is the result of applying `zero_state`.
    ///
    /// NOTE(review): the body is elided (`{ ... }`) in this rendered
    /// declaration; consult the crate source for the default implementation.
    fn seq(&self, input: &Tensor) -> (Tensor, Self::State) { ... }
}

Trait for Recurrent Neural Networks.

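Associated Types

type State

The recurrent state. It is produced by zero_state, passed by reference to step and seq_init, and returned by step, seq_init and seq.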

Required methods

fn zero_state(&self, batch_dim: i64) -> Self::State

A zero state from which the recurrent network is usually initialized.

fn step(&self, input: &Tensor, state: &Self::State) -> Self::State

Applies a single step of the recurrent network.

The input should have dimensions [batch_size, features].

fn seq_init(&self, input: &Tensor, state: &Self::State) -> (Tensor, Self::State)

Applies multiple steps of the recurrent network, starting from the given state.

The input should have dimensions [batch_size, seq_len, features].

Loading content...

Provided methods

fn seq(&self, input: &Tensor) -> (Tensor, Self::State)

Applies multiple steps of the recurrent network.

The input should have dimensions [batch_size, seq_len, features]. The initial state is the result of applying zero_state.

Loading content...

Implementors

impl RNN for GRU

type State = GRUState

impl RNN for LSTM

type State = LSTMState

Loading content...