// basic_transformer/basic_encoder.rs

use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

#[allow(clippy::duplicate_mod)]
#[path = "multi_head_attention.rs"]
mod multi_head_attention;
use multi_head_attention::MultiHeadAttention;

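/// A single Transformer encoder block: multi-head self-attention and a
/// two-layer feed-forward network, each followed by a residual connection.
/// Unlike the standard architecture, this minimal example applies no layer
/// normalization.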
pub struct EncoderBlock {
    pub _embed_dim: usize,
    pub _num_heads: usize,
    mha: MultiHeadAttention,
    ffn_in: LinearLayer,
    ffn_out: LinearLayer,
}

impl EncoderBlock {
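    /// Builds a block whose FFN hidden width is `2 * embed_dim`. When a seed
    /// is given, the sub-layers receive consecutive derived seeds so their
    /// initializations are distinct but reproducible.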
    pub fn new(embed_dim: usize, num_heads: usize, seed: Option<u64>) -> Self {
        let s0 = seed;
        let s1 = s0.map(|s| s + 1);
        let s2 = s0.map(|s| s + 2);
        Self {
            _embed_dim: embed_dim,
            _num_heads: num_heads,
            mha: MultiHeadAttention::new(embed_dim, num_heads, s0),
            ffn_in: LinearLayer::new(embed_dim, embed_dim * 2, s1),
            ffn_out: LinearLayer::new(embed_dim * 2, embed_dim, s2),
        }
    }

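    /// Collects mutable references to every trainable tensor in the block so
    /// they can be registered with an optimizer.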
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        params.extend(self.mha.parameters());
        params.extend(self.ffn_in.parameters());
        params.extend(self.ffn_out.parameters());
        params
    }

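    /// Runs one encoder pass: self-attention (query, key, and value are all
    /// `input`) plus a residual, then the position-wise FFN plus a second
    /// residual. Activations of shape `[batch, seq, embed]` are flattened to
    /// 2-D for the linear layers and reshaped back afterwards.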
    pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
        // Self-attention sub-layer with a residual connection.
        let attn = self.mha.forward(input, input, input, attn_mask);
        let res1 = attn.add_tensor(input);

        // FFN sub-layer: flatten to [batch * seq, embed] for the linear
        // layers, apply ReLU between them, then restore the 3-D shape and
        // add the second residual.
        let (b, t, e) = Self::triple(input);
        let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
        let hidden = self.ffn_in.forward(&x2d).relu();
        let out2d = self.ffn_out.forward(&hidden);
        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
        out.add_tensor(&res1)
    }

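    /// Splits a tensor's 3-D shape into its `(batch, seq, embed)` dimensions.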
    fn triple(t: &Tensor) -> (usize, usize, usize) {
        let d = t.shape().dims();
        (d[0], d[1], d[2])
    }
}
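
// Demo: run a forward pass on random input, then take a single Adam step
// against a dummy loss (the mean of the encoder output).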
#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Encoder Example ===");

    let batch = 2usize;
    let seq = 6usize;
    let embed = 32usize;
    let heads = 4usize;

    // Seeded random input of shape [batch, seq, embed].
    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
    let mut enc = EncoderBlock::new(embed, heads, Some(123));

    let out = enc.forward(&input, None);
    println!("Output shape: {:?}", out.shape().dims());

    // Register every trainable tensor with the optimizer.
    let mut opt = Adam::with_learning_rate(0.01);
    let mut params = enc.parameters();
    for p in &params {
        opt.add_parameter(p);
    }

    // Backpropagate from the scalar loss, update, and clear gradients.
    let mut loss = out.mean();
    loss.backward(None);
    opt.step(&mut params);
    opt.zero_grad(&mut params);
    println!("Loss: {:.6}", loss.value());
    println!("=== Done ===");
    Ok(())
}