ruvllm_wasm/workers/
mod.rs

//! Web Workers for Parallel Inference in WASM
//!
//! This module provides multi-threaded execution in browsers using Web Workers
//! with SharedArrayBuffer for zero-copy data sharing.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                      Main Thread                                 │
//! │  ┌──────────────────┐  ┌──────────────────┐                     │
//! │  │ ParallelInference│  │ SharedBufferMgr  │                     │
//! │  └────────┬─────────┘  └────────┬─────────┘                     │
//! │           │                     │                               │
//! │           ▼                     ▼                               │
//! │  ┌────────────────────────────────────────┐                     │
//! │  │            WorkerPool                  │                     │
//! │  │  ┌──────────┐ ┌──────────┐ ┌──────────┐│                     │
//! │  │  │TaskQueue │ │SharedMem │ │ Workers  ││                     │
//! │  │  └──────────┘ └──────────┘ └──────────┘│                     │
//! │  └────────────────────────────────────────┘                     │
//! └─────────────────────────────────────────────────────────────────┘
//!                        │ postMessage │
//!                        ▼             ▼
//! ┌────────────────┐ ┌────────────────┐ ┌────────────────┐
//! │   Worker 0     │ │   Worker 1     │ │   Worker N     │
//! │ ┌────────────┐ │ │ ┌────────────┐ │ │ ┌────────────┐ │
//! │ │SharedArray │ │ │ │SharedArray │ │ │ │SharedArray │ │
//! │ │  Buffer    │ │ │ │  Buffer    │ │ │ │  Buffer    │ │
//! │ │   View     │ │ │ │   View     │ │ │ │   View     │ │
//! │ └────────────┘ │ │ └────────────┘ │ │ └────────────┘ │
//! └────────────────┘ └────────────────┘ └────────────────┘
//! ```
//!
//! # Features
//!
//! - **SharedArrayBuffer**: Zero-copy memory sharing between threads
//! - **Atomics**: Thread synchronization primitives
//! - **Dynamic Worker Count**: Based on `navigator.hardwareConcurrency`
//! - **Graceful Fallback**: Single-threaded mode when SharedArrayBuffer unavailable
//!
//! # Example
//!
//! ```javascript
//! import { ParallelInference } from 'ruvllm-wasm';
//!
//! // Create parallel inference engine
//! const engine = await ParallelInference.new(4); // 4 workers
//!
//! // Check capabilities
//! console.log('Workers:', engine.workerCount());
//! console.log('Shared memory:', engine.isSharedMemoryAvailable());
//!
//! // Parallel matrix multiplication
//! const result = await engine.matmul(a, b, m, n, k);
//! ```
//!
//! # Browser Requirements
//!
//! For SharedArrayBuffer to work, the page must be served with:
//! - `Cross-Origin-Opener-Policy: same-origin`
//! - `Cross-Origin-Embedder-Policy: require-corp`
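//!
//! As a rough sketch (assuming a plain Node.js `http` static file server; the
//! port and paths below are illustrative, and MIME-type handling is omitted),
//! the headers can be added like this:
//!
//! ```javascript
//! const http = require('http');
//! const fs = require('fs');
//!
//! http.createServer((req, res) => {
//!   // These two headers make the page cross-origin isolated, which is what
//!   // SharedArrayBuffer requires.
//!   res.setHeader('Cross-Origin-Opener-Policy', 'same-origin');
//!   res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp');
//!   const stream = fs.createReadStream('.' + req.url);
//!   stream.on('error', () => { res.statusCode = 404; res.end(); });
//!   stream.pipe(res);
//! }).listen(8080);
//! ```
//!
//! Without these headers the engine still loads, but falls back to the slower
//! message-passing mode (see `isSharedMemoryAvailable`).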

pub mod feature_detect;
pub mod messages;
pub mod pool;
pub mod shared;

pub use feature_detect::*;
pub use messages::*;
pub use pool::*;
pub use shared::*;

use wasm_bindgen::prelude::*;

/// Maximum recommended number of workers (prevents resource exhaustion)
pub const MAX_WORKERS: usize = 16;

/// Minimum number of workers
pub const MIN_WORKERS: usize = 2;

/// WASM page size in bytes (64KB)
pub const WASM_PAGE_SIZE: usize = 65536;

/// Alignment for SIMD operations (16 bytes for 128-bit SIMD)
pub const SIMD_ALIGNMENT: usize = 16;

/// Main parallel inference interface for WASM.
///
/// Provides a high-level API for parallel compute operations in the browser.
/// Automatically manages the worker pool and shared memory.
#[wasm_bindgen]
pub struct ParallelInference {
    pool: WorkerPool,
    shared_buffers: SharedBufferManager,
    initialized: bool,
}

#[wasm_bindgen]
impl ParallelInference {
    /// Create a new ParallelInference instance.
    ///
    /// # Arguments
    /// * `num_workers` - Number of workers to spawn. If `None`, the optimal count is used.
    ///
    /// # Returns
    /// A Promise that resolves to a ParallelInference instance.
    ///
    /// # Example (JavaScript)
    /// ```javascript
    /// const inference = await ParallelInference.new(4);
    /// ```
    #[wasm_bindgen]
    pub async fn new(num_workers: Option<usize>) -> Result<ParallelInference, JsValue> {
        crate::utils::set_panic_hook();

        let worker_count = num_workers.unwrap_or_else(optimal_worker_count);
        let worker_count = worker_count.clamp(MIN_WORKERS, MAX_WORKERS);

        crate::utils::log(&format!(
            "Initializing ParallelInference with {} workers",
            worker_count
        ));

        // Check for SharedArrayBuffer support
        let shared_memory_available = is_shared_array_buffer_available();
        if !shared_memory_available {
            crate::utils::warn(
                "SharedArrayBuffer not available. Using fallback mode with message passing.",
            );
        }

        // Check cross-origin isolation
        if shared_memory_available && !cross_origin_isolated() {
            crate::utils::warn(
                "Page is not cross-origin isolated. SharedArrayBuffer may not work correctly.",
            );
        }

        let pool = WorkerPool::new(worker_count).await?;
        let shared_buffers = SharedBufferManager::new();

        crate::utils::log("ParallelInference initialized successfully");

        Ok(ParallelInference {
            pool,
            shared_buffers,
            initialized: true,
        })
    }

    /// Perform parallel matrix multiplication.
    ///
    /// Computes C = A * B where:
    /// - A is m x k
    /// - B is k x n
    /// - C is m x n
    ///
    /// # Arguments
    /// * `a` - Matrix A as a flat array (row-major)
    /// * `b` - Matrix B as a flat array (row-major)
    /// * `m` - Number of rows in A
    /// * `n` - Number of columns in B
    /// * `k` - Number of columns in A (= number of rows in B)
    ///
    /// # Returns
    /// The result matrix C as a Float32Array
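    ///
    /// # Example (JavaScript)
    /// A minimal sketch; `engine` is a `ParallelInference` instance and the
    /// values are chosen purely for illustration:
    /// ```javascript
    /// // 2x3 * 3x2 -> 2x2
    /// const a = new Float32Array([1, 2, 3, 4, 5, 6]);
    /// const b = new Float32Array([1, 2, 3, 4, 5, 6]);
    /// const c = await engine.matmul(a, b, 2, 2, 3); // [22, 28, 49, 64]
    /// ```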
    #[wasm_bindgen]
    pub async fn matmul(
        &mut self,
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<Vec<f32>, JsValue> {
        if !self.initialized {
            return Err(JsValue::from_str("ParallelInference not initialized"));
        }

        // Validate dimensions
        if a.len() != m * k {
            return Err(JsValue::from_str(&format!(
                "Matrix A size mismatch: expected {} ({}x{}), got {}",
                m * k,
                m,
                k,
                a.len()
            )));
        }
        if b.len() != k * n {
            return Err(JsValue::from_str(&format!(
                "Matrix B size mismatch: expected {} ({}x{}), got {}",
                k * n,
                k,
                n,
                b.len()
            )));
        }

        // For small matrices, compute directly on main thread
        if m * n * k < 10000 {
            return Ok(self.matmul_single_thread(a, b, m, n, k));
        }

        // Use parallel computation
        self.pool.parallel_matmul(a, b, m, n, k).await
    }

    /// Perform parallel multi-head attention.
    ///
    /// Computes softmax(Q * K^T / sqrt(d_k)) * V for each attention head.
    ///
    /// # Arguments
    /// * `q` - Query tensor, shape (num_heads, seq_len, head_dim), flattened row-major
    /// * `k` - Key tensor, shape (num_heads, seq_len, head_dim), flattened row-major
    /// * `v` - Value tensor, shape (num_heads, seq_len, head_dim), flattened row-major
    /// * `num_heads` - Number of attention heads
    /// * `head_dim` - Dimension of each head
    /// * `seq_len` - Sequence length
    ///
    /// # Returns
    /// Output tensor, shape (num_heads, seq_len, head_dim), flattened row-major
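    ///
    /// # Example (JavaScript)
    /// A minimal sketch; the shapes and zero-filled tensors are chosen purely
    /// for illustration:
    /// ```javascript
    /// const heads = 1, headDim = 64, seqLen = 128;
    /// const len = heads * seqLen * headDim;
    /// const q = new Float32Array(len);
    /// const k = new Float32Array(len);
    /// const v = new Float32Array(len);
    /// const out = await engine.attention(q, k, v, heads, headDim, seqLen);
    /// ```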
    #[wasm_bindgen(js_name = attention)]
    pub async fn parallel_attention(
        &mut self,
        q: &[f32],
        k: &[f32],
        v: &[f32],
        num_heads: usize,
        head_dim: usize,
        seq_len: usize,
    ) -> Result<Vec<f32>, JsValue> {
        if !self.initialized {
            return Err(JsValue::from_str("ParallelInference not initialized"));
        }

        // Validate dimensions
        let expected_size = num_heads * seq_len * head_dim;
        if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
            return Err(JsValue::from_str(&format!(
                "Tensor size mismatch: expected {}, got Q={}, K={}, V={}",
                expected_size,
                q.len(),
                k.len(),
                v.len()
            )));
        }

        // For small tensors, compute on main thread
        if expected_size < 10000 {
            return Ok(self.attention_single_thread(q, k, v, num_heads, head_dim, seq_len));
        }

        self.pool
            .parallel_attention(q, k, v, num_heads, head_dim, seq_len)
            .await
    }

    /// Perform parallel layer normalization.
    ///
    /// # Arguments
    /// * `input` - Input tensor; its length must be a multiple of `gamma.len()`
    /// * `gamma` - Scale parameter (its length defines the normalized dimension)
    /// * `beta` - Shift parameter
    /// * `epsilon` - Small constant for numerical stability
    ///
    /// # Returns
    /// Normalized tensor
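    ///
    /// # Example (JavaScript)
    /// A minimal sketch; the values are chosen purely for illustration:
    /// ```javascript
    /// const input = new Float32Array([1, 2, 3, 4, 4, 3, 2, 1]); // two rows of 4
    /// const gamma = new Float32Array([1, 1, 1, 1]);
    /// const beta  = new Float32Array([0, 0, 0, 0]);
    /// const normalized = await engine.layerNorm(input, gamma, beta, 1e-5);
    /// ```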
    #[wasm_bindgen(js_name = layerNorm)]
    pub async fn layer_norm(
        &mut self,
        input: &[f32],
        gamma: &[f32],
        beta: &[f32],
        epsilon: f32,
    ) -> Result<Vec<f32>, JsValue> {
        if !self.initialized {
            return Err(JsValue::from_str("ParallelInference not initialized"));
        }

        if input.len() < 1000 {
            return Ok(self.layer_norm_single_thread(input, gamma, beta, epsilon));
        }

        self.pool.parallel_norm(input, gamma, beta, epsilon).await
    }

    /// Get the number of active workers.
    #[wasm_bindgen(js_name = workerCount)]
    pub fn worker_count(&self) -> usize {
        self.pool.worker_count()
    }

    /// Check if SharedArrayBuffer is available.
    #[wasm_bindgen(js_name = isSharedMemoryAvailable)]
    pub fn is_shared_memory_available(&self) -> bool {
        is_shared_array_buffer_available()
    }

    /// Check if the page is cross-origin isolated.
    #[wasm_bindgen(js_name = isCrossOriginIsolated)]
    pub fn is_cross_origin_isolated(&self) -> bool {
        cross_origin_isolated()
    }

    /// Check if Atomics API is available.
    #[wasm_bindgen(js_name = isAtomicsAvailable)]
    pub fn is_atomics_available(&self) -> bool {
        is_atomics_available()
    }

    /// Get optimal worker count for the current hardware.
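    ///
    /// # Example (JavaScript)
    /// A small sketch; because this method takes no `self`, it is exposed as a
    /// static method on the class:
    /// ```javascript
    /// const n = ParallelInference.optimalWorkerCount();
    /// const engine = await ParallelInference.new(n);
    /// ```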
    #[wasm_bindgen(js_name = optimalWorkerCount)]
    pub fn get_optimal_worker_count() -> usize {
        optimal_worker_count()
    }

    /// Terminate all workers and clean up resources.
    #[wasm_bindgen]
    pub fn terminate(&mut self) {
        self.pool.terminate();
        self.shared_buffers.clear();
        self.initialized = false;
        crate::utils::log("ParallelInference terminated");
    }

    /// Get statistics about the worker pool, serialized as a JSON string.
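    ///
    /// # Example (JavaScript)
    /// A small sketch; the exact fields depend on the `WorkerPool` stats struct:
    /// ```javascript
    /// const stats = JSON.parse(engine.getStats());
    /// console.log(stats);
    /// ```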
    #[wasm_bindgen(js_name = getStats)]
    pub fn get_stats(&self) -> Result<String, JsValue> {
        let stats = self.pool.stats();
        serde_json::to_string(&stats).map_err(|e| JsValue::from_str(&e.to_string()))
    }

    // Private helper methods for single-threaded fallback

    fn matmul_single_thread(&self, a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
        let mut c = vec![0.0f32; m * n];

        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0f32;
                for l in 0..k {
                    sum += a[i * k + l] * b[l * n + j];
                }
                c[i * n + j] = sum;
            }
        }

        c
    }

    fn attention_single_thread(
        &self,
        q: &[f32],
        k: &[f32],
        v: &[f32],
        num_heads: usize,
        head_dim: usize,
        seq_len: usize,
    ) -> Vec<f32> {
        let mut output = vec![0.0f32; num_heads * seq_len * head_dim];
        let scale = 1.0 / (head_dim as f32).sqrt();

        for h in 0..num_heads {
            let head_offset = h * seq_len * head_dim;

            // Compute attention scores: Q * K^T
            let mut scores = vec![0.0f32; seq_len * seq_len];
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[head_offset + i * head_dim + d]
                            * k[head_offset + j * head_dim + d];
                    }
                    scores[i * seq_len + j] = dot * scale;
                }
            }

            // Softmax
            for i in 0..seq_len {
                let row_start = i * seq_len;
                let max_val = scores[row_start..row_start + seq_len]
                    .iter()
                    .fold(f32::NEG_INFINITY, |a, &b| a.max(b));

                let mut sum = 0.0f32;
                for j in 0..seq_len {
                    scores[row_start + j] = (scores[row_start + j] - max_val).exp();
                    sum += scores[row_start + j];
                }

                for j in 0..seq_len {
                    scores[row_start + j] /= sum;
                }
            }

            // Compute output: scores * V
            for i in 0..seq_len {
                for d in 0..head_dim {
                    let mut sum = 0.0f32;
                    for j in 0..seq_len {
                        sum += scores[i * seq_len + j] * v[head_offset + j * head_dim + d];
                    }
                    output[head_offset + i * head_dim + d] = sum;
                }
            }
        }

        output
    }

    fn layer_norm_single_thread(
        &self,
        input: &[f32],
        gamma: &[f32],
        beta: &[f32],
        epsilon: f32,
    ) -> Vec<f32> {
        let n = input.len();
        let hidden_dim = gamma.len();

        if n % hidden_dim != 0 {
            return input.to_vec(); // Fallback: return input unchanged
        }

        let batch_size = n / hidden_dim;
        let mut output = vec![0.0f32; n];

        for b in 0..batch_size {
            let start = b * hidden_dim;
            let end = start + hidden_dim;
            let slice = &input[start..end];

            // Compute mean
            let mean: f32 = slice.iter().sum::<f32>() / hidden_dim as f32;

            // Compute variance
            let variance: f32 =
                slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / hidden_dim as f32;

            // Normalize
            let std = (variance + epsilon).sqrt();
            for i in 0..hidden_dim {
                output[start + i] = ((input[start + i] - mean) / std) * gamma[i] + beta[i];
            }
        }

        output
    }
}

impl Drop for ParallelInference {
    fn drop(&mut self) {
        self.terminate();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_single_thread() {
        let inference = ParallelInference {
            pool: WorkerPool::empty(),
            shared_buffers: SharedBufferManager::new(),
            initialized: true,
        };

        // 2x3 * 3x2 = 2x2
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let c = inference.matmul_single_thread(&a, &b, 2, 2, 3);

        // Expected: [[22, 28], [49, 64]]
        assert_eq!(c.len(), 4);
        assert!((c[0] - 22.0).abs() < 0.001);
        assert!((c[1] - 28.0).abs() < 0.001);
        assert!((c[2] - 49.0).abs() < 0.001);
        assert!((c[3] - 64.0).abs() < 0.001);
    }

    #[test]
    fn test_layer_norm_single_thread() {
        let inference = ParallelInference {
            pool: WorkerPool::empty(),
            shared_buffers: SharedBufferManager::new(),
            initialized: true,
        };

        let input = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0, 1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0, 0.0];
        let epsilon = 1e-5;

        let output = inference.layer_norm_single_thread(&input, &gamma, &beta, epsilon);

        // After normalization, mean should be ~0 and std ~1
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 0.001);
    }
}