scoped_stream_sink/
stream_sink_ext.rs

1use core::future::Future;
2use core::marker::PhantomData;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6use futures_core::{FusedStream, Stream};
7use pin_project_lite::pin_project;
8
9use crate::{State, StreamSink};
10
11pin_project! {
12    /// Return type of [`StreamSinkExt::map_send()`].
13    pub struct MapSend<T, F, I> {
14        #[pin]
15        t: T,
16        f: F,
17        phantom: PhantomData<I>,
18    }
19}
20
21impl<T, F, SI, RI, I> StreamSink<SI, RI> for MapSend<T, F, I>
22where
23    T: StreamSink<I, RI>,
24    F: FnMut(I) -> SI,
25{
26    type Error = T::Error;
27
28    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
29        let zelf = self.project();
30        let f = zelf.f;
31        match zelf.t.poll_stream_sink(cx) {
32            State::Error(e) => State::Error(e),
33            State::Pending => State::Pending,
34            State::End => State::End,
35            State::RecvReady => State::RecvReady,
36            State::SendReady(i) => State::SendReady(f(i)),
37            State::SendRecvReady(i) => State::SendRecvReady(f(i)),
38        }
39    }
40
41    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
42        self.project().t.start_send(item)
43    }
44
45    fn poll_close(
46        self: Pin<&mut Self>,
47        cx: &mut Context<'_>,
48    ) -> Poll<Result<Option<SI>, Self::Error>> {
49        let zelf = self.project();
50        let f = zelf.f;
51        match zelf.t.poll_close(cx) {
52            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
53            Poll::Ready(Ok(Some(i))) => Poll::Ready(Ok(Some(f(i)))),
54            Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)),
55            Poll::Pending => Poll::Pending,
56        }
57    }
58}
59
60pin_project! {
61    /// Return type of [`StreamSinkExt::map_recv()`].
62    pub struct MapRecv<T, F, I> {
63        #[pin]
64        t: T,
65        f: F,
66        phantom: PhantomData<I>,
67    }
68}
69
70impl<T, F, SI, RI, I> StreamSink<SI, RI> for MapRecv<T, F, I>
71where
72    T: StreamSink<SI, I>,
73    F: FnMut(RI) -> I,
74{
75    type Error = T::Error;
76
77    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
78        self.project().t.poll_stream_sink(cx)
79    }
80
81    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
82        let zelf = self.project();
83        let f = zelf.f;
84        zelf.t.start_send(f(item))
85    }
86
87    fn poll_close(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90    ) -> Poll<Result<Option<SI>, Self::Error>> {
91        self.project().t.poll_close(cx)
92    }
93}
94
95pin_project! {
96    /// Return type of [`StreamSinkExt::map_error()`].
97    pub struct MapError<T, F> {
98        #[pin]
99        t: T,
100        f: F,
101    }
102}
103
104impl<T, F, SI, RI, E> StreamSink<SI, RI> for MapError<T, F>
105where
106    T: StreamSink<SI, RI>,
107    F: FnMut(T::Error) -> E,
108{
109    type Error = E;
110
111    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
112        let zelf = self.project();
113        let f = zelf.f;
114        match zelf.t.poll_stream_sink(cx) {
115            State::Error(e) => State::Error(f(e)),
116            State::Pending => State::Pending,
117            State::End => State::End,
118            State::RecvReady => State::RecvReady,
119            State::SendReady(i) => State::SendReady(i),
120            State::SendRecvReady(i) => State::SendRecvReady(i),
121        }
122    }
123
124    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
125        let zelf = self.project();
126        let f = zelf.f;
127        match zelf.t.start_send(item) {
128            Ok(v) => Ok(v),
129            Err(e) => Err(f(e)),
130        }
131    }
132
133    fn poll_close(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136    ) -> Poll<Result<Option<SI>, Self::Error>> {
137        let zelf = self.project();
138        let f = zelf.f;
139        match zelf.t.poll_close(cx) {
140            Poll::Ready(Err(e)) => Poll::Ready(Err(f(e))),
141            Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
142            Poll::Pending => Poll::Pending,
143        }
144    }
145}
146
147pin_project! {
148    /// Return type of [`StreamSinkExt::error_cast()`].
149    pub struct ErrorCast<T, E> {
150        #[pin]
151        t: T,
152        phantom: PhantomData<E>,
153    }
154}
155
156impl<T, SI, RI, E> StreamSink<SI, RI> for ErrorCast<T, E>
157where
158    T: StreamSink<SI, RI>,
159    T::Error: Into<E>,
160{
161    type Error = E;
162
163    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
164        match self.project().t.poll_stream_sink(cx) {
165            State::Error(e) => State::Error(e.into()),
166            State::Pending => State::Pending,
167            State::End => State::End,
168            State::RecvReady => State::RecvReady,
169            State::SendReady(i) => State::SendReady(i),
170            State::SendRecvReady(i) => State::SendRecvReady(i),
171        }
172    }
173
174    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
175        match self.project().t.start_send(item) {
176            Ok(v) => Ok(v),
177            Err(e) => Err(e.into()),
178        }
179    }
180
181    fn poll_close(
182        self: Pin<&mut Self>,
183        cx: &mut Context<'_>,
184    ) -> Poll<Result<Option<SI>, Self::Error>> {
185        match self.project().t.poll_close(cx) {
186            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
187            Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
188            Poll::Pending => Poll::Pending,
189        }
190    }
191}
192
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
194enum ChainState {
195    Pending,
196    Ready,
197    UEnd,
198    VEnd,
199    Done,
200}
201
202pin_project! {
203    /// Return type of [`StreamSinkExt::chain()`].
204    pub struct Chain<U, V, II> {
205        #[pin]
206        u: U,
207        #[pin]
208        v: V,
209
210        state: ChainState,
211        phantom: PhantomData<II>,
212    }
213}
214
215impl<U, V, SI, II, RI> StreamSink<SI, RI> for Chain<U, V, II>
216where
217    U: StreamSink<II, RI>,
218    V: StreamSink<SI, II, Error = U::Error>,
219{
220    type Error = U::Error;
221
222    fn poll_stream_sink(self: Pin<&mut Self>, cx: &mut Context<'_>) -> State<SI, Self::Error> {
223        let mut zelf = self.project();
224
225        loop {
226            match *zelf.state {
227                ChainState::Pending => match zelf.v.as_mut().poll_stream_sink(&mut *cx) {
228                    State::Pending => return State::Pending,
229                    State::Error(e) => return State::Error(e),
230                    State::End => *zelf.state = ChainState::VEnd,
231                    State::SendReady(i) => return State::SendReady(i),
232                    State::RecvReady => *zelf.state = ChainState::Ready,
233                    State::SendRecvReady(i) => {
234                        *zelf.state = ChainState::Ready;
235                        return State::SendReady(i);
236                    }
237                },
238                ChainState::Ready => match zelf.u.as_mut().poll_stream_sink(&mut *cx) {
239                    State::Pending => return State::Pending,
240                    State::Error(e) => return State::Error(e),
241                    State::End => *zelf.state = ChainState::UEnd,
242                    State::SendReady(i) => match zelf.v.as_mut().start_send(i) {
243                        Err(e) => return State::Error(e),
244                        Ok(_) => *zelf.state = ChainState::Pending,
245                    },
246                    State::RecvReady => return State::RecvReady,
247                    State::SendRecvReady(i) => match zelf.v.as_mut().start_send(i) {
248                        Err(e) => return State::Error(e),
249                        Ok(_) => {
250                            *zelf.state = ChainState::Pending;
251                            return State::RecvReady;
252                        }
253                    },
254                },
255                ChainState::VEnd => match zelf.u.as_mut().poll_stream_sink(&mut *cx) {
256                    State::Pending => return State::Pending,
257                    State::Error(e) => return State::Error(e),
258                    State::End => *zelf.state = ChainState::Done,
259                    State::SendReady(_) => (),
260                    State::RecvReady | State::SendRecvReady(_) => return State::RecvReady,
261                },
262                ChainState::UEnd => match zelf.v.as_mut().poll_close(&mut *cx) {
263                    Poll::Pending => return State::Pending,
264                    Poll::Ready(Err(e)) => return State::Error(e),
265                    Poll::Ready(Ok(Some(_))) => (),
266                    Poll::Ready(Ok(None)) => *zelf.state = ChainState::Done,
267                },
268                ChainState::Done => return State::End,
269            }
270        }
271    }
272
273    fn start_send(self: Pin<&mut Self>, item: RI) -> Result<(), Self::Error> {
274        self.project().u.start_send(item)
275    }
276
277    fn poll_close(
278        self: Pin<&mut Self>,
279        cx: &mut Context<'_>,
280    ) -> Poll<Result<Option<SI>, Self::Error>> {
281        let mut zelf = self.project();
282
283        loop {
284            match *zelf.state {
285                ChainState::Pending => match zelf.v.as_mut().poll_stream_sink(&mut *cx) {
286                    State::Pending => return Poll::Pending,
287                    State::Error(e) => return Poll::Ready(Err(e)),
288                    State::End => *zelf.state = ChainState::VEnd,
289                    State::SendReady(i) => return Poll::Ready(Ok(Some(i))),
290                    State::RecvReady => *zelf.state = ChainState::Ready,
291                    State::SendRecvReady(i) => {
292                        *zelf.state = ChainState::Ready;
293                        return Poll::Ready(Ok(Some(i)));
294                    }
295                },
296                ChainState::Ready => match zelf.u.as_mut().poll_close(&mut *cx) {
297                    Poll::Pending => return Poll::Pending,
298                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
299                    Poll::Ready(Ok(Some(i))) => match zelf.v.as_mut().start_send(i) {
300                        Err(e) => return Poll::Ready(Err(e)),
301                        Ok(_) => *zelf.state = ChainState::Pending,
302                    },
303                    Poll::Ready(Ok(None)) => *zelf.state = ChainState::UEnd,
304                },
305                ChainState::VEnd => match zelf.u.as_mut().poll_close(&mut *cx) {
306                    Poll::Pending => return Poll::Pending,
307                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
308                    Poll::Ready(Ok(Some(_))) => (),
309                    Poll::Ready(Ok(None)) => *zelf.state = ChainState::Done,
310                },
311                ChainState::UEnd => match zelf.v.as_mut().poll_close(&mut *cx) {
312                    Poll::Pending => return Poll::Pending,
313                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
314                    Poll::Ready(Ok(Some(_))) => (),
315                    Poll::Ready(Ok(None)) => *zelf.state = ChainState::Done,
316                },
317                ChainState::Done => return Poll::Ready(Ok(None)),
318            }
319        }
320    }
321}
322
323pin_project! {
324    /// Return type of [`StreamSinkExt::send_one()`].
325    pub struct SendOne<'a, T: ?Sized, SI, RI, E> {
326        value: Option<(Pin<&'a mut T>, RI)>,
327        error: Option<E>,
328        phantom: PhantomData<SI>,
329    }
330}
331
332impl<'a, T, SI, RI, E> Stream for SendOne<'a, T, SI, RI, E>
333where
334    T: StreamSink<SI, RI, Error = E> + ?Sized,
335{
336    type Item = Result<SI, E>;
337
338    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
339        let zelf = self.project();
340        if let Some(e) = zelf.error.take() {
341            return Poll::Ready(Some(Err(e)));
342        }
343
344        let Some((t, _)) = zelf.value else {
345            return Poll::Ready(None);
346        };
347        match t.as_mut().poll_stream_sink(cx) {
348            State::Error(e) => {
349                *zelf.value = None;
350                Poll::Ready(Some(Err(e)))
351            }
352            State::Pending => Poll::Pending,
353            State::SendReady(v) => Poll::Ready(Some(Ok(v))),
354            State::SendRecvReady(v) => {
355                // SAFETY: zelf.value has value
356                let (t, i) = zelf.value.take().unwrap();
357                *zelf.error = t.start_send(i).err();
358                Poll::Ready(Some(Ok(v)))
359            }
360            State::RecvReady => {
361                // SAFETY: zelf.value has value
362                let (t, i) = zelf.value.take().unwrap();
363                match t.start_send(i) {
364                    Ok(_) => Poll::Ready(None),
365                    Err(e) => Poll::Ready(Some(Err(e))),
366                }
367            }
368            State::End => {
369                *zelf.value = None;
370                Poll::Ready(None)
371            }
372        }
373    }
374
375    fn size_hint(&self) -> (usize, Option<usize>) {
376        match self {
377            Self { error: Some(_), .. } => (1, Some(1)),
378            Self { value: None, .. } => (0, Some(0)),
379            _ => (0, None),
380        }
381    }
382}
383
384impl<'a, T, SI, RI, E> FusedStream for SendOne<'a, T, SI, RI, E>
385where
386    T: StreamSink<SI, RI, Error = E> + ?Sized,
387{
388    fn is_terminated(&self) -> bool {
389        self.value.is_none() && self.error.is_none()
390    }
391}
392
393pin_project! {
394    /// Return type of [`StreamSinkExt::send_iter()`].
395    pub struct SendIter<'a, T: ?Sized, SI, RI, IT, E> {
396        value: Option<(Pin<&'a mut T>, IT)>,
397        error: Option<E>,
398        phantom: PhantomData<(SI, RI)>,
399    }
400}
401
402impl<'a, T, SI, RI, IT, E> Stream for SendIter<'a, T, SI, RI, IT, E>
403where
404    T: StreamSink<SI, RI, Error = E> + ?Sized,
405    IT: Iterator<Item = RI>,
406{
407    type Item = Result<SI, E>;
408
409    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410        let zelf = self.as_mut().project();
411        if let Some(e) = zelf.error.take() {
412            *zelf.value = None;
413            return Poll::Ready(Some(Err(e)));
414        }
415
416        let Some((t, it)) = zelf.value else {
417            return Poll::Ready(None);
418        };
419        match t.as_mut().poll_stream_sink(cx) {
420            State::Error(e) => {
421                *zelf.value = None;
422                Poll::Ready(Some(Err(e)))
423            }
424            State::Pending => Poll::Pending,
425            State::SendReady(v) => Poll::Ready(Some(Ok(v))),
426            State::SendRecvReady(v) => {
427                *zelf.error = if let Some(i) = it.next() {
428                    t.as_mut().start_send(i).err()
429                } else {
430                    *zelf.value = None;
431                    None
432                };
433                Poll::Ready(Some(Ok(v)))
434            }
435            State::RecvReady => {
436                let r = if let Some(i) = it.next() {
437                    t.as_mut().start_send(i).err()
438                } else {
439                    *zelf.value = None;
440                    None
441                };
442                if let Some(e) = r {
443                    *zelf.value = None;
444                    Poll::Ready(Some(Err(e)))
445                } else {
446                    self.poll_next(cx)
447                }
448            }
449            State::End => {
450                *zelf.value = None;
451                Poll::Ready(None)
452            }
453        }
454    }
455
456    fn size_hint(&self) -> (usize, Option<usize>) {
457        match self {
458            Self { error: Some(_), .. } => (1, Some(1)),
459            Self { value: None, .. } => (0, Some(0)),
460            _ => (0, None),
461        }
462    }
463}
464
465impl<'a, T, SI, RI, IT, E> FusedStream for SendIter<'a, T, SI, RI, IT, E>
466where
467    T: StreamSink<SI, RI, Error = E> + ?Sized,
468    IT: Iterator<Item = RI>,
469{
470    fn is_terminated(&self) -> bool {
471        self.value.is_none() && self.error.is_none()
472    }
473}
474
475pin_project! {
476    /// Return type of [`StreamSinkExt::send_try_iter()`].
477    pub struct SendTryIter<'a, T: ?Sized, SI, RI, IT, E> {
478        value: Option<(Pin<&'a mut T>, IT)>,
479        error: Option<E>,
480        phantom: PhantomData<(SI, RI)>,
481    }
482}
483
484impl<'a, T, SI, RI, IT, E> Stream for SendTryIter<'a, T, SI, RI, IT, E>
485where
486    T: StreamSink<SI, RI, Error = E> + ?Sized,
487    IT: Iterator<Item = Result<RI, E>>,
488{
489    type Item = Result<SI, E>;
490
491    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
492        let zelf = self.as_mut().project();
493        if let Some(e) = zelf.error.take() {
494            *zelf.value = None;
495            return Poll::Ready(Some(Err(e)));
496        }
497
498        let Some((t, it)) = zelf.value else {
499            return Poll::Ready(None);
500        };
501        match t.as_mut().poll_stream_sink(cx) {
502            State::Error(e) => Poll::Ready(Some(Err(e))),
503            State::Pending => Poll::Pending,
504            State::SendReady(v) => Poll::Ready(Some(Ok(v))),
505            State::SendRecvReady(v) => {
506                *zelf.error = match it.next() {
507                    Some(Ok(i)) => t.as_mut().start_send(i).err(),
508                    Some(Err(e)) => Some(e),
509                    None => {
510                        *zelf.value = None;
511                        None
512                    }
513                };
514                Poll::Ready(Some(Ok(v)))
515            }
516            State::RecvReady => {
517                let r = match it.next() {
518                    Some(Ok(i)) => t.as_mut().start_send(i).err(),
519                    Some(Err(e)) => Some(e),
520                    None => {
521                        *zelf.value = None;
522                        None
523                    }
524                };
525                if let Some(e) = r {
526                    *zelf.value = None;
527                    Poll::Ready(Some(Err(e)))
528                } else {
529                    self.poll_next(cx)
530                }
531            }
532            State::End => {
533                *zelf.value = None;
534                Poll::Ready(None)
535            }
536        }
537    }
538
539    fn size_hint(&self) -> (usize, Option<usize>) {
540        match self {
541            Self { error: Some(_), .. } => (1, Some(1)),
542            Self { value: None, .. } => (0, Some(0)),
543            _ => (0, None),
544        }
545    }
546}
547
548impl<'a, T, SI, RI, IT, E> FusedStream for SendTryIter<'a, T, SI, RI, IT, E>
549where
550    T: StreamSink<SI, RI, Error = E> + ?Sized,
551    IT: Iterator<Item = Result<RI, E>>,
552{
553    fn is_terminated(&self) -> bool {
554        self.value.is_none() && self.error.is_none()
555    }
556}
557
558pin_project! {
559    /// Return type of [`StreamSinkExt::close()`].
560    pub struct Close<'a, T: ?Sized, SI, RI> {
561        ptr: Option<Pin<&'a mut T>>,
562        phantom: PhantomData<(SI, RI)>,
563    }
564}
565
566impl<'a, T, SI, RI> Stream for Close<'a, T, SI, RI>
567where
568    T: StreamSink<SI, RI> + ?Sized,
569{
570    type Item = Result<SI, T::Error>;
571
572    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
573        let ptr = self.project().ptr;
574        let Some(t) = ptr else {
575            return Poll::Ready(None);
576        };
577
578        let r = match t.as_mut().poll_close(cx) {
579            Poll::Pending => return Poll::Pending,
580            Poll::Ready(v) => v.transpose(),
581        };
582        if matches!(r, None | Some(Err(_))) {
583            *ptr = None;
584        }
585        Poll::Ready(r)
586    }
587
588    fn size_hint(&self) -> (usize, Option<usize>) {
589        match self {
590            Self { ptr: None, .. } => (0, Some(0)),
591            _ => (0, None),
592        }
593    }
594}
595
596impl<'a, T, SI, RI> FusedStream for Close<'a, T, SI, RI>
597where
598    T: StreamSink<SI, RI> + ?Sized,
599{
600    fn is_terminated(&self) -> bool {
601        self.ptr.is_none()
602    }
603}
604
605pin_project! {
606    /// Return type of [`StreamSinkExt::ready()`].
607    pub struct Ready<'a, T: ?Sized, SI, RI> {
608        ptr: Pin<&'a mut T>,
609        phantom: PhantomData<(SI, RI)>,
610    }
611}
612
613impl<'a, T, SI, RI> Future for Ready<'a, T, SI, RI>
614where
615    T: ?Sized + StreamSink<SI, RI>,
616{
617    type Output = Result<(Option<SI>, bool), T::Error>;
618
619    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
620        let zelf = self.project();
621
622        match zelf.ptr.as_mut().poll_stream_sink(cx) {
623            State::Pending => Poll::Pending,
624            State::Error(e) => Poll::Ready(Err(e)),
625            State::End => Poll::Ready(Ok((None, false))),
626            State::RecvReady => Poll::Ready(Ok((None, true))),
627            State::SendReady(i) => Poll::Ready(Ok((Some(i), false))),
628            State::SendRecvReady(i) => Poll::Ready(Ok((Some(i), true))),
629        }
630    }
631}
632
633pin_project! {
634    /// Return type of [`StreamSinkExt::try_send_one()`].
635    pub struct TrySendOne<'a, T: ?Sized, SI, F> {
636        ptr: Pin<&'a mut T>,
637        f: Option<F>,
638        phantom: PhantomData<SI>,
639    }
640}
641
642impl<'a, T, SI, RI, F> Future for TrySendOne<'a, T, SI, F>
643where
644    F: FnOnce() -> RI,
645    T: ?Sized + StreamSink<SI, RI>,
646{
647    type Output = Result<Option<SI>, T::Error>;
648
649    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
650        let zelf = self.project();
651
652        let mut f = || {
653            zelf.f
654                .take()
655                .expect("Future should not be polled after completion")()
656        };
657        match zelf.ptr.as_mut().poll_stream_sink(cx) {
658            State::Pending => Poll::Pending,
659            State::Error(e) => Poll::Ready(Err(e)),
660            State::End => Poll::Ready(Ok(None)),
661            State::SendReady(i) => Poll::Ready(Ok(Some(i))),
662            State::RecvReady => match zelf.ptr.as_mut().start_send(f()) {
663                Ok(_) => Poll::Ready(Ok(None)),
664                Err(e) => Poll::Ready(Err(e)),
665            },
666            State::SendRecvReady(i) => match zelf.ptr.as_mut().start_send(f()) {
667                Ok(_) => Poll::Ready(Ok(Some(i))),
668                Err(e) => Poll::Ready(Err(e)),
669            },
670        }
671    }
672}
673
674pin_project! {
675    /// Return type of [`StreamSinkExt::try_send_future()`].
676    pub struct TrySendFuture<'a, T: ?Sized, SI, F> {
677        ptr: Pin<&'a mut T>,
678        #[pin]
679        fut: F,
680        send: Option<SI>,
681        ready: bool,
682    }
683}
684
685impl<'a, T, SI, RI, F> Future for TrySendFuture<'a, T, SI, F>
686where
687    T: ?Sized + StreamSink<SI, RI>,
688    F: Future<Output = Result<RI, T::Error>>,
689{
690    type Output = Result<Option<SI>, T::Error>;
691
692    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
693        let zelf = self.project();
694
695        while !*zelf.ready {
696            return match zelf.ptr.as_mut().poll_stream_sink(&mut *cx) {
697                State::Pending => Poll::Pending,
698                State::Error(e) => Poll::Ready(Err(e)),
699                State::End => Poll::Ready(Ok(None)),
700                State::RecvReady => {
701                    *zelf.ready = true;
702                    continue;
703                }
704                State::SendReady(i) => Poll::Ready(Ok(Some(i))),
705                State::SendRecvReady(i) => {
706                    *zelf.ready = true;
707                    *zelf.send = Some(i);
708                    continue;
709                }
710            };
711        }
712
713        match zelf.fut.poll(cx) {
714            Poll::Pending => Poll::Pending,
715            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
716            Poll::Ready(Ok(r)) => match zelf.ptr.as_mut().start_send(r) {
717                Err(e) => Poll::Ready(Err(e)),
718                Ok(_) => Poll::Ready(Ok(zelf.send.take())),
719            },
720        }
721    }
722}
723
724/// Extension trait for [`StreamSink`]. Contains helper methods for using [`StreamSink`].
725pub trait StreamSinkExt<SendItem, RecvItem = SendItem>: StreamSink<SendItem, RecvItem> {
726    /// Maps the `SendItem`.
727    fn map_send<F, I>(self, f: F) -> MapSend<Self, F, SendItem>
728    where
729        Self: Sized,
730        F: FnMut(SendItem) -> I,
731    {
732        MapSend {
733            t: self,
734            f,
735            phantom: PhantomData,
736        }
737    }
738
739    /// Maps the `RecvItem`.
740    fn map_recv<F, I>(self, f: F) -> MapRecv<Self, F, I>
741    where
742        Self: Sized,
743        F: FnMut(I) -> RecvItem,
744    {
745        MapRecv {
746            t: self,
747            f,
748            phantom: PhantomData,
749        }
750    }
751
752    /// Maps the error type.
753    fn map_error<F, E>(self, f: F) -> MapError<Self, F>
754    where
755        Self: Sized,
756        F: FnMut(Self::Error) -> E,
757    {
758        MapError { t: self, f }
759    }
760
761    /// Cast the error type.
762    fn error_cast<E>(self) -> ErrorCast<Self, E>
763    where
764        Self: Sized,
765        Self::Error: Into<E>,
766    {
767        ErrorCast {
768            t: self,
769            phantom: PhantomData,
770        }
771    }
772
773    /// Chain two [`StreamSink`].
774    ///
775    /// If either of the [`StreamSink`] ends, the other one will try to be closed.
776    /// Some data may gets lost because of that.
777    fn chain<Other, Item>(self, other: Other) -> Chain<Self, Other, SendItem>
778    where
779        Self: Sized,
780        Other: StreamSink<Item, SendItem, Error = Self::Error>,
781    {
782        Chain {
783            u: self,
784            v: other,
785            state: ChainState::Pending,
786            phantom: PhantomData,
787        }
788    }
789
790    /// Send one item.
791    ///
792    /// The resulting [`Stream`] may be dropped at anytime, with only consequence is loss of item.
793    /// To not lose the item, use [`try_send_one`](Self::try_send_one).
794    fn send_one<'a>(
795        self: Pin<&'a mut Self>,
796        item: RecvItem,
797    ) -> SendOne<'a, Self, SendItem, RecvItem, Self::Error> {
798        SendOne {
799            value: Some((self, item)),
800            error: None,
801            phantom: PhantomData,
802        }
803    }
804
805    /// Send items from an [`IntoIterator`].
806    ///
807    /// The resulting [`Stream`] may be dropped at anytime, with no loss of item.
808    fn send_iter<'a, I: IntoIterator<Item = RecvItem>>(
809        self: Pin<&'a mut Self>,
810        iter: I,
811    ) -> SendIter<'a, Self, SendItem, RecvItem, I::IntoIter, Self::Error> {
812        SendIter {
813            value: Some((self, iter.into_iter())),
814            error: None,
815            phantom: PhantomData,
816        }
817    }
818
819    /// Send items from a fallible [`IntoIterator`].
820    ///
821    /// The resulting [`Stream`] may be dropped at anytime, with no loss of item.
822    fn send_try_iter<'a, I: IntoIterator<Item = Result<RecvItem, Self::Error>>>(
823        self: Pin<&'a mut Self>,
824        iter: I,
825    ) -> SendTryIter<'a, Self, SendItem, RecvItem, I::IntoIter, Self::Error> {
826        SendTryIter {
827            value: Some((self, iter.into_iter())),
828            error: None,
829            phantom: PhantomData,
830        }
831    }
832
833    /// Closes the [`StreamSink`].
834    ///
835    /// You must handle all items that came out of the resulting [`Stream`].
836    fn close<'a>(self: Pin<&'a mut Self>) -> Close<'a, Self, SendItem, RecvItem> {
837        Close {
838            ptr: Some(self),
839            phantom: PhantomData,
840        }
841    }
842
843    /// Polls until it's ready. It is safe drop the [`Future`] before it's ready (cancel-safety).
844    ///
845    /// Possible return value (after awaited):
846    /// - `Err(error)` : Error happened.
847    /// - `Ok((None, false))` : [`StreamSink`] is closed.
848    /// - `Ok((Some(item), false))` : An item is sent.
849    /// - `Ok((None, true))` : Ready to receive item.
850    /// - `Ok((Some(item), true))` : Item is sent and it's ready to receive another item.
851    fn ready<'a>(self: Pin<&'a mut Self>) -> Ready<'a, Self, SendItem, RecvItem> {
852        Ready {
853            ptr: self,
854            phantom: PhantomData,
855        }
856    }
857
858    /// Try to send an item.
859    /// It is safe to drop the [`Future`] before it's ready.
860    fn try_send_one<'a, F: FnOnce() -> RecvItem>(
861        self: Pin<&'a mut Self>,
862        f: F,
863    ) -> TrySendOne<'a, Self, SendItem, F> {
864        TrySendOne {
865            ptr: self,
866            f: Some(f),
867            phantom: PhantomData,
868        }
869    }
870
871    /// Try to send an item using [`Future`].
872    /// It is cancen-safe if and only if the inner [`Future`] is also cancel-safe.
873    fn try_send_future<'a, F: Future<Output = Result<RecvItem, Self::Error>>>(
874        self: Pin<&'a mut Self>,
875        fut: F,
876    ) -> TrySendFuture<'a, Self, SendItem, F> {
877        TrySendFuture {
878            ptr: self,
879            fut,
880            send: None,
881            ready: false,
882        }
883    }
884}
885
886impl<SI, RI, T> StreamSinkExt<SI, RI> for T where T: StreamSink<SI, RI> {}
887
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScopedStreamSink;

    use std::pin::pin;
    use std::prelude::rust_2021::*;
    use std::time::Duration;

    use anyhow::{bail, Error as AnyError, Result as AnyResult};
    use futures_util::{SinkExt, StreamExt};
    use tokio::task::yield_now;
    use tokio::time::timeout;

    /// Wall-clock budget for a single test body before it is reported as hung.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Builds a `usize -> usize` stream-sink whose body forwards every received
    /// item, doubled, to its sink until the input stream ends or a send fails.
    ///
    /// A macro (rather than a function) keeps the closure inline at each call
    /// site, so its signature is inferred exactly as if it were written there.
    macro_rules! doubler {
        () => {
            <ScopedStreamSink<usize, usize, AnyError>>::new(|mut rx, mut tx| {
                Box::pin(async move {
                    while let Some(item) = rx.next().await {
                        tx.send(item * 2).await?;
                    }

                    Ok(())
                })
            })
        };
    }

    /// Yields back to the tokio scheduler `count` times in a row.
    async fn yield_times(count: usize) {
        for _ in 0..count {
            yield_now().await;
        }
    }

    /// Awaits `body`, failing with "Time ran out" if it exceeds [`TEST_TIMEOUT`].
    async fn test_helper<F>(body: F) -> AnyResult<()>
    where
        F: Future<Output = AnyResult<()>> + Send,
    {
        let Ok(outcome) = timeout(TEST_TIMEOUT, body).await else {
            bail!("Time ran out");
        };
        outcome
    }

    #[tokio::test]
    async fn test_ready_simple() -> AnyResult<()> {
        // A body that finishes immediately without ever sending.
        let ss = <ScopedStreamSink<usize, usize, AnyError>>::new(|_, _| Box::pin(async { Ok(()) }));

        test_helper(async move {
            let ss = pin!(ss);
            assert_eq!(ss.ready().await?, (None, false));

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_ready_send() -> AnyResult<()> {
        // Sends 0..10, yielding `n` times right after sending `n`.
        let ss = <ScopedStreamSink<usize, usize, AnyError>>::new(|_, mut tx| {
            Box::pin(async move {
                for n in 0..10 {
                    tx.send(n).await?;
                    yield_times(n).await;
                }

                Ok(())
            })
        });

        test_helper(async move {
            let mut ss = pin!(ss);
            for n in 0..10 {
                assert_eq!(ss.as_mut().ready().await?, (Some(n), true));
                // One idle readiness report per yield performed by the body.
                for _ in 0..n {
                    assert_eq!(ss.as_mut().ready().await?, (None, true));
                }
            }
            assert_eq!(ss.as_mut().ready().await?, (None, false));

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_close_send() -> AnyResult<()> {
        // Same producer as `test_ready_send`, but drained through `close`.
        let ss = <ScopedStreamSink<usize, usize, AnyError>>::new(|_, mut tx| {
            Box::pin(async move {
                for n in 0..10 {
                    tx.send(n).await?;
                    yield_times(n).await;
                }

                Ok(())
            })
        });

        test_helper(async move {
            let mut ss = pin!(ss);
            let mut drained = ss.as_mut().close();
            for expected in 0..10 {
                assert_eq!(drained.next().await.transpose()?, Some(expected));
            }
            assert_eq!(drained.next().await.transpose()?, None);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_transform() -> AnyResult<()> {
        let ss = doubler!();

        test_helper(async move {
            let mut ss = pin!(ss);
            for n in 0..10 {
                let mut pending = ss.as_mut().send_one(n);
                assert_eq!(pending.next().await.transpose()?, None);
                assert_eq!(ss.as_mut().ready().await?.0, Some(n * 2));
            }
            assert_eq!(ss.close().next().await.transpose()?, None);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_transform_iter() -> AnyResult<()> {
        let ss = doubler!();

        test_helper(async move {
            let mut ss = pin!(ss);
            let mut outputs = ss.as_mut().send_iter(0..10);
            for n in 0..10 {
                assert_eq!(outputs.next().await.transpose()?, Some(n * 2));
            }
            assert_eq!(outputs.next().await.transpose()?, None);
            assert_eq!(ss.close().next().await.transpose()?, None);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_map_send() -> AnyResult<()> {
        // `map_send` rewrites what comes *out*: (n * 2) + 10.
        let ss = doubler!().map_send(|x| x + 10);

        test_helper(async move {
            let mut ss = pin!(ss);
            for n in 0..10 {
                let mut pending = ss.as_mut().send_one(n);
                assert_eq!(pending.next().await.transpose()?, None);
                assert_eq!(ss.as_mut().ready().await?.0, Some(n * 2 + 10));
            }
            assert_eq!(ss.close().next().await.transpose()?, None);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_map_recv() -> AnyResult<()> {
        // `map_recv` rewrites what goes *in*: (n + 10) * 2.
        let ss = doubler!().map_recv(|x| x + 10);

        test_helper(async move {
            let mut ss = pin!(ss);
            for n in 0..10 {
                let mut pending = ss.as_mut().send_one(n);
                assert_eq!(pending.next().await.transpose()?, None);
                assert_eq!(ss.as_mut().ready().await?.0, Some((n + 10) * 2));
            }
            assert_eq!(ss.close().next().await.transpose()?, None);

            Ok(())
        })
        .await
    }
}