syntaxdot_transformers/models/encoder.rs
use tch::Tensor;

use crate::models::layer_output::LayerOutput;
use crate::TransformerError;

/// Encoder networks.
pub trait Encoder {
    /// Apply the encoder.
    ///
    /// Returns the output and attention per layer. The (optional)
    /// attention mask of shape `[batch_size, time_steps]` indicates
    /// which tokens should be included in (`true`) or excluded from
    /// (`false`) attention. This can be used to mask inactive time steps.
    fn encode(
        &self,
        input: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Vec<LayerOutput>, TransformerError>;

    /// Get the number of layers returned by the encoder.
    fn n_layers(&self) -> i64;
}
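
// A minimal, hedged usage sketch, not part of the original module: it shows
// how a caller might build a `[batch_size, time_steps]` attention mask and
// invoke `encode`. The function name `encode_batch_sketch` and the dummy
// inputs are assumptions for illustration; only `Encoder`, `LayerOutput`,
// and `TransformerError` come from this crate.
#[allow(dead_code)]
fn encode_batch_sketch(encoder: &impl Encoder) -> Result<(), TransformerError> {
    use tch::{Device, Kind};

    // Two sequences of four time steps; zero token ids stand in for real
    // piece identifiers (an assumption for the sketch).
    let input = Tensor::zeros(&[2, 4], (Kind::Int64, Device::Cpu));

    // `true` marks active time steps; the second sequence is padded after
    // two tokens, so its last two positions are excluded from attention.
    let attention_mask = Tensor::of_slice(&[
        true, true, true, true, // sequence 1: fully active
        true, true, false, false, // sequence 2: padded
    ])
    .reshape(&[2, 4]);

    // `train = false` requests evaluation behavior (e.g. no dropout).
    let layer_outputs = encoder.encode(&input, Some(&attention_mask), false)?;

    // Per the trait contract, one `LayerOutput` is returned per layer.
    assert_eq!(layer_outputs.len() as i64, encoder.n_layers());

    Ok(())
}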