//! ruvector_sparse_inference/backend/wasm.rs
//!
//! WebAssembly backend: SIMD128 kernels on `wasm32`, scalar fallbacks elsewhere.

use super::Backend;
use crate::config::ActivationType;
use ndarray::Array2;

#[cfg(target_arch = "wasm32")]
use std::arch::wasm32::*;
/// Compute backend for WebAssembly targets.
///
/// On `wasm32` the hot kernels use `simd128` (f32x4) intrinsics; on all other
/// targets the scalar fallbacks in this file are compiled instead.
pub struct WasmBackend;
13impl Backend for WasmBackend {
14 fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
15 debug_assert_eq!(a.len(), b.len());
16
17 #[cfg(target_arch = "wasm32")]
18 return dot_product_wasm_simd(a, b);
19
20 #[cfg(not(target_arch = "wasm32"))]
21 dot_product_scalar(a, b)
22 }
23
24 fn sparse_matmul(&self, matrix: &Array2<f32>, input: &[f32], rows: &[usize]) -> Vec<f32> {
25 rows.iter()
26 .map(|&row_idx| {
27 let row = matrix.row(row_idx);
28 self.dot_product(row.as_slice().unwrap(), input)
29 })
30 .collect()
31 }
32
33 fn sparse_matmul_accumulate(
34 &self,
35 matrix: &Array2<f32>,
36 input: &[f32],
37 cols: &[usize],
38 output: &mut [f32],
39 ) {
40 for (i, &col_idx) in cols.iter().enumerate() {
41 let col = matrix.column(col_idx);
42 self.axpy(output, col.as_slice().unwrap(), input[i]);
43 }
44 }
45
46 fn activation(&self, data: &mut [f32], activation_type: ActivationType) {
47 match activation_type {
48 ActivationType::Relu => {
49 #[cfg(target_arch = "wasm32")]
50 relu_wasm_simd(data);
51 #[cfg(not(target_arch = "wasm32"))]
52 relu_scalar(data);
53 }
54 ActivationType::Gelu => gelu_scalar(data),
55 ActivationType::Silu | ActivationType::Swish => silu_scalar(data),
56 ActivationType::Identity => { }
57 }
58 }
59
60 fn add(&self, a: &mut [f32], b: &[f32]) {
61 #[cfg(target_arch = "wasm32")]
62 add_wasm_simd(a, b);
63
64 #[cfg(not(target_arch = "wasm32"))]
65 for (x, y) in a.iter_mut().zip(b.iter()) {
66 *x += y;
67 }
68 }
69
70 fn axpy(&self, a: &mut [f32], b: &[f32], scalar: f32) {
71 #[cfg(target_arch = "wasm32")]
72 axpy_wasm_simd(a, b, scalar);
73
74 #[cfg(not(target_arch = "wasm32"))]
75 for (x, y) in a.iter_mut().zip(b.iter()) {
76 *x += y * scalar;
77 }
78 }
79
80 fn name(&self) -> &'static str {
81 "WASM-SIMD"
82 }
83
84 fn simd_width(&self) -> usize {
85 4 }
87}
88
/// f32x4 SIMD dot product for the `wasm32` target.
///
/// Requires the build to enable the `simd128` target feature
/// (`-C target-feature=+simd128`). Assumes `b.len() >= a.len()`.
#[cfg(target_arch = "wasm32")]
fn dot_product_wasm_simd(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    debug_assert!(b.len() >= n);
    let chunks = n / 4;

    let mut sum = f32x4_splat(0.0);

    for i in 0..chunks {
        // BUG FIX: `v128_load` is an `unsafe fn`; the original code called it
        // without an unsafe block and did not compile on wasm32.
        // SAFETY: `i * 4 + 4 <= n`, so both pointers are valid for a 16-byte
        // read; `v128_load` emits an unaligned load, so there is no alignment
        // requirement beyond validity.
        let (va, vb) = unsafe {
            (
                v128_load(a.as_ptr().add(i * 4) as *const v128),
                v128_load(b.as_ptr().add(i * 4) as *const v128),
            )
        };
        sum = f32x4_add(sum, f32x4_mul(va, vb));
    }

    // Horizontal reduction of the four lane partial sums.
    let sum_arr = [
        f32x4_extract_lane::<0>(sum),
        f32x4_extract_lane::<1>(sum),
        f32x4_extract_lane::<2>(sum),
        f32x4_extract_lane::<3>(sum),
    ];
    let mut result: f32 = sum_arr.iter().sum();

    // Scalar tail for lengths not divisible by 4.
    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }

    result
}
120
/// In-place f32x4 SIMD ReLU (`x = max(x, 0)`) for the `wasm32` target.
///
/// Requires the `simd128` target feature.
#[cfg(target_arch = "wasm32")]
fn relu_wasm_simd(data: &mut [f32]) {
    let zero = f32x4_splat(0.0);
    let chunks = data.len() / 4;

    for i in 0..chunks {
        // BUG FIX: `v128_load`/`v128_store` are `unsafe fn`s; the original
        // code called them without an unsafe block and did not compile.
        // SAFETY: `i * 4 + 4 <= data.len()`, so the pointer is valid for a
        // 16-byte read and write; the intrinsics perform unaligned accesses.
        unsafe {
            let ptr = data.as_mut_ptr().add(i * 4);
            let v = v128_load(ptr as *const v128);
            v128_store(ptr as *mut v128, f32x4_max(v, zero));
        }
    }

    // Scalar tail for lengths not divisible by 4.
    for i in (chunks * 4)..data.len() {
        data[i] = data[i].max(0.0);
    }
}
137
/// In-place f32x4 SIMD element-wise add (`a[i] += b[i]`) for `wasm32`.
///
/// Requires the `simd128` target feature. Assumes `b.len() >= a.len()`.
#[cfg(target_arch = "wasm32")]
fn add_wasm_simd(a: &mut [f32], b: &[f32]) {
    debug_assert!(b.len() >= a.len());
    let chunks = a.len() / 4;

    for i in 0..chunks {
        // BUG FIX: the load/store intrinsics are `unsafe fn`s; the original
        // code called them without an unsafe block and did not compile.
        // SAFETY: `i * 4 + 4 <= a.len() <= b.len()`, so both pointers are
        // valid for 16 bytes; the intrinsics perform unaligned accesses.
        unsafe {
            let pa = a.as_mut_ptr().add(i * 4);
            let va = v128_load(pa as *const v128);
            let vb = v128_load(b.as_ptr().add(i * 4) as *const v128);
            v128_store(pa as *mut v128, f32x4_add(va, vb));
        }
    }

    // Scalar tail for lengths not divisible by 4.
    for i in (chunks * 4)..a.len() {
        a[i] += b[i];
    }
}
155
/// In-place f32x4 SIMD AXPY (`a[i] += b[i] * scalar`) for `wasm32`.
///
/// Requires the `simd128` target feature. Assumes `b.len() >= a.len()`.
#[cfg(target_arch = "wasm32")]
fn axpy_wasm_simd(a: &mut [f32], b: &[f32], scalar: f32) {
    debug_assert!(b.len() >= a.len());
    let vs = f32x4_splat(scalar);
    let chunks = a.len() / 4;

    for i in 0..chunks {
        // BUG FIX: the load/store intrinsics are `unsafe fn`s; the original
        // code called them without an unsafe block and did not compile.
        // SAFETY: `i * 4 + 4 <= a.len() <= b.len()`, so both pointers are
        // valid for 16 bytes; the intrinsics perform unaligned accesses.
        unsafe {
            let pa = a.as_mut_ptr().add(i * 4);
            let va = v128_load(pa as *const v128);
            let vb = v128_load(b.as_ptr().add(i * 4) as *const v128);
            v128_store(pa as *mut v128, f32x4_add(va, f32x4_mul(vb, vs)));
        }
    }

    // Scalar tail for lengths not divisible by 4.
    for i in (chunks * 4)..a.len() {
        a[i] += b[i] * scalar;
    }
}
174
/// Scalar fallback dot product used on non-wasm targets.
#[cfg(not(target_arch = "wasm32"))]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate pairwise products left-to-right, matching the iterator
    // `sum()` ordering exactly.
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
181
/// Scalar fallback in-place ReLU used on non-wasm targets.
#[cfg(not(target_arch = "wasm32"))]
fn relu_scalar(data: &mut [f32]) {
    // `f32::max` is kept (rather than a `< 0.0` branch) so NaN inputs keep
    // the original semantics: `NaN.max(0.0)` yields 0.0.
    data.iter_mut().for_each(|v| *v = v.max(0.0));
}
186
/// In-place GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn gelu_scalar(data: &mut [f32]) {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const GELU_COEF: f32 = 0.044715;
    data.iter_mut().for_each(|v| {
        let x = *v;
        // Cube first, then scale, to keep the exact fp operation order.
        let cubed = x * x * x;
        let inner = SQRT_2_OVER_PI * (x + GELU_COEF * cubed);
        *v = 0.5 * x * (1.0 + inner.tanh());
    });
}
196
/// In-place SiLU/Swish: `x * sigmoid(x)`, computed as `x / (1 + e^-x)`.
fn silu_scalar(data: &mut [f32]) {
    data.iter_mut().for_each(|v| {
        // Division (not multiply-by-reciprocal) preserves the original
        // rounding behavior exactly.
        let denom = 1.0 + (-*v).exp();
        *v /= denom;
    });
}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product() {
        let backend = WasmBackend;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let result = backend.dot_product(&a, &b);
        assert!((result - 40.0).abs() < 1e-5);
    }

    // Length 5 exercises the scalar remainder path after the 4-wide chunks.
    #[test]
    fn test_dot_product_remainder() {
        let backend = WasmBackend;
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 2.0];
        let result = backend.dot_product(&a, &b);
        assert!((result - 20.0).abs() < 1e-5);
    }

    #[test]
    fn test_add() {
        let backend = WasmBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        backend.add(&mut a, &b);
        assert_eq!(a, vec![6.0, 8.0, 10.0, 12.0]);
    }

    // Previously untested: axpy, including the remainder path (len = 5).
    #[test]
    fn test_axpy() {
        let backend = WasmBackend;
        let mut a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0];
        backend.axpy(&mut a, &b, 2.0);
        assert_eq!(a, vec![3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    // Previously untested: the activation dispatch (ReLU path).
    #[test]
    fn test_relu_activation() {
        let backend = WasmBackend;
        let mut d = vec![-1.0, 0.0, 2.0, -3.5, 4.0];
        backend.activation(&mut d, ActivationType::Relu);
        assert_eq!(d, vec![0.0, 0.0, 2.0, 0.0, 4.0]);
    }
}