Skip to main content

vortex_io/runtime/
pool.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5use std::sync::atomic::AtomicBool;
6use std::sync::atomic::Ordering;
7use std::time::Duration;
8
9use parking_lot::Mutex;
10use smol::block_on;
11use vortex_error::VortexExpect;
12use vortex_utils::parallelism::get_available_parallelism;
13
14#[derive(Clone)]
15pub struct CurrentThreadWorkerPool {
16    executor: Arc<smol::Executor<'static>>,
17    state: Arc<Mutex<PoolState>>,
18}
19
20impl CurrentThreadWorkerPool {
21    pub(super) fn new(executor: Arc<smol::Executor<'static>>) -> Self {
22        Self {
23            executor,
24            state: Arc::new(Mutex::new(PoolState::default())),
25        }
26    }
27
28    /// Set the number of worker threads to the available system parallelism as reported by
29    /// [`get_available_parallelism()`] minus 1, to leave a slot open for the calling thread.
30    pub fn set_workers_to_available_parallelism(&self) {
31        let n = get_available_parallelism()
32            .map(|n| n.saturating_sub(1).max(1))
33            .unwrap_or(1);
34        self.set_workers(n);
35    }
36
37    /// Set the number of worker threads
38    /// - If n > current: spawns additional workers
39    /// - If n < current: signals extra workers to shut down
40    pub fn set_workers(&self, n: usize) {
41        let mut state = self.state.lock();
42        let current = state.workers.len();
43
44        if n > current {
45            // Spawn new workers
46            for _ in current..n {
47                let shutdown = Arc::new(AtomicBool::new(false));
48                let executor = Arc::clone(&self.executor);
49                let shutdown_clone = Arc::clone(&shutdown);
50
51                std::thread::Builder::new()
52                    .name("vortex-current-thread-worker".to_string())
53                    .spawn(move || {
54                        // Run the executor with a sleeping future that checks for shutdown
55                        block_on(executor.run(async move {
56                            while !shutdown_clone.load(Ordering::Relaxed) {
57                                smol::Timer::after(Duration::from_millis(100)).await;
58                            }
59                        }))
60                    })
61                    .vortex_expect("Failed to spawn current thread worker");
62
63                state.workers.push(WorkerHandle { shutdown });
64            }
65        } else if n < current {
66            // Signal extra workers to shutdown and remove them
67            while state.workers.len() > n {
68                if let Some(worker) = state.workers.pop() {
69                    worker.shutdown.store(true, Ordering::Relaxed);
70                }
71            }
72        }
73    }
74
75    /// Get the current number of worker threads
76    pub fn worker_count(&self) -> usize {
77        self.state.lock().workers.len()
78    }
79}
80
81#[derive(Default)]
82struct PoolState {
83    /// The set of worker handles for the background threads.
84    workers: Vec<WorkerHandle>,
85}
86
/// The pool's side of the connection to a single worker thread.
struct WorkerHandle {
    /// Stop request for the worker. The worker thread holds the other `Arc` and polls this
    /// flag; storing `true` asks it to exit.
    shutdown: Arc<AtomicBool>,
}
91
92impl Drop for CurrentThreadWorkerPool {
93    fn drop(&mut self) {
94        let mut state = self.state.lock();
95
96        // Signal all workers to shut down
97        for worker in state.workers.drain(..) {
98            worker.shutdown.store(true, Ordering::Relaxed);
99        }
100    }
101}