state_department/
async.rs

1use crate::{
2    lazy::LazyState,
3    manager::{StateManager, StateRef},
4    StateRegistry, INITIALIZED,
5};
6use async_once_cell::OnceCell;
7use std::{
8    any::{Any, TypeId},
9    cell::UnsafeCell,
10    future::Future,
11    marker::PhantomData,
12    pin::Pin,
13};
14
/// Zero-sized marker that selects the asynchronous-only flavour of
/// [`StateManager`].
///
/// It carries no data and is only used as the context type parameter, as in
/// `StateManager<AsyncOnlyContext>`. For that instantiation the accessors
/// (`get`, `try_get`) are `async`, so they can await values registered with
/// `insert_async_lazy`.
pub struct AsyncOnlyContext;
17
18impl StateManager<AsyncOnlyContext> {
19    /// Creates a new [`StateManager`] appropriate for asynchronous-only
20    /// contexts.
21    ///
22    /// # Example
23    ///
24    /// ```rust
25    /// use state_department::AsyncState;
26    ///
27    /// static STATE: AsyncState = AsyncState::new();
28    /// ```
29    pub const fn new() -> Self {
30        Self::new_()
31    }
32
33    /// Returns a reference to a value stored in the state.
34    ///
35    /// # Panics
36    ///
37    /// * If the state has not yet been initialized.
38    /// * If the state has been dropped.
39    /// * If the state does not contain a value of the requested type.
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// use state_department::AsyncState;
45    ///
46    /// static STATE: AsyncState = AsyncState::new();
47    ///
48    /// struct Foo {
49    ///     bar: i32
50    /// }
51    ///
52    /// # tokio_test::block_on(async {
53    /// let _lifetime = STATE.init_async(async |state| {
54    ///     state.insert(Foo { bar: 42 });
55    /// })
56    /// .await;
57    ///
58    /// let foo = STATE.get::<Foo>().await;
59    ///
60    /// assert_eq!(foo.bar, 42);
61    /// # });
62    /// ```
63    #[must_use]
64    pub async fn get<T: Send + Sync + 'static>(&self) -> StateRef<'_, T, AsyncOnlyContext> {
65        match self.try_get().await {
66            Some(v) => v,
67            None => panic!("State for {:?} not found", std::any::type_name::<T>()),
68        }
69    }
70
71    /// Attempts to get a reference to a value stored in the state.
72    ///
73    /// This function does not panic.
74    ///
75    /// # Example
76    ///
77    /// ```rust
78    /// use state_department::AsyncState;
79    ///
80    /// static STATE: AsyncState = AsyncState::new();
81    ///
82    /// struct Foo {
83    ///     bar: i32
84    /// }
85    ///
86    /// # tokio_test::block_on(async {
87    /// let _lifetime = STATE.init_async(async |state| {
88    ///     state.insert(Foo { bar: 42 });
89    /// })
90    /// .await;
91    ///
92    /// let foo = STATE.try_get::<Foo>().await;
93    ///
94    /// assert_eq!(foo.unwrap().bar, 42);
95    ///
96    /// let str = STATE.try_get::<String>().await;
97    ///
98    /// assert!(str.is_none());
99    /// # });
100    /// ```
101    #[must_use]
102    pub async fn try_get<T: Send + Sync + 'static>(
103        &self,
104    ) -> Option<StateRef<'_, T, AsyncOnlyContext>> {
105        if self.initialized.load(std::sync::atomic::Ordering::Acquire) != INITIALIZED {
106            return None;
107        }
108
109        let state = unsafe { (*self.state.get()).assume_init_ref() }.upgrade()?;
110
111        if let Some(value) = state.get(&TypeId::of::<T>()) {
112            let value = value.as_ref() as &dyn Any;
113
114            if let Some(value) = value.downcast_ref::<T>() {
115                return Some(StateRef {
116                    value,
117                    _state: state,
118                    _phantom: PhantomData,
119                });
120            }
121
122            if let Some(value) = value.downcast_ref::<LazyState<T>>() {
123                return Some(StateRef {
124                    value: value.get(),
125                    _state: state,
126                    _phantom: PhantomData,
127                });
128            }
129
130            if let Some(value) = value.downcast_ref::<AsyncLazyState<T>>() {
131                return Some(StateRef {
132                    value: value.get().await,
133                    _state: state,
134                    _phantom: PhantomData,
135                });
136            }
137        }
138
139        None
140    }
141}
142impl Default for StateManager<AsyncOnlyContext> {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148impl StateRegistry<AsyncOnlyContext> {
149    /// Inserts a value into the state that is initialized lazily (when first
150    /// accessed) in an asynchronous context.
151    ///
152    /// # Example
153    ///
154    /// ```rust
155    /// use state_department::AsyncState;
156    ///
157    /// static STATE: AsyncState = AsyncState::new();
158    ///
159    /// struct Foo {
160    ///     bar: i32
161    /// }
162    ///
163    /// # tokio_test::block_on(async {
164    /// let _lifetime = STATE.init_async(async |state| {
165    ///     state.insert_async_lazy(async {
166    ///         println!("Initializing Foo...");
167    ///
168    ///         // Something expensive or long-running...
169    ///
170    ///         Foo { bar: 42 }
171    ///     });
172    /// })
173    /// .await;
174    ///
175    /// let foo = STATE.get::<Foo>().await;
176    /// # });
177    ///
178    /// // > Initializing Foo...
179    /// ```
180    pub fn insert_async_lazy<T, F>(&mut self, init: F)
181    where
182        T: Send + Sync + 'static,
183        F: Future<Output = T> + Send + 'static,
184    {
185        self.insert_(
186            TypeId::of::<T>(),
187            Box::new(AsyncLazyState {
188                init: UnsafeCell::new(Some(Box::pin(init))),
189                once: OnceCell::new(),
190            }),
191        );
192    }
193}
194
195struct AsyncLazyState<T: Send + Sync + 'static> {
196    init: UnsafeCell<Option<Pin<Box<dyn Future<Output = T> + Send + 'static>>>>,
197    once: OnceCell<T>,
198}
199impl<T: Send + Sync + 'static> AsyncLazyState<T> {
200    async fn get(&self) -> &T {
201        self.once
202            .get_or_init(async {
203                let init = unsafe { (*self.init.get()).take() }.unwrap();
204                init.await
205            })
206            .await
207    }
208}
209unsafe impl<T: Send + Sync + 'static> Send for AsyncLazyState<T> {}
210unsafe impl<T: Send + Sync + 'static> Sync for AsyncLazyState<T> {}
211
#[test]
fn test_state() {
    use std::sync::atomic::{AtomicU8, Ordering};

    struct Foo {
        bar: AtomicU8,
    }

    struct Baz {
        qux: i32,
    }

    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo {
                bar: AtomicU8::new(42),
            });
            registry.insert(Baz { qux: 24 });
        });

        // First borrow: see the initial value, then mutate it through the
        // shared reference.
        let first: StateRef<'_, Foo, AsyncOnlyContext> = manager.get::<Foo>().await;
        assert_eq!(first.bar.load(Ordering::Relaxed), 42);
        first.bar.store(24, Ordering::Release);
        drop(first);

        // A fresh borrow observes the mutation.
        let second = manager.get::<Foo>().await;
        assert_eq!(second.bar.load(Ordering::Acquire), 24);
        drop(second);

        let baz = manager.get::<Baz>().await;
        assert_eq!(baz.qux, 24);
        drop(baz);

        // Every reference is gone, so the lifetime can be released cleanly.
        guard.try_drop().unwrap();
    });
}
258
#[test]
fn test_state_drop_with_ref() {
    struct Foo;

    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo);
        });

        // An outstanding reference must make `try_drop` refuse.
        let _held = manager.get::<Foo>().await;
        drop(guard.try_drop().unwrap_err());
    });
}
275
#[test]
fn test_state_use_after_lifetime_drop() {
    struct Foo;

    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo);
        });

        // Nothing borrows the state, so releasing it succeeds...
        guard.try_drop().unwrap();

        // ...and afterwards lookups report absence instead of a value.
        let lookup = manager.try_get::<Foo>().await;
        assert!(lookup.is_none());
    });
}
292
#[test]
fn test_state_drop_without_lifetime() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static DROPPED: AtomicU8 = AtomicU8::new(0);

    struct Foo;

    impl Drop for Foo {
        fn drop(&mut self) {
            DROPPED.store(1, Ordering::Release);
        }
    }

    // Current value of the drop flag: 0 while `Foo` is alive, 1 once dropped.
    fn drop_flag() -> u8 {
        DROPPED.load(Ordering::Acquire)
    }

    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo);
        });

        let held = manager.get::<Foo>().await;
        assert_eq!(drop_flag(), 0);

        // Dropping the lifetime alone does not free `Foo` while `held` exists.
        drop(guard);
        assert_eq!(drop_flag(), 0);

        // Releasing the last reference frees it.
        drop(held);
        assert_eq!(drop_flag(), 1);

        drop(manager);
        assert_eq!(drop_flag(), 1);
    });
}
330
#[test]
fn test_lazy_initialization() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static FOO_INITIALIZED: AtomicU8 = AtomicU8::new(0);

    struct Foo {
        bar: i32,
    }

    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let _guard = manager.init(|registry| {
            registry.insert_async_lazy(async {
                FOO_INITIALIZED.store(1, Ordering::Release);
                Foo { bar: 42 }
            });
        });

        // Registering the future must not run it.
        assert_eq!(FOO_INITIALIZED.load(Ordering::Acquire), 0);

        // The first lookup drives the initializer to completion.
        let foo = manager.get::<Foo>().await;
        assert_eq!(FOO_INITIALIZED.load(Ordering::Acquire), 1);
        assert_eq!(foo.bar, 42);
    });
}
367
#[test]
fn test_sync_lazy_initialization_from_async() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static FOO_INITIALIZED: AtomicU8 = AtomicU8::new(0);

    struct Foo {
        bar: i32,
    }

    // Synchronous initializer that records that it has run.
    fn make_foo() -> Foo {
        FOO_INITIALIZED.store(1, Ordering::Release);
        Foo { bar: 42 }
    }

    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let _guard = manager.init(|registry| {
            registry.insert_lazy(make_foo);
        });

        // Registration alone must not invoke the initializer.
        assert_eq!(FOO_INITIALIZED.load(Ordering::Acquire), 0);

        // An async lookup runs the synchronous initializer on first access.
        let foo = manager.get::<Foo>().await;
        assert_eq!(FOO_INITIALIZED.load(Ordering::Acquire), 1);
        assert_eq!(foo.bar, 42);
    });
}
404
/// Checks that a `static` async-only manager can be read and mutated from
/// several OS threads at once.
#[test]
fn test_state_across_threads() {
    use std::sync::{
        atomic::{AtomicU8, Ordering},
        Barrier,
    };

    static STATE: StateManager<AsyncOnlyContext> = StateManager::<AsyncOnlyContext>::new();

    struct Foo {
        bar: AtomicU8,
    }

    const THREAD_COUNT: u8 = 10;

    let _lifetime = STATE.init(|state| {
        state.insert(Foo {
            bar: AtomicU8::new(0),
        });
    });

    // `tokio_test::block_on` drives futures on a current-thread runtime, so
    // tasks spawned inside one `block_on` never leave the test thread. Use
    // real OS threads, each with its own runtime, so the shared static is
    // actually accessed concurrently. The barrier lines them up first.
    let barrier = Barrier::new(usize::from(THREAD_COUNT));

    // `thread::scope` joins every thread and re-raises any panic on exit.
    std::thread::scope(|scope| {
        for _ in 0..THREAD_COUNT {
            scope.spawn(|| {
                barrier.wait();

                tokio_test::block_on(async {
                    STATE
                        .get::<Foo>()
                        .await
                        .bar
                        .fetch_add(1, Ordering::Release);
                });
            });
        }
    });

    let total =
        tokio_test::block_on(async { STATE.get::<Foo>().await.bar.load(Ordering::Acquire) });

    assert_eq!(total, THREAD_COUNT);
}
456
#[test]
#[should_panic = "State for \"()\" not found"]
fn test_state_get_inside_init() {
    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        // The state is not published until the initializer returns, so a
        // lookup from inside the initializer must panic.
        let pending = manager.init_async(async |registry| {
            registry.insert(());

            let _ = manager.get::<()>().await;
        });

        let _ = pending.await;
    });
}
471
#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_init_inside_init() {
    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let _ = manager.init(|_| {
            // Re-entrant call: the outer initializer is still running.
            let _ = manager.init(|_| {});
        });
    });
}
482
#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_already_initialized() {
    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        // The first pass initializes (and immediately drops the lifetime); the
        // second pass must be rejected.
        for _ in 0..2 {
            let _ = manager.init(|_| {});
        }
    });
}