use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[path = "basic_linear_layer.rs"]
mod basic_linear_layer;
pub use basic_linear_layer::LinearLayer;

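/// Multi-head attention built from four linear projections: query, key,
/// value, and output, all of size `embed_dim x embed_dim`.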
pub struct MultiHeadAttention {
    pub embed_dim: usize,
    pub num_heads: usize,
    head_dim: usize,
    q_proj: LinearLayer,
    k_proj: LinearLayer,
    v_proj: LinearLayer,
    out_proj: LinearLayer,
}

impl MultiHeadAttention {
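    /// Creates the block. `embed_dim` must be divisible by `num_heads`; each
    /// projection gets a distinct seed derived from `seed` so their random
    /// initializations differ.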
    pub fn new(embed_dim: usize, num_heads: usize, seed: Option<u64>) -> Self {
        assert!(
            embed_dim.is_multiple_of(num_heads),
            "embed_dim must be divisible by num_heads"
        );
        let head_dim = embed_dim / num_heads;
        let s0 = seed;
        let s1 = s0.map(|s| s + 1);
        let s2 = s0.map(|s| s + 2);
        let s3 = s0.map(|s| s + 3);
        Self {
            embed_dim,
            num_heads,
            head_dim,
            q_proj: LinearLayer::new(embed_dim, embed_dim, s0),
            k_proj: LinearLayer::new(embed_dim, embed_dim, s1),
            v_proj: LinearLayer::new(embed_dim, embed_dim, s2),
            out_proj: LinearLayer::new(embed_dim, embed_dim, s3),
        }
    }

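    /// Collects mutable references to every trainable parameter, in
    /// projection order (Q, K, V, output).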
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters());
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params
    }

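    /// Scaled dot-product attention over `num_heads` heads.
    ///
    /// `query` is `[batch, tgt_len, embed_dim]`; `key` and `value` are
    /// `[batch, src_len, embed_dim]`. A 4D `attn_mask` whose shape matches
    /// `[batch, num_heads, tgt_len, src_len]` is treated as a keep mask
    /// (entries below 0.5 are masked out); any other shape is added to the
    /// attention logits. Returns `[batch, tgt_len, embed_dim]`.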
    pub fn forward(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attn_mask: Option<&Tensor>,
    ) -> Tensor {
        let (q, k, v) =
            Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);

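        // Split the embedding dimension into heads: [B, T, E] -> [B, H, T, D].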
        let (b, tq, _e) = Self::triple(query);
        let (_b2, tk, _e2) = Self::triple(key);
        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);

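        // Attention logits: Q.K^T scaled by sqrt(head_dim), shape [B, H, Tq, Tk].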
        let k_t = k.transpose(2, 3);
        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
        if let Some(mask) = attn_mask {
            let dims = mask.shape().dims().to_vec();
            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
                // Keep mask: entries below 0.5 mark positions to exclude.
                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
            } else {
                // Any other shape is treated as an additive mask on the logits.
                logits = logits.add_tensor(mask);
            }
        }
        let attn = logits.softmax(3);

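        // Weighted sum of values, then merge heads: [B, H, Tq, D] -> [B, Tq, H*D].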
        let context = attn.matmul(&v);
        let context = context.permute(vec![0, 2, 1, 3]);
        let context = context.contiguous().view(vec![
            b as i32,
            tq as i32,
            (self.num_heads * self.head_dim) as i32,
        ]);

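        // Final output projection, applied position-wise via a 2D view.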
        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
        let out2d = self.out_proj.forward(&flat);
        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
    }

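    /// Applies the Q, K, and V projections, flattening the batch and sequence
    /// dimensions so each `LinearLayer` sees a 2D `[batch * seq, embed]` input.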
    fn project_qkv(
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        q_proj: &LinearLayer,
        k_proj: &LinearLayer,
        v_proj: &LinearLayer,
    ) -> (Tensor, Tensor, Tensor) {
        let (bq, tq, eq) = Self::triple(query);
        let (bk, tk, ek) = Self::triple(key);
        let (bv, tv, ev) = Self::triple(value);
        assert!(eq == ek && ek == ev, "Q, K, V embed dims must match");
        let q2d = query.view(vec![(bq * tq) as i32, eq as i32]);
        let k2d = key.view(vec![(bk * tk) as i32, ek as i32]);
        let v2d = value.view(vec![(bv * tv) as i32, ev as i32]);
        let q = q_proj
            .forward(&q2d)
            .view(vec![bq as i32, tq as i32, eq as i32]);
        let k = k_proj
            .forward(&k2d)
            .view(vec![bk as i32, tk as i32, ek as i32]);
        let v = v_proj
            .forward(&v2d)
            .view(vec![bv as i32, tv as i32, ev as i32]);
        (q, k, v)
    }

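    /// Reshapes `[B, T, H * D]` into `[B, H, T, D]` so each head attends
    /// independently.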
    fn split_heads(x: &Tensor, b: usize, t: usize, h: usize, d: usize) -> Tensor {
        x.view(vec![b as i32, t as i32, h as i32, d as i32])
            .permute(vec![0, 2, 1, 3])
    }

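    /// Extracts `(batch, seq, embed)` from a 3D tensor, asserting its rank.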
    fn triple(t: &Tensor) -> (usize, usize, usize) {
        let dims = t.shape().dims();
        assert!(dims.len() == 3, "expected 3D tensor [batch, seq, embed]");
        (dims[0], dims[1], dims[2])
    }
}

#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Multi-Head Attention Example ===");

    let batch = 2usize;
    let src_len = 5usize;
    let tgt_len = 4usize;
    let embed = 16usize;
    let heads = 4usize;

    let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
    let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
    let value = Tensor::randn(vec![batch, src_len, embed], Some(9));

    let mut mha = MultiHeadAttention::new(embed, heads, Some(42));

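    // Build a causal-style keep mask (1.0 = attend, 0.0 = masked). Its 4D
    // shape matches [batch, heads, tgt, src], so `forward` treats entries
    // below 0.5 as masked positions.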
    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
    for v in mask.data_mut() {
        *v = 1.0;
    }
    for bi in 0..batch {
        for h in 0..heads {
            for i in 0..tgt_len {
                for j in (i + 1)..src_len {
                    let offset = mask.memory_offset(&[bi, h, i, j]);
                    mask.data_mut()[offset] = 0.0;
                }
            }
        }
    }

    let out = mha.forward(&query, &key, &value, Some(&mask));
    println!("Output shape: {:?}", out.shape().dims());

    // Register each parameter with Adam before taking a step.
    let mut optimizer = Adam::with_learning_rate(0.01);
    let mut params = mha.parameters();
    for p in &params {
        optimizer.add_parameter(p);
    }

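    // Scalar loss (mean of all outputs), backprop, then one Adam update.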
    let mut loss = out.mean();
    loss.backward(None);
    optimizer.step(&mut params);
    optimizer.zero_grad(&mut params);

    println!("Loss: {:.6}", loss.value());
    println!("=== Done ===");
    Ok(())
}