// state_department/manager.rs

1use crate::{INITIALIZED, INITIALIZING, UNINITIALIZED};
2use std::{
3    any::{Any, TypeId},
4    cell::UnsafeCell,
5    collections::HashMap,
6    marker::PhantomData,
7    mem::MaybeUninit,
8    ops::Deref,
9    sync::{atomic::AtomicU8, Arc, Weak},
10};
11
/// The state manager.
///
/// Holds only a *weak* handle to the [`StateRegistry`] built during
/// initialization. The strong handle is returned to the caller as a
/// [`StateLifetime`], so the manager itself never keeps the registry alive.
pub struct StateManager<R> {
    /// Weak handle to the registry. Written only by `finish_init`, and only
    /// treated as initialized (e.g. by `Drop`) once `initialized` is
    /// `INITIALIZED`.
    pub(crate) state: UnsafeCell<MaybeUninit<Weak<StateRegistry<R>>>>,
    /// Initialization flag: `UNINITIALIZED`, `INITIALIZING` or `INITIALIZED`.
    pub(crate) initialized: AtomicU8,
    /// Ties the manager to its type parameter `R` without storing one
    /// (presumably a context marker such as `AnyContext` — confirm).
    pub(crate) _phantom: PhantomData<R>,
}
18impl<R> StateManager<R> {
19    pub(crate) const fn new_() -> Self {
20        Self {
21            state: UnsafeCell::new(MaybeUninit::uninit()),
22            initialized: AtomicU8::new(UNINITIALIZED),
23            _phantom: PhantomData,
24        }
25    }
26
27    /// Initializes the [`StateManager`], giving you an entrypoint for
28    /// populating the state with your desired values.
29    ///
30    /// # Panics
31    ///
32    /// * If the state has already been initialized.
33    ///
34    /// # Example
35    ///
36    /// ```rust
37    /// use state_department::State;
38    ///
39    /// static STATE: State = State::new();
40    ///
41    /// struct Foo {
42    ///     bar: i32
43    /// }
44    ///
45    /// let _lifetime = STATE.init(|state| {
46    ///     state.insert(Foo { bar: 42 });
47    /// });
48    ///
49    /// assert_eq!(STATE.get::<Foo>().bar, 42);
50    /// ```
51    pub fn init<F>(&self, init: F) -> StateLifetime<R>
52    where
53        F: FnOnce(&mut StateRegistry<R>),
54    {
55        match self.try_init(|state| {
56            init(state);
57
58            Ok::<_, ()>(())
59        }) {
60            Some(Ok(result)) => result,
61            Some(Err(_)) => unreachable!(),
62            None => panic!("State already initialized or is currently initializing"),
63        }
64    }
65
66    /// Initializes the [`StateManager`], giving you a **fallible** entrypoint
67    /// for populating the state with your desired values.
68    ///
69    /// Returns `None` if the [`StateManager`] is already initialized.
70    ///
71    /// # Example
72    ///
73    /// ```rust
74    /// use state_department::State;
75    ///
76    /// static STATE: State = State::new();
77    ///
78    /// let lifetime = STATE.try_init(|state| {
79    ///     Err("oh no!")
80    /// });
81    ///
82    /// assert!(lifetime.unwrap().is_err());
83    /// ```
84    pub fn try_init<E, F>(&self, init: F) -> Option<Result<StateLifetime<R>, E>>
85    where
86        F: FnOnce(&mut StateRegistry<R>) -> Result<(), E>,
87    {
88        if !self.try_start_init() {
89            return None;
90        }
91
92        let mut state = StateRegistry::default();
93
94        let result = init(&mut state);
95        let result = result.map(|_| self.finish_init(state));
96
97        Some(result)
98    }
99
100    /// Initializes the [`StateManager`] asynchronously, giving you an
101    /// entrypoint for populating the state with your desired values in an
102    /// asynchronous context.
103    ///
104    /// # Panics
105    ///
106    /// * If the state has already been initialized.
107    ///
108    /// # Example
109    ///
110    /// ```rust
111    /// use state_department::State;
112    ///
113    /// static STATE: State = State::new();
114    ///
115    /// struct Foo {
116    ///     bar: i32
117    /// }
118    ///
119    /// # tokio_test::block_on(async {
120    /// let _lifetime = STATE.init_async(async |state| {
121    ///     state.insert(Foo { bar: 42 });
122    /// })
123    /// .await;
124    ///
125    /// assert_eq!(STATE.get::<Foo>().bar, 42);
126    /// # });
127    /// ```
128    pub async fn init_async<F>(&self, init: F) -> StateLifetime<R>
129    where
130        F: AsyncFnOnce(&mut StateRegistry<R>),
131    {
132        match self
133            .try_init_async(async |state| {
134                init(state).await;
135
136                Ok::<_, ()>(())
137            })
138            .await
139        {
140            Some(Ok(result)) => result,
141            Some(Err(_)) => unreachable!(),
142            None => panic!("State already initialized or is currently initializing"),
143        }
144    }
145
146    /// Initializes the [`StateManager`] asynchronously, giving you a
147    /// **fallible** entrypoint for populating the state with your desired
148    /// values in an asynchronous context.
149    ///
150    /// Returns `None` if the [`StateManager`] is already initialized.
151    ///
152    /// # Example
153    ///
154    /// ```rust
155    /// use state_department::State;
156    ///
157    /// static STATE: State = State::new();
158    ///
159    /// # tokio_test::block_on(async {
160    /// let lifetime = STATE.try_init_async(async |state| {
161    /// #   state.insert(());
162    ///     Err("oh no!")
163    /// });
164    ///
165    /// assert!(lifetime.await.unwrap().is_err());
166    /// # });
167    /// ```
168    pub async fn try_init_async<E, F>(&self, init: F) -> Option<Result<StateLifetime<R>, E>>
169    where
170        F: AsyncFnOnce(&mut StateRegistry<R>) -> Result<(), E>,
171    {
172        if !self.try_start_init() {
173            return None;
174        }
175
176        let mut state = StateRegistry::default();
177        let result = init(&mut state).await;
178        let result = result.map(|_| self.finish_init(state));
179
180        Some(result)
181    }
182
183    #[must_use = "returns whether the state can now be initialized"]
184    fn try_start_init(&self) -> bool {
185        self.initialized
186            .compare_exchange(
187                UNINITIALIZED,
188                INITIALIZING,
189                std::sync::atomic::Ordering::AcqRel,
190                std::sync::atomic::Ordering::Acquire,
191            )
192            .is_ok()
193    }
194
195    fn finish_init(&self, mut state: StateRegistry<R>) -> StateLifetime<R> {
196        state.map.shrink_to_fit();
197
198        // https://github.com/rust-lang/rust-clippy/issues/11382
199        #[allow(clippy::arc_with_non_send_sync)]
200        let state = Arc::new(state);
201
202        unsafe { (*self.state.get()).write(Arc::downgrade(&state)) };
203
204        self.initialized
205            .store(INITIALIZED, std::sync::atomic::Ordering::Release);
206
207        StateLifetime { state: Some(state) }
208    }
209}
210impl<R> Drop for StateManager<R> {
211    fn drop(&mut self) {
212        let initialized = self.initialized.get_mut();
213
214        if *initialized == INITIALIZED {
215            *initialized = UNINITIALIZED;
216
217            unsafe { self.state.get_mut().assume_init_drop() };
218        }
219    }
220}
// SAFETY: `initialized` is atomic, so the only unsynchronized interior
// mutability is `state`. It is written only by `finish_init`, which runs only
// after the caller has won the UNINITIALIZED -> INITIALIZING compare-exchange
// in `try_start_init`, and the write is published by a Release store of
// INITIALIZED. NOTE(review): soundness also relies on readers elsewhere in the
// crate Acquire-loading `initialized` and touching `state` only once they see
// INITIALIZED — confirm. The impl is unconditional in `R`, which presumably is
// only ever a type-level marker (it is stored only as `PhantomData<R>`).
unsafe impl<R> Sync for StateManager<R> {}
222
223/// This is the lifetime of your state. If this value is dropped, your state
224/// will be dropped with it, provided that there are no currently held
225/// references to the state.
226///
227/// It's recommended that you gracefully drop the state yourself at the end
228/// of your program by calling [`StateLifetime::try_drop`]. This gives you
229/// an opportunity to raise an error, log a message, or otherwise handle the
230/// situation where there are still held references to your state.
231///
232/// If this value is dropped and there are still held references to the state,
233/// nothing will happen. The values held within the state will not be dropped
234/// until the corresponding [`StateManager`] is dropped.
235#[must_use]
236pub struct StateLifetime<R> {
237    state: Option<Arc<StateRegistry<R>>>,
238}
239impl<R> StateLifetime<R> {
240    /// Attempts to drop the state, returning the [`StateLifetime`] as an
241    /// [`Err`] if there are still held references.
242    pub fn try_drop(mut self) -> Result<(), Self> {
243        let Some(state) = self.state.take() else {
244            return Err(self);
245        };
246
247        let Some(mut state) = Arc::into_inner(state) else {
248            return Err(self);
249        };
250
251        state.map.clear();
252
253        Ok(())
254    }
255}
256impl<R> Drop for StateLifetime<R> {
257    fn drop(&mut self) {
258        let Some(state) = self.state.take() else {
259            return;
260        };
261
262        let Some(state) = Arc::into_inner(state) else {
263            return;
264        };
265
266        drop(state);
267    }
268}
269impl<R> std::fmt::Debug for StateLifetime<R> {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        f.debug_struct("StateLifetime").finish()
272    }
273}
274
/// A held reference to something in the [`StateManager`].
///
/// Whilst this is held, [`StateLifetime::try_drop`] will return its [`Err`]
/// variant, and dropping the [`StateManager`] will not drop the values held
/// within it. It holds a strong reference to the registry, so dropping the
/// [`StateLifetime`] does not drop those values either.
pub struct StateRef<'a, T: Send + Sync + 'static, R> {
    /// Strong reference that keeps the registry alive while this `StateRef`
    /// exists.
    pub(crate) _state: Arc<StateRegistry<R>>,
    /// Carries the `'a` lifetime parameter. NOTE(review): what `'a` is tied
    /// to is decided by the constructor (not shown here) — confirm.
    pub(crate) _phantom: PhantomData<&'a ()>,
    /// Pointer to the referenced value, dereferenced by `Deref`. Presumably
    /// it points into a value owned by `_state`'s registry — confirm at the
    /// construction site.
    pub(crate) value: *const T,
}
285impl<T: Send + Sync + 'static, R> Deref for StateRef<'_, T, R> {
286    type Target = T;
287
288    #[inline(always)]
289    fn deref(&self) -> &Self::Target {
290        unsafe { &*self.value }
291    }
292}
// SAFETY: the raw `value` pointer is what suppresses the auto traits. The only
// access to it shown in this file is `Deref`, which yields `&T`, and
// `T: Send + Sync` makes `&T` both `Send` and `Sync`. The registry kept alive
// by `_state` stores only `Send + Sync` values. NOTE(review): this assumes
// `value` points into that registry (set by a constructor not shown here) and
// that `R` is only a type-level marker — confirm.
unsafe impl<T: Send + Sync + 'static, R> Send for StateRef<'_, T, R> {}
unsafe impl<T: Send + Sync + 'static, R> Sync for StateRef<'_, T, R> {}
295
/// The registry of values stored in the [`StateManager`].
///
/// You will be given an opportunity to populate this with your desired values
/// when initializing the [`StateManager`] via [`StateManager::init`] or
/// [`StateManager::try_init`].
///
/// Values are keyed by their [`TypeId`], so the registry holds at most one
/// value per type. When the registry itself is dropped, its remaining values
/// are dropped in the reverse of the order in which they were inserted.
pub struct StateRegistry<R> {
    /// The stored values, keyed by the [`TypeId`] of their concrete type.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    /// Insertion history, walked in reverse by `Drop` to tear values down.
    drop_order: Vec<TypeId>,
    /// Ties the registry to the manager's type parameter `R` without storing one.
    _phantom: PhantomData<R>,
}
306impl<R> Default for StateRegistry<R> {
307    fn default() -> Self {
308        Self {
309            map: HashMap::new(),
310            drop_order: Vec::new(),
311            _phantom: PhantomData,
312        }
313    }
314}
315impl<R> StateRegistry<R> {
316    #[inline(always)]
317    pub(crate) fn get(&self, type_id: &TypeId) -> Option<&Box<dyn Any + Send + Sync>> {
318        self.map.get(type_id)
319    }
320
321    pub(crate) fn insert_(&mut self, type_id: TypeId, value: Box<dyn Any + Send + Sync>) {
322        self.map.insert(type_id, value);
323        self.drop_order.push(type_id);
324    }
325
326    /// Inserts a value into the state.
327    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
328        self.insert_(TypeId::of::<T>(), Box::new(value));
329    }
330}
331impl<R> Drop for StateRegistry<R> {
332    fn drop(&mut self) {
333        for type_id in self.drop_order.iter().rev() {
334            self.map.remove(type_id);
335        }
336    }
337}
338
#[test]
fn test_drop_order() {
    // Counts how many `Foo`s have been dropped so far.
    static DROP: AtomicU8 = AtomicU8::new(0);

    // `Foo<N>` checks, as it is dropped, that it is the `N`th (0-based) drop.
    struct Foo<const N: u8>;
    impl<const N: u8> Drop for Foo<N> {
        fn drop(&mut self) {
            let seen = DROP.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            assert!(
                seen == N,
                "drop order is incorrect, expected {N}, got {seen}"
            );
            println!("dropped Foo<{N}>");
        }
    }

    let mut reg = StateRegistry::<crate::AnyContext>::default();

    // The map is filled in ascending order, which must not matter...
    reg.map.insert(TypeId::of::<Foo<0>>(), Box::new(Foo::<0>));
    reg.map.insert(TypeId::of::<Foo<1>>(), Box::new(Foo::<1>));
    reg.map.insert(TypeId::of::<Foo<2>>(), Box::new(Foo::<2>));
    reg.map.insert(TypeId::of::<Foo<3>>(), Box::new(Foo::<3>));

    // ...because teardown walks `drop_order` backwards, so Foo<0> goes first.
    reg.drop_order.extend([
        TypeId::of::<Foo<3>>(),
        TypeId::of::<Foo<2>>(),
        TypeId::of::<Foo<1>>(),
        TypeId::of::<Foo<0>>(),
    ]);

    drop(reg);
}
375
#[test]
fn test_drop_order_rev() {
    static DROP: AtomicU8 = AtomicU8::new(0);

    // `Foo<N>` must be the `N`th (0-based) value to be dropped.
    struct Foo<const N: u8>;
    impl<const N: u8> Drop for Foo<N> {
        fn drop(&mut self) {
            let seen = DROP.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

            if seen == N {
                println!("dropped Foo<{N}>");
            } else {
                panic!("drop order is incorrect, expected {N}, got {seen}");
            }
        }
    }

    // Insertion history, oldest first. The map is filled in this same order.
    let history = [
        TypeId::of::<Foo<3>>(),
        TypeId::of::<Foo<2>>(),
        TypeId::of::<Foo<1>>(),
        TypeId::of::<Foo<0>>(),
    ];
    let values: [Box<dyn Any + Send + Sync>; 4] = [
        Box::new(Foo::<3>),
        Box::new(Foo::<2>),
        Box::new(Foo::<1>),
        Box::new(Foo::<0>),
    ];

    let mut reg = StateRegistry::<crate::AnyContext>::default();
    for (type_id, value) in history.iter().copied().zip(values) {
        reg.map.insert(type_id, value);
    }
    reg.drop_order.extend(history);

    // Walking the history backwards must drop Foo<0> first and Foo<3> last.
    drop(reg);
}