ruvector_sparse_inference/backend/cpu.rs

//! CPU backend with portable SIMD optimizations

use super::Backend;
use crate::config::ActivationType;
use ndarray::Array2;
use std::sync::OnceLock;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// Cached SIMD feature detection for x86_64
#[cfg(target_arch = "x86_64")]
static SIMD_FEATURES: OnceLock<SimdFeatures> = OnceLock::new();

#[cfg(target_arch = "x86_64")]
#[derive(Debug, Clone, Copy)]
struct SimdFeatures {
    has_avx2: bool,
    has_sse41: bool,
    has_fma: bool,
}

#[cfg(target_arch = "x86_64")]
fn get_simd_features() -> SimdFeatures {
    *SIMD_FEATURES.get_or_init(|| SimdFeatures {
        has_avx2: is_x86_feature_detected!("avx2"),
        has_sse41: is_x86_feature_detected!("sse4.1"),
        has_fma: is_x86_feature_detected!("fma"),
    })
}

/// CPU backend using portable SIMD
pub struct CpuBackend;

impl Backend for CpuBackend {
    fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            // dot_product_avx2 uses FMA, so require both features before dispatching.
            if features.has_avx2 && features.has_fma {
                return unsafe { dot_product_avx2(a, b) };
            } else if features.has_sse41 {
                return unsafe { dot_product_sse(a, b) };
            }
            return dot_product_scalar(a, b);
        }

        #[cfg(target_arch = "aarch64")]
        return unsafe { dot_product_neon(a, b) };

        // Scalar fallback for other architectures
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        dot_product_scalar(a, b)
    }

    fn sparse_matmul(&self, matrix: &Array2<f32>, input: &[f32], rows: &[usize]) -> Vec<f32> {
        let mut output = Vec::with_capacity(rows.len());

        for &row_idx in rows {
            let row = matrix.row(row_idx);
            // Rows of a standard-layout Array2 are contiguous, so the slice view exists.
            let dot = self.dot_product(row.as_slice().unwrap(), input);
            output.push(dot);
        }

        output
    }

    fn sparse_matmul_accumulate(
        &self,
        matrix: &Array2<f32>,
        input: &[f32],
        cols: &[usize],
        output: &mut [f32],
    ) {
        for (i, &col_idx) in cols.iter().enumerate() {
            let col = matrix.column(col_idx);
            let scalar = input[i];
            // Column views may not be contiguous, so iterate element by element.
            for (j, &val) in col.iter().enumerate() {
                output[j] += val * scalar;
            }
        }
    }

    fn activation(&self, data: &mut [f32], activation_type: ActivationType) {
        #[cfg(target_arch = "x86_64")]
        let features = get_simd_features();

        match activation_type {
            ActivationType::Relu => {
                #[cfg(target_arch = "x86_64")]
                if features.has_avx2 {
                    return unsafe { relu_avx2(data) };
                }
                relu_scalar(data);
            }
            ActivationType::Gelu => {
                // gelu_avx2 uses FMA, so require both features before dispatching.
                #[cfg(target_arch = "x86_64")]
                if features.has_avx2 && features.has_fma {
                    return unsafe { gelu_avx2(data) };
                }
                gelu_scalar(data);
            }
            ActivationType::Silu | ActivationType::Swish => {
                // silu_avx2 uses FMA, so require both features before dispatching.
                #[cfg(target_arch = "x86_64")]
                if features.has_avx2 && features.has_fma {
                    return unsafe { silu_avx2(data) };
                }
                silu_scalar(data);
            }
            ActivationType::Identity => { /* no-op */ }
        }
    }

    fn add(&self, a: &mut [f32], b: &[f32]) {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        if get_simd_features().has_avx2 {
            return unsafe { add_avx2(a, b) };
        }

        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x += y;
        }
    }

    fn axpy(&self, a: &mut [f32], b: &[f32], scalar: f32) {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            // axpy_avx2 uses FMA, so require both features before dispatching.
            if features.has_avx2 && features.has_fma {
                return unsafe { axpy_avx2(a, b, scalar) };
            }
        }

        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x += y * scalar;
        }
    }

    fn name(&self) -> &'static str {
        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            if features.has_avx2 {
                return "CPU-AVX2";
            } else if features.has_sse41 {
                return "CPU-SSE4.1";
            }
            return "CPU-Scalar";
        }
        #[cfg(target_arch = "aarch64")]
        return "CPU-NEON";

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        "CPU-Scalar"
    }

    fn simd_width(&self) -> usize {
        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            if features.has_avx2 { return 8; }
            if features.has_sse41 { return 4; }
            return 1;
        }
        #[cfg(target_arch = "aarch64")]
        return 4;

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        1
    }
}

// ============ AVX2 Implementations ============

// `_mm256_fmadd_ps` requires the FMA feature in addition to AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 8;

    let mut sum = _mm256_setzero_ps();

    for i in 0..chunks {
        let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum
    let sum128 = _mm_add_ps(
        _mm256_extractf128_ps(sum, 0),
        _mm256_extractf128_ps(sum, 1),
    );
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    let mut result = _mm_cvtss_f32(sum32);

    // Handle remainder
    for i in (chunks * 8)..n {
        result += a[i] * b[i];
    }

    result
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(data: &mut [f32]) {
    let zero = _mm256_setzero_ps();
    let chunks = data.len() / 8;

    for i in 0..chunks {
        let ptr = data.as_mut_ptr().add(i * 8);
        let v = _mm256_loadu_ps(ptr);
        let result = _mm256_max_ps(v, zero);
        _mm256_storeu_ps(ptr, result);
    }

    // Handle remainder
    for i in (chunks * 8)..data.len() {
        data[i] = data[i].max(0.0);
    }
}

/// SIMD GELU using the tanh approximation
/// GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
/// with a fast rational tanh approximation on the SIMD path
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gelu_avx2(data: &mut [f32]) {
    let chunks = data.len() / 8;

    // Constants for GELU approximation
    let half = _mm256_set1_ps(0.5);
    let one = _mm256_set1_ps(1.0);
    let sqrt_2_over_pi = _mm256_set1_ps(0.7978845608); // sqrt(2/π)
    let coef = _mm256_set1_ps(0.044715);

    // Constants for fast tanh approximation: tanh(x) ≈ x * (27 + x²) / (27 + 9x²)
    let c27 = _mm256_set1_ps(27.0);
    let c9 = _mm256_set1_ps(9.0);

    for i in 0..chunks {
        let ptr = data.as_mut_ptr().add(i * 8);
        let x = _mm256_loadu_ps(ptr);

        // x³
        let x2 = _mm256_mul_ps(x, x);
        let x3 = _mm256_mul_ps(x2, x);

        // inner = sqrt(2/π) * (x + 0.044715 * x³)
        let inner = _mm256_mul_ps(sqrt_2_over_pi, _mm256_fmadd_ps(coef, x3, x));

        // Fast tanh approximation
        let inner2 = _mm256_mul_ps(inner, inner);
        let num = _mm256_fmadd_ps(inner2, one, c27); // 27 + inner²
        let den = _mm256_fmadd_ps(inner2, c9, c27); // 27 + 9*inner²
        let tanh_approx = _mm256_mul_ps(inner, _mm256_div_ps(num, den));

        // 0.5 * x * (1 + tanh)
        let result = _mm256_mul_ps(half, _mm256_mul_ps(x, _mm256_add_ps(one, tanh_approx)));
        _mm256_storeu_ps(ptr, result);
    }

    // Handle remainder with scalar code (exact tanh)
    for i in (chunks * 8)..data.len() {
        let x = data[i];
        let x3 = x * x * x;
        let inner = 0.7978845608 * (x + 0.044715 * x3);
        data[i] = 0.5 * x * (1.0 + inner.tanh());
    }
}

/// SIMD SiLU (Swish) using a fast sigmoid approximation
/// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn silu_avx2(data: &mut [f32]) {
    let chunks = data.len() / 8;

    // For sigmoid, use the identity 1/(1+e^-x) = 0.5 + 0.5*tanh(x/2)
    let half = _mm256_set1_ps(0.5);
    let c27 = _mm256_set1_ps(27.0);
    let c9 = _mm256_set1_ps(9.0);
    let one = _mm256_set1_ps(1.0);

    for i in 0..chunks {
        let ptr = data.as_mut_ptr().add(i * 8);
        let x = _mm256_loadu_ps(ptr);

        // Use sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
        let x_half = _mm256_mul_ps(x, half);

        // Fast tanh(x/2)
        let xh2 = _mm256_mul_ps(x_half, x_half);
        let num = _mm256_fmadd_ps(xh2, one, c27);
        let den = _mm256_fmadd_ps(xh2, c9, c27);
        let tanh_approx = _mm256_mul_ps(x_half, _mm256_div_ps(num, den));

        // sigmoid = 0.5 + 0.5 * tanh
        let sigmoid = _mm256_fmadd_ps(half, tanh_approx, half);

        // silu = x * sigmoid
        let result = _mm256_mul_ps(x, sigmoid);
        _mm256_storeu_ps(ptr, result);
    }

    // Handle remainder with scalar code (exact sigmoid)
    for i in (chunks * 8)..data.len() {
        let x = data[i];
        data[i] = x / (1.0 + (-x).exp());
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_avx2(a: &mut [f32], b: &[f32]) {
    let chunks = a.len() / 8;

    for i in 0..chunks {
        let pa = a.as_mut_ptr().add(i * 8);
        let pb = b.as_ptr().add(i * 8);
        let va = _mm256_loadu_ps(pa);
        let vb = _mm256_loadu_ps(pb);
        _mm256_storeu_ps(pa, _mm256_add_ps(va, vb));
    }

    for i in (chunks * 8)..a.len() {
        a[i] += b[i];
    }
}

// `_mm256_fmadd_ps` requires the FMA feature in addition to AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn axpy_avx2(a: &mut [f32], b: &[f32], scalar: f32) {
    let vs = _mm256_set1_ps(scalar);
    let chunks = a.len() / 8;

    for i in 0..chunks {
        let pa = a.as_mut_ptr().add(i * 8);
        let pb = b.as_ptr().add(i * 8);
        let va = _mm256_loadu_ps(pa);
        let vb = _mm256_loadu_ps(pb);
        let result = _mm256_fmadd_ps(vb, vs, va);
        _mm256_storeu_ps(pa, result);
    }

    for i in (chunks * 8)..a.len() {
        a[i] += b[i] * scalar;
    }
}

// ============ SSE4.1 Implementations ============

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 4;

    let mut sum = _mm_setzero_ps();

    for i in 0..chunks {
        let va = _mm_loadu_ps(a.as_ptr().add(i * 4));
        let vb = _mm_loadu_ps(b.as_ptr().add(i * 4));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
    }

    // Horizontal sum
    let sum2 = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
    let sum1 = _mm_add_ss(sum2, _mm_shuffle_ps(sum2, sum2, 1));
    let mut result = _mm_cvtss_f32(sum1);

    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }

    result
}

// ============ NEON Implementations (ARM) ============

#[cfg(target_arch = "aarch64")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 4;

    let mut sum = vdupq_n_f32(0.0);

    for i in 0..chunks {
        let va = vld1q_f32(a.as_ptr().add(i * 4));
        let vb = vld1q_f32(b.as_ptr().add(i * 4));
        sum = vfmaq_f32(sum, va, vb);
    }

    // Horizontal sum
    let mut result = vaddvq_f32(sum);

    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }

    result
}

// ============ Scalar Fallbacks ============

fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn relu_scalar(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = x.max(0.0);
    }
}

fn gelu_scalar(data: &mut [f32]) {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const GELU_COEF: f32 = 0.044715;

    for x in data.iter_mut() {
        let x3 = *x * *x * *x;
        let inner = SQRT_2_OVER_PI * (*x + GELU_COEF * x3);
        *x = 0.5 * *x * (1.0 + inner.tanh());
    }
}

fn silu_scalar(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = *x / (1.0 + (-*x).exp());
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product() {
        let backend = CpuBackend;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let result = backend.dot_product(&a, &b);
        assert!((result - 40.0).abs() < 1e-5);
    }
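
    // A minimal sketch of a remainder-path check: lengths that are not a
    // multiple of the SIMD width should still agree with the scalar fallback.
    // The length 11 is an assumption chosen to cover one full 8-lane AVX2
    // chunk plus a 3-element tail; any non-multiple of 8 would do.
    #[test]
    fn test_dot_product_remainder() {
        let backend = CpuBackend;
        let a: Vec<f32> = (0..11).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..11).map(|i| i as f32 * 0.5).collect();
        let expected = dot_product_scalar(&a, &b);
        let result = backend.dot_product(&a, &b);
        assert!((result - expected).abs() < 1e-4);
    }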

    #[test]
    fn test_relu() {
        let backend = CpuBackend;
        let mut data = vec![-1.0, 0.0, 1.0, 2.0];
        backend.activation(&mut data, ActivationType::Relu);
        assert_eq!(data, vec![0.0, 0.0, 1.0, 2.0]);
    }
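
    // A minimal sketch of an accuracy check for the approximate activations.
    // The AVX2 GELU path uses a fast rational tanh, so it is compared against
    // the exact scalar formula only with a loose tolerance and only on the
    // moderate range [-2, 2]; both the range and the 5e-2 tolerance are
    // assumptions for illustration, not project requirements.
    #[test]
    fn test_gelu_close_to_scalar_reference() {
        let backend = CpuBackend;
        // 16 values so at least one full 8-lane SIMD chunk is exercised.
        let mut data: Vec<f32> = (0..16).map(|i| -2.0 + i as f32 * (4.0 / 15.0)).collect();
        let mut reference = data.clone();

        backend.activation(&mut data, ActivationType::Gelu);
        gelu_scalar(&mut reference);

        for (a, b) in data.iter().zip(reference.iter()) {
            assert!((a - b).abs() < 5e-2, "gelu mismatch: {} vs {}", a, b);
        }
    }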

    #[test]
    fn test_add() {
        let backend = CpuBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        backend.add(&mut a, &b);
        assert_eq!(a, vec![6.0, 8.0, 10.0, 12.0]);
    }
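
    // A minimal sketch of a semantics check for the sparse helpers, using
    // small hand-computed matrices (the values are illustrative, not taken
    // from the crate): `sparse_matmul` gathers the selected rows, while
    // `sparse_matmul_accumulate` scales the selected columns by the paired
    // input entries and accumulates into `output`.
    #[test]
    fn test_sparse_matmul_and_accumulate() {
        let backend = CpuBackend;

        // Rows 0 and 2 of a 3x2 matrix, each dotted with the input [1, 1].
        let matrix = ndarray::arr2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let result = backend.sparse_matmul(&matrix, &[1.0, 1.0], &[0, 2]);
        assert_eq!(result, vec![3.0, 11.0]);

        // Columns 0 and 2 of a 2x3 matrix, scaled by 2.0 and 1.0 respectively.
        let matrix = ndarray::arr2(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let mut output = vec![0.0f32; 2];
        backend.sparse_matmul_accumulate(&matrix, &[2.0, 1.0], &[0, 2], &mut output);
        assert_eq!(output, vec![5.0, 14.0]);
    }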

    #[test]
    fn test_axpy() {
        let backend = CpuBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];
        backend.axpy(&mut a, &b, 2.0);
        assert_eq!(a, vec![3.0, 4.0, 5.0, 6.0]);
    }
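
    // A minimal sketch exercising the vectorized add/axpy paths on a length
    // that spans full SIMD chunks plus a tail (19 is an arbitrary choice),
    // cross-checked against a plain scalar computation.
    #[test]
    fn test_add_axpy_long_vectors() {
        let backend = CpuBackend;
        let a0: Vec<f32> = (0..19).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..19).map(|i| i as f32 * 0.25).collect();

        // axpy: a += 3 * b
        let mut a = a0.clone();
        backend.axpy(&mut a, &b, 3.0);
        for ((x, x0), y) in a.iter().zip(a0.iter()).zip(b.iter()) {
            assert!((x - (x0 + 3.0 * y)).abs() < 1e-5);
        }

        // add: a += b (applied on top of the axpy result)
        let before = a.clone();
        backend.add(&mut a, &b);
        for ((x, x0), y) in a.iter().zip(before.iter()).zip(b.iter()) {
            assert!((x - (x0 + y)).abs() < 1e-5);
        }
    }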
}