//! CPU reference backend for `yule_gpu` (`cpu.rs`): implements the
//! `ComputeBackend` trait with plain scalar loops over heap byte buffers.

1use crate::{BackendKind, BufferHandle, ComputeBackend, DeviceInfo, buffer::next_buffer_handle};
2use yule_core::error::{Result, YuleError};
3use std::collections::HashMap;
4use std::sync::Mutex;
5
/// CPU fallback implementation of the compute backend.
///
/// "Device memory" is emulated with heap-allocated byte vectors keyed by
/// buffer-handle id. The `Mutex` serializes all access from `&self` methods.
pub struct CpuBackend {
    // handle id -> raw byte storage for that buffer
    buffers: Mutex<HashMap<u64, Vec<u8>>>,
}
9
10impl CpuBackend {
11    pub fn new() -> Self {
12        Self {
13            buffers: Mutex::new(HashMap::new()),
14        }
15    }
16
17    /// Get a raw byte buffer, returning an error if not found.
18    fn get_buf<'a>(
19        buffers: &'a HashMap<u64, Vec<u8>>,
20        handle: &BufferHandle,
21    ) -> Result<&'a Vec<u8>> {
22        buffers
23            .get(&handle.0)
24            .ok_or_else(|| YuleError::Gpu(format!("buffer {} not found", handle.0)))
25    }
26
27    /// Get a mutable raw byte buffer.
28    fn get_buf_mut<'a>(
29        buffers: &'a mut HashMap<u64, Vec<u8>>,
30        handle: &BufferHandle,
31    ) -> Result<&'a mut Vec<u8>> {
32        buffers
33            .get_mut(&handle.0)
34            .ok_or_else(|| YuleError::Gpu(format!("buffer {} not found", handle.0)))
35    }
36}
37
38impl Default for CpuBackend {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
/// Reinterpret `&[u8]` as `&[f32]`. Bytes are read in native endianness
/// (little-endian on the targets this crate runs on — TODO confirm that is
/// the only supported layout). Length must be a multiple of 4.
///
/// NOTE(review): `bytemuck::cast_slice` panics at runtime if `data` is not
/// 4-byte aligned. `Vec<u8>` only guarantees 1-byte alignment, so this
/// relies on the global allocator returning word-aligned blocks (true in
/// practice for mainstream allocators, but not a language guarantee).
#[inline]
fn as_f32_slice(data: &[u8]) -> &[f32] {
    debug_assert!(data.len() % 4 == 0);
    // cast_slice re-validates both length and alignment, so no unsafe code
    // is needed here.
    bytemuck::cast_slice(data)
}
53
/// Mutable counterpart of [`as_f32_slice`]: reinterpret `&mut [u8]` as
/// `&mut [f32]`. Same alignment caveat applies — `cast_slice_mut` panics if
/// the buffer is not 4-byte aligned or its length is not a multiple of 4.
#[inline]
fn as_f32_slice_mut(data: &mut [u8]) -> &mut [f32] {
    debug_assert!(data.len() % 4 == 0);
    bytemuck::cast_slice_mut(data)
}
60
61impl ComputeBackend for CpuBackend {
    fn name(&self) -> &str {
        // Stable, human-readable identifier for this backend.
        "cpu"
    }
65
66    fn device_info(&self) -> DeviceInfo {
67        DeviceInfo {
68            name: "CPU".into(),
69            backend: BackendKind::Cpu,
70            memory_bytes: 0, // TODO: detect system RAM
71            compute_units: std::thread::available_parallelism()
72                .map(|p| p.get() as u32)
73                .unwrap_or(1),
74        }
75    }
76
77    fn allocate(&self, size_bytes: usize) -> Result<BufferHandle> {
78        let handle = next_buffer_handle();
79        // Allocate aligned to 64 bytes for SIMD. Vec<u8> doesn't guarantee
80        // this, but on most allocators, allocations >= 64 bytes are aligned.
81        // For a production engine we'd use aligned_alloc or a custom allocator.
82        let buf = vec![0u8; size_bytes];
83        self.buffers.lock().unwrap().insert(handle.0, buf);
84        Ok(handle)
85    }
86
87    fn free(&self, handle: BufferHandle) -> Result<()> {
88        self.buffers.lock().unwrap().remove(&handle.0);
89        Ok(())
90    }
91
92    /// Matrix multiply: C[m,n] = A[m,k] * B[k,n]
93    /// For single-token decode (m=1), this is a GEMV.
94    /// Buffers hold row-major f32 data.
95    fn matmul(
96        &self,
97        a: &BufferHandle,
98        b: &BufferHandle,
99        out: &BufferHandle,
100        m: u32,
101        n: u32,
102        k: u32,
103    ) -> Result<()> {
104        let mut buffers = self.buffers.lock().unwrap();
105        // We need to borrow a, b immutably and out mutably.
106        // Since they share one HashMap, we extract raw pointers carefully.
107        let a_data = Self::get_buf(&buffers, a)?.as_ptr();
108        let b_data = Self::get_buf(&buffers, b)?.as_ptr();
109        let a_len = Self::get_buf(&buffers, a)?.len();
110        let b_len = Self::get_buf(&buffers, b)?.len();
111        let out_buf = Self::get_buf_mut(&mut buffers, out)?;
112        let out_f32 = as_f32_slice_mut(out_buf);
113
114        // SAFETY: pointers stay valid while lock is held, and out is distinct
115        let a_f32: &[f32] = bytemuck::cast_slice(unsafe {
116            std::slice::from_raw_parts(a_data, a_len)
117        });
118        let b_f32: &[f32] = bytemuck::cast_slice(unsafe {
119            std::slice::from_raw_parts(b_data, b_len)
120        });
121
122        let (m, n, k) = (m as usize, n as usize, k as usize);
123
124        // Naive GEMM: C[i,j] = sum_p A[i,p] * B[p,j]
125        // For m=1 this is a GEMV and dominates inference time.
126        for i in 0..m {
127            for j in 0..n {
128                let mut sum = 0.0f32;
129                for p in 0..k {
130                    sum += a_f32[i * k + p] * b_f32[p * n + j];
131                }
132                out_f32[i * n + j] = sum;
133            }
134        }
135        Ok(())
136    }
137
138    /// Softmax: out[i] = exp(input[i] - max) / sum(exp(input - max))
139    /// Numerically stable via max subtraction.
140    fn softmax(
141        &self,
142        input: &BufferHandle,
143        output: &BufferHandle,
144        size: u32,
145    ) -> Result<()> {
146        let mut buffers = self.buffers.lock().unwrap();
147        let inp_data = Self::get_buf(&buffers, input)?.as_ptr();
148        let inp_len = Self::get_buf(&buffers, input)?.len();
149        let out_buf = Self::get_buf_mut(&mut buffers, output)?;
150        let out_f32 = as_f32_slice_mut(out_buf);
151
152        let inp_f32: &[f32] = bytemuck::cast_slice(unsafe {
153            std::slice::from_raw_parts(inp_data, inp_len)
154        });
155
156        let n = size as usize;
157        let mut max_val = f32::NEG_INFINITY;
158        for i in 0..n {
159            if inp_f32[i] > max_val {
160                max_val = inp_f32[i];
161            }
162        }
163
164        let mut sum = 0.0f32;
165        for i in 0..n {
166            let e = (inp_f32[i] - max_val).exp();
167            out_f32[i] = e;
168            sum += e;
169        }
170
171        let inv_sum = 1.0 / sum;
172        for i in 0..n {
173            out_f32[i] *= inv_sum;
174        }
175        Ok(())
176    }
177
178    /// RMSNorm: out = (input / rms) * weight
179    /// where rms = sqrt(mean(input^2) + eps)
180    fn rms_norm(
181        &self,
182        input: &BufferHandle,
183        weight: &BufferHandle,
184        output: &BufferHandle,
185        size: u32,
186        eps: f32,
187    ) -> Result<()> {
188        let mut buffers = self.buffers.lock().unwrap();
189        let inp_data = Self::get_buf(&buffers, input)?.as_ptr();
190        let inp_len = Self::get_buf(&buffers, input)?.len();
191        let wt_data = Self::get_buf(&buffers, weight)?.as_ptr();
192        let wt_len = Self::get_buf(&buffers, weight)?.len();
193        let out_buf = Self::get_buf_mut(&mut buffers, output)?;
194        let out_f32 = as_f32_slice_mut(out_buf);
195
196        let inp_f32: &[f32] = bytemuck::cast_slice(unsafe {
197            std::slice::from_raw_parts(inp_data, inp_len)
198        });
199        let wt_f32: &[f32] = bytemuck::cast_slice(unsafe {
200            std::slice::from_raw_parts(wt_data, wt_len)
201        });
202
203        let n = size as usize;
204        let mut ss = 0.0f32;
205        for i in 0..n {
206            ss += inp_f32[i] * inp_f32[i];
207        }
208        let rms = (ss / n as f32 + eps).sqrt();
209        let inv_rms = 1.0 / rms;
210
211        for i in 0..n {
212            out_f32[i] = inp_f32[i] * inv_rms * wt_f32[i];
213        }
214        Ok(())
215    }
216
217    /// RoPE (Rotary Position Embedding) applied in-place to Q and K buffers.
218    /// head_dim: dimension per attention head (typically 128).
219    /// Applies rotation in pairs: (q[2i], q[2i+1]) rotated by pos * freq.
220    fn rope(
221        &self,
222        q: &BufferHandle,
223        k: &BufferHandle,
224        pos: u32,
225        head_dim: u32,
226        freq_base: f32,
227        _n_heads_q: u32,
228        _n_heads_k: u32,
229    ) -> Result<()> {
230        let mut buffers = self.buffers.lock().unwrap();
231
232        // Apply RoPE to both Q and K in sequence
233        for handle in [q, k] {
234            let buf = Self::get_buf_mut(&mut buffers, handle)?;
235            let f32_data = as_f32_slice_mut(buf);
236            let hd = head_dim as usize;
237            let n_heads = f32_data.len() / hd;
238
239            for h in 0..n_heads {
240                let base = h * hd;
241                for i in 0..(hd / 2) {
242                    let freq = 1.0 / freq_base.powf(2.0 * i as f32 / hd as f32);
243                    let theta = pos as f32 * freq;
244                    let cos_t = theta.cos();
245                    let sin_t = theta.sin();
246
247                    let x0 = f32_data[base + 2 * i];
248                    let x1 = f32_data[base + 2 * i + 1];
249                    f32_data[base + 2 * i] = x0 * cos_t - x1 * sin_t;
250                    f32_data[base + 2 * i + 1] = x0 * sin_t + x1 * cos_t;
251                }
252            }
253        }
254        Ok(())
255    }
256
257    /// SiLU (Sigmoid Linear Unit): out[i] = input[i] * sigmoid(input[i])
258    /// Also known as swish. Used in SwiGLU FFN.
259    fn silu(
260        &self,
261        input: &BufferHandle,
262        output: &BufferHandle,
263        size: u32,
264    ) -> Result<()> {
265        let mut buffers = self.buffers.lock().unwrap();
266        let inp_data = Self::get_buf(&buffers, input)?.as_ptr();
267        let inp_len = Self::get_buf(&buffers, input)?.len();
268        let out_buf = Self::get_buf_mut(&mut buffers, output)?;
269        let out_f32 = as_f32_slice_mut(out_buf);
270
271        let inp_f32: &[f32] = bytemuck::cast_slice(unsafe {
272            std::slice::from_raw_parts(inp_data, inp_len)
273        });
274
275        let n = size as usize;
276        for i in 0..n {
277            let x = inp_f32[i];
278            let sigmoid = 1.0 / (1.0 + (-x).exp());
279            out_f32[i] = x * sigmoid;
280        }
281        Ok(())
282    }
283
284    /// Element-wise multiply: out[i] = a[i] * b[i]
285    fn element_mul(
286        &self,
287        a: &BufferHandle,
288        b: &BufferHandle,
289        output: &BufferHandle,
290        size: u32,
291    ) -> Result<()> {
292        let mut buffers = self.buffers.lock().unwrap();
293        let a_data = Self::get_buf(&buffers, a)?.as_ptr();
294        let a_len = Self::get_buf(&buffers, a)?.len();
295        let b_data = Self::get_buf(&buffers, b)?.as_ptr();
296        let b_len = Self::get_buf(&buffers, b)?.len();
297        let out_buf = Self::get_buf_mut(&mut buffers, output)?;
298        let out_f32 = as_f32_slice_mut(out_buf);
299
300        let a_f32: &[f32] = bytemuck::cast_slice(unsafe {
301            std::slice::from_raw_parts(a_data, a_len)
302        });
303        let b_f32: &[f32] = bytemuck::cast_slice(unsafe {
304            std::slice::from_raw_parts(b_data, b_len)
305        });
306
307        let n = size as usize;
308        for i in 0..n {
309            out_f32[i] = a_f32[i] * b_f32[i];
310        }
311        Ok(())
312    }
313
314    /// Element-wise add: out[i] = a[i] + b[i]
315    fn add(
316        &self,
317        a: &BufferHandle,
318        b: &BufferHandle,
319        output: &BufferHandle,
320        size: u32,
321    ) -> Result<()> {
322        let mut buffers = self.buffers.lock().unwrap();
323        let a_data = Self::get_buf(&buffers, a)?.as_ptr();
324        let a_len = Self::get_buf(&buffers, a)?.len();
325        let b_data = Self::get_buf(&buffers, b)?.as_ptr();
326        let b_len = Self::get_buf(&buffers, b)?.len();
327        let out_buf = Self::get_buf_mut(&mut buffers, output)?;
328        let out_f32 = as_f32_slice_mut(out_buf);
329
330        let a_f32: &[f32] = bytemuck::cast_slice(unsafe {
331            std::slice::from_raw_parts(a_data, a_len)
332        });
333        let b_f32: &[f32] = bytemuck::cast_slice(unsafe {
334            std::slice::from_raw_parts(b_data, b_len)
335        });
336
337        let n = size as usize;
338        for i in 0..n {
339            out_f32[i] = a_f32[i] + b_f32[i];
340        }
341        Ok(())
342    }
343
344    fn copy_to_device(&self, data: &[u8], handle: &BufferHandle) -> Result<()> {
345        let mut buffers = self.buffers.lock().unwrap();
346        let buf = buffers.get_mut(&handle.0)
347            .ok_or_else(|| YuleError::Gpu("buffer not found".into()))?;
348        buf[..data.len()].copy_from_slice(data);
349        Ok(())
350    }
351
352    fn copy_from_device(&self, handle: &BufferHandle, data: &mut [u8]) -> Result<()> {
353        let buffers = self.buffers.lock().unwrap();
354        let buf = buffers.get(&handle.0)
355            .ok_or_else(|| YuleError::Gpu("buffer not found".into()))?;
356        data.copy_from_slice(&buf[..data.len()]);
357        Ok(())
358    }
359
360    fn copy_buffer(&self, src: &BufferHandle, dst: &BufferHandle, size: usize) -> Result<()> {
361        let mut buffers = self.buffers.lock().unwrap();
362        let src_ptr = Self::get_buf(&buffers, src)?.as_ptr();
363        let src_len = Self::get_buf(&buffers, src)?.len();
364        let dst_buf = Self::get_buf_mut(&mut buffers, dst)?;
365        let n = size.min(src_len).min(dst_buf.len());
366        let src_slice = unsafe { std::slice::from_raw_parts(src_ptr, n) };
367        dst_buf[..n].copy_from_slice(src_slice);
368        Ok(())
369    }
370
371    fn copy_buffer_offset(
372        &self, src: &BufferHandle, dst: &BufferHandle,
373        src_offset: usize, dst_offset: usize, size: usize,
374    ) -> Result<()> {
375        let mut buffers = self.buffers.lock().unwrap();
376        let src_ptr = Self::get_buf(&buffers, src)?.as_ptr();
377        let src_len = Self::get_buf(&buffers, src)?.len();
378        let dst_buf = Self::get_buf_mut(&mut buffers, dst)?;
379        let src_slice = unsafe { std::slice::from_raw_parts(src_ptr.add(src_offset), size.min(src_len - src_offset)) };
380        dst_buf[dst_offset..dst_offset + size].copy_from_slice(src_slice);
381        Ok(())
382    }
383
    fn synchronize(&self) -> Result<()> {
        // Every op in this backend completes before returning, so there is
        // never outstanding work to wait for.
        Ok(()) // CPU is synchronous
    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: upload `data` to `handle` as raw f32 bytes.
    fn write_f32(backend: &CpuBackend, handle: &BufferHandle, data: &[f32]) {
        let bytes: &[u8] = bytemuck::cast_slice(data);
        backend.copy_to_device(bytes, handle).unwrap();
    }

    /// Test helper: download `n` f32 values from `handle`.
    fn read_f32(backend: &CpuBackend, handle: &BufferHandle, n: usize) -> Vec<f32> {
        let mut bytes = vec![0u8; n * 4];
        backend.copy_from_device(handle, &mut bytes).unwrap();
        bytemuck::cast_slice(&bytes).to_vec()
    }

    #[test]
    fn test_softmax() {
        let b = CpuBackend::new();
        let inp = b.allocate(16).unwrap(); // 4 floats
        let out = b.allocate(16).unwrap();
        write_f32(&b, &inp, &[1.0, 2.0, 3.0, 4.0]);

        b.softmax(&inp, &out, 4).unwrap();
        let result = read_f32(&b, &out, 4);

        // Check sums to 1.0
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        // Check monotonicity
        assert!(result[3] > result[2]);
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    #[test]
    fn test_rms_norm() {
        let b = CpuBackend::new();
        let inp = b.allocate(16).unwrap();
        let wt = b.allocate(16).unwrap();
        let out = b.allocate(16).unwrap();

        write_f32(&b, &inp, &[1.0, 2.0, 3.0, 4.0]);
        write_f32(&b, &wt, &[1.0, 1.0, 1.0, 1.0]);

        b.rms_norm(&inp, &wt, &out, 4, 1e-6).unwrap();
        let result = read_f32(&b, &out, 4);

        // RMS of [1,2,3,4] = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386
        let rms = (7.5f32 + 1e-6).sqrt();
        assert!((result[0] - 1.0 / rms).abs() < 1e-4);
        assert!((result[3] - 4.0 / rms).abs() < 1e-4);
    }

    #[test]
    fn test_silu() {
        let b = CpuBackend::new();
        let inp = b.allocate(12).unwrap();
        let out = b.allocate(12).unwrap();

        write_f32(&b, &inp, &[0.0, 1.0, -1.0]);
        b.silu(&inp, &out, 3).unwrap();
        let result = read_f32(&b, &out, 3);

        // silu(0) = 0 * 0.5 = 0
        assert!((result[0] - 0.0).abs() < 1e-5);
        // silu(1) = 1 * sigmoid(1) ≈ 0.7311
        assert!((result[1] - 0.7311).abs() < 1e-3);
        // silu(-1) = -1 * sigmoid(-1) ≈ -0.2689
        assert!((result[2] - (-0.2689)).abs() < 1e-3);
    }

    #[test]
    fn test_element_mul() {
        let b = CpuBackend::new();
        let a = b.allocate(12).unwrap();
        let bh = b.allocate(12).unwrap();
        let out = b.allocate(12).unwrap();

        write_f32(&b, &a, &[2.0, 3.0, 4.0]);
        write_f32(&b, &bh, &[5.0, 6.0, 7.0]);
        b.element_mul(&a, &bh, &out, 3).unwrap();
        let result = read_f32(&b, &out, 3);

        assert!((result[0] - 10.0).abs() < 1e-5);
        assert!((result[1] - 18.0).abs() < 1e-5);
        assert!((result[2] - 28.0).abs() < 1e-5);
    }

    #[test]
    fn test_add() {
        let b = CpuBackend::new();
        let a = b.allocate(12).unwrap();
        let bh = b.allocate(12).unwrap();
        let out = b.allocate(12).unwrap();

        write_f32(&b, &a, &[1.0, 2.0, 3.0]);
        write_f32(&b, &bh, &[4.0, 5.0, 6.0]);
        b.add(&a, &bh, &out, 3).unwrap();
        let result = read_f32(&b, &out, 3);

        assert!((result[0] - 5.0).abs() < 1e-5);
        assert!((result[1] - 7.0).abs() < 1e-5);
        assert!((result[2] - 9.0).abs() < 1e-5);
    }

    #[test]
    fn test_matmul_gemv() {
        // GEMV: 1×4 times 4×3 = 1×3
        let b = CpuBackend::new();
        let a = b.allocate(16).unwrap();  // 1×4
        let bh = b.allocate(48).unwrap(); // 4×3
        let out = b.allocate(12).unwrap(); // 1×3

        write_f32(&b, &a, &[1.0, 2.0, 3.0, 4.0]);
        // B row-major: [[1,0,0],[0,1,0],[0,0,1],[1,1,1]]
        write_f32(&b, &bh, &[
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0,
            1.0, 1.0, 1.0,
        ]);

        b.matmul(&a, &bh, &out, 1, 3, 4).unwrap();
        let result = read_f32(&b, &out, 3);

        // C[0,0] = 1*1 + 2*0 + 3*0 + 4*1 = 5
        // C[0,1] = 1*0 + 2*1 + 3*0 + 4*1 = 6
        // C[0,2] = 1*0 + 2*0 + 3*1 + 4*1 = 7
        assert!((result[0] - 5.0).abs() < 1e-5);
        assert!((result[1] - 6.0).abs() < 1e-5);
        assert!((result[2] - 7.0).abs() < 1e-5);
    }

    #[test]
    fn test_rope_single_head_pos0() {
        let b = CpuBackend::new();
        let q = b.allocate(16).unwrap(); // 1 head, head_dim=4
        let k = b.allocate(16).unwrap();

        write_f32(&b, &q, &[1.0, 0.0, 1.0, 0.0]);
        write_f32(&b, &k, &[0.0, 1.0, 0.0, 1.0]);

        b.rope(&q, &k, 0, 4, 10000.0, 1, 1).unwrap();
        let q_result = read_f32(&b, &q, 4);
        let k_result = read_f32(&b, &k, 4);

        // At pos=0, theta=0 for all freqs, so cos=1, sin=0 → no change
        assert!((q_result[0] - 1.0).abs() < 1e-5);
        assert!((q_result[1] - 0.0).abs() < 1e-5);
        assert!((k_result[1] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_rope_multi_head() {
        let b = CpuBackend::new();
        // 2 Q heads, 1 KV head, head_dim=4
        let q = b.allocate(32).unwrap(); // 2 * 4 floats
        let k = b.allocate(16).unwrap(); // 1 * 4 floats

        write_f32(&b, &q, &[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]);
        write_f32(&b, &k, &[1.0, 0.0, 1.0, 0.0]);

        // pos=0 → no rotation, but verify both heads are processed
        b.rope(&q, &k, 0, 4, 10000.0, 2, 1).unwrap();
        let q_result = read_f32(&b, &q, 8);
        let k_result = read_f32(&b, &k, 4);

        // Both Q heads should be unchanged at pos=0
        assert!((q_result[0] - 1.0).abs() < 1e-5); // head 0
        assert!((q_result[4] - 0.0).abs() < 1e-5); // head 1
        assert!((q_result[5] - 1.0).abs() < 1e-5); // head 1
        assert!((k_result[0] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_rope_nonzero_pos() {
        let b = CpuBackend::new();
        let q = b.allocate(16).unwrap(); // 1 head, head_dim=4
        let k = b.allocate(16).unwrap();

        write_f32(&b, &q, &[1.0, 0.0, 1.0, 0.0]);
        write_f32(&b, &k, &[1.0, 0.0, 1.0, 0.0]);

        b.rope(&q, &k, 5, 4, 10000.0, 1, 1).unwrap();
        let q_result = read_f32(&b, &q, 4);

        // At pos=5, pair (q[0], q[1]) should be rotated
        // freq = 1/10000^(0/4) = 1.0, theta = 5.0
        let cos5 = 5.0f32.cos();
        let sin5 = 5.0f32.sin();
        // q[0] = 1.0 * cos5 - 0.0 * sin5 = cos5
        assert!((q_result[0] - cos5).abs() < 1e-4);
        // q[1] = 1.0 * sin5 + 0.0 * cos5 = sin5
        assert!((q_result[1] - sin5).abs() < 1e-4);
    }

    #[test]
    fn test_copy_buffer() {
        let b = CpuBackend::new();
        let src = b.allocate(16).unwrap();
        let dst = b.allocate(16).unwrap();

        write_f32(&b, &src, &[1.0, 2.0, 3.0, 4.0]);
        b.copy_buffer(&src, &dst, 16).unwrap();
        let result = read_f32(&b, &dst, 4);

        assert!((result[0] - 1.0).abs() < 1e-5);
        assert!((result[3] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn test_copy_buffer_offset() {
        let b = CpuBackend::new();
        let src = b.allocate(16).unwrap(); // 4 floats
        let dst = b.allocate(32).unwrap(); // 8 floats, initially zeros

        write_f32(&b, &src, &[10.0, 20.0, 30.0, 40.0]);

        // Copy 2 floats from src offset 4 (starting at float[1]) to dst offset 8 (float[2])
        b.copy_buffer_offset(&src, &dst, 4, 8, 8).unwrap();
        let result = read_f32(&b, &dst, 8);

        assert!((result[0] - 0.0).abs() < 1e-5); // untouched
        assert!((result[1] - 0.0).abs() < 1e-5); // untouched
        assert!((result[2] - 20.0).abs() < 1e-5); // copied from src[1]
        assert!((result[3] - 30.0).abs() < 1e-5); // copied from src[2]
        assert!((result[4] - 0.0).abs() < 1e-5); // untouched
    }

    #[test]
    fn test_copy_buffer_offset_kv_cache_pattern() {
        // Simulate KV cache write: copy one position's worth of data to a specific offset
        let b = CpuBackend::new();
        let n_kv_heads = 2;
        let head_dim = 4;
        let kv_stride = n_kv_heads * head_dim; // 8 floats per position
        let max_seq_len = 4;

        let k_tmp = b.allocate(kv_stride * 4).unwrap();
        let k_cache = b.allocate(max_seq_len * kv_stride * 4).unwrap();

        // Write position data
        let k_data: Vec<f32> = (0..kv_stride).map(|i| (i + 1) as f32).collect();
        write_f32(&b, &k_tmp, &k_data);

        // Write to position 2
        let pos = 2;
        let cache_byte_offset = pos * kv_stride * 4;
        b.copy_buffer_offset(&k_tmp, &k_cache, 0, cache_byte_offset, kv_stride * 4).unwrap();

        let cache = read_f32(&b, &k_cache, max_seq_len * kv_stride);

        // Position 0 and 1 should be zeros
        assert!((cache[0] - 0.0).abs() < 1e-5);
        assert!((cache[kv_stride - 1] - 0.0).abs() < 1e-5);
        // Position 2 should have our data
        assert!((cache[pos * kv_stride] - 1.0).abs() < 1e-5);
        assert!((cache[pos * kv_stride + kv_stride - 1] - kv_stride as f32).abs() < 1e-5);
        // Position 3 should be zeros
        assert!((cache[3 * kv_stride] - 0.0).abs() < 1e-5);
    }

    #[test]
    fn test_attention_manual() {
        // Test the full attention pipeline: score → softmax → value aggregation
        // Single head, head_dim=2, seq_len=3
        let b = CpuBackend::new();
        let hd = 2;
        let seq_len = 3;

        let q = b.allocate(hd * 4).unwrap();
        let k_cache = b.allocate(seq_len * hd * 4).unwrap();
        let v_cache = b.allocate(seq_len * hd * 4).unwrap();
        let scores = b.allocate(seq_len * 4).unwrap();
        let out = b.allocate(hd * 4).unwrap();

        // Q = [1, 0]
        write_f32(&b, &q, &[1.0, 0.0]);

        // K cache: 3 positions, each 2 dims
        // K[0] = [1, 0], K[1] = [0, 1], K[2] = [1, 1]
        write_f32(&b, &k_cache, &[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        // V cache: V[0] = [10, 20], V[1] = [30, 40], V[2] = [50, 60]
        write_f32(&b, &v_cache, &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);

        // Compute scores manually: Q · K[t] / sqrt(hd)
        // score[0] = (1*1 + 0*0) / sqrt(2) = 1/sqrt(2) ≈ 0.7071
        // score[1] = (1*0 + 0*1) / sqrt(2) = 0
        // score[2] = (1*1 + 0*1) / sqrt(2) = 1/sqrt(2) ≈ 0.7071
        let scale = 1.0 / (hd as f32).sqrt();
        let s0 = 1.0 * scale;
        let s1 = 0.0 * scale;
        let s2 = 1.0 * scale;

        // Softmax (reference computation done host-side)
        let max_s = s0.max(s1).max(s2);
        let e0 = (s0 - max_s).exp();
        let e1 = (s1 - max_s).exp();
        let e2 = (s2 - max_s).exp();
        let sum = e0 + e1 + e2;
        let w0 = e0 / sum;
        let w1 = e1 / sum;
        let w2 = e2 / sum;

        // Weighted V: out = w0*V[0] + w1*V[1] + w2*V[2]
        let expected_0 = w0 * 10.0 + w1 * 30.0 + w2 * 50.0;
        let expected_1 = w0 * 20.0 + w1 * 40.0 + w2 * 60.0;

        // Now use the actual backend ops
        // Step 1: compute scores
        let q_f32 = read_f32(&b, &q, hd);
        let k_f32 = read_f32(&b, &k_cache, seq_len * hd);
        let mut scores_f32 = vec![0.0f32; seq_len];
        for t in 0..seq_len {
            let mut dot = 0.0f32;
            for d in 0..hd {
                dot += q_f32[d] * k_f32[t * hd + d];
            }
            scores_f32[t] = dot * scale;
        }
        write_f32(&b, &scores, &scores_f32);

        // Step 2: softmax (note: in-place — input and output are the same handle)
        b.softmax(&scores, &scores, seq_len as u32).unwrap();
        let weights = read_f32(&b, &scores, seq_len);

        // Verify softmax weights match
        assert!((weights[0] - w0).abs() < 1e-4);
        assert!((weights[1] - w1).abs() < 1e-4);
        assert!((weights[2] - w2).abs() < 1e-4);

        // Step 3: weighted value sum
        let v_f32 = read_f32(&b, &v_cache, seq_len * hd);
        let mut out_f32 = vec![0.0f32; hd];
        for t in 0..seq_len {
            for d in 0..hd {
                out_f32[d] += weights[t] * v_f32[t * hd + d];
            }
        }
        write_f32(&b, &out, &out_f32);
        let result = read_f32(&b, &out, hd);

        assert!((result[0] - expected_0).abs() < 1e-3);
        assert!((result[1] - expected_1).abs() < 1e-3);
    }

    #[test]
    fn test_attention_gqa() {
        // 2 Q heads sharing 1 KV head (GQA ratio 2:1)
        // head_dim=2, seq_len=2
        let b = CpuBackend::new();
        let hd = 2;
        let n_heads = 2;
        let n_kv_heads = 1;
        let kv_stride = n_kv_heads * hd;
        let seq_len = 2;

        // Q: 2 heads × 2 dims = 4 floats
        let q = b.allocate(n_heads * hd * 4).unwrap();
        write_f32(&b, &q, &[1.0, 0.0, 0.0, 1.0]); // head0=[1,0], head1=[0,1]

        // KV cache: 2 positions × 1 kv head × 2 dims
        let k_cache = b.allocate(seq_len * kv_stride * 4).unwrap();
        let v_cache = b.allocate(seq_len * kv_stride * 4).unwrap();
        write_f32(&b, &k_cache, &[1.0, 0.0, 0.0, 1.0]); // K[0]=[1,0], K[1]=[0,1]
        write_f32(&b, &v_cache, &[10.0, 20.0, 30.0, 40.0]); // V[0]=[10,20], V[1]=[30,40]

        let scores_buf = b.allocate(seq_len * 4).unwrap();
        let attn_out = b.allocate(n_heads * hd * 4).unwrap();

        let scale = 1.0 / (hd as f32).sqrt();
        let kv_group = n_heads / n_kv_heads;

        // Process each Q head, both share the same KV head
        for h in 0..n_heads {
            let kv_h = h / kv_group;
            let head_offset = h * hd;
            let kv_off = kv_h * hd;

            // Compute scores for this head
            let q_f32 = read_f32(&b, &q, n_heads * hd);
            let k_f32 = read_f32(&b, &k_cache, seq_len * kv_stride);
            let mut scores = vec![0.0f32; seq_len];
            for t in 0..seq_len {
                let mut dot = 0.0f32;
                for d in 0..hd {
                    dot += q_f32[head_offset + d] * k_f32[t * kv_stride + kv_off + d];
                }
                scores[t] = dot * scale;
            }
            write_f32(&b, &scores_buf, &scores);

            b.softmax(&scores_buf, &scores_buf, seq_len as u32).unwrap();
            let weights = read_f32(&b, &scores_buf, seq_len);

            // Weighted value
            let v_f32 = read_f32(&b, &v_cache, seq_len * kv_stride);
            let mut head_out = vec![0.0f32; hd];
            for t in 0..seq_len {
                for d in 0..hd {
                    head_out[d] += weights[t] * v_f32[t * kv_stride + kv_off + d];
                }
            }

            // Write to attn_out at head offset
            let mut full_out = read_f32(&b, &attn_out, n_heads * hd);
            full_out[head_offset..head_offset + hd].copy_from_slice(&head_out);
            write_f32(&b, &attn_out, &full_out);
        }

        let result = read_f32(&b, &attn_out, n_heads * hd);

        // Head 0: Q=[1,0], K[0]=[1,0]→dot=1, K[1]=[0,1]→dot=0
        // Scores: [1/√2, 0], softmax gives more weight to position 0
        // Head 1: Q=[0,1], K[0]=[1,0]→dot=0, K[1]=[0,1]→dot=1
        // Scores: [0, 1/√2], softmax gives more weight to position 1

        // Head 0 should weight V[0]=[10,20] more heavily
        assert!(result[0] < 25.0); // closer to 10 than 30
        assert!(result[1] < 35.0); // closer to 20 than 40

        // Head 1 should weight V[1]=[30,40] more heavily
        assert!(result[2] > 15.0); // closer to 30 than 10
        assert!(result[3] > 25.0); // closer to 40 than 20
    }
}
815}