1use parking_lot::{Condvar, Mutex};
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
8use std::thread;
9
10pub struct ParallelMapper<Input, Output, Map>
11where
12 Input: Send + 'static,
13 Output: Send + 'static,
14 Map: Fn(Input) -> Output + Send + Sync + 'static,
15{
16 max_workers: u32,
18 active_workers_count: Arc<AtomicU32>,
20 workers: Vec<thread::JoinHandle<()>>,
22 pending: Arc<AtomicU32>,
24 queue_filled_cond: Arc<Condvar>,
26 queue_empty_cond: Arc<Condvar>,
28 work_queue: Arc<Mutex<Option<Input>>>,
30 is_shutdown: Arc<AtomicBool>,
32 new_result_cond: Arc<Condvar>,
34 results: Arc<Mutex<Vec<Output>>>,
36 map: Arc<Map>,
38}
39
40impl<Input, Output, Map> ParallelMapper<Input, Output, Map>
41where
42 Input: Send + 'static,
43 Output: Send + 'static,
44 Map: Fn(Input) -> Output + Send + Sync + 'static,
45{
46 pub fn new(max_workers: u32, map: Map) -> Self {
47 Self {
48 max_workers,
49 active_workers_count: Arc::new(AtomicU32::new(0u32)),
50 workers: Vec::with_capacity(max_workers as usize),
52 pending: Arc::new(AtomicU32::new(0u32)),
53 queue_filled_cond: Arc::new(Condvar::new()),
54 queue_empty_cond: Arc::new(Condvar::new()),
55 work_queue: Arc::new(Mutex::new(None)),
56 is_shutdown: Arc::new(AtomicBool::new(false)),
57 new_result_cond: Arc::new(Condvar::new()),
58 results: Arc::new(Mutex::new(Vec::new())),
59 map: Arc::new(map),
60 }
61 }
62
63 pub fn concurrency(&self) -> u32 {
64 self.max_workers
65 }
66
67 pub fn push(&mut self, input: Input) {
68 debug_assert!(!self.is_shutdown.load(Ordering::Relaxed));
69 self.pending.fetch_add(1, Ordering::AcqRel);
70 if self.workers.is_empty() {
72 {
73 let mut wq = self.work_queue.lock();
74 debug_assert!(wq.is_none());
75 wq.insert(input);
76 }
77 return self.spawn_worker();
78 }
79
80 let mut wq = self.work_queue.lock();
82 while wq.is_some() {
83 self.queue_empty_cond.wait(&mut wq);
84 }
85 wq.insert(input);
86 self.queue_filled_cond.notify_one();
87 drop(wq);
88
89 let are_all_busy =
90 self.active_workers_count.load(Ordering::Relaxed) == self.workers.len() as u32;
91 if are_all_busy && self.workers.len() < self.max_workers as usize {
92 self.spawn_worker();
93 }
94 }
95
96 pub fn pop_result(&mut self) -> Option<Output> {
98 let mut results = self.results.lock();
99 if results.is_empty() {
100 return None;
101 }
102 results.pop()
103 }
104
105 pub fn pop_all(&mut self) -> Vec<Output> {
107 let mut results = self.results.lock();
108 std::mem::take(results.as_mut())
109 }
110
111 pub fn process_all<I>(&mut self, inputs: I) -> Vec<Output>
117 where
118 I: IntoIterator,
119 I::Item: std::borrow::Borrow<Input>,
120 Input: Clone,
121 {
122 inputs.into_iter().for_each(|input| {
123 self.push(std::borrow::Borrow::borrow(&input).clone());
124 });
125 self.wait();
126 self.pop_all()
127 }
128
129 pub fn wait(&self) {
130 if self.is_shutdown.load(Ordering::Relaxed) {
131 return;
132 }
133 while self.pending.load(Ordering::Relaxed) != 0 {
134 let mut res = self.results.lock();
135 while res.is_empty() {
136 self.new_result_cond.wait(&mut res);
137 }
138 }
139 }
140
141 pub fn finish(&mut self) -> Vec<Output> {
142 self.is_shutdown.store(true, Ordering::Relaxed);
143 self.queue_filled_cond.notify_all();
144 for w in self.workers.drain(..) {
145 w.join();
146 }
147 let mut results = self.results.lock();
148 std::mem::take(results.as_mut())
149 }
150
151 fn spawn_worker(&mut self) {
152 let is_shutdown = Arc::clone(&self.is_shutdown);
153 let active_workers_count = Arc::clone(&self.active_workers_count);
154 let pending = Arc::clone(&self.pending);
155 let work_queue = Arc::clone(&self.work_queue);
156 let queue_empty_cond = Arc::clone(&self.queue_empty_cond);
157 let queue_filled_cond = Arc::clone(&self.queue_filled_cond);
158 let new_result_cond = Arc::clone(&self.new_result_cond);
159 let results = Arc::clone(&self.results);
160 let map = Arc::clone(&self.map);
161 self.workers.push(thread::spawn(move || {
164 loop {
166 let mut input = {
168 let mut wq = work_queue.lock();
169 queue_filled_cond.wait_while(&mut wq, |q| {
170 q.is_none() && !is_shutdown.load(Ordering::Acquire)
171 });
172 active_workers_count.fetch_add(1, Ordering::AcqRel);
173 let inpt = wq.take();
174 queue_empty_cond.notify_one();
175 inpt
176 };
177 if input.is_none() && is_shutdown.load(Ordering::Acquire) {
178 active_workers_count.fetch_sub(1, Ordering::AcqRel);
179 break;
180 }
181 if input.is_some() {
182 let result = map(input.take().unwrap());
184 results.lock().push(result);
185 pending.fetch_sub(1, Ordering::AcqRel);
186 new_result_cond.notify_one();
187 }
188 active_workers_count.fetch_sub(1, Ordering::AcqRel);
189 }
190 }));
192 }
193}
194
195impl<Input, Output, Map> Drop for ParallelMapper<Input, Output, Map>
196where
197 Input: Send + 'static,
198 Output: Send + 'static,
199 Map: Fn(Input) -> Output + Send + Sync + 'static,
200{
201 fn drop(&mut self) {
202 self.is_shutdown.store(true, Ordering::Relaxed);
203 self.queue_filled_cond.notify_all();
204 for w in self.workers.drain(..) {
205 let _ = w.join();
206 }
207 }
208}
209
210pub struct DynParallelMapper<Input, Output>
211where
212 Input: Send + 'static,
213 Output: Send + 'static,
214{
215 par_mapper: ParallelMapper<Input, Output, Box<dyn Fn(Input) -> Output + Send + Sync>>,
216}
217
218impl<Input, Output> DynParallelMapper<Input, Output>
219where
220 Input: Send + 'static,
221 Output: Send + 'static,
222{
223 pub fn new(max_workers: u32, map: Box<dyn Fn(Input) -> Output + Send + Sync>) -> Self {
224 Self {
225 par_mapper: ParallelMapper::new(max_workers, map),
226 }
227 }
228 pub fn concurrency(&self) -> u32 {
229 self.par_mapper.concurrency()
230 }
231 pub fn process_all<I>(&mut self, inputs: I) -> Vec<Output>
232 where
233 I: IntoIterator,
234 I::Item: std::borrow::Borrow<Input>,
235 Input: Clone,
236 {
237 self.par_mapper.process_all(inputs)
238 }
239 pub fn push(&mut self, input: Input) {
240 self.par_mapper.push(input)
241 }
242 pub fn pop(&mut self) -> Option<Output> {
243 self.par_mapper.pop_result()
244 }
245 pub fn pop_all(&mut self) -> Vec<Output> {
246 self.par_mapper.pop_all()
247 }
248 pub fn wait(&self) {
249 self.par_mapper.wait();
250 }
251 pub fn finish(&mut self) -> Vec<Output> {
252 self.par_mapper.finish()
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::Duration;

    use crate::parallel_mapper::{DynParallelMapper, ParallelMapper};

    /// Four inputs pushed to a four-worker mapper all come back from `finish`.
    #[test]
    fn test_basic_workers() {
        let mut square_computers = ParallelMapper::<i32, i32, _>::new(4, |x| {
            std::thread::sleep(Duration::from_millis(50));
            x * x
        });

        for x in 1..5 {
            square_computers.push(x);
        }
        let res = square_computers.finish();

        assert_eq!(4, res.len());
        assert!(res.contains(&1));
        assert!(res.contains(&4));
        assert!(res.contains(&9));
        assert!(res.contains(&16));
    }

    /// Inputs with different processing times are all mapped, and the pool
    /// never exceeds the configured worker count.
    #[test]
    fn test_workers_with_various_completion_time() {
        let mut par_mapper = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(x as u64));
            x * x
        });
        let inputs = [17, 11, 7, 23, 61, 31, 79];
        let first_batch_len = 4;
        for &x in &inputs[..first_batch_len] {
            par_mapper.push(x);
        }

        assert!(par_mapper.workers.len() <= 4);

        let res1 = par_mapper.pop_result();
        for &x in &inputs[first_batch_len..] {
            par_mapper.push(x);
        }

        assert!(par_mapper.workers.len() <= 4);

        let res2 = par_mapper.pop_result();
        let mut results = Vec::new();
        // `Option` is iterable: this appends the value only when present.
        results.extend(res1);
        results.extend(res2);
        results.extend(par_mapper.finish());
        results.sort();

        let mut expected_outputs = inputs.map(|x| x * x);
        expected_outputs.sort();
        assert_eq!(results, expected_outputs);
    }

    /// Dropping a mapper with work in flight must terminate.
    #[test]
    fn test_drop_parallel_mapper_instance() {
        let mut par_mapper = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(100));
            x * x
        });
        let inputs_count = 16;
        for x in 0..inputs_count {
            par_mapper.push(x);
        }
        drop(par_mapper);
    }

    /// Interleaving `push` and `pop_result` neither loses nor duplicates outputs.
    #[test]
    fn test_interleave_push_pop() {
        let mut workers = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(if x % 2 != 0 { 20 } else { 12 }));
            x * x
        });

        let iterations_count = 1024i32;
        let mut results = Vec::<i32>::with_capacity(iterations_count as usize);
        for x in 0..iterations_count {
            workers.push(x);
            if let Some(res) = workers.pop_result() {
                results.push(res);
            }
        }
        results.extend(workers.finish());
        results.sort();
        let expected_results: Vec<i32> = (0..iterations_count).map(|x| x * x).collect();

        assert_eq!(results.len(), expected_results.len());
        assert_eq!(results, expected_results);
    }

    /// After `wait` returns, every pushed input has a stored output.
    #[test]
    fn test_wait() {
        let mut par_mapper = ParallelMapper::<u32, u32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(x as u64));
            x * x
        });

        let inputs: Vec<u32> = vec![2, 7, 97, 31, 257, 929, 19, 313];
        let mut results = Vec::<u32>::with_capacity(inputs.len());
        for &x in &inputs {
            par_mapper.push(x);
            results.extend(par_mapper.pop_all());
        }
        par_mapper.wait();
        results.extend(par_mapper.pop_all());
        results.sort();

        let mut expected: Vec<u32> = inputs.iter().map(|x| x * x).collect();
        expected.sort();

        assert_eq!(results.len(), expected.len());
        assert_eq!(results, expected);
    }

    /// `process_all` maps a borrowed collection and returns every output.
    #[test]
    fn test_process_all() {
        let mut workers = ParallelMapper::<u32, u32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(x as u64));
            x * x
        });
        let inputs = vec![23, 11, 67, 251, 7, 8, 641, 37];
        let mut res = workers.process_all(&inputs);

        let mut expected: Vec<u32> = inputs.iter().map(|x| x * x).collect();
        res.sort();
        expected.sort();
        assert_eq!(res, expected);
    }

    /// The boxed-closure wrapper produces the same outputs as the generic mapper.
    #[test]
    fn test_dyn_parallel_mapper() {
        let mut par_mapper = DynParallelMapper::<i32, i32>::new(4, Box::new(|x| x * x));
        par_mapper.push(5);
        par_mapper.push(7);
        thread::sleep(Duration::from_millis(10));
        let mut results = par_mapper.pop_all();
        results.extend(par_mapper.finish());

        assert_eq!(results.len(), 2);
        assert!(results.contains(&25));
        assert!(results.contains(&49));
    }
}