// Example: basic_transformer/multi_head_attention.rs

1//! Multi-Head Attention (MHA) - Minimal example using public API only
2//!
3//! This example implements a small, self-contained multi-head attention module
4//! using public `train_station` APIs and reuses the `LinearLayer` from
5//! `basic_linear_layer.rs`. It demonstrates shape-safe forward passes plus a
6//! tiny optimization step to verify gradients flow.
7
8use train_station::{
9    optimizers::{Adam, Optimizer},
10    Tensor,
11};
12
13// Reuse the LinearLayer example implementation without duplicating it.
14// This pulls in the module locally (its `main` stays namespaced and is unused).
15#[path = "basic_linear_layer.rs"]
16mod basic_linear_layer;
17pub use basic_linear_layer::LinearLayer;
18
/// Minimal multi-head attention implemented with public API
///
/// Holds four learnable linear projections (Q, K, V, output), each mapping
/// `embed_dim -> embed_dim`. `head_dim` is the per-head width,
/// `embed_dim / num_heads` (see `new`).
pub struct MultiHeadAttention {
    /// Model (embedding) dimension of all inputs and outputs.
    pub embed_dim: usize,
    /// Number of attention heads; must divide `embed_dim`.
    pub num_heads: usize,
    // Per-head width: embed_dim / num_heads.
    head_dim: usize,
    // Learnable projections
    q_proj: LinearLayer,
    k_proj: LinearLayer,
    v_proj: LinearLayer,
    out_proj: LinearLayer,
}
30
31impl MultiHeadAttention {
32    pub fn new(embed_dim: usize, num_heads: usize, seed: Option<u64>) -> Self {
33        assert!(
34            embed_dim.is_multiple_of(num_heads),
35            "embed_dim must be divisible by num_heads"
36        );
37        let head_dim = embed_dim / num_heads;
38        let s0 = seed;
39        let s1 = s0.map(|s| s + 1);
40        let s2 = s0.map(|s| s + 2);
41        let s3 = s0.map(|s| s + 3);
42        Self {
43            embed_dim,
44            num_heads,
45            head_dim,
46            q_proj: LinearLayer::new(embed_dim, embed_dim, s0),
47            k_proj: LinearLayer::new(embed_dim, embed_dim, s1),
48            v_proj: LinearLayer::new(embed_dim, embed_dim, s2),
49            out_proj: LinearLayer::new(embed_dim, embed_dim, s3),
50        }
51    }
52
53    /// Collect mutable parameter references for optimization
54    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
55        let mut params = Vec::new();
56        params.extend(self.q_proj.parameters());
57        params.extend(self.k_proj.parameters());
58        params.extend(self.v_proj.parameters());
59        params.extend(self.out_proj.parameters());
60        params
61    }
62
63    /// Forward pass
64    ///
65    /// - query: [batch, tgt_len, embed_dim]
66    /// - key:   [batch, src_len, embed_dim]
67    /// - value: [batch, src_len, embed_dim]
68    /// - attn_mask: Optional mask broadcastable to [batch, heads, tgt_len, src_len]
69    ///   If provided as a boolean mask (true = keep, false = mask), it will be
70    ///   applied via masked_fill with -1e9 before softmax. If provided as tensor
71    ///   with other shapes, it is added as is (additive mask).
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
124
125    fn project_qkv(
126        query: &Tensor,
127        key: &Tensor,
128        value: &Tensor,
129        q_proj: &LinearLayer,
130        k_proj: &LinearLayer,
131        v_proj: &LinearLayer,
132    ) -> (Tensor, Tensor, Tensor) {
133        let (bq, tq, eq) = Self::triple(query);
134        let (bk, tk, ek) = Self::triple(key);
135        let (_bv, tv, ev) = Self::triple(value);
136        assert!(eq == ek && ek == ev, "Q,K,V embed dims must match");
137        let q2d = query.view(vec![(bq * tq) as i32, eq as i32]);
138        let k2d = key.view(vec![(bk * tk) as i32, ek as i32]);
139        let v2d = value.view(vec![(_bv * tv) as i32, ev as i32]);
140        let q = q_proj
141            .forward(&q2d)
142            .view(vec![bq as i32, tq as i32, eq as i32]);
143        let k = k_proj
144            .forward(&k2d)
145            .view(vec![bk as i32, tk as i32, ek as i32]);
146        let v = v_proj
147            .forward(&v2d)
148            .view(vec![bk as i32, tv as i32, ev as i32]);
149        (q, k, v)
150    }
151
152    fn split_heads(x: &Tensor, b: usize, t: usize, h: usize, d: usize) -> Tensor {
153        x.view(vec![b as i32, t as i32, h as i32, d as i32])
154            .permute(vec![0, 2, 1, 3])
155    }
156
157    fn triple(t: &Tensor) -> (usize, usize, usize) {
158        let dims = t.shape().dims();
159        assert!(dims.len() == 3, "expected 3D tensor [batch, seq, embed]");
160        (dims[0], dims[1], dims[2])
161    }
162}
163
164#[allow(unused)]
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166    println!("=== Multi-Head Attention Example ===");
167
168    let batch = 2usize;
169    let src_len = 5usize;
170    let tgt_len = 4usize;
171    let embed = 16usize;
172    let heads = 4usize;
173
174    let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175    let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176    let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178    let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180    // Simple causal mask for target self-attention shape [b, h, tq, tk]
181    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182    // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183    // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184    if src_len >= tgt_len {
185        // set upper triangle to a large negative value for head 0
186        for i in 0..tgt_len {
187            for j in (i + 1)..src_len {
188                let idx = [0usize, 0usize, i, j];
189                // Quick set via data_mut using a slice view
190                let offset = mask.memory_offset(&idx);
191                let data = mask.data_mut();
192                data[offset] = -1e9;
193            }
194        }
195    }
196
197    let out = mha.forward(&query, &key, &value, Some(&mask));
198    println!("Output shape: {:?}", out.shape().dims());
199
200    // Tiny training step to confirm gradients are wired
201    let mut optimizer = Adam::with_learning_rate(0.01);
202    let mut params = mha.parameters();
203    for p in &params {
204        optimizer.add_parameter(p);
205    }
206
207    // Dummy loss = mean of output
208    let mut loss = out.mean();
209    loss.backward(None);
210    optimizer.step(&mut params);
211    optimizer.zero_grad(&mut params);
212
213    println!("Loss: {:.6}", loss.value());
214    println!("=== Done ===");
215    Ok(())
216}