scoped_stream_sink/
scoped_stream_sink.rs

1use alloc::boxed::Box;
2use core::cell::{Cell, UnsafeCell};
3use core::convert::Infallible;
4use core::future::Future;
5use core::marker::{PhantomData, PhantomPinned};
6use core::mem::transmute;
7use core::ops::DerefMut;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10
11use futures_core::Stream;
12use futures_sink::Sink;
13use pin_project_lite::pin_project;
14
15#[cfg(feature = "std")]
16use crate::LocalThread;
17use crate::{State, StreamSink};
18
#[cfg(feature = "std")]
pin_project! {
    /// Scoped version of [`StreamSink`]. Makes building [`StreamSink`] much easier to do:
    /// the inner logic is written as a future that reads from a [`StreamPart`] and writes
    /// to a [`SinkPart`].
    #[must_use = "StreamSink will not do anything if not used"]
    pub struct ScopedStreamSink<'env, SI, RI, E> {
        // NOTE: field order is load-bearing. Struct fields drop in declaration order, and
        // `fut` is built (in `new`) from `StreamPart`/`SinkPart` pointers into `data`, so
        // `fut` must be dropped before `data` is freed. Do not reorder these fields.
        fut: Option<Pin<Box<dyn Future<Output = Result<(), E>> + Send + 'env>>>,

        // Heap-pinned state shared with the inner halves; its address never changes.
        data: Pin<Box<StreamSinkInner<'env, 'env, SI, RI, E>>>,
    }
}
29
/// Rendezvous state shared between an outer scoped stream-sink and its two inner halves.
///
/// Each direction holds at most one in-flight item: a second send before the first is
/// taken panics.
struct StreamSinkInnerData<SI, RI, E> {
    /// Item (or error) written by the inner sink half, taken by the outer side.
    send: UnsafeCell<Option<Result<SI, E>>>,
    /// Item written by the outer side, taken by the inner stream half.
    recv: UnsafeCell<Option<RI>>,
    /// Set when the inner sink half closes or the inner future finishes.
    close_send: Cell<bool>,
    /// Set when the outer side starts closing or the inner future finishes.
    close_recv: Cell<bool>,

    // Makes the type `!Unpin`, a technique borrowed from Tokio, so that the aliased
    // `Pin<&mut _>` pointers handed out by the constructors pass Miri.
    // <https://github.com/rust-lang/rust/pull/82834>
    _pinned: PhantomPinned,
}
40
// SAFETY: Every method of `StreamSinkInnerData` takes `&self` and mutates through
// `Cell`/`UnsafeCell`, so this type is NOT thread-safe on its own. These impls are only
// sound while every access stays on a single thread:
// - `StreamSinkInner` wraps it in `LocalThread`, whose use from another thread is
//   documented (on `StreamPart`/`SinkPart`) to panic;
// - `LocalStreamSinkInner` is `!Send`/`!Sync` through its `*mut u8` marker.
// NOTE(review): those docs say part of the thread check is debug-only, and `LocalThread`
// is not visible here. Confirm it guards every access path. (Unlike the unstable
// `core::sync::Exclusive`, which is `Sync` because it never hands out `&T`, this type is
// used exclusively through `&self`.)
unsafe impl<SI: Send, RI: Send, E: Send> Send for StreamSinkInnerData<SI, RI, E> {}
unsafe impl<SI: Send, RI: Send, E: Send> Sync for StreamSinkInnerData<SI, RI, E> {}
45
46impl<SI, RI, E> StreamSinkInnerData<SI, RI, E> {
47    const fn new() -> Self {
48        StreamSinkInnerData {
49            send: UnsafeCell::new(None),
50            recv: UnsafeCell::new(None),
51            close_send: Cell::new(false),
52            close_recv: Cell::new(false),
53            _pinned: PhantomPinned,
54        }
55    }
56}
57
#[cfg(feature = "std")]
pin_project! {
    /// Heap-pinned state shared by `ScopedStreamSink` and its `StreamPart`/`SinkPart`
    /// halves. The data sits behind `LocalThread`, which (per the `StreamPart`/`SinkPart`
    /// docs) is meant to panic on use from a thread other than the outer one.
    struct StreamSinkInner<'scope, 'env: 'scope, SI, RI, E> {
        #[pin]
        inner: LocalThread<StreamSinkInnerData<SI, RI, E>>,

        // Carries the `'scope`/`'env` lifetimes and the item types without storing them.
        phantom: PhantomData<&'scope mut &'env (SI, RI, E)>,
    }
}
67
#[cfg(feature = "std")]
pin_project! {
    /// [`Stream`] half of inner [`ScopedStreamSink`].
    /// Produces values of the receive type (`RI`) sent by the outer [`ScopedStreamSink`].
    /// Can only be closed from its outer [`ScopedStreamSink`].
    ///
    /// # Note About Thread-safety
    ///
    /// Even though [`StreamPart`] is both [`Send`] and [`Stream`], its reference
    /// **should** not be sent across threads. This is currently impossible, due to
    /// the lack of an async version of [`scope`](std::thread::scope).
    /// To future-proof that possibility, any usage of it will panic if called from a
    /// different thread than the outer thread. It may also panic the outer thread.
    ///
    /// Also note that some of the checks depend on the `debug_assertions` build config
    /// (i.e. they only run on debug builds).
    #[must_use = "Stream will not do anything if not used"]
    pub struct StreamPart<'scope, 'env: 'scope, SI, RI, E> {
        ptr: Pin<&'scope mut StreamSinkInner<'scope, 'env, SI, RI, E>>,
    }
}
89
#[cfg(feature = "std")]
pin_project! {
    /// [`Sink`] half of inner [`ScopedStreamSink`].
    /// Accepts both plain send-type (`SI`) items and `Result<SI, E>` items; an `Err` is
    /// delivered to the outer [`ScopedStreamSink`] as its error.
    /// Closing completes once the outer [`ScopedStreamSink`] has taken the last item sent;
    /// sending after closing panics.
    ///
    /// # Note About Thread-safety
    ///
    /// Even though [`SinkPart`] is both [`Send`] and [`Sink`], its reference
    /// **should** not be sent across threads. This is currently impossible, due to
    /// the lack of an async version of [`scope`](std::thread::scope).
    /// To future-proof that possibility, any usage of it will panic if called from a
    /// different thread than the outer thread. It may also panic the outer thread.
    ///
    /// Also note that some of the checks depend on the `debug_assertions` build config
    /// (i.e. they only run on debug builds).
    #[must_use = "Sink will not do anything if not used"]
    pub struct SinkPart<'scope, 'env: 'scope, SI, RI, E> {
        ptr: Pin<&'scope mut StreamSinkInner<'scope, 'env, SI, RI, E>>,
    }
}
111
#[cfg(feature = "std")]
impl<'env, SI, RI, E> ScopedStreamSink<'env, SI, RI, E> {
    /// Creates a new [`ScopedStreamSink`] driven by the future that `f` returns.
    ///
    /// `f` receives the inner [`StreamPart`] (items coming from the outer side) and
    /// [`SinkPart`] (items going to the outer side). Both are bound to a higher-ranked
    /// `'scope`, so they cannot escape the returned future. That scoping is what keeps
    /// their internal pointers valid.
    pub fn new<F>(f: F) -> Self
    where
        for<'scope> F: FnOnce(
            StreamPart<'scope, 'env, SI, RI, E>,
            SinkPart<'scope, 'env, SI, RI, E>,
        )
            -> Pin<Box<dyn Future<Output = Result<(), E>> + Send + 'scope>>,
    {
        let mut data = Box::pin(StreamSinkInner {
            inner: LocalThread::new(StreamSinkInnerData::new()),
            phantom: PhantomData,
        });

        // SAFETY: `data` is heap-pinned, so its pointee stays put even when the `Box` is
        // moved into `Self` below. The lifetime-extended pointers only live inside the
        // future built by `f`, which is stored in `fut` and therefore dropped before `data`
        // (field declaration order). The stream half only touches `recv`/`close_recv`, and
        // the sink half only touches `send`/`close_send`.
        let (stream, sink) = unsafe {
            (
                StreamPart {
                    ptr: transmute::<Pin<&mut StreamSinkInner<SI, RI, E>>, _>(data.as_mut()),
                },
                SinkPart {
                    ptr: transmute::<Pin<&mut StreamSinkInner<SI, RI, E>>, _>(data.as_mut()),
                },
            )
        };

        Self {
            fut: Some(f(stream, sink)),
            data,
        }
    }
}
149
150impl<SI, RI, E> StreamSinkInnerData<SI, RI, E> {
151    fn stream_sink<F>(&self, cx: &mut Context<'_>, fut: &mut Option<Pin<F>>) -> State<SI, E>
152    where
153        F: DerefMut,
154        F::Target: Future<Output = Result<(), E>>,
155    {
156        let ret = match fut {
157            Some(f) => f.as_mut().poll(cx),
158            None => Poll::Ready(Ok(())),
159        };
160
161        if let Poll::Ready(v) = ret {
162            *fut = None;
163            self.close_send.set(true);
164            self.close_recv.set(true);
165
166            if let Err(e) = v {
167                return State::Error(e);
168            }
169        }
170
171        match unsafe {
172            (
173                (*self.send.get()).take(),
174                !self.close_recv.get() && (*self.recv.get()).is_none(),
175            )
176        } {
177            (Some(Err(e)), _) => State::Error(e),
178            (Some(Ok(i)), true) => State::SendRecvReady(i),
179            (Some(Ok(i)), false) => State::SendReady(i),
180            (None, _) if fut.is_none() => State::End,
181            (None, true) => State::RecvReady,
182            (None, false) => State::Pending,
183        }
184    }
185
186    fn send_outer(&self, item: RI) {
187        if self.close_recv.get() {
188            panic!("ScopedStreamSink is closed!");
189        }
190        let recv = unsafe { &mut *self.recv.get() };
191        if recv.is_some() {
192            panic!("ScopedStreamSink is not ready to receive!");
193        }
194
195        *recv = Some(item);
196    }
197
198    fn close_outer<F>(
199        &self,
200        cx: &mut Context<'_>,
201        fut: &mut Option<Pin<F>>,
202    ) -> Poll<Result<Option<SI>, E>>
203    where
204        F: DerefMut,
205        F::Target: Future<Output = Result<(), E>>,
206    {
207        self.close_recv.set(true);
208        let ret = match fut {
209            Some(f) => f.as_mut().poll(cx),
210            None => Poll::Ready(Ok(())),
211        };
212
213        if let Poll::Ready(v) = ret {
214            *fut = None;
215
216            if let Err(e) = v {
217                return Poll::Ready(Err(e));
218            }
219        }
220
221        let ret = unsafe { (*self.send.get()).take() };
222        if ret.is_none() && fut.is_some() {
223            Poll::Pending
224        } else {
225            Poll::Ready(ret.transpose())
226        }
227    }
228
229    fn next(&self) -> Poll<Option<RI>> {
230        match unsafe { (*self.recv.get()).take() } {
231            v @ Some(_) => Poll::Ready(v),
232            None if self.close_recv.get() => Poll::Ready(None),
233            None => Poll::Pending,
234        }
235    }
236
237    fn flush<E2>(&self) -> Poll<Result<(), E2>> {
238        if !self.close_send.get() && unsafe { (*self.send.get()).is_none() } {
239            Poll::Ready(Ok(()))
240        } else {
241            Poll::Pending
242        }
243    }
244
245    fn send_inner(&self, item: Result<SI, E>) {
246        if self.close_send.get() {
247            panic!("ScopedStreamSink is closed!");
248        }
249        let send = unsafe { &mut *self.send.get() };
250        if send.is_some() {
251            panic!("poll_ready() is not called first!");
252        }
253
254        *send = Some(item);
255    }
256
257    fn close_inner<E2>(&self) -> Poll<Result<(), E2>> {
258        self.close_send.set(true);
259        if unsafe { (*self.send.get()).is_none() } {
260            Poll::Ready(Ok(()))
261        } else {
262            Poll::Pending
263        }
264    }
265}
266
#[cfg(feature = "std")]
impl<'env, SI, RI, E> StreamSink<SI, RI> for ScopedStreamSink<'env, SI, RI, E> {
    type Error = E;

    /// Drives the inner future and reports which items can be exchanged next.
    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
        let projected = self.project();
        let fut = projected.fut;
        projected.data.inner.set_inner_ctx().stream_sink(cx, fut)
    }

    /// Hands `item` to the inner stream half (panics if it cannot accept one right now).
    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
        let this: &Self = self.into_ref().get_ref();
        this.data.inner.set_inner_ctx().send_outer(item);
        Ok(())
    }

    /// Closes the receive direction and keeps draining items until the future finishes.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<SI>, Self::Error>> {
        let projected = self.project();
        let fut = projected.fut;
        projected.data.inner.set_inner_ctx().close_outer(cx, fut)
    }
}
289
#[cfg(feature = "std")]
impl<'scope, 'env, SI, RI, E> Stream for StreamPart<'scope, 'env, SI, RI, E> {
    type Item = RI;

    /// Takes the item sent by the outer side, yielding `None` once it has closed.
    /// The context is unused: the outer side re-polls the inner future itself.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let part: &Self = Pin::into_ref(self).get_ref();
        part.ptr.inner.get_inner().next()
    }
}
298
#[cfg(feature = "std")]
impl<'scope, 'env, SI, RI, E> Sink<Result<SI, E>> for SinkPart<'scope, 'env, SI, RI, E> {
    type Error = Infallible;

    /// Ready exactly when a flush would complete (slot free, sink still open).
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<Result<SI, E>>>::poll_flush(self, cx)
    }

    /// Ready once the outer side has taken the pending item.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let part: &Self = Pin::into_ref(self).get_ref();
        part.ptr.inner.get_inner().flush()
    }

    /// Places `item` in the single send slot; panics if the slot is occupied or closed.
    fn start_send(self: Pin<&mut Self>, item: Result<SI, E>) -> Result<(), Self::Error> {
        let part: &Self = Pin::into_ref(self).get_ref();
        part.ptr.inner.get_inner().send_inner(item);
        Ok(())
    }

    /// Closes the send direction; completes once the last item has been taken.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let part: &Self = Pin::into_ref(self).get_ref();
        part.ptr.inner.get_inner().close_inner()
    }
}
320
#[cfg(feature = "std")]
impl<'scope, 'env, SI, RI, E> Sink<SI> for SinkPart<'scope, 'env, SI, RI, E> {
    type Error = Infallible;

    // Each operation forwards to the same-named method of the `Sink<Result<SI, E>>` impl;
    // plain items are wrapped in `Ok` on the way through.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<Result<SI, E>>>::poll_ready(self, cx)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<Result<SI, E>>>::poll_flush(self, cx)
    }

    fn start_send(self: Pin<&mut Self>, item: SI) -> Result<(), Self::Error> {
        let wrapped: Result<SI, E> = Ok(item);
        <Self as Sink<Result<SI, E>>>::start_send(self, wrapped)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<Result<SI, E>>>::poll_close(self, cx)
    }
}
341
pin_project! {
    /// Locally scoped version of [`StreamSink`]. Does not implement [`Send`]
    /// (the inner future is not required to be `Send`).
    #[must_use = "StreamSink will not do anything if not used"]
    pub struct LocalScopedStreamSink<'env, SI, RI, E> {
        // NOTE: field order is load-bearing. Struct fields drop in declaration order, and
        // `fut` is built (in `new`) from `LocalStreamPart`/`LocalSinkPart` pointers into
        // `data`, so `fut` must be dropped before `data` is freed. Do not reorder.
        fut: Option<Pin<Box<dyn Future<Output = Result<(), E>> + 'env>>>,

        // Heap-pinned state shared with the inner halves; its address never changes.
        data: Pin<Box<LocalStreamSinkInner<'env, 'env, SI, RI, E>>>,
    }
}
351
pin_project! {
    /// Heap-pinned state shared by `LocalScopedStreamSink` and its
    /// `LocalStreamPart`/`LocalSinkPart` halves.
    struct LocalStreamSinkInner<'scope, 'env: 'scope, SI, RI, E> {
        inner: StreamSinkInnerData<SI, RI, E>,

        // Makes the type `!Unpin`.
        #[pin]
        pinned: PhantomPinned,
        // Carries the lifetimes and item types; the `*mut u8` makes the type
        // `!Send`/`!Sync`, so the shared state cannot leave the current thread.
        phantom: PhantomData<(&'scope mut &'env (SI, RI, E), *mut u8)>,
    }
}
361
362pin_project! {
363    /// [`Stream`] half of inner [`LocalScopedStreamSink`].
364    /// Produce receive type values.
365    /// Can only be closed from it's outer [`LocalScopedStreamSink`].
366    #[must_use = "Stream will not do anything if not used"]
367    pub struct LocalStreamPart<'scope, 'env: 'scope, SI, RI, E> {
368        ptr: Pin<&'scope mut LocalStreamSinkInner<'scope, 'env, SI, RI, E>>,
369    }
370}
371
372pin_project! {
373    /// [`Sink`] half of inner [`LocalScopedStreamSink`].
374    /// Can receive both send type or a [`Result`] type.
375    /// Closing will complete when outer [`LocalScopedStreamSink`] is closed and received all data.
376    #[must_use = "Sink will not do anything if not used"]
377    pub struct LocalSinkPart<'scope, 'env: 'scope, SI, RI, E> {
378        ptr: Pin<&'scope mut LocalStreamSinkInner<'scope, 'env, SI, RI, E>>,
379    }
380}
381
382impl<'env, SI, RI, E> LocalScopedStreamSink<'env, SI, RI, E> {
383    /// Creates new [`LocalScopedStreamSink`].
384    /// Safety is guaranteed by scoping both [`LocalStreamPart`] and [`LocalSinkPart`].
385    pub fn new<F>(f: F) -> Self
386    where
387        for<'scope> F: FnOnce(
388            LocalStreamPart<'scope, 'env, SI, RI, E>,
389            LocalSinkPart<'scope, 'env, SI, RI, E>,
390        ) -> Pin<Box<dyn Future<Output = Result<(), E>> + 'scope>>,
391    {
392        let mut data = Box::pin(LocalStreamSinkInner {
393            inner: StreamSinkInnerData::new(),
394
395            pinned: PhantomPinned,
396            phantom: PhantomData,
397        });
398
399        let (stream, sink);
400        // SAFETY: Borrow is scoped, so it can't get out of scope.
401        // Also, StreamPart and SinkPart write access is separated.
402        unsafe {
403            stream = LocalStreamPart {
404                ptr: transmute::<Pin<&mut LocalStreamSinkInner<SI, RI, E>>, _>(data.as_mut()),
405            };
406            sink = LocalSinkPart {
407                ptr: transmute::<Pin<&mut LocalStreamSinkInner<SI, RI, E>>, _>(data.as_mut()),
408            };
409        }
410        let fut = f(stream, sink);
411
412        Self {
413            fut: Some(fut),
414            data,
415        }
416    }
417}
418
419impl<'env, SI, RI, E> StreamSink<SI, RI> for LocalScopedStreamSink<'env, SI, RI, E> {
420    type Error = E;
421
422    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
423        let this = self.project();
424        this.data.as_mut().project().inner.stream_sink(cx, this.fut)
425    }
426
427    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
428        self.project()
429            .data
430            .as_mut()
431            .project()
432            .inner
433            .send_outer(item);
434        Ok(())
435    }
436
437    fn poll_close(
438        self: Pin<&mut Self>,
439        cx: &mut Context<'_>,
440    ) -> Poll<Result<Option<SI>, Self::Error>> {
441        let this = self.project();
442        this.data.as_mut().project().inner.close_outer(cx, this.fut)
443    }
444}
445
446impl<'scope, 'env, SI, RI, E> Stream for LocalStreamPart<'scope, 'env, SI, RI, E> {
447    type Item = RI;
448
449    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
450        self.project().ptr.as_mut().project().inner.next()
451    }
452}
453
454impl<'scope, 'env, SI, RI, E> Sink<Result<SI, E>> for LocalSinkPart<'scope, 'env, SI, RI, E> {
455    type Error = Infallible;
456
457    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
458        <Self as Sink<Result<SI, E>>>::poll_flush(self, cx)
459    }
460
461    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
462        self.project().ptr.as_mut().project().inner.flush()
463    }
464
465    fn start_send(self: Pin<&mut Self>, item: Result<SI, E>) -> Result<(), Self::Error> {
466        self.project().ptr.as_mut().project().inner.send_inner(item);
467        Ok(())
468    }
469
470    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
471        self.project().ptr.as_mut().project().inner.close_inner()
472    }
473}
474
475impl<'scope, 'env, SI, RI, E> Sink<SI> for LocalSinkPart<'scope, 'env, SI, RI, E> {
476    type Error = Infallible;
477
478    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
479        <Self as Sink<Result<SI, E>>>::poll_flush(self, cx)
480    }
481
482    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
483        <Self as Sink<Result<SI, E>>>::poll_flush(self, cx)
484    }
485
486    fn start_send(self: Pin<&mut Self>, item: SI) -> Result<(), Self::Error> {
487        <Self as Sink<Result<SI, E>>>::start_send(self, Ok(item))
488    }
489
490    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
491        <Self as Sink<Result<SI, E>>>::poll_close(self, cx)
492    }
493}