1use std::future::{Future, IntoFuture};
2use std::pin::Pin;
3use std::task::Poll;
4
5use futures_lite::Stream;
6use tokio_util::sync::{WaitForCancellationFuture, WaitForCancellationFutureOwned};
7
8use crate::{Context, ContextTracker};
9
10pub struct ContextRef<'a> {
15 inner: ContextRefInner<'a>,
16}
17
18impl From<Context> for ContextRef<'_> {
19 fn from(ctx: Context) -> Self {
20 ContextRef {
21 inner: ContextRefInner::Owned {
22 fut: ctx.token.cancelled_owned(),
23 tracker: ctx.tracker,
24 },
25 }
26 }
27}
28
29impl<'a> From<&'a Context> for ContextRef<'a> {
30 fn from(ctx: &'a Context) -> Self {
31 ContextRef {
32 inner: ContextRefInner::Ref {
33 fut: ctx.token.cancelled(),
34 },
35 }
36 }
37}
38
39pin_project_lite::pin_project! {
40 #[project = ContextRefInnerProj]
41 enum ContextRefInner<'a> {
42 Owned {
43 #[pin] fut: WaitForCancellationFutureOwned,
44 tracker: ContextTracker,
45 },
46 Ref {
47 #[pin] fut: WaitForCancellationFuture<'a>,
48 },
49 }
50}
51
52impl std::future::Future for ContextRefInner<'_> {
53 type Output = ();
54
55 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
56 match self.project() {
57 ContextRefInnerProj::Owned { fut, .. } => fut.poll(cx),
58 ContextRefInnerProj::Ref { fut } => fut.poll(cx),
59 }
60 }
61}
62
63pin_project_lite::pin_project! {
64 pub struct FutureWithContext<'a, F> {
68 #[pin]
69 future: F,
70 #[pin]
71 ctx: ContextRefInner<'a>,
72 _marker: std::marker::PhantomData<&'a ()>,
73 }
74}
75
76impl<F: Future> Future for FutureWithContext<'_, F> {
77 type Output = Option<F::Output>;
78
79 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
80 let this = self.project();
81
82 match (this.ctx.poll(cx), this.future.poll(cx)) {
83 (_, Poll::Ready(v)) => std::task::Poll::Ready(Some(v)),
84 (Poll::Ready(_), Poll::Pending) => std::task::Poll::Ready(None),
85 (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
86 }
87 }
88}
89
90pub trait ContextFutExt<Fut> {
91 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, Fut>
111 where
112 Self: Sized;
113}
114
115impl<F: IntoFuture> ContextFutExt<F::IntoFuture> for F {
116 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, F::IntoFuture>
117 where
118 F: IntoFuture,
119 {
120 FutureWithContext {
121 future: self.into_future(),
122 ctx: ctx.into().inner,
123 _marker: std::marker::PhantomData,
124 }
125 }
126}
127
128pin_project_lite::pin_project! {
129 pub struct StreamWithContext<'a, F> {
133 #[pin]
134 stream: F,
135 #[pin]
136 ctx: ContextRefInner<'a>,
137 _marker: std::marker::PhantomData<&'a ()>,
138 }
139}
140
141impl<F: Stream> Stream for StreamWithContext<'_, F> {
142 type Item = F::Item;
143
144 fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
145 let this = self.project();
146
147 match (this.ctx.poll(cx), this.stream.poll_next(cx)) {
148 (Poll::Ready(_), _) => std::task::Poll::Ready(None),
149 (Poll::Pending, Poll::Ready(v)) => std::task::Poll::Ready(v),
150 (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
151 }
152 }
153
154 fn size_hint(&self) -> (usize, Option<usize>) {
155 self.stream.size_hint()
156 }
157}
158
159pub trait ContextStreamExt<Stream> {
160 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, Stream>
184 where
185 Self: Sized;
186}
187
188impl<F: Stream> ContextStreamExt<F> for F {
189 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, F> {
190 StreamWithContext {
191 stream: self,
192 ctx: ctx.into().inner,
193 _marker: std::marker::PhantomData,
194 }
195 }
196}
197
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::time::Duration;

    use futures_lite::{Stream, StreamExt};
    use scuffle_future_ext::FutureExt;

    use super::{Context, ContextFutExt, ContextStreamExt};

    /// A long-running future wrapped with an owned context must be cut short
    /// by shutdown and report `None`.
    #[tokio::test]
    async fn future() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(
            async {
                // Far longer than the test runs: only cancellation can end it.
                tokio::time::sleep(Duration::from_secs(10)).await;
            }
            .with_context(ctx),
        );

        // Give the spawned task time to start and park on its sleep.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Expected to cancel the context and return once the task has let go
        // of it.
        handler.shutdown().await;

        // The wrapped future never completed, so the adapter must yield `None`.
        assert_eq!(task.await.unwrap(), None);
    }

    /// A future that is immediately ready wins over cancellation.
    #[tokio::test]
    async fn future_result() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async { 1 }.with_context(ctx));

        handler.shutdown().await;

        assert_eq!(task.await.unwrap(), Some(1));
    }

    /// Same as `future`, but the context is only borrowed by the adapter.
    #[tokio::test]
    async fn future_ctx_by_ref() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async move {
            let outcome = async {
                // Far longer than the test runs: only cancellation can end it.
                tokio::time::sleep(Duration::from_secs(10)).await;
            }
            .with_context(&ctx)
            .await;

            // The adapter only borrowed the context, so it has to be released
            // explicitly for shutdown to finish.
            drop(ctx);

            // The wrapped future never completed, so the adapter must yield `None`.
            assert_eq!(outcome, None);
        });

        handler.shutdown().await;

        task.await.unwrap();
    }

    /// Items flow through until the context is cancelled, then the stream ends.
    #[tokio::test]
    async fn stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::iter(0..10).with_context(ctx));

            // The hint of the wrapped stream is forwarded unchanged.
            assert_eq!(stream.size_hint(), (10, Some(10)));

            assert_eq!(stream.next().await, Some(0));
            assert_eq!(stream.next().await, Some(1));
            assert_eq!(stream.next().await, Some(2));
            assert_eq!(stream.next().await, Some(3));

            // After cancellation the adapter stops, even though the wrapped
            // stream still has items left.
            handler.cancel();

            assert_eq!(stream.next().await, None);
        }

        // The stream (and with it the owned context) was dropped at the end of
        // the scope above, so shutdown can complete.
        handler.shutdown().await;
    }

    /// Without cancellation, a stream that never yields stays pending.
    #[tokio::test]
    async fn pending_stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::pending::<()>().with_context(ctx));

            // Nothing cancels the context here, so the only way out is the timeout.
            assert!(
                stream
                    .next()
                    .with_timeout(Duration::from_millis(200))
                    .await
                    .is_err()
            );
        }

        handler.shutdown().await;
    }
}