ruvector_sparse_inference/backend/wasm.rs

//! WebAssembly backend with portable SIMD
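//!
//! Note: the SIMD path relies on the wasm `simd128` feature, typically
//! enabled with `RUSTFLAGS="-C target-feature=+simd128"` when building
//! for wasm32.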

use super::Backend;
use crate::config::ActivationType;
use ndarray::Array2;

#[cfg(target_arch = "wasm32")]
use std::arch::wasm32::*;

/// WASM backend using wasm32 SIMD instructions
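///
/// Illustrative usage sketch (marked `ignore` because running it requires
/// the `Backend` trait from the parent module to be in scope):
/// ```ignore
/// let backend = WasmBackend;
/// let d = backend.dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
/// assert!((d - 32.0).abs() < 1e-6); // 1*4 + 2*5 + 3*6 = 32
/// ```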
pub struct WasmBackend;

impl Backend for WasmBackend {
    fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "wasm32")]
        return dot_product_wasm_simd(a, b);

        #[cfg(not(target_arch = "wasm32"))]
        dot_product_scalar(a, b)
    }

    fn sparse_matmul(&self, matrix: &Array2<f32>, input: &[f32], rows: &[usize]) -> Vec<f32> {
        rows.iter()
            .map(|&row_idx| {
                // Rows of a standard (row-major) Array2 are contiguous,
                // so as_slice() always succeeds here.
                let row = matrix.row(row_idx);
                self.dot_product(row.as_slice().unwrap(), input)
            })
            .collect()
    }

    fn sparse_matmul_accumulate(
        &self,
        matrix: &Array2<f32>,
        input: &[f32],
        cols: &[usize],
        output: &mut [f32],
    ) {
        for (i, &col_idx) in cols.iter().enumerate() {
            // Columns of a row-major Array2 are strided, not contiguous, so
            // as_slice() would return None here. Copy the column into a
            // temporary buffer before the SIMD axpy.
            let col = matrix.column(col_idx).to_vec();
            self.axpy(output, &col, input[i]);
        }
    }

    fn activation(&self, data: &mut [f32], activation_type: ActivationType) {
        match activation_type {
            ActivationType::Relu => {
                #[cfg(target_arch = "wasm32")]
                relu_wasm_simd(data);
                #[cfg(not(target_arch = "wasm32"))]
                relu_scalar(data);
            }
            ActivationType::Gelu => gelu_scalar(data),
            ActivationType::Silu | ActivationType::Swish => silu_scalar(data),
            ActivationType::Identity => { /* no-op */ }
        }
    }

    fn add(&self, a: &mut [f32], b: &[f32]) {
        #[cfg(target_arch = "wasm32")]
        add_wasm_simd(a, b);

        #[cfg(not(target_arch = "wasm32"))]
        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x += y;
        }
    }

    fn axpy(&self, a: &mut [f32], b: &[f32], scalar: f32) {
        #[cfg(target_arch = "wasm32")]
        axpy_wasm_simd(a, b, scalar);

        #[cfg(not(target_arch = "wasm32"))]
        for (x, y) in a.iter_mut().zip(b.iter()) {
            *x += y * scalar;
        }
    }

    fn name(&self) -> &'static str {
        "WASM-SIMD"
    }

    fn simd_width(&self) -> usize {
        4 // 128-bit SIMD = 4 x f32
    }
}

// ============ WASM SIMD Implementations ============

#[cfg(target_arch = "wasm32")]
fn dot_product_wasm_simd(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 4;

    let mut sum = f32x4_splat(0.0);

    for i in 0..chunks {
        // SAFETY: i * 4 + 3 < n, so each load reads four in-bounds f32s;
        // v128_load tolerates unaligned pointers.
        unsafe {
            let va = v128_load(a[i * 4..].as_ptr() as *const v128);
            let vb = v128_load(b[i * 4..].as_ptr() as *const v128);
            sum = f32x4_add(sum, f32x4_mul(va, vb));
        }
    }

    // Horizontal sum across the four lanes
    let sum_arr = [
        f32x4_extract_lane::<0>(sum),
        f32x4_extract_lane::<1>(sum),
        f32x4_extract_lane::<2>(sum),
        f32x4_extract_lane::<3>(sum),
    ];
    let mut result: f32 = sum_arr.iter().sum();

    // Handle remainder elements past the last full chunk
    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }

    result
}

#[cfg(target_arch = "wasm32")]
fn relu_wasm_simd(data: &mut [f32]) {
    let zero = f32x4_splat(0.0);
    let chunks = data.len() / 4;

    for i in 0..chunks {
        // SAFETY: the chunk [i * 4, i * 4 + 4) is in bounds; v128 loads and
        // stores tolerate unaligned pointers.
        unsafe {
            let v = v128_load(data[i * 4..].as_ptr() as *const v128);
            let result = f32x4_max(v, zero);
            v128_store(data[i * 4..].as_mut_ptr() as *mut v128, result);
        }
    }

    for i in (chunks * 4)..data.len() {
        data[i] = data[i].max(0.0);
    }
}

#[cfg(target_arch = "wasm32")]
fn add_wasm_simd(a: &mut [f32], b: &[f32]) {
    let chunks = a.len() / 4;

    for i in 0..chunks {
        // SAFETY: both chunks are in bounds; v128 loads and stores tolerate
        // unaligned pointers.
        unsafe {
            let va = v128_load(a[i * 4..].as_ptr() as *const v128);
            let vb = v128_load(b[i * 4..].as_ptr() as *const v128);
            let result = f32x4_add(va, vb);
            v128_store(a[i * 4..].as_mut_ptr() as *mut v128, result);
        }
    }

    for i in (chunks * 4)..a.len() {
        a[i] += b[i];
    }
}

#[cfg(target_arch = "wasm32")]
fn axpy_wasm_simd(a: &mut [f32], b: &[f32], scalar: f32) {
    let vs = f32x4_splat(scalar);
    let chunks = a.len() / 4;

    for i in 0..chunks {
        // SAFETY: both chunks are in bounds; v128 loads and stores tolerate
        // unaligned pointers.
        unsafe {
            let va = v128_load(a[i * 4..].as_ptr() as *const v128);
            let vb = v128_load(b[i * 4..].as_ptr() as *const v128);
            let result = f32x4_add(va, f32x4_mul(vb, vs));
            v128_store(a[i * 4..].as_mut_ptr() as *mut v128, result);
        }
    }

    for i in (chunks * 4)..a.len() {
        a[i] += b[i] * scalar;
    }
}

// ============ Scalar Fallbacks ============

#[cfg(not(target_arch = "wasm32"))]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

#[cfg(not(target_arch = "wasm32"))]
fn relu_scalar(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = x.max(0.0);
    }
}

/// Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
fn gelu_scalar(data: &mut [f32]) {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const GELU_COEF: f32 = 0.044715;
    for x in data.iter_mut() {
        let x3 = *x * *x * *x;
        let inner = SQRT_2_OVER_PI * (*x + GELU_COEF * x3);
        *x = 0.5 * *x * (1.0 + inner.tanh());
    }
}

/// SiLU (a.k.a. Swish): x * sigmoid(x)
fn silu_scalar(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = *x / (1.0 + (-*x).exp());
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product() {
        let backend = WasmBackend;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let result = backend.dot_product(&a, &b);
        assert!((result - 40.0).abs() < 1e-5);
    }

    #[test]
    fn test_add() {
        let backend = WasmBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        backend.add(&mut a, &b);
        assert_eq!(a, vec![6.0, 8.0, 10.0, 12.0]);
    }
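
    // Additional illustrative tests (a sketch assuming the Backend methods
    // declared above; the odd lengths deliberately exercise the scalar
    // remainder path alongside the SIMD chunks).

    #[test]
    fn test_axpy_with_remainder() {
        let backend = WasmBackend;
        // Length 5 is not a multiple of the SIMD width (4), so the last
        // element goes through the scalar remainder loop.
        let mut a = vec![1.0, 1.0, 1.0, 1.0, 1.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        backend.axpy(&mut a, &b, 2.0);
        assert_eq!(a, vec![3.0, 5.0, 7.0, 9.0, 11.0]);
    }

    #[test]
    fn test_relu_activation() {
        let backend = WasmBackend;
        let mut data = vec![-1.0, 0.0, 2.5, -3.0, 4.0];
        backend.activation(&mut data, ActivationType::Relu);
        assert_eq!(data, vec![0.0, 0.0, 2.5, 0.0, 4.0]);
    }

    #[test]
    fn test_sparse_matmul_accumulate_columns() {
        let backend = WasmBackend;
        // 2x3 row-major matrix; accumulate columns 0 and 2 scaled by inputs:
        // output = 2 * [1, 4] + 3 * [3, 6] = [11, 26]
        let matrix =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let mut output = vec![0.0, 0.0];
        backend.sparse_matmul_accumulate(&matrix, &[2.0, 3.0], &[0, 2], &mut output);
        assert_eq!(output, vec![11.0, 26.0]);
    }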
}