Skip to main content

rustrade_integration/socket/
on_stream_err.rs

1use futures::{Sink, Stream};
2use pin_project::pin_project;
3use std::{
4    pin::Pin,
5    task::{Context, Poll, ready},
6};
7
8/// Handles stream errors and determines the appropriate [`StreamErrorAction`].
9pub trait StreamErrorHandler<Err> {
10    /// Handles a stream error and returns the action to take.
11    fn handle(&mut self, error: &Err) -> StreamErrorAction;
12}
13
14impl<Err, F> StreamErrorHandler<Err> for F
15where
16    F: FnMut(&Err) -> StreamErrorAction,
17{
18    #[inline]
19    fn handle(&mut self, error: &Err) -> StreamErrorAction {
20        self(error)
21    }
22}
23
/// Action to take in response to a stream error.
///
/// Returned by a [`StreamErrorHandler`] for every error item the wrapped
/// stream yields. The enum is a plain fieldless tag, so it also derives `Eq`
/// and `Hash` and can be used as a map/set key or compared in `Eq` contexts.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StreamErrorAction {
    /// Keep the stream alive and pass the error through to the consumer.
    Continue,
    /// End the stream (the wrapper yields `None`) so the owner can reconnect.
    Reconnect,
}
32
33/// Stream wrapper that applies error handling to a Result stream.
34///
35/// When an error occurs:
36/// - `StreamErrorAction::Continue`: Pass the error through
37/// - `StreamErrorAction::Reconnect`: End the stream (triggers reconnection)
38#[derive(Debug)]
39#[pin_project]
40pub struct OnStreamErr<S, ErrHandler> {
41    #[pin]
42    socket: S,
43    on_err: ErrHandler,
44}
45
46impl<S, ErrHandler> OnStreamErr<S, ErrHandler> {
47    pub fn new(socket: S, on_err: ErrHandler) -> Self {
48        Self { socket, on_err }
49    }
50}
51
52impl<S, StOk, StErr, ErrHandler> Stream for OnStreamErr<S, ErrHandler>
53where
54    S: Stream<Item = Result<StOk, StErr>>,
55    ErrHandler: StreamErrorHandler<StErr>,
56{
57    type Item = Result<StOk, StErr>;
58
59    #[inline]
60    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61        let mut this = self.project();
62
63        let next_ready = ready!(this.socket.as_mut().poll_next(cx));
64
65        let Some(result) = next_ready else {
66            return Poll::Ready(None);
67        };
68
69        match result {
70            Ok(item) => Poll::Ready(Some(Ok(item))),
71            Err(error) => match (this.on_err).handle(&error) {
72                StreamErrorAction::Continue => Poll::Ready(Some(Err(error))),
73                StreamErrorAction::Reconnect => Poll::Ready(None),
74            },
75        }
76    }
77}
78
79impl<St, ErrHandler, Item> Sink<Item> for OnStreamErr<St, ErrHandler>
80where
81    St: Sink<Item>,
82{
83    type Error = St::Error;
84
85    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.project().socket.poll_ready(cx)
87    }
88
89    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
90        self.project().socket.start_send(item)
91    }
92
93    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94        self.project().socket.poll_flush(cx)
95    }
96
97    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        self.project().socket.poll_close(cx)
99    }
100}
101
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: panics on bad input are acceptable
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::cell::Cell;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    type TestError = &'static str;

    #[tokio::test]
    async fn test_on_stream_err_passes_through_ok() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = OnStreamErr::new(rx, |_error: &TestError| StreamErrorAction::Continue);

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));

        tx.send(Ok(2)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(2)));

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_continue_action() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = OnStreamErr::new(rx, |_error: &TestError| StreamErrorAction::Continue);

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));

        tx.send(Err("error1")).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("error1"))
        );

        tx.send(Ok(2)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(2)));

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_reconnect_action() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        let mut stream = OnStreamErr::new(rx, |error: &TestError| {
            if *error == "fatal" {
                StreamErrorAction::Reconnect
            } else {
                StreamErrorAction::Continue
            }
        });

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));

        tx.send(Err("non-fatal")).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("non-fatal"))
        );

        // A fatal error must end the stream even though the inner stream
        // still has an item buffered behind it.
        tx.send(Err("fatal")).unwrap();
        tx.send(Ok(2)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_handler_only_sees_errors() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = Context::from_waker(waker);

        let (tx, rx) = mpsc::unbounded_channel::<Result<i32, TestError>>();
        let rx = UnboundedReceiverStream::new(rx);

        // Shared counter so the test can observe handler invocations while
        // the stream (which owns the closure) is still alive.
        let calls = Cell::new(0_usize);
        let mut stream = OnStreamErr::new(rx, |_error: &TestError| {
            calls.set(calls.get() + 1);
            StreamErrorAction::Continue
        });

        tx.send(Ok(1)).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(1)));
        assert_eq!(calls.get(), 0, "handler must not run for Ok items");

        tx.send(Err("e1")).unwrap();
        tx.send(Err("e2")).unwrap();
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("e1"))
        );
        assert_eq!(
            assert_ready!(stream.poll_next_unpin(&mut cx)),
            Some(Err("e2"))
        );
        assert_eq!(calls.get(), 2, "handler must run exactly once per error");
    }
}