1//! Shuttle's implementation of [`std::thread`].
2
3use crate::runtime::execution::ExecutionState;
4use crate::runtime::task::TaskId;
5use crate::runtime::thread;
6use std::marker::PhantomData;
7use std::panic::Location;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::time::Duration;
10
11pub use std::thread::{panicking, Result};
12
/// A unique identifier for a running thread
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ThreadId {
    // TODO Should we add an execution id here, like Loom does?
    // The Shuttle task backing this thread; exposed as a plain index via
    // the `From<ThreadId> for usize` impl below.
    task_id: TaskId,
}
19
20impl From<ThreadId> for usize {
21    fn from(id: ThreadId) -> usize {
22        id.task_id.into()
23    }
24}
25
/// A handle to a thread.
#[derive(Debug, Clone)]
pub struct Thread {
    // Thread name, if one was set (e.g. via `Builder::name`); returned by `name()`.
    name: Option<String>,
    // Identifier returned by `id()`; also used by `unpark` to locate the task.
    id: ThreadId,
}
32
33impl Thread {
34    /// Gets the thread's name.
35    pub fn name(&self) -> Option<&str> {
36        self.name.as_deref()
37    }
38
39    /// Gets the thread's unique identifier
40    pub fn id(&self) -> ThreadId {
41        self.id
42    }
43
44    /// Atomically makes the handle's token available if it is not already.
45    pub fn unpark(&self) {
46        thread::switch();
47
48        ExecutionState::with(|s| {
49            s.get_mut(self.id.task_id).unpark();
50        });
51    }
52}
53
/// A scope to spawn scoped threads in.
///
/// See [`scope`] for details.
pub struct Scope<'scope, 'env: 'scope> {
    // Count of scoped threads still running; the last one to exit unblocks `main_task`.
    num_running_threads: AtomicUsize,
    // The task that called `scope`, which blocks until all scoped threads complete.
    main_task: TaskId,
    // Invariant marker tying spawned closures to the `'scope` lifetime.
    scope: PhantomData<&'scope mut &'scope ()>,
    // Invariant marker for the borrowed environment outliving the scope.
    env: PhantomData<&'env mut &'env ()>,
}
63
64impl std::fmt::Debug for Scope<'_, '_> {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("Scope")
67            .field("num_running_threads", &self.num_running_threads.load(Ordering::Relaxed))
68            .field("main_thread", &self.main_task)
69            .finish_non_exhaustive()
70    }
71}
72
impl<'scope> Scope<'scope, '_> {
    /// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
    ///
    /// Unlike non-scoped threads, threads spawned with this function may
    /// borrow non-`'static` data from outside the scope. See [`scope`] for
    /// details.
    #[track_caller]
    pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        // Register the new thread before it starts, so `scope` knows to wait for it.
        self.num_running_threads.fetch_add(1, Ordering::Relaxed);

        let finished = std::sync::Arc::new(AtomicBool::new(false));
        let scope_closure = {
            let finished = finished.clone();
            move || {
                let ret = f();

                // Provide the pre-exit scheduling point *here*, before `finished` is
                // set, rather than in `thread_fn` (see the invariant note below).
                if ExecutionState::with(|s| s.exit_current_truncates_execution()) {
                    thread::switch();
                }

                finished.store(true, Ordering::Relaxed);

                // If we were the last running scoped thread, wake the task blocked in `scope`.
                if self.num_running_threads.fetch_sub(1, Ordering::Relaxed) == 1 {
                    ExecutionState::with(|s| s.get_mut(self.main_task).unblock());
                }

                ret
            }
        };

        // Note: Scoped threads wrap their inner function in some additional logic (above) to update the `finished` variable on termination.
        // This logic is expected to run atomically with scoped thread termination, so there *cannot* be a switch after the atomic store.
        // To avoid violating this invariant, we pass `switch_before_exit = false` (below). Instead, we provide our own context switch on exit
        // (above) in the `scope_closure` *before* setting `finished` to be `true`.
        // SAFETY: main task is blocked until all scoped closures complete so all captured references remain valid
        ScopedJoinHandle {
            handle: unsafe { spawn_named_unchecked(scope_closure, None, None, false, Location::caller()) },
            finished,
            _marker: PhantomData,
        }
    }
}
119
/// Creates a scope for spawning scoped threads.
///
/// The function passed to `scope` will be provided a [`Scope`] object,
/// through which scoped threads can be [spawned][`Scope::spawn`].
pub fn scope<'env, F, T>(f: F) -> T
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
    let scope = Scope {
        num_running_threads: AtomicUsize::new(0),
        main_task: ExecutionState::with(|s| s.current().id()),
        env: PhantomData,
        scope: PhantomData,
    };

    // Run the user's closure, which may spawn scoped threads against `scope`.
    let ret = f(&scope);

    // If any scoped threads are still running, block until the last one unblocks
    // us (see `Scope::spawn`), so borrows of `'env` data remain valid while they run.
    if scope.num_running_threads.load(Ordering::Relaxed) != 0 {
        tracing::info!("thread blocked, waiting for completion of scoped threads");
        ExecutionState::with(|s| s.current_mut().block(false));
        thread::switch();
    }

    ret
}
145
146/// Spawn a new thread, returning a JoinHandle for it.
147///
148/// The join handle can be used (via the `join` method) to block until the child thread has
149/// finished.
150#[track_caller]
151pub fn spawn<F, T>(f: F) -> JoinHandle<T>
152where
153    F: FnOnce() -> T,
154    F: Send + 'static,
155    T: Send + 'static,
156{
157    spawn_named(f, None, None, Location::caller())
158}
159
160fn spawn_named<F, T>(
161    f: F,
162    name: Option<String>,
163    stack_size: Option<usize>,
164    caller: &'static Location<'static>,
165) -> JoinHandle<T>
166where
167    F: FnOnce() -> T,
168    F: Send + 'static,
169    T: Send + 'static,
170{
171    // SAFETY: F is static so all captured references must be `static and therefore
172    // will outlive the spawned continuation
173    unsafe { spawn_named_unchecked(f, name, stack_size, true, caller) }
174}
175
/// Spawn a continuation running `f` and return a [`JoinHandle`] for it, without requiring the
/// closure or its captures to be `'static`.
///
/// # Safety
///
/// Must ensure all captured references in `f` are valid for at least as long as the spawned
/// continuation will run.
unsafe fn spawn_named_unchecked<F, T>(
    f: F,
    name: Option<String>,
    stack_size: Option<usize>,
    switch_before_exit: bool,
    caller: &'static Location<'static>,
) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    T: Send,
{
    // TODO Check if it's worth avoiding the call to `ExecutionState::config()` if we're going
    // TODO to use an existing continuation from the pool.
    let stack_size = stack_size.unwrap_or_else(|| ExecutionState::with(|s| s.config.stack_size));
    // Shared slot that `thread_fn` publishes the thread's result into; read by `join`.
    let result = std::sync::Arc::new(std::sync::Mutex::new(None));
    let task_id = {
        let result = std::sync::Arc::clone(&result);

        // Allocate `thread_fn` on the heap and assume a `'static` bound.
        let f: Box<dyn FnOnce()> = Box::new(move || thread_fn(f, switch_before_exit, result));
        // SAFETY: per this function's contract, every reference captured by `f` outlives the
        // spawned continuation, so erasing the lifetime to `'static` is sound here.
        let f: Box<dyn FnOnce() + 'static> = unsafe { std::mem::transmute(f) };

        ExecutionState::spawn_thread(f, stack_size, name.clone(), None, caller)
    };

    let thread = Thread {
        id: ThreadId { task_id },
        name,
    };

    JoinHandle {
        task_id,
        thread,
        result,
    }
}
213
/// Body of a Shuttle thread, that runs the given closure, handles thread-local destructors, and
/// stores the result of the thread in the given lock.
///
/// The `switch_before_exit` parameter will provide a conditional scheduling point after the inner
/// function has completed. If this parameter is set to `false`, then this function should be wrapped
/// in another function which provides a conditional scheduling point before the task exits
/// (as `Scope::spawn` does with its `scope_closure`).
pub(crate) fn thread_fn<F, T>(
    f: F,
    switch_before_exit: bool,
    result: std::sync::Arc<std::sync::Mutex<Option<Result<T>>>>,
) where
    F: FnOnce() -> T,
{
    let ret = f();

    if switch_before_exit && ExecutionState::with(|s| s.exit_current_truncates_execution()) {
        // Exiting the last attached task can truncate the execution. To make the previous
        // event visible before truncation, we need a scheduling point before exiting.
        thread::switch();
    }

    tracing::trace!("thread finished, dropping thread locals");

    // Run thread-local destructors before publishing the result, because
    // [`JoinHandle::join`] says join "waits for the associated thread to finish", but
    // destructors must be run on the thread, so it can't be considered "finished" if the
    // destructors haven't run yet.
    // See `pop_local` for details on why this loop looks this slightly funky way.
    while let Some(local) = ExecutionState::with(|state| state.current_mut().pop_local()) {
        tracing::trace!("dropping thread local {:p}", local);
        drop(local);
    }

    tracing::trace!("done dropping thread locals");

    // Publish the result and unblock the waiter. We need to do this now, because once this
    // closure completes, the Execution will consider this task Finished and invoke the
    // scheduler.
    *result.lock().unwrap() = Some(Ok(ret));
    ExecutionState::with(|state| {
        if let Some(waiter) = state.current_mut().take_waiter() {
            state.get_mut(waiter).unblock();
        }
    });
}
258
/// An owned permission to join on a scoped thread (block on its termination).
///
/// See [`Scope::spawn`] for details.
#[derive(Debug)]
pub struct ScopedJoinHandle<'scope, T> {
    // Underlying non-scoped join handle; `join` and `thread` delegate to it.
    handle: JoinHandle<T>,
    // Set to `true` by the scoped closure once the user function has returned
    // (see `Scope::spawn`); read by `is_finished`.
    finished: std::sync::Arc<AtomicBool>,
    // Ties this handle to the `'scope` lifetime it was spawned in.
    _marker: PhantomData<&'scope T>,
}
268
269impl<T> ScopedJoinHandle<'_, T> {
270    /// Waits for the associated thread to finish.
271    pub fn join(self) -> Result<T> {
272        self.handle.join()
273    }
274
275    /// Extracts a handle to the underlying thread.
276    pub fn thread(&self) -> &Thread {
277        self.handle.thread()
278    }
279
280    /// Checks if the associated thread has finished running its main function.
281    ///
282    /// This might return `true` for a brief moment after the thread's main
283    /// function has returned, but before the thread itself has stopped running.
284    pub fn is_finished(&self) -> bool {
285        self.finished.load(Ordering::Relaxed)
286    }
287}
288
/// An owned permission to join on a thread (block on its termination).
#[derive(Debug)]
pub struct JoinHandle<T> {
    // Task backing this thread; used to look it up in the `ExecutionState`.
    task_id: TaskId,
    // Cached handle returned by `thread()`.
    thread: Thread,
    // Slot the thread body publishes its result into (see `thread_fn`).
    result: std::sync::Arc<std::sync::Mutex<Option<Result<T>>>>,
}

// NOTE(review): these impls are unconditional (no `T: Send` bound), yet `join` moves a `T`
// across threads. Every spawn path in this file (`spawn`, `Builder::spawn`, `Scope::spawn`,
// `spawn_named_unchecked`) requires `T: Send`, which appears to be the invariant relied on
// here — confirm that no other constructor of `JoinHandle` exists.
unsafe impl<T> Send for JoinHandle<T> {}
unsafe impl<T> Sync for JoinHandle<T> {}
299
impl<T> JoinHandle<T> {
    /// Waits for the associated thread to finish.
    pub fn join(self) -> Result<T> {
        let is_finished = ExecutionState::with(|state| state.get(self.task_id).finished());
        // If the joinee task is finished then the joiner will not block below, so provide a
        // scheduling point here instead to make the join visible to the scheduler.
        if is_finished {
            thread::switch();
        }

        let should_block = ExecutionState::with(|state| {
            let me = state.current().id();
            let target = state.get_mut(self.task_id);
            // Register ourselves as the waiter; if `set_waiter` refuses (presumably because
            // the target has already finished — see `thread_fn`), the result is already
            // published and we don't block.
            if target.set_waiter(me) {
                state.current_mut().block(false);
                true
            } else {
                false
            }
        });

        if should_block {
            thread::switch();
        }

        // Waiting thread inherits the clock of the finished thread
        ExecutionState::with(|state| {
            let target = state.get_mut(self.task_id);
            let clock = target.clock.clone();
            state.update_clock(&clock);
        });

        self.result.lock().unwrap().take().expect("target should have finished")
    }

    /// Extracts a handle to the underlying thread.
    pub fn thread(&self) -> &Thread {
        &self.thread
    }
}
339
/// Cooperatively gives up a timeslice to the Shuttle scheduler.
///
/// Some Shuttle schedulers use this as a hint to deprioritize the current thread in order for other
/// threads to make progress (e.g., in a spin loop).
pub fn yield_now() {
    // NOTE(review): waking our own waker appears to keep this task runnable across the yield;
    // confirm against the semantics of `Task::waker` and `request_yield`.
    let waker = ExecutionState::with(|state| state.current().waker());
    waker.wake_by_ref();
    // Hint the scheduler to prefer a different task at the following switch.
    ExecutionState::request_yield();
    thread::switch();
}
350
/// Puts the current thread to sleep for at least the specified amount of time.
// Note that Shuttle does not model time, so this behaves just like a context switch and the
// requested duration is ignored.
pub fn sleep(_dur: Duration) {
    thread::switch();
}
356
357/// Get a handle to the thread that invokes it
358pub fn current() -> Thread {
359    let (task_id, name) = ExecutionState::with(|s| {
360        let me = s.current();
361        (me.id(), me.name())
362    });
363
364    Thread {
365        id: ThreadId { task_id },
366        name,
367    }
368}
369
/// Blocks unless or until the current thread's token is made available (may wake spuriously).
pub fn park() {
    // `Task::park` reports whether we should actually block (i.e., the token was unavailable).
    let switch = ExecutionState::with(|s| s.current_mut().park());

    // We only need to context switch if the park token was unavailable. If it was available, then
    // any execution reachable by context switching here would also be reachable by having not
    // chosen this thread at the last context switch, because the park state of a thread is only
    // observable by the thread itself. We also mark it as an explicit yield request by the task,
    // since otherwise some schedulers might prefer to reschedule the current task, which in this
    // context would result in spurious wakeups triggering nearly every time.
    if switch {
        ExecutionState::request_yield();
        thread::switch();
    }
}
385
/// Blocks unless or until the current thread's token is made available or the specified duration
/// has been reached (may wake spuriously).
///
/// Note that Shuttle does not model time, so this behaves identically to `park`. In particular,
/// Shuttle does not assume that the timeout will ever fire, so if all threads are blocked in a call
/// to `park_timeout` it will be treated as a deadlock.
pub fn park_timeout(_dur: Duration) {
    // The duration is ignored; see the note above.
    park();
}
395
396/// Thread factory, which can be used in order to configure the properties of a new thread.
397#[derive(Debug, Default)]
398pub struct Builder {
399    name: Option<String>,
400    stack_size: Option<usize>,
401}
402
403impl Builder {
404    /// Generates the base configuration for spawning a thread, from which configuration methods can be chained.
405    pub fn new() -> Self {
406        Self {
407            name: None,
408            stack_size: None,
409        }
410    }
411
412    /// Names the thread-to-be. Currently the name is used for identification only in panic messages.
413    pub fn name(mut self, name: String) -> Self {
414        self.name = Some(name);
415        self
416    }
417
418    /// Sets the size of the stack (in bytes) for the new thread.
419    pub fn stack_size(mut self, stack_size: usize) -> Self {
420        self.stack_size = Some(stack_size);
421        self
422    }
423
424    /// Spawns a new thread by taking ownership of the Builder, and returns an `io::Result` to its `JoinHandle`.
425    #[track_caller]
426    pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
427    where
428        F: FnOnce() -> T,
429        F: Send + 'static,
430        T: Send + 'static,
431    {
432        Ok(spawn_named(f, self.name, self.stack_size, Location::caller()))
433    }
434}
435
/// A thread local storage key which owns its contents
// NOTE: the fields must stay public: function pointers in `const fn`s are unstable, so literal
// struct construction is the only way to build one of these. User code should not rely on
// these fields.
pub struct LocalKey<T: 'static> {
    #[doc(hidden)]
    pub init: fn() -> T,
    #[doc(hidden)]
    pub _p: PhantomData<T>,
}

// Safety: `LocalKey` is only a key; each thread resolves it to its own private value of type T,
// so sharing or sending the key itself is fine.
unsafe impl<T> Send for LocalKey<T> {}
unsafe impl<T> Sync for LocalKey<T> {}

impl<T: 'static> std::fmt::Debug for LocalKey<T> {
    /// Opaque debug representation; the per-thread values are not accessible here.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fmt.debug_struct("LocalKey").finish_non_exhaustive()
    }
}
456
impl<T: 'static> LocalKey<T> {
    /// Acquires a reference to the value in this TLS key.
    ///
    /// This will lazily initialize the value if this thread has not referenced this key yet.
    ///
    /// # Panics
    ///
    /// Panics if the key is accessed during or after destruction; see [`LocalKey::try_with`]
    /// for a non-panicking variant.
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        self.try_with(f).expect(
            "cannot access a Thread Local Storage value \
            during or after destruction",
        )
    }

    /// Acquires a reference to the value in this TLS key.
    ///
    /// This will lazily initialize the value if this thread has not referenced this key yet. If the
    /// key has been destroyed (which may happen if this is called in a destructor), this function
    /// will return an AccessError.
    pub fn try_with<F, R>(&'static self, f: F) -> std::result::Result<R, AccessError>
    where
        F: FnOnce(&T) -> R,
    {
        // First access on this thread: run the initializer, store the value in the current
        // task's local storage, then re-read it to get a reference with the right lifetime.
        let value = self.get().unwrap_or_else(|| {
            let value = (self.init)();

            ExecutionState::with(move |state| {
                state.current_mut().init_local(self, value);
            });

            self.get().unwrap()
        })?;

        Ok(f(value))
    }

    // Look up this key in the current task's local storage. Returns `None` if the slot was
    // never initialized, `Some(Err(_))` if it has already been destructed.
    fn get(&'static self) -> Option<std::result::Result<&'static T, AccessError>> {
        // Safety: see the usage below
        unsafe fn extend_lt<'b, T>(t: &'_ T) -> &'b T {
            std::mem::transmute(t)
        }

        ExecutionState::with(|state| {
            if let Ok(value) = state.current().local(self)? {
                // Safety: unfortunately the lifetime of a value in our thread-local storage is
                // bound to the lifetime of `ExecutionState`, which has no visible relation to the
                // lifetime of the thread we're running on. However, *we* know that the
                // `ExecutionState` outlives any thread, including the caller, and so it's safe to
                // give the caller the lifetime it's asking for here.
                Some(Ok(unsafe { extend_lt(value) }))
            } else {
                // Slot has already been destructed
                Some(Err(AccessError))
            }
        })
    }
}
514
/// An error returned by [`LocalKey::try_with`]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[non_exhaustive]
pub struct AccessError;

impl std::fmt::Display for AccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Formatter::pad` is exactly what `<str as Display>::fmt` does, so width,
        // alignment, and precision flags behave identically to delegating to the str.
        f.pad("already destroyed")
    }
}

impl std::error::Error for AccessError {}