// sci_form/transport/worker.rs
use serde::{Deserialize, Serialize};

16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct WorkerTask {
19 pub id: usize,
21 pub kind: TaskKind,
23 pub smiles: Vec<String>,
25 pub params: Vec<f64>,
27}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
31pub enum TaskKind {
32 EmbedBatch,
34 ComputeEsp,
36 ComputeDos,
38 ComputePopulation,
40 ComputeUff,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct WorkerResult {
47 pub task_id: usize,
49 pub success: bool,
51 pub error: Option<String>,
53 pub data: String,
55}
56
57pub fn split_batch(smiles: &[String], n_workers: usize, seed: u64) -> Vec<WorkerTask> {
65 let n = smiles.len();
66 let workers = n_workers.max(1).min(n);
67 let chunk_size = n.div_ceil(workers);
68
69 smiles
70 .chunks(chunk_size)
71 .enumerate()
72 .map(|(i, chunk)| WorkerTask {
73 id: i,
74 kind: TaskKind::EmbedBatch,
75 smiles: chunk.to_vec(),
76 params: vec![seed as f64],
77 })
78 .collect()
79}
80
81pub fn merge_results(mut results: Vec<WorkerResult>) -> Vec<WorkerResult> {
83 results.sort_by_key(|r| r.task_id);
84 results
85}
86
/// Estimates how many workers to use for `n_items` items.
///
/// The heuristic is one worker per 100 items (at least one), limited by
/// `max_workers` and by a hard ceiling of 8.
///
/// # Returns
///
/// A value in `1..=8` whenever `max_workers >= 1`. If `max_workers` is `0`
/// the result is `0`, because the caller-supplied limit is always honoured.
#[must_use]
pub fn estimate_workers(n_items: usize, max_workers: usize) -> usize {
    /// Number of items one worker is expected to handle.
    const ITEMS_PER_WORKER: usize = 100;
    /// Upper bound on the estimate regardless of `max_workers`.
    const WORKER_CEILING: usize = 8;

    let ideal = (n_items / ITEMS_PER_WORKER).max(1);
    ideal.min(max_workers).min(WORKER_CEILING)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` placeholder entries named `mol0`, `mol1`, ...
    fn sample_smiles(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("mol{}", i)).collect()
    }

    /// Builds a successful result carrying `data` for the given task id.
    fn ok_result(task_id: usize, data: &str) -> WorkerResult {
        WorkerResult {
            task_id,
            success: true,
            error: None,
            data: data.to_string(),
        }
    }

    #[test]
    fn test_split_batch_even() {
        let smiles = sample_smiles(10);
        let tasks = split_batch(&smiles, 2, 42);
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].smiles.len(), 5);
        assert_eq!(tasks[1].smiles.len(), 5);
        assert_eq!(tasks[0].id, 0);
        assert_eq!(tasks[1].id, 1);
        // Every task is an embedding task that carries the seed.
        for task in &tasks {
            assert_eq!(task.kind, TaskKind::EmbedBatch);
            assert_eq!(task.params, vec![42.0]);
        }
    }

    #[test]
    fn test_split_batch_uneven() {
        let smiles = sample_smiles(7);
        let tasks = split_batch(&smiles, 3, 42);
        assert_eq!(tasks.len(), 3);
        assert_eq!(tasks[0].smiles.len(), 3);
        assert_eq!(tasks[1].smiles.len(), 3);
        assert_eq!(tasks[2].smiles.len(), 1);
        // Chunks are contiguous, so concatenating them restores the input.
        let rejoined: Vec<String> = tasks.iter().flat_map(|t| t.smiles.clone()).collect();
        assert_eq!(rejoined, smiles);
    }

    #[test]
    fn test_split_batch_more_workers_than_items() {
        let smiles: Vec<String> = vec!["C".to_string(), "CC".to_string()];
        let tasks = split_batch(&smiles, 10, 42);
        // The worker count is capped at the number of items.
        assert_eq!(tasks.len(), 2);
    }

    #[test]
    fn test_merge_results_ordered() {
        let results = vec![ok_result(2, "r2"), ok_result(0, "r0"), ok_result(1, "r1")];
        let merged = merge_results(results);
        assert_eq!(merged[0].task_id, 0);
        assert_eq!(merged[1].task_id, 1);
        assert_eq!(merged[2].task_id, 2);
        // Each payload must travel with its task id.
        assert_eq!(merged[0].data, "r0");
        assert_eq!(merged[1].data, "r1");
        assert_eq!(merged[2].data, "r2");
    }

    #[test]
    fn test_estimate_workers() {
        assert_eq!(estimate_workers(50, 4), 1);
        assert_eq!(estimate_workers(500, 8), 5);
        assert_eq!(estimate_workers(10000, 8), 8);
        assert_eq!(estimate_workers(10000, 4), 4);
    }

    #[test]
    fn test_worker_task_serialization() {
        let task = WorkerTask {
            id: 0,
            kind: TaskKind::EmbedBatch,
            smiles: vec!["C".to_string(), "CC".to_string()],
            params: vec![42.0],
        };
        let json = serde_json::to_string(&task).unwrap();
        let back: WorkerTask = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, 0);
        assert_eq!(back.smiles.len(), 2);
        // The remaining fields must survive the round trip as well.
        assert_eq!(back.kind, task.kind);
        assert_eq!(back.smiles, task.smiles);
        assert_eq!(back.params, task.params);
    }
}