1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
use tch::Tensor;

use crate::models::layer_output::LayerOutput;
use crate::TransformerError;

/// Encoder networks.
pub trait Encoder {
    /// Apply the encoder.
    ///
    /// Returns the output and attention per layer. The (optional)
    /// attention mask of shape `[batch_size, time_steps]` indicates
    /// which tokens should be included (`true`) and excluded (`false`) from
    /// attention. This can be used to mask inactive timesteps.
    fn encode(
        &self,
        input: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Vec<LayerOutput>, TransformerError>;

    /// Get the number of layers that is returned by the encoder.
    fn n_layers(&self) -> i64;
}