stubborn_io/tokio/
io.rs

1use crate::config::ReconnectOptions;
2use crate::log::{error, info};
3use std::future::Future;
4use std::io::{self, ErrorKind, IoSlice};
5use std::marker::PhantomData;
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use std::time::Duration;
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12use tokio::time::sleep;
13
/// Trait that should be implemented for an [AsyncRead] and/or [AsyncWrite]
/// item to enable it to work with the [StubbornIo] struct.
pub trait UnderlyingIo<C>: Sized + Unpin
where
    C: Clone + Send + Unpin,
{
    /// Builds a new IO item from `ctor_arg`.
    ///
    /// [StubbornIo] calls this for the initial connection and again for every reconnect
    /// attempt, each time handing over a fresh clone of its stored constructor argument.
    fn establish(ctor_arg: C) -> Pin<Box<dyn Future<Output = io::Result<Self>> + Send>>;

    /// Decides whether an [io::Error] raised during operation means the connection is gone.
    ///
    /// Not every error is a disconnect (e.g. `WouldBlock`). The default implementation treats
    /// the error kinds listed in its body as disconnects; override it to tailor the
    /// classification.
    fn is_disconnect_error(&self, err: &io::Error) -> bool {
        const DISCONNECT_KINDS: [io::ErrorKind; 10] = [
            io::ErrorKind::NotFound,
            io::ErrorKind::PermissionDenied,
            io::ErrorKind::ConnectionRefused,
            io::ErrorKind::ConnectionReset,
            io::ErrorKind::ConnectionAborted,
            io::ErrorKind::NotConnected,
            io::ErrorKind::AddrInUse,
            io::ErrorKind::AddrNotAvailable,
            io::ErrorKind::BrokenPipe,
            io::ErrorKind::AlreadyExists,
        ];

        DISCONNECT_KINDS.contains(&err.kind())
    }

    /// For IO items that implement `AsyncRead`, decides whether a *successful* read of
    /// `bytes_read` bytes actually means the connection has closed.
    ///
    /// The default treats a zero-byte read as end-of-stream, which is how tokio's `TcpStream`
    /// reports a closed peer.
    fn is_final_read(&self, bytes_read: usize) -> bool {
        matches!(bytes_read, 0)
    }
}
54
/// Bookkeeping for one round of reconnect attempts.
struct AttemptsTracker {
    /// How many reconnect attempts have been scheduled so far in this round.
    attempt_num: usize,
    /// Remaining retry delays; the round gives up once this yields `None`.
    retries_remaining: Box<dyn Iterator<Item = Duration> + Send + Sync>,
}

/// A boxed, type-erased reconnect attempt that resolves to a freshly established IO item.
type ReconnectFuture<T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send>>;

/// State carried while a `StubbornIo` is disconnected and trying to reconnect.
struct ReconnectStatus<T, C> {
    /// Attempt counter plus the remaining retry schedule.
    attempts_tracker: AttemptsTracker,
    /// The in-flight reconnect attempt. `poll_disconnect` clones this `Arc` handle before
    /// locking it, so the attempt can be driven while the wrapper's own state is mutated.
    reconnect_attempt: Arc<Mutex<ReconnectFuture<T>>>,
    /// Records the constructor-argument type `C` without storing a value of it.
    _phantom_data: PhantomData<C>,
}
66
67impl<T, C> ReconnectStatus<T, C>
68where
69    T: UnderlyingIo<C>,
70    C: Clone + Send + Unpin + 'static,
71{
72    pub fn new(options: &ReconnectOptions) -> Self {
73        ReconnectStatus {
74            attempts_tracker: AttemptsTracker {
75                attempt_num: 0,
76                retries_remaining: (options.retries_to_attempt_fn)(),
77            },
78            reconnect_attempt: Arc::new(Mutex::new(Box::pin(async {
79                unreachable!("Not going to happen")
80            }))),
81            _phantom_data: PhantomData,
82        }
83    }
84}
85
/// The StubbornIo is a wrapper over a tokio AsyncRead/AsyncWrite item that will automatically
/// invoke the [UnderlyingIo::establish] upon initialization and when a reconnect is needed.
/// Because it implements deref, you are able to invoke all of the original methods on the wrapped IO.
pub struct StubbornIo<T, C> {
    // Where the wrapper currently is in its connect / reconnect lifecycle.
    status: Status<T, C>,
    // The live IO item; overwritten whenever a reconnect attempt succeeds.
    underlying_io: T,
    // Retry schedule and lifecycle callbacks supplied at construction.
    options: ReconnectOptions,
    // Cloned and passed to `UnderlyingIo::establish` on every (re)connect attempt.
    ctor_arg: C,
}

/// Connection lifecycle of a [StubbornIo].
enum Status<T, C> {
    /// The underlying IO is usable; reads and writes are forwarded to it.
    Connected,
    /// A disconnect was detected; holds the retry schedule and the pending reconnect attempt.
    Disconnected(ReconnectStatus<T, C>),
    /// The retry schedule ran out without reconnecting; all further IO fails with
    /// `ErrorKind::NotConnected`.
    FailedAndExhausted,
}
101
/// Builds a ready poll result carrying an [io::Error] of the given `kind` with `reason` as its
/// payload.
#[inline]
fn poll_err<T>(
    kind: ErrorKind,
    reason: impl Into<Box<dyn std::error::Error + Send + Sync>>,
) -> Poll<io::Result<T>> {
    Poll::Ready(Err(io::Error::new(kind, reason)))
}

/// Error reported once every reconnect attempt has been used up.
fn exhausted_err<T>() -> Poll<io::Result<T>> {
    const MESSAGE: &str = "Disconnected. Connection attempts have been exhausted.";
    poll_err(ErrorKind::NotConnected, MESSAGE)
}

/// Error reported for operations that cannot proceed while a reconnect is pending.
fn disconnected_err<T>() -> Poll<io::Result<T>> {
    const MESSAGE: &str = "Underlying I/O is disconnected.";
    poll_err(ErrorKind::NotConnected, MESSAGE)
}
121
122impl<T, C> Deref for StubbornIo<T, C> {
123    type Target = T;
124
125    fn deref(&self) -> &Self::Target {
126        &self.underlying_io
127    }
128}
129
130impl<T, C> DerefMut for StubbornIo<T, C> {
131    fn deref_mut(&mut self) -> &mut Self::Target {
132        &mut self.underlying_io
133    }
134}
135
136impl<T, C> StubbornIo<T, C>
137where
138    T: UnderlyingIo<C>,
139    C: Clone + Send + Unpin + 'static,
140{
141    /// Connects or creates a handle to the UnderlyingIo item,
142    /// using the default reconnect options.
143    pub async fn connect(ctor_arg: C) -> io::Result<Self> {
144        let options = ReconnectOptions::new();
145        Self::connect_with_options(ctor_arg, options).await
146    }
147
148    pub async fn connect_with_options(ctor_arg: C, options: ReconnectOptions) -> io::Result<Self> {
149        let tcp = match T::establish(ctor_arg.clone()).await {
150            Ok(tcp) => {
151                info!("Initial connection succeeded.");
152                (options.on_connect_callback)();
153                tcp
154            }
155            Err(e) => {
156                error!("Initial connection failed due to: {:?}.", e);
157                (options.on_connect_fail_callback)();
158
159                if options.exit_if_first_connect_fails {
160                    error!("Bailing after initial connection failure.");
161                    return Err(e);
162                }
163
164                let mut result = Err(e);
165
166                for (i, duration) in (options.retries_to_attempt_fn)().enumerate() {
167                    let reconnect_num = i + 1;
168
169                    info!(
170                        "Will re-perform initial connect attempt #{} in {:?}.",
171                        reconnect_num, duration
172                    );
173
174                    sleep(duration).await;
175
176                    info!("Attempting reconnect #{} now.", reconnect_num);
177
178                    match T::establish(ctor_arg.clone()).await {
179                        Ok(tcp) => {
180                            result = Ok(tcp);
181                            (options.on_connect_callback)();
182                            info!("Initial connection successfully established.");
183                            break;
184                        }
185                        Err(e) => {
186                            (options.on_connect_fail_callback)();
187                            result = Err(e);
188                        }
189                    }
190                }
191
192                match result {
193                    Ok(tcp) => tcp,
194                    Err(e) => {
195                        error!("No more re-connect retries remaining. Never able to establish initial connection.");
196                        return Err(e);
197                    }
198                }
199            }
200        };
201
202        Ok(StubbornIo {
203            status: Status::Connected,
204            ctor_arg,
205            underlying_io: tcp,
206            options,
207        })
208    }
209
210    fn on_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
211        match &mut self.status {
212            // initial disconnect
213            Status::Connected => {
214                error!("Disconnect occurred");
215                (self.options.on_disconnect_callback)();
216                self.status = Status::Disconnected(ReconnectStatus::new(&self.options));
217            }
218            Status::Disconnected(_) => {
219                (self.options.on_connect_fail_callback)();
220            }
221            Status::FailedAndExhausted => {
222                unreachable!("on_disconnect will not occur for already exhausted state.")
223            }
224        };
225
226        let ctor_arg = self.ctor_arg.clone();
227
228        // this is ensured to be true now
229        if let Status::Disconnected(reconnect_status) = &mut self.status {
230            let next_duration = match reconnect_status.attempts_tracker.retries_remaining.next() {
231                Some(duration) => duration,
232                None => {
233                    error!("No more re-connect retries remaining. Giving up.");
234                    self.status = Status::FailedAndExhausted;
235                    cx.waker().wake_by_ref();
236                    return;
237                }
238            };
239
240            let future_instant = sleep(next_duration);
241
242            reconnect_status.attempts_tracker.attempt_num += 1;
243            let cur_num = reconnect_status.attempts_tracker.attempt_num;
244
245            let reconnect_attempt = async move {
246                future_instant.await;
247                info!("Attempting reconnect #{} now.", cur_num);
248                T::establish(ctor_arg).await
249            };
250
251            reconnect_status.reconnect_attempt = Arc::new(Mutex::new(Box::pin(reconnect_attempt)));
252
253            info!(
254                "Will perform reconnect attempt #{} in {:?}.",
255                reconnect_status.attempts_tracker.attempt_num, next_duration
256            );
257
258            cx.waker().wake_by_ref();
259        }
260    }
261
262    fn poll_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
263        let (attempt, attempt_num) = match self.status {
264            Status::Connected => unreachable!(),
265            Status::Disconnected(ref mut status) => (
266                status.reconnect_attempt.clone(),
267                status.attempts_tracker.attempt_num,
268            ),
269            Status::FailedAndExhausted => unreachable!(),
270        };
271
272        let mut attempt = attempt.lock().unwrap();
273
274        match attempt.as_mut().poll(cx) {
275            Poll::Ready(Ok(underlying_io)) => {
276                info!("Connection re-established");
277                cx.waker().wake_by_ref();
278                self.status = Status::Connected;
279                (self.options.on_connect_callback)();
280                self.underlying_io = underlying_io;
281            }
282            Poll::Ready(Err(err)) => {
283                error!("Connection attempt #{} failed: {:?}", attempt_num, err);
284                self.on_disconnect(cx);
285            }
286            Poll::Pending => {}
287        }
288    }
289
290    fn is_read_disconnect_detected(
291        &self,
292        poll_result: &Poll<io::Result<()>>,
293        bytes_read: usize,
294    ) -> bool {
295        match poll_result {
296            Poll::Ready(Ok(())) if self.is_final_read(bytes_read) => true,
297            Poll::Ready(Err(err)) => self.is_disconnect_error(err),
298            _ => false,
299        }
300    }
301
302    fn is_write_disconnect_detected<X>(&self, poll_result: &Poll<io::Result<X>>) -> bool {
303        match poll_result {
304            Poll::Ready(Err(err)) => self.is_disconnect_error(err),
305            _ => false,
306        }
307    }
308}
309
310impl<T, C> AsyncRead for StubbornIo<T, C>
311where
312    T: UnderlyingIo<C> + AsyncRead,
313    C: Clone + Send + Unpin + 'static,
314{
315    fn poll_read(
316        mut self: Pin<&mut Self>,
317        cx: &mut Context<'_>,
318        buf: &mut ReadBuf<'_>,
319    ) -> Poll<io::Result<()>> {
320        match &mut self.status {
321            Status::Connected => {
322                let pre_len = buf.filled().len();
323                let poll = AsyncRead::poll_read(Pin::new(&mut self.underlying_io), cx, buf);
324                let post_len = buf.filled().len();
325                let bytes_read = post_len - pre_len;
326                if self.is_read_disconnect_detected(&poll, bytes_read) {
327                    self.on_disconnect(cx);
328                    Poll::Pending
329                } else {
330                    poll
331                }
332            }
333            Status::Disconnected(_) => {
334                self.poll_disconnect(cx);
335                Poll::Pending
336            }
337            Status::FailedAndExhausted => exhausted_err(),
338        }
339    }
340}
341
342impl<T, C> AsyncWrite for StubbornIo<T, C>
343where
344    T: UnderlyingIo<C> + AsyncWrite,
345    C: Clone + Send + Unpin + 'static,
346{
347    fn poll_write(
348        mut self: Pin<&mut Self>,
349        cx: &mut Context<'_>,
350        buf: &[u8],
351    ) -> Poll<io::Result<usize>> {
352        match &mut self.status {
353            Status::Connected => {
354                let poll = AsyncWrite::poll_write(Pin::new(&mut self.underlying_io), cx, buf);
355
356                if self.is_write_disconnect_detected(&poll) {
357                    self.on_disconnect(cx);
358                    Poll::Pending
359                } else {
360                    poll
361                }
362            }
363            Status::Disconnected(_) => {
364                self.poll_disconnect(cx);
365                Poll::Pending
366            }
367            Status::FailedAndExhausted => exhausted_err(),
368        }
369    }
370
371    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
372        match &mut self.status {
373            Status::Connected => {
374                let poll = AsyncWrite::poll_flush(Pin::new(&mut self.underlying_io), cx);
375
376                if self.is_write_disconnect_detected(&poll) {
377                    self.on_disconnect(cx);
378                    Poll::Pending
379                } else {
380                    poll
381                }
382            }
383            Status::Disconnected(_) => {
384                self.poll_disconnect(cx);
385                Poll::Pending
386            }
387            Status::FailedAndExhausted => exhausted_err(),
388        }
389    }
390
391    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
392        match &mut self.status {
393            Status::Connected => {
394                let poll = AsyncWrite::poll_shutdown(Pin::new(&mut self.underlying_io), cx);
395                if poll.is_ready() {
396                    // if completed, we are disconnected whether error or not
397                    self.on_disconnect(cx);
398                }
399
400                poll
401            }
402            Status::Disconnected(_) => disconnected_err(),
403            Status::FailedAndExhausted => exhausted_err(),
404        }
405    }
406
407    fn poll_write_vectored(
408        mut self: Pin<&mut Self>,
409        cx: &mut Context<'_>,
410        bufs: &[IoSlice<'_>],
411    ) -> Poll<io::Result<usize>> {
412        match &mut self.status {
413            Status::Connected => {
414                let poll =
415                    AsyncWrite::poll_write_vectored(Pin::new(&mut self.underlying_io), cx, bufs);
416
417                if self.is_write_disconnect_detected(&poll) {
418                    self.on_disconnect(cx);
419                    Poll::Pending
420                } else {
421                    poll
422                }
423            }
424            Status::Disconnected(_) => {
425                self.poll_disconnect(cx);
426                Poll::Pending
427            }
428            Status::FailedAndExhausted => exhausted_err(),
429        }
430    }
431
432    fn is_write_vectored(&self) -> bool {
433        self.underlying_io.is_write_vectored()
434    }
435}