1use core::future::Future;
55use std::{any::Any, cell::RefCell, pin::Pin, rc::Rc, sync::Mutex, task::Poll};
56
57use pin_project::pin_project;
58
thread_local! {
    /// `true` while an `AsyncContext` on this thread is polling its inner future;
    /// `with_async_context` consults it to refuse creating a nested context.
    static HAS_CONTEXT: RefCell<bool> = RefCell::default();

    /// The context of the `AsyncContext` currently polling on this thread, type-erased so a
    /// single slot serves every context type (it holds an `Rc<RefCell<Option<C>>>`).
    /// `None` whenever no context is being polled.
    static CONTEXT: RefCell<Option<Rc<RefCell<dyn Any>>>> = RefCell::default();
}
63
64#[pin_project]
67pub struct AsyncContext<C, T, F>
68where
69 C: 'static + ToString,
70 F: Future<Output = T>,
71{
72 ctx: Mutex<Option<C>>,
74
75 #[pin]
77 future: F,
78}
79
80pub fn with_async_context<C, T, F>(ctx: C, future: F) -> AsyncContext<C, T, F>
120where
121 C: 'static + ToString,
122 F: Future<Output = T>,
123{
124 if HAS_CONTEXT.with(|x| *x.borrow()) {
126 panic!("Cannot create nested contexts.");
127 }
128
129 AsyncContext {
130 ctx: Mutex::new(Some(ctx)),
131 future,
132 }
133}
134
135impl<C, T, F> Future for AsyncContext<C, T, F>
136where
137 C: 'static + ToString,
138 F: Future<Output = T>,
139{
140 type Output = (T, C);
142
143 fn poll(
144 self: std::pin::Pin<&mut Self>,
145 cx: &mut std::task::Context<'_>,
146 ) -> core::task::Poll<Self::Output> {
147 let ctx: Option<C> = self
149 .ctx
150 .lock()
151 .expect("Failed to lock context mutex")
152 .take();
153 let ctx = Rc::new(RefCell::new(ctx));
154
155 HAS_CONTEXT.with(|x| *x.borrow_mut() = true);
157
158 CONTEXT.with(|x| *x.borrow_mut() = Some(ctx.clone()));
160
161 let projection = self.project();
163 let future: Pin<&mut F> = projection.future;
164
165 let poll = future.poll(cx);
167
168 let ctx = ctx.take().expect(
169 "No context is attached to the AyncContext - this is not supposed to be possible.",
170 );
171
172 HAS_CONTEXT.with(|x| *x.borrow_mut() = false);
174 CONTEXT.with(|x| *x.borrow_mut() = None);
175
176 match poll {
177 Poll::Ready(value) => return Poll::Ready((value, ctx)),
179 Poll::Pending => {
181 projection
182 .ctx
183 .lock()
184 .expect("Failed to lock context mutex")
185 .replace(ctx);
186 return Poll::Pending;
187 }
188 }
189 }
190}
191
192pub fn context_as_string<C: 'static + ToString>() -> String {
194 from_context(|ctx: Option<&C>| match ctx {
195 Some(c) => c.to_string(),
196 None => "(no context)".to_string(),
197 })
198}
199
200pub fn from_context<C, F, R>(f: F) -> R
227where
228 F: FnOnce(Option<&C>) -> R,
229 C: 'static,
230{
231 CONTEXT.with(|value| match value.borrow().as_ref() {
232 None => f(None),
233 Some(ctx) => {
234 let ctx_inner = ctx.borrow();
235 let ctx_ref = ctx_inner
236 .downcast_ref::<Option<C>>()
237 .expect("Context type mismatch");
238 match ctx_ref {
239 Some(c) => f(Some(c)),
240 None => f(None),
241 }
242 }
243 })
244}
245
246pub fn from_context_mut<C, F, R>(f: F) -> R
275where
276 F: FnOnce(Option<&mut C>) -> R,
277 C: 'static,
278{
279 CONTEXT.with(|value| {
280 let mut binding = value.borrow_mut();
281 match binding.as_mut() {
282 None => f(None),
283 Some(ctx) => {
284 let mut ctx_inner = ctx.borrow_mut();
285 let ctx_ref = ctx_inner
286 .downcast_mut::<Option<C>>()
287 .expect("Context type mismatch");
288 match ctx_ref {
289 Some(c) => f(Some(c)),
290 None => f(None),
291 }
292 }
293 }
294 })
295}
296
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, fmt::Display, sync::Arc, time::Duration};

    use tokio::time::sleep;

    use super::*;

    /// A `String` context is visible to the wrapped future and handed back on completion.
    #[tokio::test]
    async fn test_basic_context() {
        async fn runs_with_context() -> String {
            from_context(|value: Option<&String>| value.unwrap().clone())
        }

        let async_context = with_async_context("foobar".to_string(), runs_with_context());
        let (value, ctx) = async_context.await;

        assert_eq!("foobar", value);
        assert_eq!("foobar", &*ctx);
    }

    /// Interior mutability through the shared reference from `from_context` is reflected
    /// in the context returned on completion.
    #[tokio::test]
    async fn test_mutable_context() {
        #[derive(Debug)]
        struct IntWrapper(RefCell<i32>);

        impl Display for IntWrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", self.0.borrow())
            }
        }

        async fn mutate_context() -> i32 {
            from_context(|value: Option<&IntWrapper>| {
                let val = value.unwrap();
                *val.0.borrow_mut() += 5;
                *val.0.borrow()
            })
        }

        let async_context = with_async_context(IntWrapper(RefCell::new(10)), mutate_context());
        let (value, ctx) = async_context.await;

        assert_eq!(15, value);
        assert_eq!("15", ctx.to_string());
    }

    /// `from_context_mut` mutates the context in place, and the mutation survives the
    /// context being parked and re-installed across an `.await`.
    #[tokio::test]
    async fn test_from_context_mut() {
        fn bump(by: i32) -> i32 {
            from_context_mut(|ctx: Option<&mut i32>| {
                let counter = ctx.expect("context should be installed");
                *counter += by;
                *counter
            })
        }

        async fn bump_across_await() -> (i32, i32) {
            let first = bump(1);
            sleep(Duration::from_millis(5)).await;
            let second = bump(2);
            (first, second)
        }

        let ((first, second), ctx) = with_async_context(10_i32, bump_across_await()).await;

        assert_eq!(11, first);
        assert_eq!(13, second);
        assert_eq!(13, ctx);
    }

    /// `context_as_string` renders the installed context, or a placeholder when none is set.
    #[tokio::test]
    async fn test_context_as_string() {
        async fn describe() -> String {
            context_as_string::<i32>()
        }

        assert_eq!("(no context)", context_as_string::<i32>());

        let (described, _) = with_async_context(7_i32, describe()).await;
        assert_eq!("7", described);
    }

    #[tokio::test]
    async fn test_complex_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct TestStruct {
            name: String,
            count: i32,
        }

        impl Display for TestStruct {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}:{}", self.name, self.count)
            }
        }

        async fn use_complex_context() -> TestStruct {
            from_context(|value: Option<&TestStruct>| value.unwrap().clone())
        }

        let test_struct = TestStruct {
            name: "test".to_string(),
            count: 42,
        };

        let async_context = with_async_context(test_struct.clone(), use_complex_context());
        let (value, ctx) = async_context.await;

        assert_eq!(test_struct, value);
        assert_eq!(test_struct, ctx);
    }

    #[tokio::test]
    async fn test_arc_context() {
        #[derive(Debug)]
        struct ArcWrapper(Arc<i32>);

        impl Display for ArcWrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", *self.0)
            }
        }

        async fn use_arc_context() -> i32 {
            from_context(|value: Option<&ArcWrapper>| *value.unwrap().0)
        }

        let arc_value = Arc::new(100);
        let async_context = with_async_context(ArcWrapper(arc_value.clone()), use_arc_context());
        let (value, _) = async_context.await;

        assert_eq!(100, value);
    }

    #[tokio::test]
    #[should_panic(expected = "No context found while using from_context")]
    async fn test_missing_context() {
        async fn runs_without_context() {
            from_context(|v: Option<&String>| {
                v.cloned()
                    .expect("No context found while using from_context")
            });
        }

        runs_without_context().await;
    }

    #[tokio::test]
    #[should_panic(expected = "Cannot create nested context")]
    async fn test_nested_contexts() {
        async fn inner_fn() -> String {
            let inner_val = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            sleep(Duration::from_millis(50)).await;
            inner_val
        }

        async fn outer_fn() -> String {
            let outer_val = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            let inner_context = with_async_context("inner".to_string(), inner_fn()).await;
            format!("{}-{}", outer_val, inner_context.0)
        }

        let context = with_async_context("outer".to_string(), outer_fn());
        let _ = context.await;
    }

    /// The context is still readable after the wrapped future has been suspended.
    #[tokio::test]
    async fn test_context_persistence() {
        async fn task_with_delay() -> String {
            let val = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            sleep(Duration::from_millis(50)).await;
            let val2 = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            assert_eq!(val, val2);
            val
        }

        let context = with_async_context("test".to_string(), task_with_delay());
        let (result, _) = context.await;
        assert_eq!("test", result);
    }

    /// Several contexts polled concurrently on one task each see their own value.
    #[tokio::test]
    async fn test_parallel_contexts() {
        #[derive(Debug)]
        struct IntWrapper(Arc<i32>);

        impl Display for IntWrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", *self.0)
            }
        }

        async fn task(id: i32) -> i32 {
            let val = from_context(|ctx: Option<&IntWrapper>| *ctx.unwrap().0);
            sleep(Duration::from_millis(50)).await;
            val + id
        }

        let task1 = with_async_context(IntWrapper(Arc::new(1)), task(10));
        let task2 = with_async_context(IntWrapper(Arc::new(2)), task(20));
        let task3 = with_async_context(IntWrapper(Arc::new(3)), task(30));

        let ((r1, _), (r2, _), (r3, _)) = tokio::join!(task1, task2, task3);

        assert_eq!(r1, 11);
        assert_eq!(r2, 22);
        assert_eq!(r3, 33);
    }

    #[tokio::test]
    async fn test_simple_nested_chains() {
        #[derive(Debug)]
        struct SimpleContext {
            value: i32,
        }

        impl Display for SimpleContext {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Value: {}", self.value)
            }
        }

        fn nested_task(depth: i32) -> Pin<Box<dyn Future<Output = i32> + Send>> {
            Box::pin(async move {
                if depth == 0 {
                    return from_context(|ctx: Option<&SimpleContext>| ctx.unwrap().value);
                }

                sleep(Duration::from_millis(10)).await;
                nested_task(depth - 1).await + 1
            })
        }

        let context = SimpleContext { value: 42 };
        let (result, _) = with_async_context(context, nested_task(3)).await;

        assert_eq!(result, 45);
    }

    /// Many interleaved contexts on one thread never observe each other's values.
    #[tokio::test]
    async fn test_value_chains() {
        #[derive(Debug)]
        struct NumberContext {
            value: Arc<i32>,
        }

        impl Display for NumberContext {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Number: {}", *self.value)
            }
        }

        fn check_value(
            depth: i32,
            expected_value: i32,
        ) -> Pin<Box<dyn Future<Output = i32> + Send>> {
            Box::pin(async move {
                let ret = from_context(|ctx: Option<&NumberContext>| {
                    let value = *ctx.unwrap().value;
                    assert_eq!(value, expected_value, "Context value changed");
                    value
                });
                if depth == 0 {
                    return ret;
                }

                sleep(Duration::from_millis(1)).await;
                check_value(depth - 1, expected_value).await
            })
        }

        async fn run_value_chain(n: i32) -> i32 {
            let ctx = NumberContext { value: Arc::new(n) };
            let (result, _) = with_async_context(ctx, check_value(10, n)).await;
            result
        }

        let local = tokio::task::LocalSet::new();
        // Drive the checks with `run_until` rather than a detached `spawn_local`: a panic in
        // a detached task is only captured in its dropped `JoinHandle`, so the assertions
        // below could otherwise never fail the test.
        local
            .run_until(async {
                let chain_tasks: Vec<_> = (0..500)
                    .map(|i| tokio::task::spawn_local(run_value_chain(i)))
                    .collect();

                let results = futures::future::join_all(chain_tasks).await;

                for (i, result) in results.into_iter().enumerate() {
                    assert_eq!(result.unwrap(), i as i32);
                }
            })
            .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Context type mismatch")]
    async fn test_wrong_context_type() {
        #[derive(Debug)]
        struct Context1 {
            value: i32,
        }

        impl Display for Context1 {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", self.value)
            }
        }

        #[derive(Debug)]
        struct Context2;

        async fn access_wrong_type() {
            from_context(|ctx: Option<&Context2>| {
                let _ = ctx.unwrap();
            });
        }

        let ctx = Context1 { value: 42 };
        let context = with_async_context(ctx, access_wrong_type());
        let _ = context.await;
    }
}