// syntaxdot_transformers/models/layer_output.rs

use crate::TransformerError;
use tch::Tensor;

4#[derive(Debug)]
6pub struct HiddenLayer {
7 pub output: Tensor,
9
10 pub attention: Tensor,
12}
13
14#[derive(Debug)]
16pub enum LayerOutput {
17 Embedding(Tensor),
19
20 EncoderWithAttention(HiddenLayer),
22}
23
24impl LayerOutput {
25 pub fn attention(&self) -> Option<&Tensor> {
30 match self {
31 LayerOutput::Embedding(_) => None,
32 LayerOutput::EncoderWithAttention(hidden) => Some(&hidden.attention),
33 }
34 }
35
36 pub fn embedding(&self) -> Option<&Tensor> {
41 match self {
42 LayerOutput::Embedding(embedding) => Some(embedding),
43 LayerOutput::EncoderWithAttention(_) => None,
44 }
45 }
46
47 pub fn map_output<F>(&self, f: F) -> Result<Self, TransformerError>
49 where
50 F: Fn(&Tensor) -> Result<Tensor, TransformerError>,
51 {
52 let layer = match self {
53 LayerOutput::Embedding(embedding) => LayerOutput::Embedding(f(embedding)?),
54 LayerOutput::EncoderWithAttention(HiddenLayer { output, attention }) => {
55 LayerOutput::EncoderWithAttention(HiddenLayer {
56 output: f(output)?,
57 attention: attention.shallow_clone(),
58 })
59 }
60 };
61
62 Ok(layer)
63 }
64
65 pub fn output(&self) -> &Tensor {
67 match self {
68 LayerOutput::Embedding(embedding) => embedding,
69 LayerOutput::EncoderWithAttention(hidden) => &hidden.output,
70 }
71 }
72
73 pub fn output_mut(&mut self) -> &mut Tensor {
75 match self {
76 LayerOutput::Embedding(embedding) => embedding,
77 LayerOutput::EncoderWithAttention(hidden) => &mut hidden.output,
78 }
79 }
80
81 pub fn encoder_with_attention(&self) -> Option<&HiddenLayer> {
86 match self {
87 LayerOutput::Embedding(_) => None,
88 LayerOutput::EncoderWithAttention(hidden) => Some(hidden),
89 }
90 }
91}