Skip to main content

sci_form/transport/
worker.rs

1//! Web Worker dispatch strategy for WASM-safe parallelism.
2//!
3//! In browsers, heavy computations block the main thread. Web Workers provide
4//! separate threads, but SharedArrayBuffer requires cross-origin isolation.
5//!
6//! This module provides:
7//! - Task descriptors for dispatching work to workers
8//! - Result aggregation for collecting worker outputs
9//! - Batch splitting for dividing work across N workers
10//!
11//! The actual Web Worker creation happens in JavaScript; this module provides
12//! the data structures and splitting logic.
13
14use serde::{Deserialize, Serialize};
15
16/// A task to be dispatched to a Web Worker.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct WorkerTask {
19    /// Unique task identifier.
20    pub id: usize,
21    /// Task type.
22    pub kind: TaskKind,
23    /// SMILES strings to process (for batch embedding tasks).
24    pub smiles: Vec<String>,
25    /// Numeric parameters (e.g., seed, spacing, etc.).
26    pub params: Vec<f64>,
27}
28
29/// Types of tasks that can be dispatched to workers.
30#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
31pub enum TaskKind {
32    /// Batch conformer generation.
33    EmbedBatch,
34    /// ESP grid computation.
35    ComputeEsp,
36    /// DOS computation.
37    ComputeDos,
38    /// Population analysis batch.
39    ComputePopulation,
40    /// UFF energy evaluation batch.
41    ComputeUff,
42}
43
44/// Result from a completed worker task.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct WorkerResult {
47    /// Task ID this result corresponds to.
48    pub task_id: usize,
49    /// Whether the task completed successfully.
50    pub success: bool,
51    /// Error message if failed.
52    pub error: Option<String>,
53    /// JSON-encoded result data.
54    pub data: String,
55}
56
57/// Split a batch of SMILES into N worker tasks.
58///
59/// `smiles`: all SMILES to process
60/// `n_workers`: number of workers to distribute across
61/// `seed`: RNG seed
62///
63/// Each chunk is capped at 10 000 SMILES to prevent unbounded memory
64/// allocation when few workers are requested for very large inputs.
65///
66/// Returns a vector of tasks, one per worker.
67pub fn split_batch(smiles: &[String], n_workers: usize, seed: u64) -> Vec<WorkerTask> {
68    let n = smiles.len();
69    const MAX_CHUNK: usize = 10_000;
70    // Ensure enough workers so no single chunk exceeds MAX_CHUNK
71    let min_workers_for_cap = n.div_ceil(MAX_CHUNK);
72    let workers = n_workers.max(1).max(min_workers_for_cap).min(n);
73    let chunk_size = n.div_ceil(workers);
74
75    smiles
76        .chunks(chunk_size)
77        .enumerate()
78        .map(|(i, chunk)| WorkerTask {
79            id: i,
80            kind: TaskKind::EmbedBatch,
81            smiles: chunk.to_vec(),
82            params: vec![seed as f64],
83        })
84        .collect()
85}
86
87/// Merge worker results back into order (by task_id).
88pub fn merge_results(mut results: Vec<WorkerResult>) -> Vec<WorkerResult> {
89    results.sort_by_key(|r| r.task_id);
90    results
91}
92
/// Estimate the optimal number of workers based on data size.
///
/// Heuristic: 1 worker per 100 molecules, max 8 for browser environments.
/// The caller's `max_workers` limit is applied last and wins when smaller.
pub fn estimate_workers(n_items: usize, max_workers: usize) -> usize {
    // Hard ceiling appropriate for browser Web Worker pools.
    const BROWSER_CAP: usize = 8;
    // One worker per 100 items, but always at least one.
    let by_size = if n_items < 100 { 1 } else { n_items / 100 };
    by_size.min(BROWSER_CAP).min(max_workers)
}
100
#[cfg(test)]
mod tests {
    //! Unit tests for batch splitting, result merging, and worker-count
    //! estimation. Serialization is exercised via a serde_json round-trip.
    use super::*;

    // 10 items over 2 workers -> two equal chunks of 5, ids 0 and 1.
    #[test]
    fn test_split_batch_even() {
        let smiles: Vec<String> = (0..10).map(|i| format!("mol{}", i)).collect();
        let tasks = split_batch(&smiles, 2, 42);
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].smiles.len(), 5);
        assert_eq!(tasks[1].smiles.len(), 5);
        assert_eq!(tasks[0].id, 0);
        assert_eq!(tasks[1].id, 1);
    }

    // 7 items over 3 workers: ceil(7/3) = 3 per chunk, last chunk short.
    #[test]
    fn test_split_batch_uneven() {
        let smiles: Vec<String> = (0..7).map(|i| format!("mol{}", i)).collect();
        let tasks = split_batch(&smiles, 3, 42);
        assert_eq!(tasks.len(), 3);
        assert_eq!(tasks[0].smiles.len(), 3);
        assert_eq!(tasks[1].smiles.len(), 3);
        assert_eq!(tasks[2].smiles.len(), 1);
    }

    // Worker count is clamped to the number of items.
    #[test]
    fn test_split_batch_more_workers_than_items() {
        let smiles: Vec<String> = vec!["C".to_string(), "CC".to_string()];
        let tasks = split_batch(&smiles, 10, 42);
        assert_eq!(tasks.len(), 2); // min(10, 2) = 2 workers
    }

    // Out-of-order results are re-sorted by task_id.
    #[test]
    fn test_merge_results_ordered() {
        let results = vec![
            WorkerResult {
                task_id: 2,
                success: true,
                error: None,
                data: "r2".to_string(),
            },
            WorkerResult {
                task_id: 0,
                success: true,
                error: None,
                data: "r0".to_string(),
            },
            WorkerResult {
                task_id: 1,
                success: true,
                error: None,
                data: "r1".to_string(),
            },
        ];
        let merged = merge_results(results);
        assert_eq!(merged[0].task_id, 0);
        assert_eq!(merged[1].task_id, 1);
        assert_eq!(merged[2].task_id, 2);
    }

    // Floor at 1 worker, cap at 8, and honor the caller's max_workers.
    #[test]
    fn test_estimate_workers() {
        assert_eq!(estimate_workers(50, 4), 1);
        assert_eq!(estimate_workers(500, 8), 5);
        assert_eq!(estimate_workers(10000, 8), 8);
        assert_eq!(estimate_workers(10000, 4), 4);
    }

    // JSON round-trip preserves task fields across the JS boundary.
    #[test]
    fn test_worker_task_serialization() {
        let task = WorkerTask {
            id: 0,
            kind: TaskKind::EmbedBatch,
            smiles: vec!["C".to_string(), "CC".to_string()],
            params: vec![42.0],
        };
        let json = serde_json::to_string(&task).unwrap();
        let back: WorkerTask = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, 0);
        assert_eq!(back.smiles.len(), 2);
    }
}