vexide_async/local.rs
1//! Task-local storage
2//!
3//! Task-local storage is a way to create global variables specific to the current task that live
4//! for the entirety of the task's lifetime, almost like statics. Since they are local to the task,
5//! they implement [`Send`] and [`Sync`], regardless of what the underlying data does or does not
6//! implement.
7//!
8//! Task-locals can be declared using the [`task_local`] macro, which creates a [`LocalKey`] with
9//! the same name that can be used to access the local.
10
11use std::{
12 any::Any,
13 boxed::Box,
14 cell::{BorrowError, BorrowMutError, Cell, RefCell, UnsafeCell},
15 collections::btree_map::BTreeMap,
16 ptr,
17 rc::Rc,
18 sync::{
19 LazyLock,
20 atomic::{AtomicU32, Ordering},
21 },
22};
23
24use crate::executor::EXECUTOR;
25
/// A variable stored in task-local storage.
///
/// # Usage
///
/// The primary mode of accessing this is through the [`LocalKey::with`] method. For
/// [`LocalKey<RefCell<T>>`] and [`LocalKey<Cell<T>>`], additional convenience methods are added
/// that mirror the underlying [`RefCell<T>`] or [`Cell<T>`]'s methods.
///
/// Each task sees its own copy of the value, created lazily by the key's initializer the first
/// time that task accesses it.
///
/// # Examples
///
/// ```
/// use std::cell::{Cell, RefCell};
///
/// use vexide::prelude::*;
///
/// task_local! {
///     static PHI: f64 = 1.61803;
///     static COUNTER: Cell<u32> = Cell::new(0);
///     static NAMES: RefCell<Vec<String>> = RefCell::new(Vec::new());
/// }
///
/// #[vexide::main]
/// async fn main(_peripherals: Peripherals) {
///     // LocalKey::with accepts a function and applies it to a reference, returning whatever value
///     // the function returned
///     let double_phi = PHI.with(|&phi| phi * 2.0);
///     assert_eq!(double_phi, 1.61803 * 2.0);
///
///     // We can use interior mutability
///     COUNTER.set(1);
///     assert_eq!(COUNTER.get(), 1);
///
///     NAMES.with_borrow_mut(|names| names.push(String::from("Johnny")));
///     NAMES.with_borrow(|names| assert_eq!(names.len(), 1));
///
///     // Creating another task
///     spawn(async {
///         // The locals of the previous task are completely different.
///         assert_eq!(COUNTER.get(), 0);
///         NAMES.with_borrow(|names| assert_eq!(names.len(), 0));
///     })
///     .await;
/// }
/// ```
#[derive(Debug)]
pub struct LocalKey<T: 'static> {
    /// Produces the initial value the first time a task accesses this key.
    init: fn() -> T,
    /// This key's identifier inside each task's storage, allocated on first dereference.
    key: LazyLock<u32>,
}

// `LocalKey<T>` is `Send + Sync` for every `T` without an `unsafe impl`: its only fields are a
// function pointer and a `LazyLock<u32>`, both thread-safe regardless of `T` (the `T` values
// themselves live in per-task storage, never in the key). Relying on the compiler-derived auto
// traits means a future non-thread-safe field becomes a compile error here instead of silent
// unsoundness.
fn _assert_local_key_is_send_sync<T: 'static>() {
    fn is_send_sync<X: Send + Sync>() {}
    is_send_sync::<LocalKey<T>>();
}
78
/// Declares task-local variables in [`LocalKey`]s of the same names.
///
/// Any number of declarations (including none) may appear in one invocation. Attributes written
/// on a declaration, such as doc comments, `#[cfg(...)]` or `#[allow(...)]`, are applied to the
/// generated static.
///
/// # Examples
///
/// ```
/// use std::cell::{Cell, RefCell};
///
/// use vexide::prelude::*;
///
/// task_local! {
///     /// The golden ratio.
///     static PHI: f64 = 1.61803;
///     static COUNTER: Cell<u32> = Cell::new(0);
///     static NAMES: RefCell<Vec<String>> = RefCell::new(Vec::new());
/// }
/// ```
#[macro_export]
macro_rules! task_local {
    // Base case: nothing (left) to declare.
    () => {};

    {
        $(#[$attr:meta])*
        $vis:vis static $name:ident: $type:ty = $init:expr;
        $($rest:tt)*
    } => {
        $(#[$attr])*
        // publicly reexported in crate::task
        $vis static $name: $crate::task::LocalKey<$type> = {
            // `$init` is wrapped in a function so it is evaluated on first access in each task,
            // not once when the static is created.
            fn init() -> $type { $init }
            $crate::task::LocalKey::new(init)
        };

        // Declare the remaining locals; each one carries (and forwards) its own attributes.
        $crate::task_local!($($rest)*);
    };
}
117pub use task_local;
118
119impl<T: 'static> LocalKey<T> {
120 #[doc(hidden)]
121 pub const fn new(init: fn() -> T) -> Self {
122 static LOCAL_KEY_COUNTER: AtomicU32 = AtomicU32::new(0);
123
124 Self {
125 init,
126 key: LazyLock::new(|| LOCAL_KEY_COUNTER.fetch_add(1, Ordering::Relaxed)),
127 }
128 }
129
130 /// Obtains a reference to the local and applies it to the function `f`, returning whatever `f`
131 /// returned.
132 ///
133 /// # Examples
134 ///
135 /// ```
136 /// use vexide::task::task_local;
137 ///
138 /// task_local! {
139 /// static PHI: f64 = 1.61803;
140 /// }
141 ///
142 /// let double_phi = PHI.with(|&phi| phi * 2.0);
143 /// assert_eq!(double_phi, 1.61803 * 2.0);
144 /// ```
145 pub fn with<F, R>(&'static self, f: F) -> R
146 where
147 F: FnOnce(&T) -> R,
148 {
149 TaskLocalStorage::with_current(|storage| {
150 // SAFETY: get_or_init is always called with the same return type, T
151 // Also, `key` is unique for this local key.
152 f(unsafe { storage.get_or_init(*self.key, self.init) })
153 })
154 }
155}
156
157impl<T: 'static> LocalKey<Cell<T>> {
158 /// Returns a copy of the contained value.
159 pub fn get(&'static self) -> T
160 where
161 T: Copy,
162 {
163 self.with(Cell::get)
164 }
165
166 /// Sets the contained value.
167 pub fn set(&'static self, value: T) {
168 self.with(|cell| cell.set(value));
169 }
170
171 /// Takes the value of contained value, leaving [`Default::default()`] in its place.
172 pub fn take(&'static self) -> T
173 where
174 T: Default,
175 {
176 self.with(Cell::take)
177 }
178
179 /// Replaces the contained value with `value`, returning the old contained value.
180 pub fn replace(&'static self, value: T) -> T {
181 self.with(|cell| cell.replace(value))
182 }
183}
184
185impl<T: 'static> LocalKey<RefCell<T>> {
186 /// Immutably borrows from the [`RefCell`] and applies the obtained reference to `f`.
187 ///
188 /// # Panics
189 ///
190 /// Panics if the value is currently mutably borrowed. For a non-panicking variant, use
191 /// [`LocalKey::try_with_borrow`].
192 pub fn with_borrow<F, R>(&'static self, f: F) -> R
193 where
194 F: FnOnce(&T) -> R,
195 {
196 self.with(|cell| f(&cell.borrow()))
197 }
198
199 /// Mutably borrows from the [`RefCell`] and applies the obtained reference to `f`.
200 ///
201 /// # Panics
202 ///
203 /// Panics if the value is currently borrowed. For a non-panicking variant, use
204 /// [`LocalKey::try_with_borrow_mut`].
205 pub fn with_borrow_mut<F, R>(&'static self, f: F) -> R
206 where
207 F: FnOnce(&mut T) -> R,
208 {
209 self.with(|cell| f(&mut cell.borrow_mut()))
210 }
211
212 /// Tries to immutably borrow the contained value, returning an error if it is currently
213 /// mutably borrowed, and applies the obtained reference to `f`.
214 ///
215 /// This is the non-panicking variant of [`LocalKey::with_borrow`].
216 ///
217 /// # Errors
218 ///
219 /// Returns [`BorrowError`] if the contained value is currently mutably borrowed.
220 pub fn try_with_borrow<F, R>(&'static self, f: F) -> Result<R, BorrowError>
221 where
222 F: FnOnce(&T) -> R,
223 {
224 self.with(|cell| cell.try_borrow().map(|value| f(&value)))
225 }
226
227 /// Tries to mutably borrow the contained value, returning an error if it is currently borrowed,
228 /// and applies the obtained reference to `f`.
229 ///
230 /// This is the non-panicking variant of [`LocalKey::with_borrow_mut`].
231 ///
232 /// # Errors
233 ///
234 /// Returns [`BorrowMutError`] if the contained value is currently borrowed.
235 pub fn try_with_borrow_mut<F, R>(&'static self, f: F) -> Result<R, BorrowMutError>
236 where
237 F: FnOnce(&T) -> R,
238 {
239 self.with(|cell| cell.try_borrow_mut().map(|value| f(&value)))
240 }
241
242 /// Sets the contained value.
243 ///
244 /// # Panics
245 ///
246 /// Panics if the value is currently borrowed.
247 pub fn set(&'static self, value: T) {
248 self.with_borrow_mut(|refmut| *refmut = value);
249 }
250
251 /// Takes the contained value, leaving [`Default::default()`] in its place.
252 ///
253 /// # Panics
254 ///
255 /// Panics if the value is currently borrowed.
256 pub fn take(&'static self) -> T
257 where
258 T: Default,
259 {
260 self.with(RefCell::take)
261 }
262
263 /// Replaces the contained value with `value`, returning the old contained value.
264 ///
265 /// # Panics
266 ///
267 /// Panics if the value is currently borrowed.
268 pub fn replace(&'static self, value: T) -> T {
269 self.with(|cell| cell.replace(value))
270 }
271}
272
/// A task-local value with its concrete type erased, so values of different types can share a
/// single map.
struct ErasedTaskLocal {
    value: Box<dyn Any>,
}

impl ErasedTaskLocal {
    /// Boxes `value` and erases its type.
    fn new<T: 'static>(value: T) -> Self {
        let value: Box<dyn Any> = Box::new(value);
        Self { value }
    }

    /// Returns a reference to the stored value viewed as a `T`.
    ///
    /// With debug assertions enabled the type is checked and a mismatch panics; without them no
    /// check is performed.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that the stored value really is of type `T`.
    unsafe fn get<T: 'static>(&self) -> &T {
        let erased: &dyn Any = &*self.value;

        if !cfg!(debug_assertions) {
            // SAFETY: the caller guarantees the erased value is a `T`, so casting the data pointer
            // of the trait object back to `*const T` yields a valid, aligned reference.
            return unsafe { &*ptr::from_ref(erased).cast::<T>() };
        }

        erased.downcast_ref::<T>().unwrap()
    }
}
296
// Fallback TLS block for when reading from outside of a task.
//
// `TaskLocalStorage::with_current` uses this storage whenever the executor has no task storage
// installed, so task-locals can still be accessed outside of any task. As a `thread_local!`
// there is one fallback per thread; nothing in this file ever clears it.
thread_local! {
    static FALLBACK_TLS: TaskLocalStorage = const { TaskLocalStorage::new() };
}
301
/// A block of task-local values: the storage installed for a task via `scope`, or the per-thread
/// fallback.
///
/// Values are keyed by their `LocalKey`'s `u32` ID and stored type-erased. The map sits in an
/// `UnsafeCell` so `get_or_init` can insert new entries through `&self` while references to
/// previously inserted values are still alive.
#[derive(Debug)]
pub(crate) struct TaskLocalStorage {
    locals: UnsafeCell<BTreeMap<u32, ErasedTaskLocal>>,
}
306
307impl TaskLocalStorage {
308 pub(crate) const fn new() -> Self {
309 Self {
310 locals: UnsafeCell::new(BTreeMap::new()),
311 }
312 }
313
314 pub(crate) fn scope(value: Rc<TaskLocalStorage>, scope: impl FnOnce()) {
315 let outer_scope = EXECUTOR.with(|ex| (*ex.tls.borrow_mut()).replace(value));
316
317 scope();
318
319 EXECUTOR.with(|ex| {
320 *ex.tls.borrow_mut() = outer_scope;
321 });
322 }
323
324 /// Gets the Task Local Storage data for the current task.
325 pub(crate) fn with_current<F, R>(f: F) -> R
326 where
327 F: FnOnce(&Self) -> R,
328 {
329 EXECUTOR.with(|ex| {
330 if let Some(tls) = ex.tls.borrow().as_ref() {
331 f(tls)
332 } else {
333 FALLBACK_TLS.with(|fallback| f(fallback))
334 }
335 })
336 }
337
338 /// Gets a reference to the Task Local Storage item identified by the given key.
339 ///
340 /// It is invalid to call this function multiple times with the same key and a different `T`.
341 pub(crate) unsafe fn get_or_init<T: 'static>(&self, key: u32, init: fn() -> T) -> &T {
342 // We need to be careful to not make mutable references to values already inserted into the
343 // map because the current task might have existing shared references to that data. It's
344 // okay if the pointer (ErasedTaskLocal) gets moved around, we just can't assert invalid
345 // exclusive access over its contents.
346
347 let locals = self.locals.get();
348 unsafe {
349 // init() could initialize another task local recursively, so we need to be sure there's
350 // no mutable reference to `self.locals` when we call it. We can't use the
351 // entry API because of this.
352
353 #[expect(
354 clippy::map_entry,
355 reason = "cannot hold mutable reference over init() call"
356 )]
357 if !(*locals).contains_key(&key) {
358 let new_value = ErasedTaskLocal::new(init());
359 (*locals).insert(key, new_value);
360 }
361
362 (*locals).get(&key).unwrap().get()
363 }
364 }
365}