rustrade_integration/stream/ext/
forward_clone_by.rs1use futures::{Sink, Stream, ready};
2use pin_project::pin_project;
3use std::{
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8#[derive(Debug)]
10#[pin_project]
11pub struct ForwardCloneBy<S, FnPredicate, FnForward> {
12 #[pin]
13 socket: S,
14 predicate: FnPredicate,
15 forward: FnForward,
16}
17
18impl<S, FnPredicate, FnForward> ForwardCloneBy<S, FnPredicate, FnForward> {
19 pub fn new(socket: S, predicate: FnPredicate, forward: FnForward) -> Self {
20 Self {
21 socket,
22 predicate,
23 forward,
24 }
25 }
26}
27
28impl<S, FnPredicate, FnForward, FwdErr> Stream for ForwardCloneBy<S, FnPredicate, FnForward>
29where
30 S: Stream,
31 S::Item: Clone,
32 FnPredicate: FnMut(&S::Item) -> bool,
33 FnForward: FnMut(S::Item) -> Result<(), FwdErr>,
34{
35 type Item = S::Item;
36
37 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
38 let mut this = self.project();
39
40 let next_ready = ready!(this.socket.as_mut().poll_next(cx));
41
42 let Some(item) = next_ready else {
43 return Poll::Ready(None);
44 };
45
46 if (this.predicate)(&item) && (this.forward)(item.clone()).is_err() {
47 return Poll::Ready(None);
48 }
49
50 Poll::Ready(Some(item))
51 }
52}
53
54impl<S, FnPredicate, FnForward, Item> Sink<Item> for ForwardCloneBy<S, FnPredicate, FnForward>
55where
56 S: Sink<Item>,
57{
58 type Error = S::Error;
59
60 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61 self.project().socket.poll_ready(cx)
62 }
63
64 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
65 self.project().socket.start_send(item)
66 }
67
68 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69 self.project().socket.poll_flush(cx)
70 }
71
72 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73 self.project().socket.poll_close(cx)
74 }
75}
76
#[cfg(test)]
// Channel operations below cannot fail while both ends are alive; unwrapping keeps tests terse.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::stream::ext::BarterStreamExt;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    /// Task context backed by a waker that does nothing, for driving streams by hand.
    fn noop_context() -> Context<'static> {
        Context::from_waker(futures::task::noop_waker_ref())
    }

    /// Predicate shared by all tests: only even numbers are forwarded.
    fn is_even(item: &i32) -> bool {
        *item % 2 == 0
    }

    #[tokio::test]
    async fn test_forward_clone_by() {
        let mut cx = noop_context();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();
        let (forward_tx, mut forward_rx) = mpsc::unbounded_channel::<i32>();

        let mut stream = UnboundedReceiverStream::new(rx)
            .forward_clone_by(is_even, move |item| forward_tx.send(item).map_err(|_| ()));

        // Nothing has been sent yet.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Every item is yielded; only the even ones are also forwarded.
        for value in 1..=4 {
            tx.send(value).unwrap();
            assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(value));

            if is_even(&value) {
                assert_eq!(forward_rx.try_recv().unwrap(), value);
            } else {
                assert!(forward_rx.try_recv().is_err());
            }
        }

        // Closing the source ends the stream.
        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_forward_clone_by_terminates_on_forward_error() {
        let mut cx = noop_context();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();

        // Forwarding fails only for 4, which must end the stream.
        let mut stream = UnboundedReceiverStream::new(rx)
            .forward_clone_by(is_even, |item| if item == 4 { Err(()) } else { Ok(()) });

        for (sent, expected) in [(1, Some(1)), (2, Some(2)), (4, None)] {
            tx.send(sent).unwrap();
            assert_ready_eq!(stream.poll_next_unpin(&mut cx), expected);
        }
    }

    #[tokio::test]
    async fn test_forward_clone_by_with_custom_error_type() {
        let mut cx = noop_context();

        let (tx, rx) = mpsc::unbounded_channel::<i32>();
        let (forward_tx, mut forward_rx) = mpsc::unbounded_channel::<i32>();

        // The forward callback may use any error type; here it is a `String`.
        let forward = move |item: i32| -> Result<(), String> {
            forward_tx
                .send(item)
                .map_err(|e| format!("send failed: {e}"))
        };
        let mut stream = UnboundedReceiverStream::new(rx).forward_clone_by(is_even, forward);

        tx.send(2).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(2));
        assert_eq!(forward_rx.try_recv().unwrap(), 2);

        tx.send(3).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(3));
        assert!(forward_rx.try_recv().is_err());
    }
}