ruvector_sparse_inference/backend/cpu.rs

use super::Backend;
use crate::config::ActivationType;
use ndarray::Array2;
use std::sync::OnceLock;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
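// Runtime-detected SIMD capabilities of the host CPU, probed once and cached.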
#[cfg(target_arch = "x86_64")]
static SIMD_FEATURES: OnceLock<SimdFeatures> = OnceLock::new();

#[cfg(target_arch = "x86_64")]
#[derive(Debug, Clone, Copy)]
struct SimdFeatures {
    has_avx2: bool,
    has_sse41: bool,
    has_fma: bool,
}
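// Returns the cached feature set, detecting it on first use.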
#[cfg(target_arch = "x86_64")]
fn get_simd_features() -> SimdFeatures {
    *SIMD_FEATURES.get_or_init(|| SimdFeatures {
        has_avx2: is_x86_feature_detected!("avx2"),
        has_sse41: is_x86_feature_detected!("sse4.1"),
        has_fma: is_x86_feature_detected!("fma"),
    })
}
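/// CPU backend that dispatches each operation to the fastest kernel the host
/// supports: AVX2/FMA or SSE4.1 on x86_64, NEON on aarch64, scalar otherwise.
///
/// Usage sketch (a minimal example, assuming `Backend` is in scope as above):
///
/// ```ignore
/// let backend = CpuBackend;
/// let d = backend.dot_product(&[1.0, 2.0], &[3.0, 4.0]); // 11.0
/// ```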
pub struct CpuBackend;

impl Backend for CpuBackend {
    fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            // The AVX2 kernel uses `_mm256_fmadd_ps`, so FMA must also be present.
            if features.has_avx2 && features.has_fma {
                return unsafe { dot_product_avx2(a, b) };
            } else if features.has_sse41 {
                return unsafe { dot_product_sse(a, b) };
            }
            return dot_product_scalar(a, b);
        }

        #[cfg(target_arch = "aarch64")]
        return unsafe { dot_product_neon(a, b) };

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        dot_product_scalar(a, b)
    }
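    // Dot each selected row of `matrix` against the dense `input` vector.
    // Rows of a standard (row-major) Array2 are contiguous, so `as_slice()` succeeds.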
    fn sparse_matmul(&self, matrix: &Array2<f32>, input: &[f32], rows: &[usize]) -> Vec<f32> {
        let mut output = Vec::with_capacity(rows.len());

        for &row_idx in rows {
            let row = matrix.row(row_idx);
            let dot = self.dot_product(row.as_slice().unwrap(), input);
            output.push(dot);
        }

        output
    }
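    // Column-gather accumulate: for each active column j = cols[i],
    // output += matrix[:, j] * input[i].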
    fn sparse_matmul_accumulate(
        &self,
        matrix: &Array2<f32>,
        input: &[f32],
        cols: &[usize],
        output: &mut [f32],
    ) {
        for (i, &col_idx) in cols.iter().enumerate() {
            let col = matrix.column(col_idx);
            let scalar = input[i];
            for (j, &val) in col.iter().enumerate() {
                output[j] += val * scalar;
            }
        }
    }
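    // Apply the activation in place, using the AVX2 kernels when available
    // (the GELU/SiLU kernels also require FMA) and scalar fallbacks otherwise.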
    fn activation(&self, data: &mut [f32], activation_type: ActivationType) {
        #[cfg(target_arch = "x86_64")]
        let features = get_simd_features();

        match activation_type {
            ActivationType::Relu => {
                #[cfg(target_arch = "x86_64")]
                if features.has_avx2 {
                    return unsafe { relu_avx2(data) };
                }
                relu_scalar(data);
            }
            ActivationType::Gelu => {
                #[cfg(target_arch = "x86_64")]
                if features.has_avx2 && features.has_fma {
                    return unsafe { gelu_avx2(data) };
                }
                gelu_scalar(data);
            }
            ActivationType::Silu | ActivationType::Swish => {
                #[cfg(target_arch = "x86_64")]
                if features.has_avx2 && features.has_fma {
                    return unsafe { silu_avx2(data) };
                }
                silu_scalar(data);
            }
            ActivationType::Identity => {}
        }
    }
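    // Element-wise `a[i] += b[i]`.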
    fn add(&self, a: &mut [f32], b: &[f32]) {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        if get_simd_features().has_avx2 {
            return unsafe { add_avx2(a, b) };
        }

        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x += y;
        }
    }
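    // AXPY: `a[i] += b[i] * scalar`.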
    fn axpy(&self, a: &mut [f32], b: &[f32], scalar: f32) {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            // The AVX2 AXPY kernel uses `_mm256_fmadd_ps`, so FMA is required too.
            if features.has_avx2 && features.has_fma {
                return unsafe { axpy_avx2(a, b, scalar) };
            }
        }

        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x += y * scalar;
        }
    }

    fn name(&self) -> &'static str {
        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            if features.has_avx2 {
                return "CPU-AVX2";
            } else if features.has_sse41 {
                return "CPU-SSE4.1";
            }
            return "CPU-Scalar";
        }

        #[cfg(target_arch = "aarch64")]
        return "CPU-NEON";

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        "CPU-Scalar"
    }

    fn simd_width(&self) -> usize {
        #[cfg(target_arch = "x86_64")]
        {
            let features = get_simd_features();
            if features.has_avx2 { return 8; }
            if features.has_sse41 { return 4; }
            return 1;
        }

        #[cfg(target_arch = "aarch64")]
        return 4;

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        1
    }
}
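// ---- x86_64 AVX2 kernels ----
//
// Safety: callers must verify at runtime that the features named in each
// `#[target_feature]` attribute are available (see `get_simd_features`).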
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 8;

    let mut sum = _mm256_setzero_ps();

    for i in 0..chunks {
        let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }
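    // Horizontal sum of the 8 lanes: fold 256 -> 128 -> 64 -> 32 bits.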
    let sum128 = _mm_add_ps(
        _mm256_extractf128_ps::<0>(sum),
        _mm256_extractf128_ps::<1>(sum),
    );
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps::<1>(sum64, sum64));
    let mut result = _mm_cvtss_f32(sum32);

    for i in (chunks * 8)..n {
        result += a[i] * b[i];
    }

    result
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(data: &mut [f32]) {
    let zero = _mm256_setzero_ps();
    let chunks = data.len() / 8;

    for i in 0..chunks {
        let ptr = data.as_mut_ptr().add(i * 8);
        let v = _mm256_loadu_ps(ptr);
        let result = _mm256_max_ps(v, zero);
        _mm256_storeu_ps(ptr, result);
    }

    for i in (chunks * 8)..data.len() {
        data[i] = data[i].max(0.0);
    }
}
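// GELU via the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with tanh(z) replaced by the rational approximation z * (27 + z^2) / (27 + 9 * z^2).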
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gelu_avx2(data: &mut [f32]) {
    let chunks = data.len() / 8;

    let half = _mm256_set1_ps(0.5);
    let one = _mm256_set1_ps(1.0);
    let sqrt_2_over_pi = _mm256_set1_ps(0.7978845608);
    let coef = _mm256_set1_ps(0.044715);

    let c27 = _mm256_set1_ps(27.0);
    let c9 = _mm256_set1_ps(9.0);

    for i in 0..chunks {
        let ptr = data.as_mut_ptr().add(i * 8);
        let x = _mm256_loadu_ps(ptr);

        let x2 = _mm256_mul_ps(x, x);
        let x3 = _mm256_mul_ps(x2, x);

        let inner = _mm256_mul_ps(sqrt_2_over_pi, _mm256_fmadd_ps(coef, x3, x));

        let inner2 = _mm256_mul_ps(inner, inner);
        let num = _mm256_fmadd_ps(inner2, one, c27);
        let den = _mm256_fmadd_ps(inner2, c9, c27);
        let tanh_approx = _mm256_mul_ps(inner, _mm256_div_ps(num, den));

        let result = _mm256_mul_ps(half, _mm256_mul_ps(x, _mm256_add_ps(one, tanh_approx)));
        _mm256_storeu_ps(ptr, result);
    }

    for i in (chunks * 8)..data.len() {
        let x = data[i];
        let x3 = x * x * x;
        let inner = 0.7978845608 * (x + 0.044715 * x3);
        data[i] = 0.5 * x * (1.0 + inner.tanh());
    }
}
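// SiLU/Swish: x * sigmoid(x), computed as sigmoid(x) = 0.5 * (1 + tanh(x / 2))
// using the same rational tanh approximation as the GELU kernel above.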
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn silu_avx2(data: &mut [f32]) {
    let chunks = data.len() / 8;

    let half = _mm256_set1_ps(0.5);
    let c27 = _mm256_set1_ps(27.0);
    let c9 = _mm256_set1_ps(9.0);
    let one = _mm256_set1_ps(1.0);

    for i in 0..chunks {
        let ptr = data.as_mut_ptr().add(i * 8);
        let x = _mm256_loadu_ps(ptr);

        let x_half = _mm256_mul_ps(x, half);

        let xh2 = _mm256_mul_ps(x_half, x_half);
        let num = _mm256_fmadd_ps(xh2, one, c27);
        let den = _mm256_fmadd_ps(xh2, c9, c27);
        let tanh_approx = _mm256_mul_ps(x_half, _mm256_div_ps(num, den));

        let sigmoid = _mm256_fmadd_ps(half, tanh_approx, half);

        let result = _mm256_mul_ps(x, sigmoid);
        _mm256_storeu_ps(ptr, result);
    }

    for i in (chunks * 8)..data.len() {
        let x = data[i];
        data[i] = x / (1.0 + (-x).exp());
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_avx2(a: &mut [f32], b: &[f32]) {
    let chunks = a.len() / 8;

    for i in 0..chunks {
        let pa = a.as_mut_ptr().add(i * 8);
        let pb = b.as_ptr().add(i * 8);
        let va = _mm256_loadu_ps(pa);
        let vb = _mm256_loadu_ps(pb);
        _mm256_storeu_ps(pa, _mm256_add_ps(va, vb));
    }

    for i in (chunks * 8)..a.len() {
        a[i] += b[i];
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn axpy_avx2(a: &mut [f32], b: &[f32], scalar: f32) {
    let vs = _mm256_set1_ps(scalar);
    let chunks = a.len() / 8;

    for i in 0..chunks {
        let pa = a.as_mut_ptr().add(i * 8);
        let pb = b.as_ptr().add(i * 8);
        let va = _mm256_loadu_ps(pa);
        let vb = _mm256_loadu_ps(pb);
        let result = _mm256_fmadd_ps(vb, vs, va);
        _mm256_storeu_ps(pa, result);
    }

    for i in (chunks * 8)..a.len() {
        a[i] += b[i] * scalar;
    }
}
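// ---- x86_64 SSE4.1 fallback ----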
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 4;

    let mut sum = _mm_setzero_ps();

    for i in 0..chunks {
        let va = _mm_loadu_ps(a.as_ptr().add(i * 4));
        let vb = _mm_loadu_ps(b.as_ptr().add(i * 4));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
    }
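    // Horizontal sum of the 4 lanes.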
    let sum2 = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
    let sum1 = _mm_add_ss(sum2, _mm_shuffle_ps::<1>(sum2, sum2));
    let mut result = _mm_cvtss_f32(sum1);

    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }

    result
}
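// ---- aarch64 NEON kernel (NEON is baseline on aarch64, so no runtime check is needed) ----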
#[cfg(target_arch = "aarch64")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 4;

    let mut sum = vdupq_n_f32(0.0);

    for i in 0..chunks {
        let va = vld1q_f32(a.as_ptr().add(i * 4));
        let vb = vld1q_f32(b.as_ptr().add(i * 4));
        sum = vfmaq_f32(sum, va, vb);
    }

    let mut result = vaddvq_f32(sum);

    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }

    result
}
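// ---- Portable scalar fallbacks ----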
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn relu_scalar(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = x.max(0.0);
    }
}

fn gelu_scalar(data: &mut [f32]) {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const GELU_COEF: f32 = 0.044715;

    for x in data.iter_mut() {
        let x3 = *x * *x * *x;
        let inner = SQRT_2_OVER_PI * (*x + GELU_COEF * x3);
        *x = 0.5 * *x * (1.0 + inner.tanh());
    }
}

fn silu_scalar(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = *x / (1.0 + (-*x).exp());
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product() {
        let backend = CpuBackend;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let result = backend.dot_product(&a, &b);
        assert!((result - 40.0).abs() < 1e-5);
    }

    #[test]
    fn test_relu() {
        let backend = CpuBackend;
        let mut data = vec![-1.0, 0.0, 1.0, 2.0];
        backend.activation(&mut data, ActivationType::Relu);
        assert_eq!(data, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_add() {
        let backend = CpuBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        backend.add(&mut a, &b);
        assert_eq!(a, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_axpy() {
        let backend = CpuBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];
        backend.axpy(&mut a, &b, 2.0);
        assert_eq!(a, vec![3.0, 4.0, 5.0, 6.0]);
    }
}