scuffle_context/
ext.rs

1use std::future::{Future, IntoFuture};
2use std::pin::Pin;
3use std::task::Poll;
4
5use futures_lite::Stream;
6use tokio_util::sync::{WaitForCancellationFuture, WaitForCancellationFutureOwned};
7
8use crate::{Context, ContextTracker};
9
10/// A reference to a context which implements [`Future`] and can be polled.
11/// Can either be owned or borrowed.
12///
13/// Create by using the [`From`] implementations.
14pub struct ContextRef<'a> {
15    inner: ContextRefInner<'a>,
16}
17
18impl From<Context> for ContextRef<'_> {
19    fn from(ctx: Context) -> Self {
20        ContextRef {
21            inner: ContextRefInner::Owned {
22                fut: ctx.token.cancelled_owned(),
23                tracker: ctx.tracker,
24            },
25        }
26    }
27}
28
29impl<'a> From<&'a Context> for ContextRef<'a> {
30    fn from(ctx: &'a Context) -> Self {
31        ContextRef {
32            inner: ContextRefInner::Ref {
33                fut: ctx.token.cancelled(),
34            },
35        }
36    }
37}
38
pin_project_lite::pin_project! {
    // The future behind a `ContextRef`: it waits on the context's cancellation
    // token (`cancelled_owned()` / `cancelled()`), either owning or borrowing it.
    #[project = ContextRefInnerProj]
    enum ContextRefInner<'a> {
        // Built from an owned `Context`. The tracker is stored next to the
        // future so it stays alive exactly as long as this value does.
        // NOTE(review): presumably this lets the handler see the wrapped work as
        // still running until it is dropped — confirm against `ContextTracker`.
        Owned {
            #[pin] fut: WaitForCancellationFutureOwned,
            tracker: ContextTracker,
        },
        // Built from a `&Context`; borrows the token for `'a` and holds no
        // tracker of its own.
        Ref {
            #[pin] fut: WaitForCancellationFuture<'a>,
        },
    }
}
51
52impl std::future::Future for ContextRefInner<'_> {
53    type Output = ();
54
55    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
56        match self.project() {
57            ContextRefInnerProj::Owned { fut, .. } => fut.poll(cx),
58            ContextRefInnerProj::Ref { fut } => fut.poll(cx),
59        }
60    }
61}
62
63pin_project_lite::pin_project! {
64    /// A future with a context attached to it.
65    ///
66    /// This future will be cancelled when the context is done.
67    pub struct FutureWithContext<'a, F> {
68        #[pin]
69        future: F,
70        #[pin]
71        ctx: ContextRefInner<'a>,
72        _marker: std::marker::PhantomData<&'a ()>,
73    }
74}
75
76impl<F: Future> Future for FutureWithContext<'_, F> {
77    type Output = Option<F::Output>;
78
79    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
80        let this = self.project();
81
82        match (this.ctx.poll(cx), this.future.poll(cx)) {
83            (_, Poll::Ready(v)) => std::task::Poll::Ready(Some(v)),
84            (Poll::Ready(_), Poll::Pending) => std::task::Poll::Ready(None),
85            (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
86        }
87    }
88}
89
90pub trait ContextFutExt<Fut> {
91    /// Wraps a future with a context and cancels the future when the context is
92    /// done.
93    ///
94    /// # Example
95    ///
96    /// ```rust
97    /// # use scuffle_context::{Context, ContextFutExt};
98    /// # tokio_test::block_on(async {
99    /// let (ctx, handler) = Context::new();
100    ///
101    /// tokio::spawn(async {
102    ///    // Do some work
103    ///    tokio::time::sleep(std::time::Duration::from_secs(10)).await;
104    /// }.with_context(ctx));
105    ///
106    /// // Will stop the spawned task and cancel all associated futures.
107    /// handler.cancel();
108    /// # });
109    /// ```
110    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, Fut>
111    where
112        Self: Sized;
113}
114
115impl<F: IntoFuture> ContextFutExt<F::IntoFuture> for F {
116    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, F::IntoFuture>
117    where
118        F: IntoFuture,
119    {
120        FutureWithContext {
121            future: self.into_future(),
122            ctx: ctx.into().inner,
123            _marker: std::marker::PhantomData,
124        }
125    }
126}
127
128pin_project_lite::pin_project! {
129    /// A stream with a context attached to it.
130    ///
131    /// This stream will be cancelled when the context is done.
132    pub struct StreamWithContext<'a, F> {
133        #[pin]
134        stream: F,
135        #[pin]
136        ctx: ContextRefInner<'a>,
137        _marker: std::marker::PhantomData<&'a ()>,
138    }
139}
140
141impl<F: Stream> Stream for StreamWithContext<'_, F> {
142    type Item = F::Item;
143
144    fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
145        let this = self.project();
146
147        match (this.ctx.poll(cx), this.stream.poll_next(cx)) {
148            (Poll::Ready(_), _) => std::task::Poll::Ready(None),
149            (Poll::Pending, Poll::Ready(v)) => std::task::Poll::Ready(v),
150            (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
151        }
152    }
153
154    fn size_hint(&self) -> (usize, Option<usize>) {
155        self.stream.size_hint()
156    }
157}
158
159pub trait ContextStreamExt<Stream> {
160    /// Wraps a stream with a context and stops the stream when the context is
161    /// done.
162    ///
163    /// # Example
164    ///
165    /// ```rust
166    /// # use scuffle_context::{Context, ContextStreamExt};
167    /// # use futures_lite as futures;
168    /// # use futures_lite::StreamExt;
169    /// # tokio_test::block_on(async {
170    /// let (ctx, handler) = Context::new();
171    ///
172    /// tokio::spawn(async {
173    ///     futures::stream::iter(1..=10).then(|d| async move {
174    ///         // Do some work
175    ///         tokio::time::sleep(std::time::Duration::from_secs(d)).await;
176    ///     }).with_context(ctx);
177    /// });
178    ///
179    /// // Will stop the spawned task and cancel all associated streams.
180    /// handler.cancel();
181    /// # });
182    /// ```
183    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, Stream>
184    where
185        Self: Sized;
186}
187
188impl<F: Stream> ContextStreamExt<F> for F {
189    fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, F> {
190        StreamWithContext {
191            stream: self,
192            ctx: ctx.into().inner,
193            _marker: std::marker::PhantomData,
194        }
195    }
196}
197
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_lite::{Stream, StreamExt};
    use scuffle_future_ext::FutureExt;

    use super::{Context, ContextFutExt, ContextStreamExt};

    #[tokio::test]
    async fn future() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(
            async {
                // Do some work
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            }
            .with_context(ctx),
        );

        // Sleep for a bit to make sure the future is polled at least once.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        // Will stop the spawned task and cancel all associated futures.
        handler.shutdown().await;

        task.await.unwrap();
    }

    #[tokio::test]
    async fn future_result() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async { 1 }.with_context(ctx));

        // Cancel and wait for the task. The future is ready on its first poll,
        // and a completed future wins over cancellation, so its output is kept.
        handler.shutdown().await;

        assert_eq!(task.await.unwrap(), Some(1));
    }

    #[tokio::test]
    async fn future_ctx_by_ref() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async move {
            async {
                // Do some work
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            }
            .with_context(&ctx)
            .await;

            drop(ctx);
        });

        // Will stop the spawned task and cancel all associated futures.
        handler.shutdown().await;

        task.await.unwrap();
    }

    #[tokio::test]
    async fn future_already_cancelled() {
        let (ctx, handler) = Context::new();
        handler.cancel();

        // A pending future wrapped with an already-cancelled context resolves to `None`...
        assert_eq!(std::future::pending::<()>().with_context(&ctx).await, None);
        // ...but a future that is ready on its first poll still yields its output.
        assert_eq!(async { 1 }.with_context(&ctx).await, Some(1));

        drop(ctx);
        handler.shutdown().await;
    }

    #[tokio::test]
    async fn stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::iter(0..10).with_context(ctx));

            assert_eq!(stream.size_hint(), (10, Some(10)));

            assert_eq!(stream.next().await, Some(0));
            assert_eq!(stream.next().await, Some(1));
            assert_eq!(stream.next().await, Some(2));
            assert_eq!(stream.next().await, Some(3));

            // Cancelling the context ends every stream attached to it.
            handler.cancel();

            assert_eq!(stream.next().await, None);
        }

        handler.shutdown().await;
    }

    #[tokio::test]
    async fn stream_already_cancelled() {
        let (ctx, handler) = Context::new();
        handler.cancel();

        {
            let mut stream = pin!(futures_lite::stream::iter(0..10).with_context(&ctx));

            // The context is already done, so the stream ends immediately.
            assert_eq!(stream.next().await, None);
        }

        drop(ctx);
        handler.shutdown().await;
    }

    #[tokio::test]
    async fn pending_stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::pending::<()>().with_context(ctx));

            // This is expected to timeout
            assert!(
                stream
                    .next()
                    .with_timeout(std::time::Duration::from_millis(200))
                    .await
                    .is_err()
            );
        }

        handler.shutdown().await;
    }
}