1pub mod feature_detect;
65pub mod messages;
66pub mod pool;
67pub mod shared;
68
69pub use feature_detect::*;
70pub use messages::*;
71pub use pool::*;
72pub use shared::*;
73
74use wasm_bindgen::prelude::*;
75
/// Upper bound on the number of Web Workers the pool will spawn,
/// regardless of the requested or auto-detected count.
pub const MAX_WORKERS: usize = 16;

/// Lower bound on the worker count; guarantees at least minimal parallelism.
pub const MIN_WORKERS: usize = 2;

/// Size of one WebAssembly linear-memory page in bytes (64 KiB).
pub const WASM_PAGE_SIZE: usize = 65536;

/// Byte alignment used for SIMD-friendly buffers (16 bytes = one 128-bit lane).
pub const SIMD_ALIGNMENT: usize = 16;
87
/// JavaScript-facing entry point for multi-worker inference.
///
/// Bundles a worker pool and a shared-buffer manager, and exposes parallel
/// matmul / attention / layer-norm kernels through `wasm_bindgen`.
#[wasm_bindgen]
pub struct ParallelInference {
    // Pool of Web Workers that execute the parallel kernels.
    pool: WorkerPool,
    // Manages shared scratch buffers; cleared by `terminate()`.
    shared_buffers: SharedBufferManager,
    // Set to false by `terminate()`; the public async ops return an error when false.
    initialized: bool,
}
98
99#[wasm_bindgen]
100impl ParallelInference {
101 #[wasm_bindgen(constructor)]
114 pub async fn new(num_workers: Option<usize>) -> Result<ParallelInference, JsValue> {
115 crate::utils::set_panic_hook();
116
117 let worker_count = num_workers.unwrap_or_else(optimal_worker_count);
118 let worker_count = worker_count.clamp(MIN_WORKERS, MAX_WORKERS);
119
120 crate::utils::log(&format!(
121 "Initializing ParallelInference with {} workers",
122 worker_count
123 ));
124
125 let shared_memory_available = is_shared_array_buffer_available();
127 if !shared_memory_available {
128 crate::utils::warn(
129 "SharedArrayBuffer not available. Using fallback mode with message passing.",
130 );
131 }
132
133 if shared_memory_available && !cross_origin_isolated() {
135 crate::utils::warn(
136 "Page is not cross-origin isolated. SharedArrayBuffer may not work correctly.",
137 );
138 }
139
140 let pool = WorkerPool::new(worker_count).await?;
141 let shared_buffers = SharedBufferManager::new();
142
143 crate::utils::log("ParallelInference initialized successfully");
144
145 Ok(ParallelInference {
146 pool,
147 shared_buffers,
148 initialized: true,
149 })
150 }
151
152 #[wasm_bindgen]
169 pub async fn matmul(
170 &mut self,
171 a: &[f32],
172 b: &[f32],
173 m: usize,
174 n: usize,
175 k: usize,
176 ) -> Result<Vec<f32>, JsValue> {
177 if !self.initialized {
178 return Err(JsValue::from_str("ParallelInference not initialized"));
179 }
180
181 if a.len() != m * k {
183 return Err(JsValue::from_str(&format!(
184 "Matrix A size mismatch: expected {} ({}x{}), got {}",
185 m * k,
186 m,
187 k,
188 a.len()
189 )));
190 }
191 if b.len() != k * n {
192 return Err(JsValue::from_str(&format!(
193 "Matrix B size mismatch: expected {} ({}x{}), got {}",
194 k * n,
195 k,
196 n,
197 b.len()
198 )));
199 }
200
201 if m * n * k < 10000 {
203 return Ok(self.matmul_single_thread(a, b, m, n, k));
204 }
205
206 self.pool.parallel_matmul(a, b, m, n, k).await
208 }
209
210 #[wasm_bindgen(js_name = attention)]
225 pub async fn parallel_attention(
226 &mut self,
227 q: &[f32],
228 k: &[f32],
229 v: &[f32],
230 num_heads: usize,
231 head_dim: usize,
232 seq_len: usize,
233 ) -> Result<Vec<f32>, JsValue> {
234 if !self.initialized {
235 return Err(JsValue::from_str("ParallelInference not initialized"));
236 }
237
238 let expected_size = num_heads * seq_len * head_dim;
240 if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
241 return Err(JsValue::from_str(&format!(
242 "Tensor size mismatch: expected {}, got Q={}, K={}, V={}",
243 expected_size,
244 q.len(),
245 k.len(),
246 v.len()
247 )));
248 }
249
250 if expected_size < 10000 {
252 return Ok(self.attention_single_thread(q, k, v, num_heads, head_dim, seq_len));
253 }
254
255 self.pool
256 .parallel_attention(q, k, v, num_heads, head_dim, seq_len)
257 .await
258 }
259
260 #[wasm_bindgen(js_name = layerNorm)]
271 pub async fn layer_norm(
272 &mut self,
273 input: &[f32],
274 gamma: &[f32],
275 beta: &[f32],
276 epsilon: f32,
277 ) -> Result<Vec<f32>, JsValue> {
278 if !self.initialized {
279 return Err(JsValue::from_str("ParallelInference not initialized"));
280 }
281
282 if input.len() < 1000 {
283 return Ok(self.layer_norm_single_thread(input, gamma, beta, epsilon));
284 }
285
286 self.pool.parallel_norm(input, gamma, beta, epsilon).await
287 }
288
289 #[wasm_bindgen(js_name = workerCount)]
291 pub fn worker_count(&self) -> usize {
292 self.pool.worker_count()
293 }
294
295 #[wasm_bindgen(js_name = isSharedMemoryAvailable)]
297 pub fn is_shared_memory_available(&self) -> bool {
298 is_shared_array_buffer_available()
299 }
300
301 #[wasm_bindgen(js_name = isCrossOriginIsolated)]
303 pub fn is_cross_origin_isolated(&self) -> bool {
304 cross_origin_isolated()
305 }
306
307 #[wasm_bindgen(js_name = isAtomicsAvailable)]
309 pub fn is_atomics_available(&self) -> bool {
310 is_atomics_available()
311 }
312
313 #[wasm_bindgen(js_name = optimalWorkerCount)]
315 pub fn get_optimal_worker_count() -> usize {
316 optimal_worker_count()
317 }
318
319 #[wasm_bindgen]
321 pub fn terminate(&mut self) {
322 self.pool.terminate();
323 self.shared_buffers.clear();
324 self.initialized = false;
325 crate::utils::log("ParallelInference terminated");
326 }
327
328 #[wasm_bindgen(js_name = getStats)]
330 pub fn get_stats(&self) -> Result<String, JsValue> {
331 let stats = self.pool.stats();
332 serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string()))
333 }
334
335 fn matmul_single_thread(&self, a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
338 let mut c = vec![0.0f32; m * n];
339
340 for i in 0..m {
341 for j in 0..n {
342 let mut sum = 0.0f32;
343 for l in 0..k {
344 sum += a[i * k + l] * b[l * n + j];
345 }
346 c[i * n + j] = sum;
347 }
348 }
349
350 c
351 }
352
353 fn attention_single_thread(
354 &self,
355 q: &[f32],
356 k: &[f32],
357 v: &[f32],
358 num_heads: usize,
359 head_dim: usize,
360 seq_len: usize,
361 ) -> Vec<f32> {
362 let mut output = vec![0.0f32; num_heads * seq_len * head_dim];
363 let scale = 1.0 / (head_dim as f32).sqrt();
364
365 for h in 0..num_heads {
366 let head_offset = h * seq_len * head_dim;
367
368 let mut scores = vec![0.0f32; seq_len * seq_len];
370 for i in 0..seq_len {
371 for j in 0..seq_len {
372 let mut dot = 0.0f32;
373 for d in 0..head_dim {
374 dot += q[head_offset + i * head_dim + d]
375 * k[head_offset + j * head_dim + d];
376 }
377 scores[i * seq_len + j] = dot * scale;
378 }
379 }
380
381 for i in 0..seq_len {
383 let row_start = i * seq_len;
384 let max_val = scores[row_start..row_start + seq_len]
385 .iter()
386 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
387
388 let mut sum = 0.0f32;
389 for j in 0..seq_len {
390 scores[row_start + j] = (scores[row_start + j] - max_val).exp();
391 sum += scores[row_start + j];
392 }
393
394 for j in 0..seq_len {
395 scores[row_start + j] /= sum;
396 }
397 }
398
399 for i in 0..seq_len {
401 for d in 0..head_dim {
402 let mut sum = 0.0f32;
403 for j in 0..seq_len {
404 sum += scores[i * seq_len + j] * v[head_offset + j * head_dim + d];
405 }
406 output[head_offset + i * head_dim + d] = sum;
407 }
408 }
409 }
410
411 output
412 }
413
414 fn layer_norm_single_thread(
415 &self,
416 input: &[f32],
417 gamma: &[f32],
418 beta: &[f32],
419 epsilon: f32,
420 ) -> Vec<f32> {
421 let n = input.len();
422 let hidden_dim = gamma.len();
423
424 if n % hidden_dim != 0 {
425 return input.to_vec(); }
427
428 let batch_size = n / hidden_dim;
429 let mut output = vec![0.0f32; n];
430
431 for b in 0..batch_size {
432 let start = b * hidden_dim;
433 let end = start + hidden_dim;
434 let slice = &input[start..end];
435
436 let mean: f32 = slice.iter().sum::<f32>() / hidden_dim as f32;
438
439 let variance: f32 =
441 slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / hidden_dim as f32;
442
443 let std = (variance + epsilon).sqrt();
445 for i in 0..hidden_dim {
446 output[start + i] = ((input[start + i] - mean) / std) * gamma[i] + beta[i];
447 }
448 }
449
450 output
451 }
452}
453
// Ensure workers are torn down even if the JS side never calls `terminate()`.
impl Drop for ParallelInference {
    fn drop(&mut self) {
        self.terminate();
    }
}
459
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an engine around an empty worker pool so the pure-Rust
    // single-thread kernels can be exercised without a JS environment.
    fn cpu_only_engine() -> ParallelInference {
        ParallelInference {
            pool: WorkerPool::empty(),
            shared_buffers: SharedBufferManager::new(),
            initialized: true,
        }
    }

    #[test]
    fn test_matmul_single_thread() {
        let engine = cpu_only_engine();

        // 2x3 times 3x2, both row-major.
        let lhs = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let rhs = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        let product = engine.matmul_single_thread(&lhs, &rhs, 2, 2, 3);

        let expected = [22.0f32, 28.0, 49.0, 64.0];
        assert_eq!(product.len(), expected.len());
        for (got, want) in product.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001);
        }
    }

    #[test]
    fn test_layer_norm_single_thread() {
        let engine = cpu_only_engine();

        let input = [1.0f32, 2.0, 3.0, 4.0];
        let unit_gamma = vec![1.0f32; 4];
        let zero_beta = vec![0.0f32; 4];

        let normalized =
            engine.layer_norm_single_thread(&input, &unit_gamma, &zero_beta, 1e-5);

        // With unit gamma and zero beta the normalized row must be zero-mean.
        let mean = normalized.iter().sum::<f32>() / normalized.len() as f32;
        assert!(mean.abs() < 0.001);
    }
}