syntaxdot_transformers/models/
layer_output.rs

1use crate::TransformerError;
2use tch::Tensor;
3
4/// Hidden layer output and attention.
5#[derive(Debug)]
6pub struct HiddenLayer {
7    /// The output of the layer.
8    pub output: Tensor,
9
10    /// The layer attention scores (unnormalized).
11    pub attention: Tensor,
12}
13
14/// Output of a BERT layer.
15#[derive(Debug)]
16pub enum LayerOutput {
17    /// Embedding layer output.
18    Embedding(Tensor),
19
20    /// Encoder layer output.
21    EncoderWithAttention(HiddenLayer),
22}
23
24impl LayerOutput {
25    /// Get the layer attention.
26    ///
27    /// Return a `Some` value if the layer output is from an encoder layer,
28    /// or `None` otherwise.
29    pub fn attention(&self) -> Option<&Tensor> {
30        match self {
31            LayerOutput::Embedding(_) => None,
32            LayerOutput::EncoderWithAttention(hidden) => Some(&hidden.attention),
33        }
34    }
35
36    /// Get the embedding.
37    ///
38    /// Returns `Some` if the layer output is an embedding or `None`
39    /// otherwise.
40    pub fn embedding(&self) -> Option<&Tensor> {
41        match self {
42            LayerOutput::Embedding(embedding) => Some(embedding),
43            LayerOutput::EncoderWithAttention(_) => None,
44        }
45    }
46
47    /// Map the output representation of this layer.
48    pub fn map_output<F>(&self, f: F) -> Result<Self, TransformerError>
49    where
50        F: Fn(&Tensor) -> Result<Tensor, TransformerError>,
51    {
52        let layer = match self {
53            LayerOutput::Embedding(embedding) => LayerOutput::Embedding(f(embedding)?),
54            LayerOutput::EncoderWithAttention(HiddenLayer { output, attention }) => {
55                LayerOutput::EncoderWithAttention(HiddenLayer {
56                    output: f(output)?,
57                    attention: attention.shallow_clone(),
58                })
59            }
60        };
61
62        Ok(layer)
63    }
64
65    /// Get the layer output.
66    pub fn output(&self) -> &Tensor {
67        match self {
68            LayerOutput::Embedding(embedding) => embedding,
69            LayerOutput::EncoderWithAttention(hidden) => &hidden.output,
70        }
71    }
72
73    /// Get the layer output mutably.
74    pub fn output_mut(&mut self) -> &mut Tensor {
75        match self {
76            LayerOutput::Embedding(embedding) => embedding,
77            LayerOutput::EncoderWithAttention(hidden) => &mut hidden.output,
78        }
79    }
80
81    /// Get the output of an encoder layer.
82    ///
83    /// Return a `Some` value if the layer output is from an encoder layer,
84    /// or `None` otherwise.
85    pub fn encoder_with_attention(&self) -> Option<&HiddenLayer> {
86        match self {
87            LayerOutput::Embedding(_) => None,
88            LayerOutput::EncoderWithAttention(hidden) => Some(hidden),
89        }
90    }
91}