1use crate::{
2 TaskFnPointer, TaskParamPointer, WorkItem, future::WorkFuture, queue::BatchQueue,
3 uniform_tasks_to_pointers, worker::Worker,
4};
5use std::sync::Arc;
6
/// A fixed set of [`Worker`]s that all pull tasks from one shared [`BatchQueue`].
///
/// Dropping the pool calls `shutdown` on the queue and then joins every worker's handle.
// NOTE(review): presumably aligned to 64 bytes (a common cache-line size) to keep the pool
// from sharing a cache line with neighbouring data. The struct itself only holds an `Arc`
// and a `Vec`, so confirm this alignment is still wanted.
#[repr(align(64))]
pub struct ThreadPool {
    // Shared task queue; `new` hands a clone of this `Arc` to every worker.
    queue: Arc<BatchQueue>,
    // One entry per worker; emptied and joined when the pool is dropped.
    workers: Vec<Worker>,
}
12
13impl ThreadPool {
14 pub fn new(worker_count: usize) -> Self {
15 assert!(worker_count > 0, "Must have at least one worker");
16
17 let queue = Arc::new(BatchQueue::new());
18
19 let workers: Vec<Worker> = (0..worker_count)
20 .map(|id| Worker::new(id, queue.clone()))
21 .collect();
22
23 ThreadPool { queue, workers }
24 }
25
26 pub fn submit_raw_task(&self, task_fn: TaskFnPointer, params: TaskParamPointer) -> WorkFuture {
27 self.queue.push_single_task(task_fn, params)
28 }
29
30 pub fn submit_raw_task_batch(&self, tasks: &[WorkItem]) -> WorkFuture {
31 self.queue.push_task_batch(tasks)
32 }
33
34 pub fn submit_task<T>(&self, task_fn: TaskFnPointer, params: &T) -> WorkFuture {
35 let params_ptr = params as *const T as TaskParamPointer;
36 self.submit_raw_task(task_fn, params_ptr)
37 }
38
39 pub fn submit_batch_uniform<T>(&self, task_fn: TaskFnPointer, params_vec: &[T]) -> WorkFuture {
40 let tasks = uniform_tasks_to_pointers(task_fn, params_vec);
41 self.submit_raw_task_batch(&tasks)
42 }
43}
44
45impl Drop for ThreadPool {
46 fn drop(&mut self) {
47 self.queue.shutdown();
48
49 let workers = std::mem::take(&mut self.workers);
50 for worker in workers {
51 let _ = worker.handle.join();
52 }
53 }
54}