stream_reconnect/
stream.rs

1use std::error::Error;
2use std::future::Future;
3use std::iter::once;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use futures::{ready, Sink, Stream};
11use log::{debug, error, info};
12
13use crate::config::ReconnectOptions;
14
/// Trait that must be implemented by a [Stream] and/or [Sink] type so that it
/// can be wrapped by [ReconnectStream].
///
/// Type parameters:
/// - `C`: the argument passed to [UnderlyingStream::establish]; it is cloned
///   for every connection attempt.
/// - `I`: the item type yielded when the implementor is used as a [Stream].
/// - `E`: the error type produced by connection attempts and sink operations.
pub trait UnderlyingStream<C, I, E>: Sized + Unpin
where
    C: Clone + Send + Unpin,
    E: Error,
{
    /// Establishes a new connection from `ctor_arg`.
    ///
    /// [ReconnectStream] calls this both for the initial connection and for
    /// every subsequent reconnect attempt.
    fn establish(ctor_arg: C) -> Pin<Box<dyn Future<Output = Result<Self, E>> + Send>>;

    /// Decides whether an error returned by the sink's `poll_ready` or
    /// `poll_flush` means the connection has been lost.
    ///
    /// Not every sink error is a disconnect (e.g. `WouldBlock`). Return `true`
    /// only for errors that should trigger a reconnect; any other error is
    /// returned to the caller unchanged.
    fn is_write_disconnect_error(&self, err: &E) -> bool;

    /// Decides whether an item yielded by the stream signals a disconnect.
    ///
    /// [Stream] implementations commonly report failures as items (for example
    /// `Err` values) instead of ending the stream. Return `true` for items that
    /// should trigger a reconnect; such items are not forwarded to the consumer.
    /// By default no item is treated as a disconnect.
    // `item` is unused by the default body but keeps a descriptive name for implementors.
    #[allow(unused_variables)]
    fn is_read_disconnect_error(&self, item: &I) -> bool {
        false
    }

    /// Builds the error returned by sink operations once the reconnect retries
    /// have been exhausted.
    fn exhaust_err() -> E;
}
43
/// Bookkeeping for the reconnect attempts made since the most recent disconnect.
struct AttemptsTracker {
    /// Number of reconnect attempts scheduled so far (starts at 0).
    attempt_num: usize,
    /// Delays still available for retries; once it yields `None` the
    /// reconnecting stream gives up.
    retries_remaining: Box<dyn Iterator<Item = Duration> + Send>,
}
48
49struct ReconnectStatus<T, C, I, E> {
50    attempts_tracker: AttemptsTracker,
51    reconnect_attempt: Pin<Box<dyn Future<Output = Result<T, E>> + Send>>,
52    _marker_1: PhantomData<C>,
53    _marker_2: PhantomData<I>,
54    _marker_3: PhantomData<E>,
55}
56
57impl<T, C, I, E> ReconnectStatus<T, C, I, E>
58where
59    T: UnderlyingStream<C, I, E>,
60    C: Clone + Send + Unpin + 'static,
61    E: Error + Unpin,
62{
63    pub fn new(options: &ReconnectOptions) -> Self {
64        ReconnectStatus {
65            attempts_tracker: AttemptsTracker {
66                attempt_num: 0,
67                retries_remaining: (options.retries_to_attempt_fn())(),
68            },
69            reconnect_attempt: Box::pin(async { unreachable!("Not going to happen") }),
70            _marker_1: PhantomData,
71            _marker_2: PhantomData,
72            _marker_3: PhantomData,
73        }
74    }
75}
76
77/// The ReconnectStream is a wrapper over a [Stream]/[Sink] item that will automatically
78/// invoke the [UnderlyingStream::establish] upon initialization and when a reconnect is needed.
79/// Because it implements deref, you are able to invoke all of the original methods on the wrapped stream.
80pub struct ReconnectStream<T, C, I, E> {
81    status: Status<T, C, I, E>,
82    underlying_io: T,
83    options: ReconnectOptions,
84    ctor_arg: C,
85    _marker: PhantomData<I>,
86}
87
88enum Status<T, C, I, E> {
89    Connected,
90    Disconnected(ReconnectStatus<T, C, I, E>),
91    FailedAndExhausted, // the way one feels after programming in dynamically typed languages
92}
93
94impl<T, C, I, E> Deref for ReconnectStream<T, C, I, E> {
95    type Target = T;
96
97    fn deref(&self) -> &Self::Target {
98        &self.underlying_io
99    }
100}
101
102impl<T, C, I, E> DerefMut for ReconnectStream<T, C, I, E> {
103    fn deref_mut(&mut self) -> &mut Self::Target {
104        &mut self.underlying_io
105    }
106}
107
108impl<T, C, I, E> ReconnectStream<T, C, I, E>
109where
110    T: UnderlyingStream<C, I, E>,
111    C: Clone + Send + Unpin + 'static,
112    I: Unpin,
113    E: Error + Unpin,
114{
115    /// Connects or creates a handle to the [UnderlyingStream] item,
116    /// using the default reconnect options.
117    pub async fn connect(ctor_arg: C) -> Result<Self, E> {
118        let options = ReconnectOptions::new();
119        Self::connect_with_options(ctor_arg, options).await
120    }
121
122    pub async fn connect_with_options(ctor_arg: C, options: ReconnectOptions) -> Result<Self, E> {
123        let tries = (**options.retries_to_attempt_fn())()
124            .map(Some)
125            .chain(once(None));
126        let mut result = None;
127        for (counter, maybe_delay) in tries.enumerate() {
128            match T::establish(ctor_arg.clone()).await {
129                Ok(inner) => {
130                    debug!("Initial connection succeeded.");
131                    (options.on_connect_callback())();
132                    result = Some(Ok(inner));
133                    break;
134                }
135                Err(e) => {
136                    error!("Connection failed due to: {:?}.", e);
137                    (options.on_connect_fail_callback())();
138
139                    if options.exit_if_first_connect_fails() {
140                        error!("Bailing after initial connection failure.");
141                        return Err(e);
142                    }
143
144                    result = Some(Err(e));
145
146                    if let Some(delay) = maybe_delay {
147                        debug!(
148                            "Will re-perform initial connect attempt #{} in {:?}.",
149                            counter + 1,
150                            delay
151                        );
152
153                        #[cfg(feature = "tokio")]
154                        let sleep_fut = tokio::time::sleep(delay);
155                        #[cfg(feature = "async-std")]
156                        let sleep_fut = async_std::task::sleep(delay);
157
158                        sleep_fut.await;
159
160                        debug!("Attempting reconnect #{} now.", counter + 1);
161                    }
162                }
163            }
164        }
165
166        match result.unwrap() {
167            Ok(inner) => Ok(ReconnectStream {
168                status: Status::Connected,
169                ctor_arg,
170                underlying_io: inner,
171                options,
172                _marker: PhantomData,
173            }),
174            Err(e) => {
175                error!("No more re-connect retries remaining. Never able to establish initial connection.");
176                Err(e)
177            }
178        }
179    }
180
181    fn on_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
182        match &mut self.status {
183            // initial disconnect
184            Status::Connected => {
185                error!("Disconnect occurred");
186                (self.options.on_disconnect_callback())();
187                self.status = Status::Disconnected(ReconnectStatus::new(&self.options));
188            }
189            Status::Disconnected(_) => {
190                (self.options.on_connect_fail_callback())();
191            }
192            Status::FailedAndExhausted => {
193                unreachable!("on_disconnect will not occur for already exhausted state.")
194            }
195        };
196
197        let ctor_arg = self.ctor_arg.clone();
198
199        // this is ensured to be true now
200        if let Status::Disconnected(reconnect_status) = &mut self.status {
201            let next_duration = match reconnect_status.attempts_tracker.retries_remaining.next() {
202                Some(duration) => duration,
203                None => {
204                    error!("No more re-connect retries remaining. Giving up.");
205                    self.status = Status::FailedAndExhausted;
206                    cx.waker().wake_by_ref();
207                    return;
208                }
209            };
210
211            #[cfg(feature = "tokio")]
212            let future_instant = tokio::time::sleep(next_duration);
213            #[cfg(feature = "async-std")]
214            let future_instant = async_std::task::sleep(next_duration);
215
216            reconnect_status.attempts_tracker.attempt_num += 1;
217            let cur_num = reconnect_status.attempts_tracker.attempt_num;
218
219            let reconnect_attempt = async move {
220                future_instant.await;
221                debug!("Attempting reconnect #{} now.", cur_num);
222                T::establish(ctor_arg).await
223            };
224
225            reconnect_status.reconnect_attempt = Box::pin(reconnect_attempt);
226
227            debug!(
228                "Will perform reconnect attempt #{} in {:?}.",
229                reconnect_status.attempts_tracker.attempt_num, next_duration
230            );
231
232            cx.waker().wake_by_ref();
233        }
234    }
235
236    fn poll_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
237        let (attempt, attempt_num) = match &mut self.status {
238            Status::Connected => unreachable!(),
239            Status::Disconnected(ref mut status) => (
240                Pin::new(&mut status.reconnect_attempt),
241                status.attempts_tracker.attempt_num,
242            ),
243            Status::FailedAndExhausted => unreachable!(),
244        };
245
246        match attempt.poll(cx) {
247            Poll::Ready(Ok(underlying_io)) => {
248                info!("Connection re-established");
249                cx.waker().wake_by_ref();
250                self.status = Status::Connected;
251                (self.options.on_connect_callback())();
252                self.underlying_io = underlying_io;
253            }
254            Poll::Ready(Err(err)) => {
255                error!("Connection attempt #{} failed: {:?}", attempt_num, err);
256                self.on_disconnect(cx);
257            }
258            Poll::Pending => {}
259        }
260    }
261
262    fn is_write_disconnect_detected<X>(&self, poll_result: &Poll<Result<X, E>>) -> bool {
263        match poll_result {
264            Poll::Ready(Err(err)) => self.is_write_disconnect_error(err),
265            _ => false,
266        }
267    }
268}
269
270impl<T, C, I, E> Stream for ReconnectStream<T, C, I, E>
271where
272    T: UnderlyingStream<C, I, E> + Stream<Item = I>,
273    C: Clone + Send + Unpin + 'static,
274    I: Unpin,
275    E: Error + Unpin,
276{
277    type Item = I;
278
279    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
280        match self.status {
281            Status::Connected => {
282                let poll = ready!(Pin::new(&mut self.underlying_io).poll_next(cx));
283                if let Some(poll) = poll {
284                    if self.is_read_disconnect_error(&poll) {
285                        self.on_disconnect(cx);
286                        Poll::Pending
287                    } else {
288                        Poll::Ready(Some(poll))
289                    }
290                } else {
291                    self.on_disconnect(cx);
292                    Poll::Pending
293                }
294            }
295            Status::Disconnected(_) => {
296                self.poll_disconnect(cx);
297                Poll::Pending
298            }
299            Status::FailedAndExhausted => Poll::Ready(None),
300        }
301    }
302
303    fn size_hint(&self) -> (usize, Option<usize>) {
304        self.underlying_io.size_hint()
305    }
306}
307
308impl<T, C, I, I2, E> Sink<I> for ReconnectStream<T, C, I2, E>
309where
310    T: UnderlyingStream<C, I2, E> + Sink<I, Error = E>,
311    C: Clone + Send + Unpin + 'static,
312    I2: Unpin,
313    E: Error + Unpin,
314{
315    type Error = E;
316
317    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318        match self.status {
319            Status::Connected => {
320                let poll = Pin::new(&mut self.underlying_io).poll_ready(cx);
321
322                if self.is_write_disconnect_detected(&poll) {
323                    self.on_disconnect(cx);
324                    Poll::Pending
325                } else {
326                    poll
327                }
328            }
329            Status::Disconnected(_) => {
330                self.poll_disconnect(cx);
331                Poll::Pending
332            }
333            Status::FailedAndExhausted => Poll::Ready(Err(T::exhaust_err())),
334        }
335    }
336
337    fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
338        Pin::new(&mut self.underlying_io).start_send(item)
339    }
340
341    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
342        match self.status {
343            Status::Connected => {
344                let poll = Pin::new(&mut self.underlying_io).poll_flush(cx);
345
346                if self.is_write_disconnect_detected(&poll) {
347                    self.on_disconnect(cx);
348                    Poll::Pending
349                } else {
350                    poll
351                }
352            }
353            Status::Disconnected(_) => {
354                self.poll_disconnect(cx);
355                Poll::Pending
356            }
357            Status::FailedAndExhausted => Poll::Ready(Err(T::exhaust_err())),
358        }
359    }
360
361    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
362        match self.status {
363            Status::Connected => {
364                let poll = Pin::new(&mut self.underlying_io).poll_close(cx);
365                if poll.is_ready() {
366                    // if completed, we are disconnected whether error or not
367                    self.on_disconnect(cx);
368                }
369
370                poll
371            }
372            Status::Disconnected(_) => Poll::Pending,
373            Status::FailedAndExhausted => Poll::Ready(Err(T::exhaust_err())),
374        }
375    }
376}