with_async_context/
async_context.rs

1//! # Async Context Library
2//!
//! This code was originally written by Michel Smola and released under the MIT license.
4//! The original crate can be found at https://docs.rs/async-context/0.1.1/async_context/index.html
5//! It was adapted for use at Wonop by Troels Frimodt Rønnow.
6//!
7//! This library provides a mechanism for managing contextual data across async tasks
8//! in Rust applications. It allows you to safely share state within the scope of a
9//! single async execution context.
10//!
11//! ## Key Features
12//!
13//! - Thread-safe context management using thread-local storage
14//! - Prevents nested context creation to avoid confusion and bugs
15//! - Support for both immutable and mutable access to context data
16//! - Type-safe context access through generics
17//! - Automatically cleans up context when the async scope exits
18//!
19//! ## Usage Example
20//!
21//! ```rust,no_run
22//! use with_async_context::{with_async_context, from_context, from_context_mut};
23//!
24//! #[derive(Clone)]
25//! struct MyContext {
26//!     some_value: String
27//! }
28//!
29//! impl ToString for MyContext {
30//!     fn to_string(&self) -> String {
31//!         self.some_value.clone()
32//!     }
33//! }
34//!
35//! async fn my_function() {
36//!     // Access the current context
37//!     let value = from_context(|ctx: Option<&MyContext>| {
38//!         ctx.unwrap().some_value.clone()
39//!     });
40//!
41//!     // Do something with the value...
42//!     from_context_mut(|ctx: Option<&mut MyContext>| {
43//!         if let Some(ctx) = ctx {
44//!             ctx.some_value = "updated value".to_string();
45//!         }
46//!     });
47//! }
48//!
49//! # async fn example() {
50//! let context = MyContext { some_value: "test".to_string() };
51//! let result = with_async_context(context, my_function()).await;
52//! # }
53//! ```
54use core::future::Future;
55use std::{any::Any, cell::RefCell, pin::Pin, rc::Rc, sync::Mutex, task::Poll};
56
57use pin_project::pin_project;
58
// Per-thread state shared between `AsyncContext::poll` and the `from_context*` accessors.
thread_local! {
    /// Type-erased handle to the context of the `AsyncContext` currently being polled on
    /// this thread. `poll` stores an `Rc<RefCell<Option<C>>>` here before polling the inner
    /// future and resets the slot to `None` afterwards.
    static CONTEXT: RefCell<Option<Rc<RefCell<dyn Any>>>> = RefCell::new(None);
    /// `true` while an `AsyncContext` is being polled on this thread; read by
    /// `with_async_context` to reject nested contexts.
    static HAS_CONTEXT: RefCell<bool> = RefCell::new(false);
}
63
64/// Represents an async execution context that carries data of type `C` while executing
65/// a future of type `F` that produces output of type `T`
66#[pin_project]
67pub struct AsyncContext<C, T, F>
68where
69    C: 'static + ToString,
70    F: Future<Output = T>,
71{
72    /// The context data wrapped in a mutex for thread-safety
73    ctx: Mutex<Option<C>>,
74
75    /// The future being executed, marked with #[pin] for self-referential struct support
76    #[pin]
77    future: F,
78}
79
80/// Creates a new async context and executes the provided future within it
81///
82/// # Arguments
83///
84/// * `ctx` - The context data to make available during future execution
85/// * `future` - The future to execute within the context
86///
87/// # Panics
88///
89/// Panics if attempting to create a nested context when one already exists
90///
91/// # Examples
92///
93/// ```rust,no_run
94/// use with_async_context::with_async_context;
95///
96/// struct MyContext;
97///
98/// impl MyContext {
99///     fn new() -> Self {
100///         MyContext
101///     }
102/// }
103///
104/// impl ToString for MyContext {
105///     fn to_string(&self) -> String {
106///         "MyContext".into()
107///     }
108/// }
109///
110/// async fn async_function() {}
111///
112/// # async fn example() {
113/// let result = with_async_context(
114///     MyContext::new(),
115///     async_function()
116/// ).await;
117/// # }
118/// ```
119pub fn with_async_context<C, T, F>(ctx: C, future: F) -> AsyncContext<C, T, F>
120where
121    C: 'static + ToString,
122    F: Future<Output = T>,
123{
124    // Prevent nested contexts
125    if HAS_CONTEXT.with(|x| *x.borrow()) {
126        panic!("Cannot create nested contexts.");
127    }
128
129    AsyncContext {
130        ctx: Mutex::new(Some(ctx)),
131        future,
132    }
133}
134
135impl<C, T, F> Future for AsyncContext<C, T, F>
136where
137    C: 'static + ToString,
138    F: Future<Output = T>,
139{
140    // Returns a tuple of the Future's output value and the context
141    type Output = (T, C);
142
143    fn poll(
144        self: std::pin::Pin<&mut Self>,
145        cx: &mut std::task::Context<'_>,
146    ) -> core::task::Poll<Self::Output> {
147        // Take ownership of the context from the mutex, removing it temporarily
148        let ctx: Option<C> = self
149            .ctx
150            .lock()
151            .expect("Failed to lock context mutex")
152            .take();
153        let ctx = Rc::new(RefCell::new(ctx));
154
155        // Set thread-local flags to indicate context is now active
156        HAS_CONTEXT.with(|x| *x.borrow_mut() = true);
157
158        // Store context in thread local storage
159        CONTEXT.with(|x| *x.borrow_mut() = Some(ctx.clone()));
160
161        // Project the pinned future to get mutable access
162        let projection = self.project();
163        let future: Pin<&mut F> = projection.future;
164
165        // Poll the inner future
166        let poll = future.poll(cx);
167
168        let ctx = ctx.take().expect(
169            "No context is attached to the AyncContext - this is not supposed to be possible.",
170        );
171
172        // Reset thread-local flag since we're done with this poll
173        HAS_CONTEXT.with(|x| *x.borrow_mut() = false);
174        CONTEXT.with(|x| *x.borrow_mut() = None);
175
176        match poll {
177            // If future is complete, return result with context
178            Poll::Ready(value) => return Poll::Ready((value, ctx)),
179            // If pending, restore context to mutex and return pending
180            Poll::Pending => {
181                projection
182                    .ctx
183                    .lock()
184                    .expect("Failed to lock context mutex")
185                    .replace(ctx);
186                return Poll::Pending;
187            }
188        }
189    }
190}
191
192/// Returns the current context as a string, or "(no context)" if none exists
193pub fn context_as_string<C: 'static + ToString>() -> String {
194    from_context(|ctx: Option<&C>| match ctx {
195        Some(c) => c.to_string(),
196        None => "(no context)".to_string(),
197    })
198}
199
200/// Provides immutable access to the current context value
201///
202/// # Type Parameters
203///
204/// * `C` - The type of the context value to access
205/// * `F` - The function type that will operate on the context
206/// * `R` - The return type from the function
207///
208/// # Arguments
209///
210/// * `f` - A function that receives an Option<&C> and returns R
211///
212/// # Examples
213///
214/// ```rust,no_run
215/// use with_async_context::from_context;
216///
217/// struct MyContext {
218///     some_value: String
219/// }
220/// # fn example() {
221/// let value = from_context(|ctx: Option<&MyContext>| {
222///     ctx.unwrap().some_value.clone()
223/// });
224/// # }
225/// ```
226pub fn from_context<C, F, R>(f: F) -> R
227where
228    F: FnOnce(Option<&C>) -> R,
229    C: 'static,
230{
231    CONTEXT.with(|value| match value.borrow().as_ref() {
232        None => f(None),
233        Some(ctx) => {
234            let ctx_inner = ctx.borrow();
235            let ctx_ref = ctx_inner
236                .downcast_ref::<Option<C>>()
237                .expect("Context type mismatch");
238            match ctx_ref {
239                Some(c) => f(Some(c)),
240                None => f(None),
241            }
242        }
243    })
244}
245
246/// Provides mutable access to the current context value
247///
248/// # Type Parameters
249///
250/// * `C` - The type of the context value to access
251/// * `F` - The function type that will operate on the context
252/// * `R` - The return type from the function
253///
254/// # Arguments
255///
256/// * `f` - A function that receives an Option<&mut C> and returns R
257///
258/// # Examples
259///
260/// ```rust,no_run
261/// use with_async_context::from_context_mut;
262///
263/// struct MyContext {
264///     counter: i32
265/// }
266/// # fn example() {
267/// from_context_mut(|ctx: Option<&mut MyContext>| {
268///     if let Some(ctx) = ctx {
269///         ctx.counter += 1;
270///     }
271/// });
272/// # }
273/// ```
274pub fn from_context_mut<C, F, R>(f: F) -> R
275where
276    F: FnOnce(Option<&mut C>) -> R,
277    C: 'static,
278{
279    CONTEXT.with(|value| {
280        let mut binding = value.borrow_mut();
281        match binding.as_mut() {
282            None => f(None),
283            Some(ctx) => {
284                let mut ctx_inner = ctx.borrow_mut();
285                let ctx_ref = ctx_inner
286                    .downcast_mut::<Option<C>>()
287                    .expect("Context type mismatch");
288                match ctx_ref {
289                    Some(c) => f(Some(c)),
290                    None => f(None),
291                }
292            }
293        }
294    })
295}
296
/// Tests for context creation, access, nesting rules and concurrent use.
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, fmt::Display, sync::Arc, time::Duration};

    use tokio::time::sleep;

    use super::*;

    /// The context is readable inside the future and returned alongside the output.
    #[tokio::test]
    async fn test_basic_context() {
        async fn runs_with_context() -> String {
            let value = from_context(|value: Option<&String>| value.unwrap().clone());
            value
        }

        let async_context = with_async_context("foobar".to_string(), runs_with_context());
        let (value, ctx) = async_context.await;

        assert_eq!("foobar", value);
        assert_eq!("foobar", &*ctx);
    }

    /// Interior mutability works through the immutable accessor.
    #[tokio::test]
    async fn test_mutable_context() {
        #[derive(Debug)]
        struct IntWrapper(RefCell<i32>);

        impl Display for IntWrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", self.0.borrow())
            }
        }

        async fn mutate_context() -> i32 {
            from_context(|value: Option<&IntWrapper>| {
                let val = value.unwrap();
                *val.0.borrow_mut() += 5;
                *val.0.borrow()
            })
        }

        let async_context = with_async_context(IntWrapper(RefCell::new(10)), mutate_context());
        let (value, ctx) = async_context.await;

        assert_eq!(15, value);
        assert_eq!("15", ctx.to_string());
    }

    /// `from_context_mut` hands out `&mut C` and the mutation is visible in the returned
    /// context value.
    #[tokio::test]
    async fn test_from_context_mut() {
        #[derive(Debug)]
        struct Counter(i32);

        impl Display for Counter {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", self.0)
            }
        }

        async fn bump_context() -> i32 {
            from_context_mut(|value: Option<&mut Counter>| {
                let counter = value.unwrap();
                counter.0 += 5;
                counter.0
            })
        }

        let (value, ctx) = with_async_context(Counter(10), bump_context()).await;

        assert_eq!(15, value);
        assert_eq!("15", ctx.to_string());
    }

    /// `context_as_string` renders the active context and falls back to a placeholder
    /// when none is active.
    #[tokio::test]
    async fn test_context_as_string() {
        async fn describe() -> String {
            context_as_string::<String>()
        }

        assert_eq!("(no context)", context_as_string::<String>());

        let (described, _) = with_async_context("foobar".to_string(), describe()).await;
        assert_eq!("foobar", described);
    }

    /// A struct with several fields round-trips through the context unchanged.
    #[tokio::test]
    async fn test_complex_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct TestStruct {
            name: String,
            count: i32,
        }

        impl Display for TestStruct {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}:{}", self.name, self.count)
            }
        }

        async fn use_complex_context() -> TestStruct {
            from_context(|value: Option<&TestStruct>| value.unwrap().clone())
        }

        let test_struct = TestStruct {
            name: "test".to_string(),
            count: 42,
        };

        let async_context = with_async_context(test_struct.clone(), use_complex_context());
        let (value, ctx) = async_context.await;

        assert_eq!(test_struct, value);
        assert_eq!(test_struct, ctx);
    }

    /// A context holding an `Arc` can be read inside the future.
    #[tokio::test]
    async fn test_arc_context() {
        #[derive(Debug)]
        struct ArcWrapper(Arc<i32>);

        impl Display for ArcWrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", *self.0)
            }
        }

        async fn use_arc_context() -> i32 {
            from_context(|value: Option<&ArcWrapper>| *value.unwrap().0)
        }

        let arc_value = Arc::new(100);
        let async_context = with_async_context(ArcWrapper(arc_value.clone()), use_arc_context());
        let (value, _) = async_context.await;

        assert_eq!(100, value);
    }

    /// Outside of a context the accessor passes `None` to the callback.
    #[tokio::test]
    #[should_panic(expected = "No context found while using from_context")]
    async fn test_missing_context() {
        async fn runs_without_context() {
            from_context(|v: Option<&String>| {
                v.cloned()
                    .expect("No context found while using from_context")
            });
        }

        runs_without_context().await;
    }

    /// Creating a context while another one is being polled panics.
    #[tokio::test]
    #[should_panic(expected = "Cannot create nested context")]
    async fn test_nested_contexts() {
        async fn inner_fn() -> String {
            let inner_val = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            sleep(Duration::from_millis(50)).await;
            inner_val
        }

        async fn outer_fn() -> String {
            let outer_val = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            let inner_context = with_async_context("inner".to_string(), inner_fn()).await;
            format!("{}-{}", outer_val, inner_context.0)
        }

        let context = with_async_context("outer".to_string(), outer_fn());
        let _ = context.await;
    }

    /// The context is still available after the future has been suspended and resumed.
    #[tokio::test]
    async fn test_context_persistence() {
        async fn task_with_delay() -> String {
            let val = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            sleep(Duration::from_millis(50)).await;
            let val2 = from_context(|ctx: Option<&String>| ctx.unwrap().clone());
            assert_eq!(val, val2);
            val
        }

        let context = with_async_context("test".to_string(), task_with_delay());
        let (result, _) = context.await;
        assert_eq!("test", result);
    }

    /// Sibling contexts polled from the same task each see their own value.
    #[tokio::test]
    async fn test_parallel_contexts() {
        #[derive(Debug)]
        struct IntWrapper(Arc<i32>);

        impl Display for IntWrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", *self.0)
            }
        }

        async fn task(id: i32) -> i32 {
            let val = from_context(|ctx: Option<&IntWrapper>| *ctx.unwrap().0);
            sleep(Duration::from_millis(50)).await;
            val + id
        }

        let task1 = with_async_context(IntWrapper(Arc::new(1)), task(10));
        let task2 = with_async_context(IntWrapper(Arc::new(2)), task(20));
        let task3 = with_async_context(IntWrapper(Arc::new(3)), task(30));

        let ((r1, _), (r2, _), (r3, _)) = tokio::join!(task1, task2, task3);

        assert_eq!(r1, 11);
        assert_eq!(r2, 22);
        assert_eq!(r3, 33);
    }

    /// The context is visible from futures awaited several levels below the wrapped one.
    #[tokio::test]
    async fn test_simple_nested_chains() {
        #[derive(Debug)]
        struct SimpleContext {
            value: i32,
        }

        impl Display for SimpleContext {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Value: {}", self.value)
            }
        }

        fn nested_task(depth: i32) -> Pin<Box<dyn Future<Output = i32> + Send>> {
            Box::pin(async move {
                if depth == 0 {
                    return from_context(|ctx: Option<&SimpleContext>| ctx.unwrap().value);
                }

                sleep(Duration::from_millis(10)).await;
                nested_task(depth - 1).await + 1
            })
        }

        let context = SimpleContext { value: 42 };
        let (result, _) = with_async_context(context, nested_task(3)).await;

        // Each level adds 1, so result should be 42 + 3
        assert_eq!(result, 45);
    }

    /// Many interleaved tasks on one thread never observe each other's context.
    #[tokio::test]
    async fn test_value_chains() {
        #[derive(Debug)]
        struct NumberContext {
            value: Arc<i32>,
        }

        impl Display for NumberContext {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Number: {}", *self.value)
            }
        }

        fn check_value(
            depth: i32,
            expected_value: i32,
        ) -> Pin<Box<dyn Future<Output = i32> + Send>> {
            Box::pin(async move {
                let ret = from_context(|ctx: Option<&NumberContext>| {
                    let value = *ctx.unwrap().value;
                    assert_eq!(value, expected_value, "Context value changed");
                    value
                });
                if depth == 0 {
                    return ret;
                }

                sleep(Duration::from_millis(1)).await;
                check_value(depth - 1, expected_value).await
            })
        }

        async fn run_value_chain(n: i32) -> i32 {
            let ctx = NumberContext { value: Arc::new(n) };
            let (result, _) = with_async_context(ctx, check_value(10, n)).await;
            result
        }

        // Drive the checks with `run_until` so that a failing assertion panics the test
        // itself. Spawning them as a detached local task would capture the panic in a
        // discarded `JoinHandle` and let the test pass regardless.
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let chain_tasks: Vec<_> = (0..500)
                    .map(|i| tokio::task::spawn_local(run_value_chain(i)))
                    .collect();

                let results = futures::future::join_all(chain_tasks).await;

                for (i, result) in results.into_iter().enumerate() {
                    assert_eq!(result.unwrap(), i as i32);
                }
            })
            .await;
    }

    /// Asking for a context type other than the active one panics.
    #[tokio::test]
    #[should_panic(expected = "Context type mismatch")]
    async fn test_wrong_context_type() {
        #[derive(Debug)]
        struct Context1 {
            value: i32,
        }

        impl Display for Context1 {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", self.value)
            }
        }

        #[derive(Debug)]
        struct Context2;

        async fn access_wrong_type() {
            // Try to access Context2 when Context1 is active
            from_context(|ctx: Option<&Context2>| {
                let _ = ctx.unwrap();
            });
        }

        let ctx = Context1 { value: 42 };
        let context = with_async_context(ctx, access_wrong_type());
        let _ = context.await;
    }
}