basic_transformer/basic_decoder.rs

//! Basic Transformer decoder block using the public API and shared example modules

use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

// Sibling example modules are included by path so this example stays self-contained.
#[path = "basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

#[allow(clippy::duplicate_mod)]
#[path = "multi_head_attention.rs"]
mod multi_head_attention;
use multi_head_attention::MultiHeadAttention;

/// A single decoder block: masked self-attention, cross-attention over the
/// encoder output, and a position-wise feed-forward network, each followed
/// by a residual connection.
pub struct DecoderBlock {
    // Kept for introspection; the leading underscores mark them as unused internally.
    pub _embed_dim: usize,
    pub _num_heads: usize,
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    ffn_in: LinearLayer,
    ffn_out: LinearLayer,
}

impl DecoderBlock {
    pub fn new(embed_dim: usize, num_heads: usize, seed: Option<u64>) -> Self {
        // Offset the base seed per sub-module so initialization is
        // deterministic without being identical across layers.
        let s0 = seed;
        let s1 = s0.map(|s| s + 1);
        let s2 = s0.map(|s| s + 2);
        let s3 = s0.map(|s| s + 3);
        Self {
            _embed_dim: embed_dim,
            _num_heads: num_heads,
            self_attn: MultiHeadAttention::new(embed_dim, num_heads, s0),
            cross_attn: MultiHeadAttention::new(embed_dim, num_heads, s1),
            // The FFN expands to twice the embedding width, then projects back.
            ffn_in: LinearLayer::new(embed_dim, embed_dim * 2, s2),
            ffn_out: LinearLayer::new(embed_dim * 2, embed_dim, s3),
        }
    }

    /// Collects mutable references to all trainable tensors, in a form
    /// suitable for optimizer registration.
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.ffn_in.parameters());
        params.extend(self.ffn_out.parameters());
        params
    }

    /// Forward pass.
    ///
    /// - `tgt`: [batch, tgt_len, embed]
    /// - `memory`: [batch, src_len, embed] (encoder outputs)
    /// - `causal_mask`: optional mask broadcastable to [batch, heads, tgt_len, tgt_len]
    ///   (true = keep, false = masked); a construction sketch follows this impl
    /// - `cross_mask`: optional mask broadcastable to [batch, heads, tgt_len, src_len]
    pub fn forward(
        &self,
        tgt: &Tensor,
        memory: &Tensor,
        causal_mask: Option<&Tensor>,
        cross_mask: Option<&Tensor>,
    ) -> Tensor {
        // Masked self-attention over the target, plus a residual connection.
        let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
        let res1 = self_attn.add_tensor(tgt);

        // Cross-attention: queries from the decoder state, keys and values
        // from the encoder memory, plus a residual connection.
        let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
        let res2 = cross.add_tensor(&res1);

        // Position-wise FFN: flatten to [batch * tgt_len, embed], expand and
        // apply ReLU, project back, then restore the 3D shape.
        let (b, t, e) = Self::triple(tgt);
        let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
        let hidden = self.ffn_in.forward(&x2d).relu();
        let out2d = self.ffn_out.forward(&hidden);
        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
        out.add_tensor(&res2)
    }

    /// Extracts (batch, seq_len, embed) from a rank-3 tensor's shape.
    fn triple(t: &Tensor) -> (usize, usize, usize) {
        let d = t.shape().dims();
        (d[0], d[1], d[2])
    }
}
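
// One way to build a causal mask compatible with `forward`. This sketch is
// illustrative only: `Tensor::from_slice` is an assumed constructor, not
// confirmed train_station API, so the body is left commented out; substitute
// whatever host-data constructor the crate actually provides.
//
// /// Lower-triangular causal mask of shape [1, 1, len, len], broadcastable
// /// to [batch, heads, tgt_len, tgt_len]; 1.0 = keep, 0.0 = masked.
// fn build_causal_mask(len: usize) -> Tensor {
//     let mut data = vec![0.0f32; len * len];
//     for i in 0..len {
//         for j in 0..=i {
//             // Position i may attend to every position j <= i.
//             data[i * len + j] = 1.0;
//         }
//     }
//     // Hypothetical constructor: (host data, shape) -> Tensor.
//     Tensor::from_slice(&data, vec![1, 1, len, len])
// }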

#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Decoder Example ===");

    let batch = 2usize;
    let src = 7usize;
    let tgt = 5usize;
    let embed = 32usize;
    let heads = 4usize;

    // Seeded random encoder memory and target embeddings for reproducibility.
    let memory = Tensor::randn(vec![batch, src, embed], Some(21));
    let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));

    let mut dec = DecoderBlock::new(embed, heads, Some(456));
    let out = dec.forward(&tgt_in, &memory, None, None);
    println!("Output shape: {:?}", out.shape().dims());

    // One training step: register parameters with Adam, backpropagate a
    // scalar loss, apply the update, then clear the gradients.
    let mut opt = Adam::with_learning_rate(0.01);
    let mut params = dec.parameters();
    for p in &params {
        opt.add_parameter(p);
    }
    let mut loss = out.mean();
    loss.backward(None);
    opt.step(&mut params);
    opt.zero_grad(&mut params);
    println!("Loss: {:.6}", loss.value());
    println!("=== Done ===");
    Ok(())
}
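
// To run (assuming this file sits under `examples/` in a Cargo project;
// adjust the example name to match your layout):
//   cargo run --example basic_decoder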