use crate::{INITIALIZED, INITIALIZING, UNINITIALIZED};
use std::{
    any::{Any, TypeId},
    cell::UnsafeCell,
    collections::HashMap,
    marker::PhantomData,
    mem::MaybeUninit,
    ops::Deref,
    sync::{atomic::AtomicU8, Arc, Weak},
};

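/// Lazily initialized state container tied to a marker type `R`.
///
/// The registry is built by [`StateManager::init`] / [`StateManager::try_init`]
/// (or their async variants); the manager itself only stores a [`Weak`] handle,
/// while the returned [`StateLifetime`] owns the state and controls its
/// teardown.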
pub struct StateManager<R> {
    pub(crate) state: UnsafeCell<MaybeUninit<Weak<StateRegistry<R>>>>,
    pub(crate) initialized: AtomicU8,
    pub(crate) _phantom: PhantomData<R>,
}
impl<R> StateManager<R> {
    pub(crate) const fn new_() -> Self {
        Self {
            state: UnsafeCell::new(MaybeUninit::uninit()),
            initialized: AtomicU8::new(UNINITIALIZED),
            _phantom: PhantomData,
        }
    }

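    /// Initializes the state by running `init` on a fresh [`StateRegistry`].
    ///
    /// # Panics
    ///
    /// Panics if the state is already initialized or is currently being
    /// initialized by another caller.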
    pub fn init<F>(&self, init: F) -> StateLifetime<R>
    where
        F: FnOnce(&mut StateRegistry<R>),
    {
        match self.try_init(|state| {
            init(state);

            Ok::<_, ()>(())
        }) {
            Some(Ok(result)) => result,
            Some(Err(_)) => unreachable!(),
            None => panic!("State already initialized or is currently initializing"),
        }
    }

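    /// Attempts to initialize the state.
    ///
    /// Returns `None` if initialization has already started or completed,
    /// `Some(Err(e))` if `init` fails (the manager then stays in the
    /// initializing state, so later attempts also return `None`), and
    /// `Some(Ok(lifetime))` on success.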
    pub fn try_init<E, F>(&self, init: F) -> Option<Result<StateLifetime<R>, E>>
    where
        F: FnOnce(&mut StateRegistry<R>) -> Result<(), E>,
    {
        if !self.try_start_init() {
            return None;
        }

        let mut state = StateRegistry::default();

        let result = init(&mut state);
        let result = result.map(|_| self.finish_init(state));

        Some(result)
    }

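    /// Async counterpart of [`StateManager::init`].
    ///
    /// # Panics
    ///
    /// Panics if the state is already initialized or is currently being
    /// initialized.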
    pub async fn init_async<F>(&self, init: F) -> StateLifetime<R>
    where
        F: AsyncFnOnce(&mut StateRegistry<R>),
    {
        match self
            .try_init_async(async |state| {
                init(state).await;

                Ok::<_, ()>(())
            })
            .await
        {
            Some(Ok(result)) => result,
            Some(Err(_)) => unreachable!(),
            None => panic!("State already initialized or is currently initializing"),
        }
    }

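    /// Async counterpart of [`StateManager::try_init`]; the same return
    /// conventions apply.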
    pub async fn try_init_async<E, F>(&self, init: F) -> Option<Result<StateLifetime<R>, E>>
    where
        F: AsyncFnOnce(&mut StateRegistry<R>) -> Result<(), E>,
    {
        if !self.try_start_init() {
            return None;
        }

        let mut state = StateRegistry::default();
        let result = init(&mut state).await;
        let result = result.map(|_| self.finish_init(state));

        Some(result)
    }

    #[must_use = "returns whether the state can now be initialized"]
    fn try_start_init(&self) -> bool {
        self.initialized
            .compare_exchange(
                UNINITIALIZED,
                INITIALIZING,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Acquire,
            )
            .is_ok()
    }

    fn finish_init(&self, mut state: StateRegistry<R>) -> StateLifetime<R> {
        state.map.shrink_to_fit();

        // `StateRegistry<R>` can only be `!Send`/`!Sync` through `PhantomData<R>`.
        #[allow(clippy::arc_with_non_send_sync)]
        let state = Arc::new(state);

        // SAFETY: the compare-exchange in `try_start_init` lets only one caller
        // reach this point, so we have exclusive access to the cell here.
        unsafe { (*self.state.get()).write(Arc::downgrade(&state)) };

        self.initialized
            .store(INITIALIZED, std::sync::atomic::Ordering::Release);

        StateLifetime { state: Some(state) }
    }
}
impl<R> Drop for StateManager<R> {
    fn drop(&mut self) {
        let initialized = self.initialized.get_mut();

        if *initialized == INITIALIZED {
            *initialized = UNINITIALIZED;

            // SAFETY: the weak handle was written in `finish_init`, which is
            // the only place that sets `INITIALIZED`.
            unsafe { self.state.get_mut().assume_init_drop() };
        }
    }
}
// SAFETY: the cell is written exactly once, by the single caller that wins the
// compare-exchange in `try_start_init`; the `Release` store of `INITIALIZED`
// publishes that write before any shared read.
unsafe impl<R> Sync for StateManager<R> {}

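/// Guard tied to an initialized state: the registry stays alive at least as
/// long as this value (or another strong reference to it) exists. Dropping it
/// releases the registry once it holds the last strong reference.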
#[must_use]
pub struct StateLifetime<R> {
    state: Option<Arc<StateRegistry<R>>>,
}
impl<R> StateLifetime<R> {
    pub fn try_drop(mut self) -> Result<(), Self> {
        let Some(state) = self.state.take() else {
            return Err(self);
        };

        match Arc::try_unwrap(state) {
            // This was the last strong reference: drop the registry, which
            // removes the stored values in reverse insertion order.
            Ok(state) => {
                drop(state);
                Ok(())
            }
            // Other strong references still exist; hand the `Arc` back so
            // this lifetime keeps its hold on the registry.
            Err(state) => {
                self.state = Some(state);
                Err(self)
            }
        }
    }
}
impl<R> Drop for StateLifetime<R> {
    fn drop(&mut self) {
        let Some(state) = self.state.take() else {
            return;
        };

        // If other strong references remain, the registry is freed once the
        // last of them goes away.
        let Some(state) = Arc::into_inner(state) else {
            return;
        };

        drop(state);
    }
}
impl<R> std::fmt::Debug for StateLifetime<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateLifetime").finish()
    }
}

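/// Reference to a single value of type `T` stored in the registry. It keeps
/// its own strong handle to the registry, so the pointed-to value remains
/// valid for as long as the `StateRef` is alive.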
pub struct StateRef<'a, T: Send + Sync + 'static, R> {
    pub(crate) _state: Arc<StateRegistry<R>>,
    pub(crate) _phantom: PhantomData<&'a ()>,
    pub(crate) value: *const T,
}
impl<T: Send + Sync + 'static, R> Deref for StateRef<'_, T, R> {
    type Target = T;

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        // SAFETY: `value` points into the registry held alive by `_state`, so
        // it remains valid for as long as this `StateRef` exists.
        unsafe { &*self.value }
    }
}
// SAFETY: `T` is `Send + Sync`, the pointer targets data owned by the registry
// that `_state` keeps alive, and `R` is only a `PhantomData` marker.
unsafe impl<T: Send + Sync + 'static, R> Send for StateRef<'_, T, R> {}
unsafe impl<T: Send + Sync + 'static, R> Sync for StateRef<'_, T, R> {}

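/// Type-indexed storage populated during initialization. Values are keyed by
/// their [`TypeId`] and dropped in reverse insertion order.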
pub struct StateRegistry<R> {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    drop_order: Vec<TypeId>,
    _phantom: PhantomData<R>,
}
impl<R> Default for StateRegistry<R> {
    fn default() -> Self {
        Self {
            map: HashMap::new(),
            drop_order: Vec::new(),
            _phantom: PhantomData,
        }
    }
}
impl<R> StateRegistry<R> {
    #[inline(always)]
    pub(crate) fn get(&self, type_id: &TypeId) -> Option<&Box<dyn Any + Send + Sync>> {
        self.map.get(type_id)
    }

    pub(crate) fn insert_(&mut self, type_id: TypeId, value: Box<dyn Any + Send + Sync>) {
        self.map.insert(type_id, value);
        // Remember the insertion order so values can be dropped in reverse.
        self.drop_order.push(type_id);
    }

    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
        self.insert_(TypeId::of::<T>(), Box::new(value));
    }
}
impl<R> Drop for StateRegistry<R> {
    fn drop(&mut self) {
        // Drop stored values in reverse insertion order.
        for type_id in self.drop_order.iter().rev() {
            self.map.remove(type_id);
        }
    }
}

#[test]
fn test_drop_order() {
    let mut reg = StateRegistry::<crate::AnyContext>::default();

    static DROP: AtomicU8 = AtomicU8::new(0);

    struct Foo<const N: u8>;
    impl<const N: u8> Drop for Foo<N> {
        fn drop(&mut self) {
            let drop = DROP.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

            if drop != N {
                panic!("drop order is incorrect, expected {N}, got {drop}");
            } else {
                println!("dropped Foo<{N}>");
            }
        }
    }

    let d = Foo::<3>;
    let c = Foo::<2>;
    let b = Foo::<1>;
    let a = Foo::<0>;

    reg.map.insert(TypeId::of::<Foo<0>>(), Box::new(a));
    reg.map.insert(TypeId::of::<Foo<1>>(), Box::new(b));
    reg.map.insert(TypeId::of::<Foo<2>>(), Box::new(c));
    reg.map.insert(TypeId::of::<Foo<3>>(), Box::new(d));

    reg.drop_order.push(TypeId::of::<Foo<3>>());
    reg.drop_order.push(TypeId::of::<Foo<2>>());
    reg.drop_order.push(TypeId::of::<Foo<1>>());
    reg.drop_order.push(TypeId::of::<Foo<0>>());

    drop(reg);
}

#[test]
fn test_drop_order_rev() {
    let mut reg = StateRegistry::<crate::AnyContext>::default();

    static DROP: AtomicU8 = AtomicU8::new(0);

    struct Foo<const N: u8>;
    impl<const N: u8> Drop for Foo<N> {
        fn drop(&mut self) {
            let drop = DROP.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

            if drop != N {
                panic!("drop order is incorrect, expected {N}, got {drop}");
            } else {
                println!("dropped Foo<{N}>");
            }
        }
    }

    let d = Foo::<3>;
    let c = Foo::<2>;
    let b = Foo::<1>;
    let a = Foo::<0>;

    reg.map.insert(TypeId::of::<Foo<3>>(), Box::new(d));
    reg.map.insert(TypeId::of::<Foo<2>>(), Box::new(c));
    reg.map.insert(TypeId::of::<Foo<1>>(), Box::new(b));
    reg.map.insert(TypeId::of::<Foo<0>>(), Box::new(a));

    reg.drop_order.push(TypeId::of::<Foo<3>>());
    reg.drop_order.push(TypeId::of::<Foo<2>>());
    reg.drop_order.push(TypeId::of::<Foo<1>>());
    reg.drop_order.push(TypeId::of::<Foo<0>>());

    drop(reg);
}
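
// A minimal end-to-end sketch of the init/teardown flow defined above. It
// assumes `crate::AnyContext` (the marker type already used by the drop-order
// tests) and only exercises APIs from this module.
#[test]
fn test_init_and_teardown() {
    let manager = StateManager::<crate::AnyContext>::new_();

    // First initialization succeeds and returns the lifetime guard.
    let lifetime = manager.init(|state| {
        state.insert(42u32);
        state.insert(String::from("hello"));
    });

    // A second attempt is rejected while the state is initialized.
    assert!(manager.try_init(|_| Ok::<_, ()>(())).is_none());

    // The guard holds the only strong reference, so teardown succeeds.
    assert!(lifetime.try_drop().is_ok());
}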