//! `syntaxdot_transformers/models/encoder.rs`

1use tch::Tensor;
2
3use crate::models::layer_output::LayerOutput;
4use crate::TransformerError;
5
6/// Encoder networks.
7pub trait Encoder {
8    /// Apply the encoder.
9    ///
10    /// Returns the output and attention per layer. The (optional)
11    /// attention mask of shape `[batch_size, time_steps]` indicates
12    /// which tokens should be included (`true`) and excluded (`false`) from
13    /// attention. This can be used to mask inactive timesteps.
14    fn encode(
15        &self,
16        input: &Tensor,
17        attention_mask: Option<&Tensor>,
18        train: bool,
19    ) -> Result<Vec<LayerOutput>, TransformerError>;
20
21    /// Get the number of layers that is returned by the encoder.
22    fn n_layers(&self) -> i64;
23}