1use std::sync::{Arc, Mutex, Weak};
2
3use core_affinity::CoreId;
4use indexmap::IndexMap;
5use smallvec::SmallVec;
6use spacetimedb_data_structures::map::HashMap;
7use tokio::sync::{mpsc, oneshot, watch};
8
9use super::notify_once::NotifyOnce;
10
11#[derive(Default, Clone)]
29pub struct JobCores {
30 inner: Option<Arc<Mutex<JobCoresInner>>>,
31}
32
33struct JobCoresInner {
34 job_threads: HashMap<JobThreadId, watch::Sender<CoreId>>,
36 cores: IndexMap<CoreId, CoreInfo>,
37 next_core: usize,
42 next_id: JobThreadId,
43}
44
45#[derive(Default)]
46struct CoreInfo {
47 jobs: SmallVec<[JobThreadId; 4]>,
48}
49
/// Identifier of one allocated job thread within a [`JobCores`] pool.
///
/// Ids are handed out sequentially by the pool's allocator.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct JobThreadId(usize);
52
53impl JobCores {
54 pub fn take(&self) -> JobCore {
56 let inner = if let Some(inner) = &self.inner {
57 let cores = Arc::downgrade(inner);
58 let (id, repin_rx) = inner.lock().unwrap().allocate();
59 Some(JobCoreInner {
60 repin_rx,
61 _guard: JobCoreGuard { cores, id },
62 })
63 } else {
64 None
65 };
66
67 JobCore { inner }
68 }
69}
70
71impl FromIterator<CoreId> for JobCores {
72 fn from_iter<T: IntoIterator<Item = CoreId>>(iter: T) -> Self {
73 let cores: IndexMap<_, _> = iter.into_iter().map(|id| (id, CoreInfo::default())).collect();
74 let inner = (!cores.is_empty()).then(|| {
75 Arc::new(Mutex::new(JobCoresInner {
76 job_threads: HashMap::default(),
77 cores,
78 next_core: 0,
79 next_id: JobThreadId(0),
80 }))
81 });
82 Self { inner }
83 }
84}
85
86impl JobCoresInner {
87 fn allocate(&mut self) -> (JobThreadId, watch::Receiver<CoreId>) {
88 let id = self.next_id;
89 self.next_id.0 += 1;
90
91 let (&core_id, core) = self.cores.get_index_mut(self.next_core).unwrap();
92 core.jobs.push(id);
93 self.next_core = (self.next_core + 1) % self.cores.len();
94
95 let (repin_tx, repin_rx) = watch::channel(core_id);
96 self.job_threads.insert(id, repin_tx);
97
98 (id, repin_rx)
99 }
100
101 fn deallocate(&mut self, id: JobThreadId) {
103 let core_id = *self.job_threads.remove(&id).unwrap().borrow();
104
105 let core_index = self.cores.get_index_of(&core_id).unwrap();
106
107 let steal_from_index = self.next_core.checked_sub(1).unwrap_or(self.cores.len() - 1);
114
115 let (core, steal_from) = match self.cores.get_disjoint_indices_mut([core_index, steal_from_index]) {
117 Ok([(_, core), (_, steal_from)]) => (core, Some(steal_from)),
118 Err(_) => (&mut self.cores[core_index], None),
119 };
120
121 let pos = core.jobs.iter().position(|x| *x == id).unwrap();
122 core.jobs.remove(pos);
123
124 if let Some(steal_from) = steal_from {
125 let stolen = steal_from.jobs.pop().unwrap();
132 core.jobs.push(stolen);
135 self.job_threads[&stolen].send_replace(core_id);
136 }
137
138 self.next_core = steal_from_index;
139 }
140}
141
142#[derive(Default)]
144pub struct JobCore {
145 inner: Option<JobCoreInner>,
146}
147
148struct JobCoreInner {
149 repin_rx: watch::Receiver<CoreId>,
150 _guard: JobCoreGuard,
151}
152
153impl JobCore {
154 pub fn start<F, F2, U, T>(self, init: F, unsize: F2) -> JobThread<T>
159 where
160 F: FnOnce() -> U + Send + 'static,
161 F2: FnOnce(&mut U) -> &mut T + Send + 'static,
162 U: 'static,
163 T: ?Sized + 'static,
164 {
165 let (tx, rx) = mpsc::channel::<Box<Job<T>>>(Self::JOB_CHANNEL_LENGTH);
166 let close = Arc::new(NotifyOnce::new());
167
168 let closed = close.clone();
169 let handle = tokio::runtime::Handle::current();
170 std::thread::spawn(move || {
171 let mut data = init();
172 let data = unsize(&mut data);
173 handle.block_on(self.job_loop(rx, closed, data))
174 });
175
176 JobThread { tx, close }
177 }
178
179 const JOB_CHANNEL_LENGTH: usize = 64;
182
183 async fn job_loop<T: ?Sized>(mut self, mut rx: mpsc::Receiver<Box<Job<T>>>, closed: Arc<NotifyOnce>, data: &mut T) {
184 let repin_rx = self.inner.as_mut().map(|inner| &mut inner.repin_rx);
188 let repin_loop = async {
189 if let Some(rx) = repin_rx {
190 rx.mark_changed();
191 while rx.changed().await.is_ok() {
192 core_affinity::set_for_current(*rx.borrow_and_update());
193 }
194 }
195 };
196
197 let job_loop = async {
198 while let Some(job) = rx.recv().await {
199 tokio::task::block_in_place(|| job(data))
203 }
204 };
205
206 tokio::select! {
207 () = super::also_poll(job_loop, repin_loop) => {}
208 () = closed.notified() => {}
211 }
212 }
213}
214
215struct JobCoreGuard {
217 cores: Weak<Mutex<JobCoresInner>>,
218 id: JobThreadId,
219}
220
221impl Drop for JobCoreGuard {
222 fn drop(&mut self) {
223 if let Some(cores) = self.cores.upgrade() {
224 cores.lock().unwrap().deallocate(self.id);
225 }
226 }
227}
228
229pub struct JobThread<T: ?Sized> {
237 tx: mpsc::Sender<Box<Job<T>>>,
238 close: Arc<NotifyOnce>,
239}
240
241impl<T: ?Sized> Clone for JobThread<T> {
242 fn clone(&self) -> Self {
243 Self {
244 tx: self.tx.clone(),
245 close: self.close.clone(),
246 }
247 }
248}
249
/// A unit of work for a job thread: a one-shot, sendable closure that is
/// given mutable access to the thread's state.
type Job<T> = dyn FnOnce(&mut T) + Send;
251
252impl<T: ?Sized> JobThread<T> {
253 pub async fn run<F, R>(&self, f: F) -> Result<R, JobThreadClosed>
259 where
260 F: FnOnce(&mut T) -> R + Send + 'static,
261 R: Send + 'static,
262 {
263 let (ret_tx, ret_rx) = oneshot::channel();
264
265 let span = tracing::Span::current();
266 self.tx
267 .send(Box::new(move |data| {
268 let _entered = span.entered();
269 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(data)));
270 if let Err(Err(_panic)) = ret_tx.send(result) {
271 tracing::warn!("uncaught panic on threadpool")
272 }
273 }))
274 .await
275 .map_err(|_| JobThreadClosed)?;
276
277 match ret_rx.await {
278 Ok(Ok(ret)) => Ok(ret),
279 Ok(Err(panic)) => std::panic::resume_unwind(panic),
280 Err(_closed) => Err(JobThreadClosed),
281 }
282 }
283
284 pub fn close(&self) {
286 self.close.notify();
287 }
288
289 pub async fn closed(&self) {
291 self.tx.closed().await
292 }
293
294 pub fn downgrade(&self) -> WeakJobThread<T> {
296 let tx = self.tx.downgrade();
297 let close = Arc::downgrade(&self.close);
298 WeakJobThread { tx, close }
299 }
300}
301
/// Error returned by [`JobThread::run`] when the job thread is no longer
/// accepting or completing jobs.
///
/// The derives let callers `unwrap`/`expect` the `Result` returned by `run`
/// (which needs `Debug`) and compare results in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JobThreadClosed;
303
304pub struct WeakJobThread<T: ?Sized> {
307 tx: mpsc::WeakSender<Box<Job<T>>>,
308 close: Weak<NotifyOnce>,
309}
310
311impl<T: ?Sized> WeakJobThread<T> {
312 pub fn upgrade(&self) -> Option<JobThread<T>> {
313 Option::zip(self.tx.upgrade(), self.close.upgrade()).map(|(tx, close)| JobThread { tx, close })
314 }
315}