sci_form/transport/
worker.rs1use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct WorkerTask {
19 pub id: usize,
21 pub kind: TaskKind,
23 pub smiles: Vec<String>,
25 pub params: Vec<f64>,
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
31pub enum TaskKind {
32 EmbedBatch,
34 ComputeEsp,
36 ComputeDos,
38 ComputePopulation,
40 ComputeUff,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct WorkerResult {
47 pub task_id: usize,
49 pub success: bool,
51 pub error: Option<String>,
53 pub data: String,
55}
56
57pub fn split_batch(smiles: &[String], n_workers: usize, seed: u64) -> Vec<WorkerTask> {
68 let n = smiles.len();
69 const MAX_CHUNK: usize = 10_000;
70 let min_workers_for_cap = n.div_ceil(MAX_CHUNK);
72 let workers = n_workers.max(1).max(min_workers_for_cap).min(n);
73 let chunk_size = n.div_ceil(workers);
74
75 smiles
76 .chunks(chunk_size)
77 .enumerate()
78 .map(|(i, chunk)| WorkerTask {
79 id: i,
80 kind: TaskKind::EmbedBatch,
81 smiles: chunk.to_vec(),
82 params: vec![seed as f64],
83 })
84 .collect()
85}
86
87pub fn merge_results(mut results: Vec<WorkerResult>) -> Vec<WorkerResult> {
89 results.sort_by_key(|r| r.task_id);
90 results
91}
92
/// Suggests how many workers to use for `n_items` items.
///
/// Aims for roughly one worker per 100 items, never more than `max_workers`
/// and never more than 8. The result is always at least 1. In particular,
/// `max_workers == 0` is treated as 1, matching how `split_batch` treats a
/// requested worker count of 0. Previously the floor was applied before the
/// cap, so a zero cap produced 0 workers.
pub fn estimate_workers(n_items: usize, max_workers: usize) -> usize {
    // Target number of items handled by each worker.
    const ITEMS_PER_WORKER: usize = 100;
    // Absolute ceiling on the suggested worker count.
    const HARD_CAP: usize = 8;

    (n_items / ITEMS_PER_WORKER)
        .min(max_workers)
        .min(HARD_CAP)
        // Floor last so no combination of inputs can yield 0 workers.
        .max(1)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` placeholder SMILES strings `mol0`, `mol1`, ...
    fn mols(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("mol{i}")).collect()
    }

    /// Shorthand for a successful `WorkerResult` with the given id and payload.
    fn ok_result(task_id: usize, data: &str) -> WorkerResult {
        WorkerResult {
            task_id,
            success: true,
            error: None,
            data: data.to_string(),
        }
    }

    #[test]
    fn test_split_batch_even() {
        let tasks = split_batch(&mols(10), 2, 42);
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].smiles.len(), 5);
        assert_eq!(tasks[1].smiles.len(), 5);
        assert_eq!(tasks[0].id, 0);
        assert_eq!(tasks[1].id, 1);
    }

    #[test]
    fn test_split_batch_uneven() {
        let tasks = split_batch(&mols(7), 3, 42);
        assert_eq!(tasks.len(), 3);
        assert_eq!(tasks[0].smiles.len(), 3);
        assert_eq!(tasks[1].smiles.len(), 3);
        assert_eq!(tasks[2].smiles.len(), 1);
    }

    #[test]
    fn test_split_batch_more_workers_than_items() {
        let smiles: Vec<String> = vec!["C".to_string(), "CC".to_string()];
        let tasks = split_batch(&smiles, 10, 42);
        assert_eq!(tasks.len(), 2);
        // The worker count is clamped to the item count: one molecule per task.
        assert!(tasks.iter().all(|t| t.smiles.len() == 1));
    }

    #[test]
    fn test_split_batch_zero_workers_treated_as_one() {
        let tasks = split_batch(&mols(5), 0, 42);
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].smiles.len(), 5);
    }

    #[test]
    fn test_split_batch_preserves_order_kind_and_seed() {
        let smiles = mols(9);
        let tasks = split_batch(&smiles, 4, 123);
        let rejoined: Vec<String> = tasks
            .iter()
            .flat_map(|t| t.smiles.iter().cloned())
            .collect();
        assert_eq!(rejoined, smiles);
        for (i, task) in tasks.iter().enumerate() {
            assert_eq!(task.id, i);
            assert_eq!(task.kind, TaskKind::EmbedBatch);
            assert_eq!(task.params, vec![123.0]);
        }
    }

    #[test]
    fn test_split_batch_respects_chunk_cap() {
        // One requested worker, but 20_001 molecules: the 10_000-per-task cap
        // inside split_batch forces at least three tasks.
        let tasks = split_batch(&mols(20_001), 1, 42);
        assert_eq!(tasks.len(), 3);
        assert!(tasks.iter().all(|t| t.smiles.len() <= 10_000));
        assert_eq!(
            tasks.iter().map(|t| t.smiles.len()).sum::<usize>(),
            20_001
        );
    }

    #[test]
    fn test_merge_results_ordered() {
        let merged = merge_results(vec![
            ok_result(2, "r2"),
            ok_result(0, "r0"),
            ok_result(1, "r1"),
        ]);
        assert_eq!(merged[0].task_id, 0);
        assert_eq!(merged[1].task_id, 1);
        assert_eq!(merged[2].task_id, 2);
    }

    #[test]
    fn test_merge_results_stable_for_duplicate_ids() {
        let merged = merge_results(vec![
            ok_result(1, "a"),
            ok_result(0, "z"),
            ok_result(1, "b"),
        ]);
        let data: Vec<&str> = merged.iter().map(|r| r.data.as_str()).collect();
        assert_eq!(data, vec!["z", "a", "b"]);
    }

    #[test]
    fn test_merge_results_empty() {
        assert!(merge_results(Vec::new()).is_empty());
    }

    #[test]
    fn test_estimate_workers() {
        assert_eq!(estimate_workers(50, 4), 1);
        assert_eq!(estimate_workers(500, 8), 5);
        assert_eq!(estimate_workers(10000, 8), 8);
        assert_eq!(estimate_workers(10000, 4), 4);
    }

    #[test]
    fn test_worker_task_serialization() {
        let task = WorkerTask {
            id: 0,
            kind: TaskKind::EmbedBatch,
            smiles: vec!["C".to_string(), "CC".to_string()],
            params: vec![42.0],
        };
        let json = serde_json::to_string(&task).unwrap();
        let back: WorkerTask = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, 0);
        assert_eq!(back.kind, TaskKind::EmbedBatch);
        assert_eq!(back.smiles, task.smiles);
        assert_eq!(back.params, task.params);
    }

    #[test]
    fn test_worker_result_serialization() {
        let result = WorkerResult {
            task_id: 3,
            success: false,
            error: Some("boom".to_string()),
            data: String::new(),
        };
        let json = serde_json::to_string(&result).unwrap();
        let back: WorkerResult = serde_json::from_str(&json).unwrap();
        assert_eq!(back.task_id, 3);
        assert!(!back.success);
        assert_eq!(back.error.as_deref(), Some("boom"));
        assert!(back.data.is_empty());
    }
}