// Example: basic_transformer/basic_encoder.rs

//! Basic Transformer Encoder block using public API and example LinearLayer

use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

// Reuse LinearLayer from the local example file
#[allow(clippy::duplicate_mod)]
#[path = "basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

// Reuse the MHA example module (no duplication)
#[allow(clippy::duplicate_mod)]
#[path = "multi_head_attention.rs"]
mod multi_head_attention;
use multi_head_attention::MultiHeadAttention;
19
/// Single Transformer encoder block: multi-head self-attention followed by a
/// position-wise feed-forward network, each wrapped in a residual connection
/// (see `forward`). NOTE(review): no layer normalization is applied — this is
/// a deliberately minimal example.
pub struct EncoderBlock {
    /// Embedding dimension. Underscore-prefixed because it is stored for
    /// reference only and not read after construction.
    pub _embed_dim: usize,
    /// Number of attention heads; reference only, like `_embed_dim`.
    pub _num_heads: usize,
    // Self-attention sub-layer.
    mha: MultiHeadAttention,
    // FFN expansion layer: embed -> embed * 2 (see `new`).
    ffn_in: LinearLayer,
    // FFN projection layer: embed * 2 -> embed.
    ffn_out: LinearLayer,
}
27
28impl EncoderBlock {
29    pub fn new(embed_dim: usize, num_heads: usize, seed: Option<u64>) -> Self {
30        let s0 = seed;
31        let s1 = s0.map(|s| s + 1);
32        let s2 = s0.map(|s| s + 2);
33        Self {
34            _embed_dim: embed_dim,
35            _num_heads: num_heads,
36            mha: MultiHeadAttention::new(embed_dim, num_heads, s0),
37            ffn_in: LinearLayer::new(embed_dim, embed_dim * 2, s1),
38            ffn_out: LinearLayer::new(embed_dim * 2, embed_dim, s2),
39        }
40    }
41
42    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
43        let mut params = Vec::new();
44        params.extend(self.mha.parameters());
45        params.extend(self.ffn_in.parameters());
46        params.extend(self.ffn_out.parameters());
47        params
48    }
49
50    /// Forward pass
51    /// input: [batch, seq, embed]
52    /// attn_mask: optional mask broadcastable to [batch, heads, seq, seq]
53    pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54        let attn = self.mha.forward(input, input, input, attn_mask);
55        let res1 = attn.add_tensor(input);
56
57        // Feed-forward network with ReLU and residual
58        let (b, t, e) = Self::triple(input);
59        let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60        let hidden = self.ffn_in.forward(&x2d).relu();
61        let out2d = self.ffn_out.forward(&hidden);
62        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63        out.add_tensor(&res1)
64    }
65
66    fn triple(t: &Tensor) -> (usize, usize, usize) {
67        let d = t.shape().dims();
68        (d[0], d[1], d[2])
69    }
70}
71
72#[allow(unused)]
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74    println!("=== Basic Encoder Example ===");
75
76    let batch = 2usize;
77    let seq = 6usize;
78    let embed = 32usize;
79    let heads = 4usize;
80
81    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82    let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84    // Example: no mask (set Some(mask) to use masking)
85    let out = enc.forward(&input, None);
86    println!("Output shape: {:?}", out.shape().dims());
87
88    // Verify gradients/optimization
89    let mut opt = Adam::with_learning_rate(0.01);
90    let mut params = enc.parameters();
91    for p in &params {
92        opt.add_parameter(p);
93    }
94    let mut loss = out.mean();
95    loss.backward(None);
96    opt.step(&mut params);
97    opt.zero_grad(&mut params);
98    println!("Loss: {:.6}", loss.value());
99    println!("=== Done ===");
100    Ok(())
101}