1use crate::runtime::execution::ExecutionState;
4use crate::runtime::task::TaskId;
5use crate::runtime::thread;
6use std::marker::PhantomData;
7use std::panic::Location;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::time::Duration;
10
11pub use std::thread::{panicking, Result};
12
13#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
15pub struct ThreadId {
16 task_id: TaskId,
18}
19
20impl From<ThreadId> for usize {
21 fn from(id: ThreadId) -> usize {
22 id.task_id.into()
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct Thread {
29 name: Option<String>,
30 id: ThreadId,
31}
32
33impl Thread {
34 pub fn name(&self) -> Option<&str> {
36 self.name.as_deref()
37 }
38
39 pub fn id(&self) -> ThreadId {
41 self.id
42 }
43
44 pub fn unpark(&self) {
46 thread::switch();
47
48 ExecutionState::with(|s| {
49 s.get_mut(self.id.task_id).unpark();
50 });
51 }
52}
53
54pub struct Scope<'scope, 'env: 'scope> {
58 num_running_threads: AtomicUsize,
59 main_task: TaskId,
60 scope: PhantomData<&'scope mut &'scope ()>,
61 env: PhantomData<&'env mut &'env ()>,
62}
63
64impl std::fmt::Debug for Scope<'_, '_> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Scope")
67 .field("num_running_threads", &self.num_running_threads.load(Ordering::Relaxed))
68 .field("main_thread", &self.main_task)
69 .finish_non_exhaustive()
70 }
71}
72
73impl<'scope> Scope<'scope, '_> {
74 #[track_caller]
80 pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
81 where
82 F: FnOnce() -> T + Send + 'scope,
83 T: Send + 'scope,
84 {
85 self.num_running_threads.fetch_add(1, Ordering::Relaxed);
86
87 let finished = std::sync::Arc::new(AtomicBool::new(false));
88 let scope_closure = {
89 let finished = finished.clone();
90 move || {
91 let ret = f();
92
93 if ExecutionState::with(|s| s.exit_current_truncates_execution()) {
94 thread::switch();
95 }
96
97 finished.store(true, Ordering::Relaxed);
98
99 if self.num_running_threads.fetch_sub(1, Ordering::Relaxed) == 1 {
100 ExecutionState::with(|s| s.get_mut(self.main_task).unblock());
101 }
102
103 ret
104 }
105 };
106
107 ScopedJoinHandle {
113 handle: unsafe { spawn_named_unchecked(scope_closure, None, None, false, Location::caller()) },
114 finished,
115 _marker: PhantomData,
116 }
117 }
118}
119
120pub fn scope<'env, F, T>(f: F) -> T
125where
126 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
127{
128 let scope = Scope {
129 num_running_threads: AtomicUsize::new(0),
130 main_task: ExecutionState::with(|s| s.current().id()),
131 env: PhantomData,
132 scope: PhantomData,
133 };
134
135 let ret = f(&scope);
136
137 if scope.num_running_threads.load(Ordering::Relaxed) != 0 {
138 tracing::info!("thread blocked, waiting for completion of scoped threads");
139 ExecutionState::with(|s| s.current_mut().block(false));
140 thread::switch();
141 }
142
143 ret
144}
145
146#[track_caller]
151pub fn spawn<F, T>(f: F) -> JoinHandle<T>
152where
153 F: FnOnce() -> T,
154 F: Send + 'static,
155 T: Send + 'static,
156{
157 spawn_named(f, None, None, Location::caller())
158}
159
160fn spawn_named<F, T>(
161 f: F,
162 name: Option<String>,
163 stack_size: Option<usize>,
164 caller: &'static Location<'static>,
165) -> JoinHandle<T>
166where
167 F: FnOnce() -> T,
168 F: Send + 'static,
169 T: Send + 'static,
170{
171 unsafe { spawn_named_unchecked(f, name, stack_size, true, caller) }
174}
175
176unsafe fn spawn_named_unchecked<F, T>(
178 f: F,
179 name: Option<String>,
180 stack_size: Option<usize>,
181 switch_before_exit: bool,
182 caller: &'static Location<'static>,
183) -> JoinHandle<T>
184where
185 F: FnOnce() -> T,
186 T: Send,
187{
188 let stack_size = stack_size.unwrap_or_else(|| ExecutionState::with(|s| s.config.stack_size));
191 let result = std::sync::Arc::new(std::sync::Mutex::new(None));
192 let task_id = {
193 let result = std::sync::Arc::clone(&result);
194
195 let f: Box<dyn FnOnce()> = Box::new(move || thread_fn(f, switch_before_exit, result));
197 let f: Box<dyn FnOnce() + 'static> = unsafe { std::mem::transmute(f) };
198
199 ExecutionState::spawn_thread(f, stack_size, name.clone(), None, caller)
200 };
201
202 let thread = Thread {
203 id: ThreadId { task_id },
204 name,
205 };
206
207 JoinHandle {
208 task_id,
209 thread,
210 result,
211 }
212}
213
214pub(crate) fn thread_fn<F, T>(
220 f: F,
221 switch_before_exit: bool,
222 result: std::sync::Arc<std::sync::Mutex<Option<Result<T>>>>,
223) where
224 F: FnOnce() -> T,
225{
226 let ret = f();
227
228 if switch_before_exit && ExecutionState::with(|s| s.exit_current_truncates_execution()) {
229 thread::switch();
232 }
233
234 tracing::trace!("thread finished, dropping thread locals");
235
236 while let Some(local) = ExecutionState::with(|state| state.current_mut().pop_local()) {
242 tracing::trace!("dropping thread local {:p}", local);
243 drop(local);
244 }
245
246 tracing::trace!("done dropping thread locals");
247
248 *result.lock().unwrap() = Some(Ok(ret));
252 ExecutionState::with(|state| {
253 if let Some(waiter) = state.current_mut().take_waiter() {
254 state.get_mut(waiter).unblock();
255 }
256 });
257}
258
259#[derive(Debug)]
263pub struct ScopedJoinHandle<'scope, T> {
264 handle: JoinHandle<T>,
265 finished: std::sync::Arc<AtomicBool>,
266 _marker: PhantomData<&'scope T>,
267}
268
269impl<T> ScopedJoinHandle<'_, T> {
270 pub fn join(self) -> Result<T> {
272 self.handle.join()
273 }
274
275 pub fn thread(&self) -> &Thread {
277 self.handle.thread()
278 }
279
280 pub fn is_finished(&self) -> bool {
285 self.finished.load(Ordering::Relaxed)
286 }
287}
288
289#[derive(Debug)]
291pub struct JoinHandle<T> {
292 task_id: TaskId,
293 thread: Thread,
294 result: std::sync::Arc<std::sync::Mutex<Option<Result<T>>>>,
295}
296
297unsafe impl<T> Send for JoinHandle<T> {}
298unsafe impl<T> Sync for JoinHandle<T> {}
299
300impl<T> JoinHandle<T> {
301 pub fn join(self) -> Result<T> {
303 let is_finished = ExecutionState::with(|state| state.get(self.task_id).finished());
304 if is_finished {
306 thread::switch();
307 }
308
309 let should_block = ExecutionState::with(|state| {
310 let me = state.current().id();
311 let target = state.get_mut(self.task_id);
312 if target.set_waiter(me) {
313 state.current_mut().block(false);
314 true
315 } else {
316 false
317 }
318 });
319
320 if should_block {
321 thread::switch();
322 }
323
324 ExecutionState::with(|state| {
326 let target = state.get_mut(self.task_id);
327 let clock = target.clock.clone();
328 state.update_clock(&clock);
329 });
330
331 self.result.lock().unwrap().take().expect("target should have finished")
332 }
333
334 pub fn thread(&self) -> &Thread {
336 &self.thread
337 }
338}
339
340pub fn yield_now() {
345 let waker = ExecutionState::with(|state| state.current().waker());
346 waker.wake_by_ref();
347 ExecutionState::request_yield();
348 thread::switch();
349}
350
351pub fn sleep(_dur: Duration) {
354 thread::switch();
355}
356
357pub fn current() -> Thread {
359 let (task_id, name) = ExecutionState::with(|s| {
360 let me = s.current();
361 (me.id(), me.name())
362 });
363
364 Thread {
365 id: ThreadId { task_id },
366 name,
367 }
368}
369
370pub fn park() {
372 let switch = ExecutionState::with(|s| s.current_mut().park());
373
374 if switch {
381 ExecutionState::request_yield();
382 thread::switch();
383 }
384}
385
386pub fn park_timeout(_dur: Duration) {
393 park();
394}
395
/// Thread factory used to configure a thread before spawning it.
///
/// Both settings are optional; the default builder has neither set.
#[derive(Default, Debug)]
pub struct Builder {
    /// Name for the thread-to-be, if one was requested.
    name: Option<String>,
    /// Requested stack size, if one was requested.
    stack_size: Option<usize>,
}
402
403impl Builder {
404 pub fn new() -> Self {
406 Self {
407 name: None,
408 stack_size: None,
409 }
410 }
411
412 pub fn name(mut self, name: String) -> Self {
414 self.name = Some(name);
415 self
416 }
417
418 pub fn stack_size(mut self, stack_size: usize) -> Self {
420 self.stack_size = Some(stack_size);
421 self
422 }
423
424 #[track_caller]
426 pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
427 where
428 F: FnOnce() -> T,
429 F: Send + 'static,
430 T: Send + 'static,
431 {
432 Ok(spawn_named(f, self.name, self.stack_size, Location::caller()))
433 }
434}
435
/// Key identifying one thread-local storage slot.
///
/// The fields are public but hidden from the docs.
pub struct LocalKey<T: 'static> {
    /// Produces the initial value for a thread that has not initialised the slot yet.
    #[doc(hidden)]
    pub init: fn() -> T,
    /// Zero-sized marker for the stored type.
    #[doc(hidden)]
    pub _p: PhantomData<T>,
}

// SAFETY: NOTE(review): the key itself only holds a function pointer and a marker. The
// values are kept in per-task storage (see `init_local` / `local` on the current task), so
// sharing the key does not by itself share a `T` — confirm against the task implementation.
unsafe impl<T> Send for LocalKey<T> {}
unsafe impl<T> Sync for LocalKey<T> {}

/// Debug output is just the type name; the fields are elided.
impl<T: 'static> std::fmt::Debug for LocalKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LocalKey");
        out.finish_non_exhaustive()
    }
}
456
457impl<T: 'static> LocalKey<T> {
458 pub fn with<F, R>(&'static self, f: F) -> R
462 where
463 F: FnOnce(&T) -> R,
464 {
465 self.try_with(f).expect(
466 "cannot access a Thread Local Storage value \
467 during or after destruction",
468 )
469 }
470
471 pub fn try_with<F, R>(&'static self, f: F) -> std::result::Result<R, AccessError>
477 where
478 F: FnOnce(&T) -> R,
479 {
480 let value = self.get().unwrap_or_else(|| {
481 let value = (self.init)();
482
483 ExecutionState::with(move |state| {
484 state.current_mut().init_local(self, value);
485 });
486
487 self.get().unwrap()
488 })?;
489
490 Ok(f(value))
491 }
492
493 fn get(&'static self) -> Option<std::result::Result<&'static T, AccessError>> {
494 unsafe fn extend_lt<'b, T>(t: &'_ T) -> &'b T {
496 std::mem::transmute(t)
497 }
498
499 ExecutionState::with(|state| {
500 if let Ok(value) = state.current().local(self)? {
501 Some(Ok(unsafe { extend_lt(value) }))
507 } else {
508 Some(Err(AccessError))
510 }
511 })
512 }
513}
514
/// Error returned by `LocalKey::try_with` when the thread-local value cannot be accessed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct AccessError;

/// Displays as the fixed text `already destroyed`, honouring width/precision flags.
impl std::fmt::Display for AccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the formatter's fill/width/precision, just like `Display for str`.
        f.pad("already destroyed")
    }
}

impl std::error::Error for AccessError {}