// sendcipher_core/parallel_mapper.rs

/* Created on 2025.12.03 */
/* Copyright (c) 2025-2026 Youcef Lemsafer */
/* SPDX-License-Identifier: MIT */
5use parking_lot::{Condvar, Mutex};
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
8use std::thread;
9
/// Applies a mapping function to inputs on a pool of lazily-spawned worker
/// threads. Work is handed off through a single-slot queue; outputs are
/// collected in completion order (not input order).
pub struct ParallelMapper<Input, Output, Map>
where
    Input: Send + 'static,
    Output: Send + 'static,
    Map: Fn(Input) -> Output + Send + Sync + 'static,
{
    /// The maximum number of workers in this parallel workshop
    max_workers: u32,
    /// Number of active/busy workers
    active_workers_count: Arc<AtomicU32>,
    /// Join handles of the spawned worker threads
    workers: Vec<thread::JoinHandle<()>>,
    /// Number of work items pushed but not yet turned into results
    pending: Arc<AtomicU32>,
    /// Condition variable to notify about full work queue
    queue_filled_cond: Arc<Condvar>,
    /// Condition variable to notify about empty work queue
    queue_empty_cond: Arc<Condvar>,
    /// Work queue (has at most one element)
    work_queue: Arc<Mutex<Option<Input>>>,
    /// Whether to shutdown i.e. no more work
    is_shutdown: Arc<AtomicBool>,
    /// Condition variable notifying about result availability
    new_result_cond: Arc<Condvar>,
    /// The results produced by the workers
    results: Arc<Mutex<Vec<Output>>>,
    /// The transformation performed by the workers on the input they receive
    map: Arc<Map>,
}
39
40impl<Input, Output, Map> ParallelMapper<Input, Output, Map>
41where
42    Input: Send + 'static,
43    Output: Send + 'static,
44    Map: Fn(Input) -> Output + Send + Sync + 'static,
45{
46    pub fn new(max_workers: u32, map: Map) -> Self {
47        Self {
48            max_workers,
49            active_workers_count: Arc::new(AtomicU32::new(0u32)),
50            // Lazy creation of workers
51            workers: Vec::with_capacity(max_workers as usize),
52            pending: Arc::new(AtomicU32::new(0u32)),
53            queue_filled_cond: Arc::new(Condvar::new()),
54            queue_empty_cond: Arc::new(Condvar::new()),
55            work_queue: Arc::new(Mutex::new(None)),
56            is_shutdown: Arc::new(AtomicBool::new(false)),
57            new_result_cond: Arc::new(Condvar::new()),
58            results: Arc::new(Mutex::new(Vec::new())),
59            map: Arc::new(map),
60        }
61    }
62
63    pub fn concurrency(&self) -> u32 {
64        self.max_workers
65    }
66
67    pub fn push(&mut self, input: Input) {
68        debug_assert!(!self.is_shutdown.load(Ordering::Relaxed));
69        self.pending.fetch_add(1, Ordering::AcqRel);
70        // The fast and easy path, no worker yet, push input and create a worker
71        if self.workers.is_empty() {
72            {
73                let mut wq = self.work_queue.lock();
74                debug_assert!(wq.is_none());
75                wq.insert(input);
76            }
77            return self.spawn_worker();
78        }
79
80        // Wait for the queue to become available
81        let mut wq = self.work_queue.lock();
82        while wq.is_some() {
83            self.queue_empty_cond.wait(&mut wq);
84        }
85        wq.insert(input);
86        self.queue_filled_cond.notify_one();
87        drop(wq);
88
89        let are_all_busy =
90            self.active_workers_count.load(Ordering::Relaxed) == self.workers.len() as u32;
91        if are_all_busy && self.workers.len() < self.max_workers as usize {
92            self.spawn_worker();
93        }
94    }
95
96    /// Pops and returns result if any
97    pub fn pop_result(&mut self) -> Option<Output> {
98        let mut results = self.results.lock();
99        if results.is_empty() {
100            return None;
101        }
102        results.pop()
103    }
104
105    /// Pops and returns all results if any
106    pub fn pop_all(&mut self) -> Vec<Output> {
107        let mut results = self.results.lock();
108        std::mem::take(results.as_mut())
109    }
110
111    /// Processes the provided inputs in parallel.
112    /// 
113    /// # Ordering
114    /// The order of outputs is **not guaranteed** to match the input order.
115    /// Outputs are returned in completion order.
116    pub fn process_all<I>(&mut self, inputs: I) -> Vec<Output>
117    where
118        I: IntoIterator,
119        I::Item: std::borrow::Borrow<Input>,
120        Input: Clone,
121    {
122        inputs.into_iter().for_each(|input| {
123            self.push(std::borrow::Borrow::borrow(&input).clone());
124        });
125        self.wait();
126        self.pop_all()
127    }
128
129    pub fn wait(&self) {
130        if self.is_shutdown.load(Ordering::Relaxed) {
131            return;
132        }
133        while self.pending.load(Ordering::Relaxed) != 0 {
134            let mut res = self.results.lock();
135            while res.is_empty() {
136                self.new_result_cond.wait(&mut res);
137            }
138        }
139    }
140
141    pub fn finish(&mut self) -> Vec<Output> {
142        self.is_shutdown.store(true, Ordering::Relaxed);
143        self.queue_filled_cond.notify_all();
144        for w in self.workers.drain(..) {
145            w.join();
146        }
147        let mut results = self.results.lock();
148        std::mem::take(results.as_mut())
149    }
150
151    fn spawn_worker(&mut self) {
152        let is_shutdown = Arc::clone(&self.is_shutdown);
153        let active_workers_count = Arc::clone(&self.active_workers_count);
154        let pending = Arc::clone(&self.pending);
155        let work_queue = Arc::clone(&self.work_queue);
156        let queue_empty_cond = Arc::clone(&self.queue_empty_cond);
157        let queue_filled_cond = Arc::clone(&self.queue_filled_cond);
158        let new_result_cond = Arc::clone(&self.new_result_cond);
159        let results = Arc::clone(&self.results);
160        let map = Arc::clone(&self.map);
161        //let worker_id = self.workers.len() + 1;
162
163        self.workers.push(thread::spawn(move || {
164            //  println!("Worker {} starts", worker_id);
165            loop {
166                //    println!("Worker {} loops", worker_id);
167                let mut input = {
168                    let mut wq = work_queue.lock();
169                    queue_filled_cond.wait_while(&mut wq, |q| {
170                        q.is_none() && !is_shutdown.load(Ordering::Acquire)
171                    });
172                    active_workers_count.fetch_add(1, Ordering::AcqRel);
173                    let inpt = wq.take();
174                    queue_empty_cond.notify_one();
175                    inpt
176                };
177                if input.is_none() && is_shutdown.load(Ordering::Acquire) {
178                    active_workers_count.fetch_sub(1, Ordering::AcqRel);
179                    break;
180                }
181                if input.is_some() {
182                    //        println!("Worker {} processes {}", worker_id, *input.as_ref().unwrap());
183                    let result = map(input.take().unwrap());
184                    results.lock().push(result);
185                    pending.fetch_sub(1, Ordering::AcqRel);
186                    new_result_cond.notify_one();
187                }
188                active_workers_count.fetch_sub(1, Ordering::AcqRel);
189            }
190            //  println!("Worker {} ends", worker_id);
191        }));
192    }
193}
194
195impl<Input, Output, Map> Drop for ParallelMapper<Input, Output, Map>
196where
197    Input: Send + 'static,
198    Output: Send + 'static,
199    Map: Fn(Input) -> Output + Send + Sync + 'static,
200{
201    fn drop(&mut self) {
202        self.is_shutdown.store(true, Ordering::Relaxed);
203        self.queue_filled_cond.notify_all();
204        for w in self.workers.drain(..) {
205            let _ = w.join();
206        }
207    }
208}
209
/// A [`ParallelMapper`] whose mapping function is a boxed trait object,
/// letting callers choose the function at runtime instead of fixing it in
/// the type parameters.
pub struct DynParallelMapper<Input, Output>
where
    Input: Send + 'static,
    Output: Send + 'static,
{
    // All operations delegate to this underlying generic mapper.
    par_mapper: ParallelMapper<Input, Output, Box<dyn Fn(Input) -> Output + Send + Sync>>,
}
217
impl<Input, Output> DynParallelMapper<Input, Output>
where
    Input: Send + 'static,
    Output: Send + 'static,
{
    /// Creates a mapper with at most `max_workers` workers applying `map`.
    pub fn new(max_workers: u32, map: Box<dyn Fn(Input) -> Output + Send + Sync>) -> Self {
        Self {
            par_mapper: ParallelMapper::new(max_workers, map),
        }
    }
    /// Returns the maximum number of worker threads.
    pub fn concurrency(&self) -> u32 {
        self.par_mapper.concurrency()
    }
    /// Processes all inputs in parallel; outputs come in completion order.
    pub fn process_all<I>(&mut self, inputs: I) -> Vec<Output>
    where
        I: IntoIterator,
        I::Item: std::borrow::Borrow<Input>,
        Input: Clone,
    {
        self.par_mapper.process_all(inputs)
    }
    /// Queues one input for processing (may block until the slot is free).
    pub fn push(&mut self, input: Input) {
        self.par_mapper.push(input)
    }
    /// Pops one result, if available (non-blocking).
    pub fn pop(&mut self) -> Option<Output> {
        self.par_mapper.pop_result()
    }
    /// Pops all results produced so far (non-blocking).
    pub fn pop_all(&mut self) -> Vec<Output> {
        self.par_mapper.pop_all()
    }
    /// Blocks until every pushed input has produced a result.
    pub fn wait(&self) {
        self.par_mapper.wait();
    }
    /// Shuts down the workers and returns the remaining results.
    pub fn finish(&mut self) -> Vec<Output> {
        self.par_mapper.finish()
    }
}
255
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::Duration;

    use crate::parallel_mapper::{DynParallelMapper, ParallelMapper};

    /// Pushes a few inputs and checks every expected square comes back.
    #[test]
    fn test_basic_workers() {
        let mut square_computers = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_millis(50));
            x * x
        });

        (1..5).for_each(|x| square_computers.push(x));
        let res = square_computers.finish();

        assert_eq!(4, res.len());
        for expected in [1, 4, 9, 16] {
            assert!(res.contains(&expected));
        }
    }

    /// Interleaves pushes and pops with per-item varying latencies and
    /// verifies nothing is lost and the pool never exceeds its capacity.
    #[test]
    fn test_workers_with_various_completion_time() {
        let mut par_mapper = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(x as u64));
            x * x
        });
        let inputs = [17, 11, 7, 23, 61, 31, 79];
        let first_batch_len = 4;
        inputs[..first_batch_len]
            .iter()
            .for_each(|&x| par_mapper.push(x));

        assert!(par_mapper.workers.len() <= 4);

        let res1 = par_mapper.pop_result();
        inputs[first_batch_len..]
            .iter()
            .for_each(|&x| par_mapper.push(x));

        assert!(par_mapper.workers.len() <= 4);

        let res2 = par_mapper.pop_result();
        let mut results = Vec::new();
        // `Option` iterates zero-or-one elements, so `extend` replaces the
        // `is_some()`/`unwrap()` pattern the original used.
        results.extend(res1);
        results.extend(res2);
        results.extend(par_mapper.finish());
        results.sort();

        let mut expected_outputs = inputs.map(|x| x * x);
        expected_outputs.sort();
        assert_eq!(results, expected_outputs);
    }

    /// Dropping with work still in flight must join all workers cleanly
    /// (no hang, no leaked threads).
    #[test]
    fn test_drop_parallel_mapper_instance() {
        let mut par_mapper = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(100));
            x * x
        });
        let inputs_count = 16;
        (0..inputs_count).for_each(|x| par_mapper.push(x));
        drop(par_mapper);
    }

    /// Alternates push/pop over many iterations; all squares must be seen.
    #[test]
    fn test_interleave_push_pop() {
        let mut workers = ParallelMapper::<i32, i32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(if x % 2 != 0 { 20 } else { 12 }));
            x * x
        });

        let iterations_count = 1024i32;
        let mut results = Vec::<i32>::with_capacity(iterations_count as usize);
        (0..iterations_count).for_each(|x| {
            workers.push(x);
            results.extend(workers.pop_result());
        });
        results.extend(workers.finish());
        results.sort();
        let expected_results = (0..iterations_count).map(|x| x * x).collect::<Vec<i32>>();

        assert_eq!(results.len(), expected_results.len());
        assert_eq!(results, expected_results);
    }

    /// `wait` must block until every pending input has a result.
    #[test]
    fn test_wait() {
        let mut par_mapper = ParallelMapper::<u32, u32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(x as u64));
            x * x
        });

        let inputs: Vec<u32> = vec![2, 7, 97, 31, 257, 929, 19, 313];
        let mut results = Vec::<u32>::with_capacity(inputs.len());
        inputs.iter().for_each(|x| {
            par_mapper.push(*x);
            results.extend(par_mapper.pop_all());
        });
        par_mapper.wait();
        results.extend(par_mapper.pop_all());
        results.sort();

        let mut expected: Vec<u32> = inputs.iter().map(|x| x * x).collect();
        expected.sort();

        assert_eq!(results.len(), expected.len());
        assert_eq!(results, expected);
    }

    /// `process_all` over a borrowed collection returns every square.
    #[test]
    fn test_process_all() {
        let mut workers = ParallelMapper::<u32, u32, _>::new(4, |x| {
            thread::sleep(Duration::from_micros(x as u64));
            x * x
        });
        let inputs = vec![23, 11, 67, 251, 7, 8, 641, 37];
        let mut res = workers.process_all(&inputs);

        let mut expected: Vec<u32> = inputs.iter().map(|x| x * x).collect();
        res.sort();
        expected.sort();
        assert_eq!(res, expected);
    }

    /// The boxed-closure wrapper behaves like the generic mapper.
    #[test]
    fn test_dyn_parallel_mapper() {
        let mut par_mapper = DynParallelMapper::<i32, i32>::new(4, Box::new(|x| x * x));
        par_mapper.push(5);
        par_mapper.push(7);
        thread::sleep(Duration::from_millis(10));
        let mut results = par_mapper.pop_all();
        results.extend(par_mapper.finish());

        // `len()` is already `usize`; the original's `2 as usize` cast was noise.
        assert_eq!(results.len(), 2);
        assert!(results.contains(&25));
        assert!(results.contains(&49));
    }
}
399}