Skip to main content

rustrade_integration/stream/ext/
forward_clone_by.rs

1use futures::{Sink, Stream, ready};
2use pin_project::pin_project;
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8/// Stream adapter that forwards clones of matching items whilst also yielding all items.
9#[derive(Debug)]
10#[pin_project]
11pub struct ForwardCloneBy<S, FnPredicate, FnForward> {
12    #[pin]
13    socket: S,
14    predicate: FnPredicate,
15    forward: FnForward,
16}
17
18impl<S, FnPredicate, FnForward> ForwardCloneBy<S, FnPredicate, FnForward> {
19    pub fn new(socket: S, predicate: FnPredicate, forward: FnForward) -> Self {
20        Self {
21            socket,
22            predicate,
23            forward,
24        }
25    }
26}
27
28impl<S, FnPredicate, FnForward, FwdErr> Stream for ForwardCloneBy<S, FnPredicate, FnForward>
29where
30    S: Stream,
31    S::Item: Clone,
32    FnPredicate: FnMut(&S::Item) -> bool,
33    FnForward: FnMut(S::Item) -> Result<(), FwdErr>,
34{
35    type Item = S::Item;
36
37    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
38        let mut this = self.project();
39
40        let next_ready = ready!(this.socket.as_mut().poll_next(cx));
41
42        let Some(item) = next_ready else {
43            return Poll::Ready(None);
44        };
45
46        if (this.predicate)(&item) && (this.forward)(item.clone()).is_err() {
47            return Poll::Ready(None);
48        }
49
50        Poll::Ready(Some(item))
51    }
52}
53
54impl<S, FnPredicate, FnForward, Item> Sink<Item> for ForwardCloneBy<S, FnPredicate, FnForward>
55where
56    S: Sink<Item>,
57{
58    type Error = S::Error;
59
60    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61        self.project().socket.poll_ready(cx)
62    }
63
64    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
65        self.project().socket.start_send(item)
66    }
67
68    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69        self.project().socket.poll_flush(cx)
70    }
71
72    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        self.project().socket.poll_close(cx)
74    }
75}
76
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: panics on bad input are acceptable
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    /// Builds a `Context` backed by a no-op waker so streams can be polled manually.
    fn noop_cx() -> Context<'static> {
        Context::from_waker(futures::task::noop_waker_ref())
    }

    #[tokio::test]
    async fn test_forward_clone_by() {
        let mut cx = noop_cx();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();
        let rx = UnboundedReceiverStream::new(rx);

        let (forward_tx, mut forward_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = rx.forward_clone_by(
            |item| *item % 2 == 0,
            move |item| forward_tx.send(item).map_err(|_| ()),
        );

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert!(forward_rx.try_recv().is_err());

        tx.send(2).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(2));
        assert_eq!(forward_rx.try_recv().unwrap(), 2);

        tx.send(3).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(3));
        assert!(forward_rx.try_recv().is_err());

        tx.send(4).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(4));
        assert_eq!(forward_rx.try_recv().unwrap(), 4);

        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_clone_by_terminates_on_forward_error() {
        let mut cx = noop_cx();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = rx.forward_clone_by(
            |item| *item % 2 == 0,
            |item| if item == 4 { Err(()) } else { Ok(()) },
        );

        tx.send(1).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        tx.send(2).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(2));

        tx.send(4).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_clone_by_skips_forward_for_unmatched_items() {
        let mut cx = noop_cx();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();
        let rx = UnboundedReceiverStream::new(rx);

        // `forward` always fails, but nothing matches the predicate, so `forward` must never
        // be invoked and every item must still be yielded.
        let mut stream = rx.forward_clone_by(|_| false, |_| Err::<(), ()>(()));

        for value in [1, 2, 3] {
            tx.send(value).unwrap();
            assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(value));
        }

        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_clone_by_with_custom_error_type() {
        let mut cx = noop_cx();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();
        let rx = UnboundedReceiverStream::new(rx);

        let (forward_tx, mut forward_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = rx.forward_clone_by(
            |item| *item % 2 == 0,
            move |item| -> Result<(), String> {
                forward_tx
                    .send(item)
                    .map_err(|e| format!("send failed: {e}"))
            },
        );

        tx.send(2).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(2));
        assert_eq!(forward_rx.try_recv().unwrap(), 2);

        tx.send(3).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(3));
        assert!(forward_rx.try_recv().is_err());
    }
}