1use crate::{BackendKind, BufferHandle, ComputeBackend, DeviceInfo, buffer::next_buffer_handle};
2use yule_core::error::{Result, YuleError};
3use std::collections::HashMap;
4use std::sync::Mutex;
5
/// Pure-CPU implementation of [`ComputeBackend`].
///
/// "Device" buffers are plain byte vectors kept in a map keyed by the
/// numeric buffer-handle id; the `Mutex` serializes all access so the
/// backend can be shared between threads.
pub struct CpuBackend {
    // Handle id (`BufferHandle.0`) -> raw buffer bytes.
    buffers: Mutex<HashMap<u64, Vec<u8>>>,
}
9
10impl CpuBackend {
11 pub fn new() -> Self {
12 Self {
13 buffers: Mutex::new(HashMap::new()),
14 }
15 }
16
17 fn get_buf<'a>(
19 buffers: &'a HashMap<u64, Vec<u8>>,
20 handle: &BufferHandle,
21 ) -> Result<&'a Vec<u8>> {
22 buffers
23 .get(&handle.0)
24 .ok_or_else(|| YuleError::Gpu(format!("buffer {} not found", handle.0)))
25 }
26
27 fn get_buf_mut<'a>(
29 buffers: &'a mut HashMap<u64, Vec<u8>>,
30 handle: &BufferHandle,
31 ) -> Result<&'a mut Vec<u8>> {
32 buffers
33 .get_mut(&handle.0)
34 .ok_or_else(|| YuleError::Gpu(format!("buffer {} not found", handle.0)))
35 }
36}
37
38impl Default for CpuBackend {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44#[inline]
46fn as_f32_slice(data: &[u8]) -> &[f32] {
47 debug_assert!(data.len() % 4 == 0);
48 bytemuck::cast_slice(data)
52}
53
54#[inline]
56fn as_f32_slice_mut(data: &mut [u8]) -> &mut [f32] {
57 debug_assert!(data.len() % 4 == 0);
58 bytemuck::cast_slice_mut(data)
59}
60
61impl ComputeBackend for CpuBackend {
62 fn name(&self) -> &str {
63 "cpu"
64 }
65
66 fn device_info(&self) -> DeviceInfo {
67 DeviceInfo {
68 name: "CPU".into(),
69 backend: BackendKind::Cpu,
70 memory_bytes: 0, compute_units: std::thread::available_parallelism()
72 .map(|p| p.get() as u32)
73 .unwrap_or(1),
74 }
75 }
76
77 fn allocate(&self, size_bytes: usize) -> Result<BufferHandle> {
78 let handle = next_buffer_handle();
79 let buf = vec![0u8; size_bytes];
83 self.buffers.lock().unwrap().insert(handle.0, buf);
84 Ok(handle)
85 }
86
87 fn free(&self, handle: BufferHandle) -> Result<()> {
88 self.buffers.lock().unwrap().remove(&handle.0);
89 Ok(())
90 }
91
92 fn matmul(
96 &self,
97 a: &BufferHandle,
98 b: &BufferHandle,
99 out: &BufferHandle,
100 m: u32,
101 n: u32,
102 k: u32,
103 ) -> Result<()> {
104 let mut buffers = self.buffers.lock().unwrap();
105 let a_data = Self::get_buf(&buffers, a)?.as_ptr();
108 let b_data = Self::get_buf(&buffers, b)?.as_ptr();
109 let a_len = Self::get_buf(&buffers, a)?.len();
110 let b_len = Self::get_buf(&buffers, b)?.len();
111 let out_buf = Self::get_buf_mut(&mut buffers, out)?;
112 let out_f32 = as_f32_slice_mut(out_buf);
113
114 let a_f32: &[f32] = bytemuck::cast_slice(unsafe {
116 std::slice::from_raw_parts(a_data, a_len)
117 });
118 let b_f32: &[f32] = bytemuck::cast_slice(unsafe {
119 std::slice::from_raw_parts(b_data, b_len)
120 });
121
122 let (m, n, k) = (m as usize, n as usize, k as usize);
123
124 for i in 0..m {
127 for j in 0..n {
128 let mut sum = 0.0f32;
129 for p in 0..k {
130 sum += a_f32[i * k + p] * b_f32[p * n + j];
131 }
132 out_f32[i * n + j] = sum;
133 }
134 }
135 Ok(())
136 }
137
138 fn softmax(
141 &self,
142 input: &BufferHandle,
143 output: &BufferHandle,
144 size: u32,
145 ) -> Result<()> {
146 let mut buffers = self.buffers.lock().unwrap();
147 let inp_data = Self::get_buf(&buffers, input)?.as_ptr();
148 let inp_len = Self::get_buf(&buffers, input)?.len();
149 let out_buf = Self::get_buf_mut(&mut buffers, output)?;
150 let out_f32 = as_f32_slice_mut(out_buf);
151
152 let inp_f32: &[f32] = bytemuck::cast_slice(unsafe {
153 std::slice::from_raw_parts(inp_data, inp_len)
154 });
155
156 let n = size as usize;
157 let mut max_val = f32::NEG_INFINITY;
158 for i in 0..n {
159 if inp_f32[i] > max_val {
160 max_val = inp_f32[i];
161 }
162 }
163
164 let mut sum = 0.0f32;
165 for i in 0..n {
166 let e = (inp_f32[i] - max_val).exp();
167 out_f32[i] = e;
168 sum += e;
169 }
170
171 let inv_sum = 1.0 / sum;
172 for i in 0..n {
173 out_f32[i] *= inv_sum;
174 }
175 Ok(())
176 }
177
178 fn rms_norm(
181 &self,
182 input: &BufferHandle,
183 weight: &BufferHandle,
184 output: &BufferHandle,
185 size: u32,
186 eps: f32,
187 ) -> Result<()> {
188 let mut buffers = self.buffers.lock().unwrap();
189 let inp_data = Self::get_buf(&buffers, input)?.as_ptr();
190 let inp_len = Self::get_buf(&buffers, input)?.len();
191 let wt_data = Self::get_buf(&buffers, weight)?.as_ptr();
192 let wt_len = Self::get_buf(&buffers, weight)?.len();
193 let out_buf = Self::get_buf_mut(&mut buffers, output)?;
194 let out_f32 = as_f32_slice_mut(out_buf);
195
196 let inp_f32: &[f32] = bytemuck::cast_slice(unsafe {
197 std::slice::from_raw_parts(inp_data, inp_len)
198 });
199 let wt_f32: &[f32] = bytemuck::cast_slice(unsafe {
200 std::slice::from_raw_parts(wt_data, wt_len)
201 });
202
203 let n = size as usize;
204 let mut ss = 0.0f32;
205 for i in 0..n {
206 ss += inp_f32[i] * inp_f32[i];
207 }
208 let rms = (ss / n as f32 + eps).sqrt();
209 let inv_rms = 1.0 / rms;
210
211 for i in 0..n {
212 out_f32[i] = inp_f32[i] * inv_rms * wt_f32[i];
213 }
214 Ok(())
215 }
216
217 fn rope(
221 &self,
222 q: &BufferHandle,
223 k: &BufferHandle,
224 pos: u32,
225 head_dim: u32,
226 freq_base: f32,
227 _n_heads_q: u32,
228 _n_heads_k: u32,
229 ) -> Result<()> {
230 let mut buffers = self.buffers.lock().unwrap();
231
232 for handle in [q, k] {
234 let buf = Self::get_buf_mut(&mut buffers, handle)?;
235 let f32_data = as_f32_slice_mut(buf);
236 let hd = head_dim as usize;
237 let n_heads = f32_data.len() / hd;
238
239 for h in 0..n_heads {
240 let base = h * hd;
241 for i in 0..(hd / 2) {
242 let freq = 1.0 / freq_base.powf(2.0 * i as f32 / hd as f32);
243 let theta = pos as f32 * freq;
244 let cos_t = theta.cos();
245 let sin_t = theta.sin();
246
247 let x0 = f32_data[base + 2 * i];
248 let x1 = f32_data[base + 2 * i + 1];
249 f32_data[base + 2 * i] = x0 * cos_t - x1 * sin_t;
250 f32_data[base + 2 * i + 1] = x0 * sin_t + x1 * cos_t;
251 }
252 }
253 }
254 Ok(())
255 }
256
257 fn silu(
260 &self,
261 input: &BufferHandle,
262 output: &BufferHandle,
263 size: u32,
264 ) -> Result<()> {
265 let mut buffers = self.buffers.lock().unwrap();
266 let inp_data = Self::get_buf(&buffers, input)?.as_ptr();
267 let inp_len = Self::get_buf(&buffers, input)?.len();
268 let out_buf = Self::get_buf_mut(&mut buffers, output)?;
269 let out_f32 = as_f32_slice_mut(out_buf);
270
271 let inp_f32: &[f32] = bytemuck::cast_slice(unsafe {
272 std::slice::from_raw_parts(inp_data, inp_len)
273 });
274
275 let n = size as usize;
276 for i in 0..n {
277 let x = inp_f32[i];
278 let sigmoid = 1.0 / (1.0 + (-x).exp());
279 out_f32[i] = x * sigmoid;
280 }
281 Ok(())
282 }
283
284 fn element_mul(
286 &self,
287 a: &BufferHandle,
288 b: &BufferHandle,
289 output: &BufferHandle,
290 size: u32,
291 ) -> Result<()> {
292 let mut buffers = self.buffers.lock().unwrap();
293 let a_data = Self::get_buf(&buffers, a)?.as_ptr();
294 let a_len = Self::get_buf(&buffers, a)?.len();
295 let b_data = Self::get_buf(&buffers, b)?.as_ptr();
296 let b_len = Self::get_buf(&buffers, b)?.len();
297 let out_buf = Self::get_buf_mut(&mut buffers, output)?;
298 let out_f32 = as_f32_slice_mut(out_buf);
299
300 let a_f32: &[f32] = bytemuck::cast_slice(unsafe {
301 std::slice::from_raw_parts(a_data, a_len)
302 });
303 let b_f32: &[f32] = bytemuck::cast_slice(unsafe {
304 std::slice::from_raw_parts(b_data, b_len)
305 });
306
307 let n = size as usize;
308 for i in 0..n {
309 out_f32[i] = a_f32[i] * b_f32[i];
310 }
311 Ok(())
312 }
313
314 fn add(
316 &self,
317 a: &BufferHandle,
318 b: &BufferHandle,
319 output: &BufferHandle,
320 size: u32,
321 ) -> Result<()> {
322 let mut buffers = self.buffers.lock().unwrap();
323 let a_data = Self::get_buf(&buffers, a)?.as_ptr();
324 let a_len = Self::get_buf(&buffers, a)?.len();
325 let b_data = Self::get_buf(&buffers, b)?.as_ptr();
326 let b_len = Self::get_buf(&buffers, b)?.len();
327 let out_buf = Self::get_buf_mut(&mut buffers, output)?;
328 let out_f32 = as_f32_slice_mut(out_buf);
329
330 let a_f32: &[f32] = bytemuck::cast_slice(unsafe {
331 std::slice::from_raw_parts(a_data, a_len)
332 });
333 let b_f32: &[f32] = bytemuck::cast_slice(unsafe {
334 std::slice::from_raw_parts(b_data, b_len)
335 });
336
337 let n = size as usize;
338 for i in 0..n {
339 out_f32[i] = a_f32[i] + b_f32[i];
340 }
341 Ok(())
342 }
343
344 fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()> {
345 let mut buffers = self.buffers.lock().unwrap();
346 let buf = buffers.get_mut(&handle.0)
347 .ok_or_else(|| YuleError::Gpu("buffer not found".into()))?;
348 buf[..data.len()].copy_from_slice(data);
349 Ok(())
350 }
351
352 fn copy_from_device(&self, handle: &BufferHandle, data: &mut [u8]) -> Result<()> {
353 let buffers = self.buffers.lock().unwrap();
354 let buf = buffers.get(&handle.0)
355 .ok_or_else(|| YuleError::Gpu("buffer not found".into()))?;
356 data.copy_from_slice(&buf[..data.len()]);
357 Ok(())
358 }
359
360 fn copy_buffer(&self, src: &BufferHandle, dst: &BufferHandle, size: usize) -> Result<()> {
361 let mut buffers = self.buffers.lock().unwrap();
362 let src_ptr = Self::get_buf(&buffers, src)?.as_ptr();
363 let src_len = Self::get_buf(&buffers, src)?.len();
364 let dst_buf = Self::get_buf_mut(&mut buffers, dst)?;
365 let n = size.min(src_len).min(dst_buf.len());
366 let src_slice = unsafe { std::slice::from_raw_parts(src_ptr, n) };
367 dst_buf[..n].copy_from_slice(src_slice);
368 Ok(())
369 }
370
371 fn copy_buffer_offset(
372 &self, src: &BufferHandle, dst: &BufferHandle,
373 src_offset: usize, dst_offset: usize, size: usize,
374 ) -> Result<()> {
375 let mut buffers = self.buffers.lock().unwrap();
376 let src_ptr = Self::get_buf(&buffers, src)?.as_ptr();
377 let src_len = Self::get_buf(&buffers, src)?.len();
378 let dst_buf = Self::get_buf_mut(&mut buffers, dst)?;
379 let src_slice = unsafe { std::slice::from_raw_parts(src_ptr.add(src_offset), size.min(src_len - src_offset)) };
380 dst_buf[dst_offset..dst_offset + size].copy_from_slice(src_slice);
381 Ok(())
382 }
383
384 fn synchronize(&self) -> Result<()> {
385 Ok(()) }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: writes `data` into `handle`'s buffer as raw f32 bytes.
    fn write_f32(backend: &CpuBackend, handle: &BufferHandle, data: &[f32]) {
        let bytes: &[u8] = bytemuck::cast_slice(data);
        backend.copy_to_device(bytes, handle).unwrap();
    }

    // Helper: reads the first `n` f32 values back from `handle`'s buffer.
    fn read_f32(backend: &CpuBackend, handle: &BufferHandle, n: usize) -> Vec<f32> {
        let mut bytes = vec![0u8; n * 4];
        backend.copy_from_device(handle, &mut bytes).unwrap();
        bytemuck::cast_slice(&bytes).to_vec()
    }

    // Softmax output must sum to 1 and preserve the ordering of the inputs.
    #[test]
    fn test_softmax() {
        let b = CpuBackend::new();
        let inp = b.allocate(16).unwrap(); let out = b.allocate(16).unwrap();
        write_f32(&b, &inp, &[1.0, 2.0, 3.0, 4.0]);

        b.softmax(&inp, &out, 4).unwrap();
        let result = read_f32(&b, &out, 4);

        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(result[3] > result[2]);
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    // With unit weights, rms_norm reduces to dividing by the RMS of the input.
    #[test]
    fn test_rms_norm() {
        let b = CpuBackend::new();
        let inp = b.allocate(16).unwrap();
        let wt = b.allocate(16).unwrap();
        let out = b.allocate(16).unwrap();

        write_f32(&b, &inp, &[1.0, 2.0, 3.0, 4.0]);
        write_f32(&b, &wt, &[1.0, 1.0, 1.0, 1.0]);

        b.rms_norm(&inp, &wt, &out, 4, 1e-6).unwrap();
        let result = read_f32(&b, &out, 4);

        // Mean of squares of [1,2,3,4] is (1+4+9+16)/4 = 7.5.
        let rms = (7.5f32 + 1e-6).sqrt();
        assert!((result[0] - 1.0 / rms).abs() < 1e-4);
        assert!((result[3] - 4.0 / rms).abs() < 1e-4);
    }

    // silu(0) = 0, silu(1) ~= 0.7311, silu(-1) ~= -0.2689.
    #[test]
    fn test_silu() {
        let b = CpuBackend::new();
        let inp = b.allocate(12).unwrap();
        let out = b.allocate(12).unwrap();

        write_f32(&b, &inp, &[0.0, 1.0, -1.0]);
        b.silu(&inp, &out, 3).unwrap();
        let result = read_f32(&b, &out, 3);

        assert!((result[0] - 0.0).abs() < 1e-5);
        assert!((result[1] - 0.7311).abs() < 1e-3);
        assert!((result[2] - (-0.2689)).abs() < 1e-3);
    }

    // Elementwise product of two 3-element vectors.
    #[test]
    fn test_element_mul() {
        let b = CpuBackend::new();
        let a = b.allocate(12).unwrap();
        let bh = b.allocate(12).unwrap();
        let out = b.allocate(12).unwrap();

        write_f32(&b, &a, &[2.0, 3.0, 4.0]);
        write_f32(&b, &bh, &[5.0, 6.0, 7.0]);
        b.element_mul(&a, &bh, &out, 3).unwrap();
        let result = read_f32(&b, &out, 3);

        assert!((result[0] - 10.0).abs() < 1e-5);
        assert!((result[1] - 18.0).abs() < 1e-5);
        assert!((result[2] - 28.0).abs() < 1e-5);
    }

    // Elementwise sum of two 3-element vectors.
    #[test]
    fn test_add() {
        let b = CpuBackend::new();
        let a = b.allocate(12).unwrap();
        let bh = b.allocate(12).unwrap();
        let out = b.allocate(12).unwrap();

        write_f32(&b, &a, &[1.0, 2.0, 3.0]);
        write_f32(&b, &bh, &[4.0, 5.0, 6.0]);
        b.add(&a, &bh, &out, 3).unwrap();
        let result = read_f32(&b, &out, 3);

        assert!((result[0] - 5.0).abs() < 1e-5);
        assert!((result[1] - 7.0).abs() < 1e-5);
        assert!((result[2] - 9.0).abs() < 1e-5);
    }

    // (1x4) vector times (4x3) matrix: last matrix row of ones adds a[3]
    // to each column, so out = [1+4, 2+4, 3+4] = [5, 6, 7].
    #[test]
    fn test_matmul_gemv() {
        let b = CpuBackend::new();
        let a = b.allocate(16).unwrap(); let bh = b.allocate(48).unwrap(); let out = b.allocate(12).unwrap(); write_f32(&b, &a, &[1.0, 2.0, 3.0, 4.0]);
        write_f32(&b, &bh, &[
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
            1.0, 1.0, 1.0,
        ]);

        b.matmul(&a, &bh, &out, 1, 3, 4).unwrap();
        let result = read_f32(&b, &out, 3);

        assert!((result[0] - 5.0).abs() < 1e-5);
        assert!((result[1] - 6.0).abs() < 1e-5);
        assert!((result[2] - 7.0).abs() < 1e-5);
    }

    // At pos 0 every rotation angle is 0, so rope must leave q and k unchanged.
    #[test]
    fn test_rope_single_head_pos0() {
        let b = CpuBackend::new();
        let q = b.allocate(16).unwrap(); let k = b.allocate(16).unwrap();

        write_f32(&b, &q, &[1.0, 0.0, 1.0, 0.0]);
        write_f32(&b, &k, &[0.0, 1.0, 0.0, 1.0]);

        b.rope(&q, &k, 0, 4, 10000.0, 1, 1).unwrap();
        let q_result = read_f32(&b, &q, 4);
        let k_result = read_f32(&b, &k, 4);

        assert!((q_result[0] - 1.0).abs() < 1e-5);
        assert!((q_result[1] - 0.0).abs() < 1e-5);
        assert!((k_result[1] - 1.0).abs() < 1e-5);
    }

    // Two q heads of dim 4 and one k head; still pos 0, so values pass through.
    #[test]
    fn test_rope_multi_head() {
        let b = CpuBackend::new();
        let q = b.allocate(32).unwrap(); let k = b.allocate(16).unwrap(); write_f32(&b, &q, &[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]);
        write_f32(&b, &k, &[1.0, 0.0, 1.0, 0.0]);

        b.rope(&q, &k, 0, 4, 10000.0, 2, 1).unwrap();
        let q_result = read_f32(&b, &q, 8);
        let k_result = read_f32(&b, &k, 4);

        assert!((q_result[0] - 1.0).abs() < 1e-5); assert!((q_result[4] - 0.0).abs() < 1e-5); assert!((q_result[5] - 1.0).abs() < 1e-5); assert!((k_result[0] - 1.0).abs() < 1e-5);
    }

    // For pair i = 0 the frequency is 1, so theta = pos = 5 and the first
    // pair (1, 0) rotates to (cos 5, sin 5).
    #[test]
    fn test_rope_nonzero_pos() {
        let b = CpuBackend::new();
        let q = b.allocate(16).unwrap(); let k = b.allocate(16).unwrap();

        write_f32(&b, &q, &[1.0, 0.0, 1.0, 0.0]);
        write_f32(&b, &k, &[1.0, 0.0, 1.0, 0.0]);

        b.rope(&q, &k, 5, 4, 10000.0, 1, 1).unwrap();
        let q_result = read_f32(&b, &q, 4);

        let cos5 = 5.0f32.cos();
        let sin5 = 5.0f32.sin();
        assert!((q_result[0] - cos5).abs() < 1e-4);
        assert!((q_result[1] - sin5).abs() < 1e-4);
    }

    // Whole-buffer device-to-device copy.
    #[test]
    fn test_copy_buffer() {
        let b = CpuBackend::new();
        let src = b.allocate(16).unwrap();
        let dst = b.allocate(16).unwrap();

        write_f32(&b, &src, &[1.0, 2.0, 3.0, 4.0]);
        b.copy_buffer(&src, &dst, 16).unwrap();
        let result = read_f32(&b, &dst, 4);

        assert!((result[0] - 1.0).abs() < 1e-5);
        assert!((result[3] - 4.0).abs() < 1e-5);
    }

    // Copies src bytes 4..12 (floats 20, 30) into dst bytes 8..16
    // (float slots 2 and 3); everything else stays zero.
    #[test]
    fn test_copy_buffer_offset() {
        let b = CpuBackend::new();
        let src = b.allocate(16).unwrap(); let dst = b.allocate(32).unwrap(); write_f32(&b, &src, &[10.0, 20.0, 30.0, 40.0]);

        b.copy_buffer_offset(&src, &dst, 4, 8, 8).unwrap();
        let result = read_f32(&b, &dst, 8);

        assert!((result[0] - 0.0).abs() < 1e-5); assert!((result[1] - 0.0).abs() < 1e-5); assert!((result[2] - 20.0).abs() < 1e-5); assert!((result[3] - 30.0).abs() < 1e-5); assert!((result[4] - 0.0).abs() < 1e-5); }

    // Mimics appending one position's K vector into a KV cache at `pos`,
    // then checks neighboring positions stayed zeroed.
    #[test]
    fn test_copy_buffer_offset_kv_cache_pattern() {
        let b = CpuBackend::new();
        let n_kv_heads = 2;
        let head_dim = 4;
        let kv_stride = n_kv_heads * head_dim; let max_seq_len = 4;

        let k_tmp = b.allocate(kv_stride * 4).unwrap();
        let k_cache = b.allocate(max_seq_len * kv_stride * 4).unwrap();

        let k_data: Vec<f32> = (0..kv_stride).map(|i| (i + 1) as f32).collect();
        write_f32(&b, &k_tmp, &k_data);

        let pos = 2;
        let cache_byte_offset = pos * kv_stride * 4;
        b.copy_buffer_offset(&k_tmp, &k_cache, 0, cache_byte_offset, kv_stride * 4).unwrap();

        let cache = read_f32(&b, &k_cache, max_seq_len * kv_stride);

        assert!((cache[0] - 0.0).abs() < 1e-5);
        assert!((cache[kv_stride - 1] - 0.0).abs() < 1e-5);
        assert!((cache[pos * kv_stride] - 1.0).abs() < 1e-5);
        assert!((cache[pos * kv_stride + kv_stride - 1] - kv_stride as f32).abs() < 1e-5);
        assert!((cache[3 * kv_stride] - 0.0).abs() < 1e-5);
    }

    // Runs scaled-dot-product attention by hand on top of the backend's
    // softmax, comparing against values computed directly in the test.
    // Note: softmax is called in-place (same handle for input and output).
    #[test]
    fn test_attention_manual() {
        let b = CpuBackend::new();
        let hd = 2;
        let seq_len = 3;

        let q = b.allocate(hd * 4).unwrap();
        let k_cache = b.allocate(seq_len * hd * 4).unwrap();
        let v_cache = b.allocate(seq_len * hd * 4).unwrap();
        let scores = b.allocate(seq_len * 4).unwrap();
        let out = b.allocate(hd * 4).unwrap();

        write_f32(&b, &q, &[1.0, 0.0]);

        write_f32(&b, &k_cache, &[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        write_f32(&b, &v_cache, &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);

        // Reference result: q.k dot products, scaled, softmaxed, then
        // used to blend the V rows.
        let scale = 1.0 / (hd as f32).sqrt();
        let s0 = 1.0 * scale;
        let s1 = 0.0 * scale;
        let s2 = 1.0 * scale;

        let max_s = s0.max(s1).max(s2);
        let e0 = (s0 - max_s).exp();
        let e1 = (s1 - max_s).exp();
        let e2 = (s2 - max_s).exp();
        let sum = e0 + e1 + e2;
        let w0 = e0 / sum;
        let w1 = e1 / sum;
        let w2 = e2 / sum;

        let expected_0 = w0 * 10.0 + w1 * 30.0 + w2 * 50.0;
        let expected_1 = w0 * 20.0 + w1 * 40.0 + w2 * 60.0;

        let q_f32 = read_f32(&b, &q, hd);
        let k_f32 = read_f32(&b, &k_cache, seq_len * hd);
        let mut scores_f32 = vec![0.0f32; seq_len];
        for t in 0..seq_len {
            let mut dot = 0.0f32;
            for d in 0..hd {
                dot += q_f32[d] * k_f32[t * hd + d];
            }
            scores_f32[t] = dot * scale;
        }
        write_f32(&b, &scores, &scores_f32);

        b.softmax(&scores, &scores, seq_len as u32).unwrap();
        let weights = read_f32(&b, &scores, seq_len);

        assert!((weights[0] - w0).abs() < 1e-4);
        assert!((weights[1] - w1).abs() < 1e-4);
        assert!((weights[2] - w2).abs() < 1e-4);

        let v_f32 = read_f32(&b, &v_cache, seq_len * hd);
        let mut out_f32 = vec![0.0f32; hd];
        for t in 0..seq_len {
            for d in 0..hd {
                out_f32[d] += weights[t] * v_f32[t * hd + d];
            }
        }
        write_f32(&b, &out, &out_f32);
        let result = read_f32(&b, &out, hd);

        assert!((result[0] - expected_0).abs() < 1e-3);
        assert!((result[1] - expected_1).abs() < 1e-3);
    }

    // Grouped-query attention: two query heads share one KV head; checks
    // the blended outputs fall strictly between the two V rows.
    #[test]
    fn test_attention_gqa() {
        let b = CpuBackend::new();
        let hd = 2;
        let n_heads = 2;
        let n_kv_heads = 1;
        let kv_stride = n_kv_heads * hd;
        let seq_len = 2;

        let q = b.allocate(n_heads * hd * 4).unwrap();
        write_f32(&b, &q, &[1.0, 0.0, 0.0, 1.0]); let k_cache = b.allocate(seq_len * kv_stride * 4).unwrap();
        let v_cache = b.allocate(seq_len * kv_stride * 4).unwrap();
        write_f32(&b, &k_cache, &[1.0, 0.0, 0.0, 1.0]); write_f32(&b, &v_cache, &[10.0, 20.0, 30.0, 40.0]); let scores_buf = b.allocate(seq_len * 4).unwrap();
        let attn_out = b.allocate(n_heads * hd * 4).unwrap();

        let scale = 1.0 / (hd as f32).sqrt();
        let kv_group = n_heads / n_kv_heads;

        for h in 0..n_heads {
            // Each query head maps onto its shared KV head's slice.
            let kv_h = h / kv_group;
            let head_offset = h * hd;
            let kv_off = kv_h * hd;

            let q_f32 = read_f32(&b, &q, n_heads * hd);
            let k_f32 = read_f32(&b, &k_cache, seq_len * kv_stride);
            let mut scores = vec![0.0f32; seq_len];
            for t in 0..seq_len {
                let mut dot = 0.0f32;
                for d in 0..hd {
                    dot += q_f32[head_offset + d] * k_f32[t * kv_stride + kv_off + d];
                }
                scores[t] = dot * scale;
            }
            write_f32(&b, &scores_buf, &scores);

            b.softmax(&scores_buf, &scores_buf, seq_len as u32).unwrap();
            let weights = read_f32(&b, &scores_buf, seq_len);

            let v_f32 = read_f32(&b, &v_cache, seq_len * kv_stride);
            let mut head_out = vec![0.0f32; hd];
            for t in 0..seq_len {
                for d in 0..hd {
                    head_out[d] += weights[t] * v_f32[t * kv_stride + kv_off + d];
                }
            }

            // Read-modify-write so the other head's slot is preserved.
            let mut full_out = read_f32(&b, &attn_out, n_heads * hd);
            full_out[head_offset..head_offset + hd].copy_from_slice(&head_out);
            write_f32(&b, &attn_out, &full_out);
        }

        let result = read_f32(&b, &attn_out, n_heads * hd);

        assert!(result[0] < 25.0); assert!(result[1] < 35.0); assert!(result[2] > 15.0); assert!(result[3] > 25.0); }
}