use candle_core::{Result, Tensor, DType, Device};
use candle_nn::VarBuilder;
use super::activations::CastedLinear;
use super::positional::{apply_rotary_pos_emb, RotaryEmbedding};

/// Multi-head self-attention with optional grouped-query attention (GQA).
/// Rotary position embeddings are applied when the caller passes the
/// precomputed cos/sin tables to `forward`.
pub struct Attention {
    hidden_size: usize,
    head_dim: usize,
    output_size: usize,
    num_heads: usize,
    num_key_value_heads: usize,
    causal: bool,

    qkv_proj: CastedLinear,
    o_proj: CastedLinear,
}
27
28impl Attention {
29 pub fn new(
39 hidden_size: usize,
40 head_dim: usize,
41 num_heads: usize,
42 num_key_value_heads: usize,
43 causal: bool,
44 vb: VarBuilder,
45 ) -> Result<Self> {
46 let output_size = head_dim * num_heads;
47
48 let qkv_size = (num_heads + 2 * num_key_value_heads) * head_dim;
50 let qkv_proj = CastedLinear::new(
51 hidden_size,
52 qkv_size,
53 false,
54 vb.pp("qkv_proj"),
55 )?;
56
57 let o_proj = CastedLinear::new(
59 output_size,
60 hidden_size,
61 false,
62 vb.pp("o_proj"),
63 )?;
64
65 Ok(Self {
66 hidden_size,
67 head_dim,
68 output_size,
69 num_heads,
70 num_key_value_heads,
71 causal,
72 qkv_proj,
73 o_proj,
74 })
75 }

    /// Runs attention over `hidden_states` of shape `(batch, seq_len, hidden_size)`.
    ///
    /// `cos_sin` optionally holds precomputed rotary-embedding cos/sin tables;
    /// when present they are applied to the query and key heads.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let (batch_size, seq_len, _) = hidden_states.dims3()?;

        let qkv = self.qkv_proj.forward(hidden_states)?;

        // (batch, seq, num_heads + 2 * num_kv_heads, head_dim)
        let qkv = qkv.reshape((
            batch_size,
            seq_len,
            self.num_heads + 2 * self.num_key_value_heads,
            self.head_dim,
        ))?;

        // Split the fused projection into query, key and value heads.
        let query = qkv.narrow(2, 0, self.num_heads)?;
        let key = qkv.narrow(2, self.num_heads, self.num_key_value_heads)?;
        let value = qkv.narrow(2, self.num_heads + self.num_key_value_heads, self.num_key_value_heads)?;

        let (query, key) = if let Some((cos, sin)) = cos_sin {
            apply_rotary_pos_emb(&query, &key, cos, sin)?
        } else {
            (query, key)
        };

        // Move the head dimension forward: (batch, heads, seq, head_dim).
        let query = query.transpose(1, 2)?.contiguous()?;
        let key = key.transpose(1, 2)?.contiguous()?;
        let value = value.transpose(1, 2)?.contiguous()?;

        // Grouped-query attention: repeat the key/value heads so that every
        // query head has a matching key/value head.
        let (key, value) = if self.num_key_value_heads < self.num_heads {
            let repeat_factor = self.num_heads / self.num_key_value_heads;
            (
                repeat_kv(&key, repeat_factor)?,
                repeat_kv(&value, repeat_factor)?,
            )
        } else {
            (key, value)
        };

        let attn_output = scaled_dot_product_attention(
            &query,
            &key,
            &value,
            self.causal,
        )?;

        // Back to (batch, seq, heads, head_dim), then merge the heads.
        let attn_output = attn_output.transpose(1, 2)?;
        let attn_output = attn_output.reshape((batch_size, seq_len, self.output_size))?;

        self.o_proj.forward(&attn_output)
    }
}

/// Repeats each key/value head `n` times along the head dimension so the
/// key/value tensors match the number of query heads.
fn repeat_kv(x: &Tensor, n: usize) -> Result<Tensor> {
    if n == 1 {
        return Ok(x.clone());
    }

    let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?;

    // (batch, kv_heads, 1, seq, head_dim): broadcast over the new axis, then
    // fold it into the head dimension.
    let x = x.unsqueeze(2)?;
    let x = x.broadcast_as((batch, num_kv_heads, n, seq_len, head_dim))?;

    x.reshape((batch, num_kv_heads * n, seq_len, head_dim))
}

/// Scaled dot-product attention: `softmax(Q K^T / sqrt(head_dim)) V`.
///
/// All inputs are expected in `(batch, heads, seq, head_dim)` layout. When
/// `causal` is set, an additive mask blocks attention to future positions.
fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    causal: bool,
) -> Result<Tensor> {
    let (_batch, _num_heads, seq_len, head_dim) = query.dims4()?;
    let scale = 1.0 / (head_dim as f64).sqrt();

    let scores = query.matmul(&key.transpose(2, 3)?)?;
    let scores = (scores * scale)?;

    let scores = if causal {
        let mask = create_causal_mask(seq_len, scores.device())?;
        scores.broadcast_add(&mask)?
    } else {
        scores
    };

    let attn_weights = candle_nn::ops::softmax_last_dim(&scores)?;

    attn_weights.matmul(value)
}

/// Builds a `(seq_len, seq_len)` additive causal mask: `0.0` on and below the
/// diagonal, `-inf` above it.
fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mut mask_data = vec![0.0f32; seq_len * seq_len];

    for i in 0..seq_len {
        for j in (i + 1)..seq_len {
            mask_data[i * seq_len + j] = f32::NEG_INFINITY;
        }
    }

    Tensor::from_vec(mask_data, (seq_len, seq_len), device)
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    #[test]
    fn test_attention_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let attn = Attention::new(256, 32, 8, 8, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = attn.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_attention_with_rope() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let attn = Attention::new(256, 32, 8, 8, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;

        let rope = RotaryEmbedding::new(32, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(16)?;

        let out = attn.forward(&x, Some((&cos, &sin)))?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_grouped_query_attention() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let attn = Attention::new(256, 32, 8, 2, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = attn.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }
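
    // Mirrors the shape tests above, but with `causal` enabled so the masked
    // path in `scaled_dot_product_attention` is exercised through `forward`.
    // This is only a shape check; it does not inspect the attention weights.
    #[test]
    fn test_causal_attention_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let attn = Attention::new(256, 32, 8, 2, true, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = attn.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }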

    #[test]
    fn test_causal_mask() -> Result<()> {
        let device = Device::Cpu;
        let mask = create_causal_mask(4, &device)?;

        assert_eq!(mask.dims(), &[4, 4]);

        let mask_vec = mask.flatten_all()?.to_vec1::<f32>()?;

        // Row 0: only position 0 is visible.
        assert_eq!(mask_vec[0], 0.0);
        assert!(mask_vec[1].is_infinite() && mask_vec[1].is_sign_negative());

        // Row 1 (row-major offset 4): positions 0 and 1 visible, 2 masked.
        assert_eq!(mask_vec[4], 0.0);
        assert_eq!(mask_vec[5], 0.0);
        assert!(mask_vec[6].is_infinite() && mask_vec[6].is_sign_negative());

        Ok(())
    }

    #[test]
    fn test_repeat_kv() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::randn(0f32, 1.0, (2, 2, 16, 32), &device)?;
        let repeated = repeat_kv(&x, 4)?;

        assert_eq!(repeated.dims(), &[2, 8, 16, 32]);

        Ok(())
    }
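
    // A further check on the causal mask, assuming `softmax_last_dim` sends
    // the `-inf` masked scores to zero weight: position 0 can only attend to
    // itself, so its output should equal the first value vector.
    #[test]
    fn test_causal_first_position_copies_value() -> Result<()> {
        let device = Device::Cpu;

        let query = Tensor::randn(0f32, 1.0, (1, 1, 3, 4), &device)?;
        let key = Tensor::randn(0f32, 1.0, (1, 1, 3, 4), &device)?;
        let value = Tensor::randn(0f32, 1.0, (1, 1, 3, 4), &device)?;

        let out = scaled_dot_product_attention(&query, &key, &value, true)?;

        let out_first = out.narrow(2, 0, 1)?.flatten_all()?.to_vec1::<f32>()?;
        let value_first = value.narrow(2, 0, 1)?.flatten_all()?.to_vec1::<f32>()?;
        for (o, v) in out_first.iter().zip(value_first.iter()) {
            assert!((o - v).abs() < 1e-5);
        }

        Ok(())
    }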
}