state_department/
async.rs1use crate::{
2 lazy::LazyState,
3 manager::{StateManager, StateRef},
4 StateRegistry, INITIALIZED,
5};
6use async_once_cell::OnceCell;
7use std::{
8 any::{Any, TypeId},
9 cell::UnsafeCell,
10 future::Future,
11 marker::PhantomData,
12 pin::Pin,
13};
14
/// Marker type selecting the async-only flavour of `StateManager` and
/// `StateRegistry`.
///
/// It carries no data; it only appears as a type parameter so that the
/// async-specific `impl` blocks in this module apply.
pub struct AsyncOnlyContext;

18impl StateManager<AsyncOnlyContext> {
19 pub const fn new() -> Self {
30 Self::new_()
31 }
32
33 #[must_use]
64 pub async fn get<T: Send + Sync + 'static>(&self) -> StateRef<'_, T, AsyncOnlyContext> {
65 match self.try_get().await {
66 Some(v) => v,
67 None => panic!("State for {:?} not found", std::any::type_name::<T>()),
68 }
69 }
70
71 #[must_use]
102 pub async fn try_get<T: Send + Sync + 'static>(
103 &self,
104 ) -> Option<StateRef<'_, T, AsyncOnlyContext>> {
105 if self.initialized.load(std::sync::atomic::Ordering::Acquire) != INITIALIZED {
106 return None;
107 }
108
109 let state = unsafe { (*self.state.get()).assume_init_ref() }.upgrade()?;
110
111 if let Some(value) = state.get(&TypeId::of::<T>()) {
112 let value = value.as_ref() as &dyn Any;
113
114 if let Some(value) = value.downcast_ref::<T>() {
115 return Some(StateRef {
116 value,
117 _state: state,
118 _phantom: PhantomData,
119 });
120 }
121
122 if let Some(value) = value.downcast_ref::<LazyState<T>>() {
123 return Some(StateRef {
124 value: value.get(),
125 _state: state,
126 _phantom: PhantomData,
127 });
128 }
129
130 if let Some(value) = value.downcast_ref::<AsyncLazyState<T>>() {
131 return Some(StateRef {
132 value: value.get().await,
133 _state: state,
134 _phantom: PhantomData,
135 });
136 }
137 }
138
139 None
140 }
141}
142impl Default for StateManager<AsyncOnlyContext> {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148impl StateRegistry<AsyncOnlyContext> {
149 pub fn insert_async_lazy<T, F>(&mut self, init: F)
181 where
182 T: Send + Sync + 'static,
183 F: Future<Output = T> + Send + 'static,
184 {
185 self.insert_(
186 TypeId::of::<T>(),
187 Box::new(AsyncLazyState {
188 init: UnsafeCell::new(Some(Box::pin(init))),
189 once: OnceCell::new(),
190 }),
191 );
192 }
193}
194
195struct AsyncLazyState<T: Send + Sync + 'static> {
196 init: UnsafeCell<Option<Pin<Box<dyn Future<Output = T> + Send + 'static>>>>,
197 once: OnceCell<T>,
198}
199impl<T: Send + Sync + 'static> AsyncLazyState<T> {
200 async fn get(&self) -> &T {
201 self.once
202 .get_or_init(async {
203 let init = unsafe { (*self.init.get()).take() }.unwrap();
204 init.await
205 })
206 .await
207 }
208}
209unsafe impl<T: Send + Sync + 'static> Send for AsyncLazyState<T> {}
210unsafe impl<T: Send + Sync + 'static> Sync for AsyncLazyState<T> {}
211
#[test]
fn test_state() {
    use std::sync::atomic::{AtomicU8, Ordering};

    // Eagerly inserted values are returned by `get`, and a mutation made
    // through one `StateRef` is observed by later lookups.
    tokio_test::block_on(async {
        struct Foo {
            bar: AtomicU8,
        }

        struct Baz {
            qux: i32,
        }

        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo {
                bar: AtomicU8::new(42),
            });
            registry.insert(Baz { qux: 24 });
        });

        {
            let first: StateRef<'_, Foo, AsyncOnlyContext> = manager.get::<Foo>().await;
            assert_eq!(first.bar.load(Ordering::Relaxed), 42);
            first.bar.store(24, Ordering::Release);
        }

        {
            let second = manager.get::<Foo>().await;
            assert_eq!(second.bar.load(Ordering::Acquire), 24);
        }

        {
            let baz = manager.get::<Baz>().await;
            assert_eq!(baz.qux, 24);
        }

        guard.try_drop().unwrap();
    });
}
258
#[test]
fn test_state_drop_with_ref() {
    // `try_drop` on the lifetime guard must fail while a `StateRef` is alive.
    tokio_test::block_on(async {
        struct Foo;

        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo);
        });

        let _held = manager.get::<Foo>().await;

        assert!(guard.try_drop().is_err());
    });
}
275
#[test]
fn test_state_use_after_lifetime_drop() {
    // Once the lifetime guard has been released, lookups report no state.
    tokio_test::block_on(async {
        struct Foo;

        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo);
        });
        guard.try_drop().unwrap();

        let lookup = manager.try_get::<Foo>().await;
        assert!(lookup.is_none());
    });
}
292
#[test]
fn test_state_drop_without_lifetime() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static DROPPED: AtomicU8 = AtomicU8::new(0);

    // Stored state outlives its lifetime guard while a `StateRef` still points
    // at it, and is dropped exactly when that last reference goes away.
    tokio_test::block_on(async {
        struct Foo;

        impl Drop for Foo {
            fn drop(&mut self) {
                DROPPED.store(1, Ordering::Release);
            }
        }

        let dropped = || DROPPED.load(Ordering::Acquire);

        let manager = StateManager::<AsyncOnlyContext>::new();

        let guard = manager.init(|registry| {
            registry.insert(Foo);
        });

        let foo_ref = manager.get::<Foo>().await;
        assert_eq!(dropped(), 0);

        // Releasing the guard alone keeps `Foo` alive because `foo_ref` exists.
        drop(guard);
        assert_eq!(dropped(), 0);

        // Dropping the last reference releases the value.
        drop(foo_ref);
        assert_eq!(dropped(), 1);

        drop(manager);
        assert_eq!(dropped(), 1);
    });
}
330
#[test]
fn test_lazy_initialization() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static FOO_INITIALIZED: AtomicU8 = AtomicU8::new(0);

    // A future registered with `insert_async_lazy` does not run during `init`;
    // it runs on the first `get` for its type.
    tokio_test::block_on(async {
        struct Foo {
            bar: i32,
        }

        let manager = StateManager::<AsyncOnlyContext>::new();

        let _guard = manager.init(|registry| {
            registry.insert_async_lazy(async {
                FOO_INITIALIZED.store(1, Ordering::Release);
                Foo { bar: 42 }
            });
        });

        let init_flag = || FOO_INITIALIZED.load(Ordering::Acquire);

        assert_eq!(init_flag(), 0);

        let foo = manager.get::<Foo>().await;

        assert_eq!(init_flag(), 1);
        assert_eq!(foo.bar, 42);
    });
}
367
#[test]
fn test_sync_lazy_initialization_from_async() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static FOO_INITIALIZED: AtomicU8 = AtomicU8::new(0);

    // A synchronous `insert_lazy` initializer is also deferred until the first
    // async `get` for its type.
    tokio_test::block_on(async {
        struct Foo {
            bar: i32,
        }

        let manager = StateManager::<AsyncOnlyContext>::new();

        let _guard = manager.init(|registry| {
            registry.insert_lazy(|| {
                FOO_INITIALIZED.store(1, Ordering::Release);
                Foo { bar: 42 }
            });
        });

        let init_flag = || FOO_INITIALIZED.load(Ordering::Acquire);

        assert_eq!(init_flag(), 0);

        let foo = manager.get::<Foo>().await;

        assert_eq!(init_flag(), 1);
        assert_eq!(foo.bar, 42);
    });
}
404
#[test]
fn test_state_across_threads() {
    use std::sync::atomic::{AtomicU8, Ordering};

    // `tokio_test::block_on` drives its future on a current-thread runtime, so
    // tasks spawned inside a single `block_on` would all execute on one OS
    // thread. Use real threads instead, each with its own `block_on`, so the
    // concurrent lookups genuinely share `STATE` across threads.
    static STATE: StateManager<AsyncOnlyContext> = StateManager::<AsyncOnlyContext>::new();
    const THREAD_COUNT: u8 = 10;

    struct Foo {
        bar: AtomicU8,
    }

    let _lifetime = STATE.init(|state| {
        state.insert(Foo {
            bar: AtomicU8::new(0),
        });
    });

    // Release all threads together so their lookups overlap.
    let barrier = std::sync::Arc::new(std::sync::Barrier::new(usize::from(THREAD_COUNT)));

    let threads = (0..THREAD_COUNT)
        .map(|_| {
            let barrier = std::sync::Arc::clone(&barrier);

            std::thread::spawn(move || {
                barrier.wait();

                tokio_test::block_on(async {
                    STATE
                        .get::<Foo>()
                        .await
                        .bar
                        .fetch_add(1, Ordering::Release);
                });
            })
        })
        .collect::<Vec<_>>();

    for thread in threads {
        thread.join().expect("worker thread panicked");
    }

    let total = tokio_test::block_on(async {
        STATE.get::<Foo>().await.bar.load(Ordering::Acquire)
    });
    assert_eq!(total, THREAD_COUNT);
}
456
#[test]
#[should_panic = "State for \"()\" not found"]
fn test_state_get_inside_init() {
    // State is not visible while `init_async` is still running, so `get` panics.
    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let pending_init = manager.init_async(async |registry| {
            registry.insert(());

            let _ = manager.get::<()>().await;
        });

        let _ = pending_init.await;
    });
}
471
#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_init_inside_init() {
    // Calling `init` again from inside an `init` callback must panic.
    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        let reenter = || {
            let _ = manager.init(|_| {});
        };

        let _ = manager.init(|_| reenter());
    });
}
482
#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_already_initialized() {
    // The first `init` succeeds; the second call on the same manager must panic.
    tokio_test::block_on(async {
        let manager = StateManager::<AsyncOnlyContext>::new();

        for _ in 0..2 {
            let _ = manager.init(|_| {});
        }
    });
}