1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
//! An async version of the [retry](crate::sync::retry) and
//! [retry_if](crate::sync::retry_if) function along with
//! [RetryFuture](crate::future::RetryFuture) used to implement them.

use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
    time::Instant,
};

use crate::Backoff;

/// Retries the provided function if it returns an error whenever the backoff
/// allows. The first call resulting in success will have it's value returned
/// to the caller.
///
/// # Example
/// ```
/// # use tryagain::*;
/// # async {
/// // In the real world this would do some computation that returns a result.
/// async fn returns_err() -> Result<(), ()> {
///     Err(())
/// }
///
/// // In this example we never get a value, we just spin forever.
/// let value = tryagain::future::retry(ExponentialBackoff::default(), || returns_err()).await;
/// # };
/// ```
pub fn retry<B, F, T, E, Fut>(
    backoff: B,
    func: F,
) -> RetryFuture<F, Fut, impl Fn(&E, u32) -> bool, B>
where
    B: Backoff,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    retry_if(backoff, func, |_, _| true)
}

/// Calls the provided function and if an error is returned it is passed to
/// the predicate to determine if the function should be retried when the
/// backoff function allows.
///
/// # Example
/// ```
/// # use tryagain::*;
/// # async {
/// enum Error {
///     Recoverable,
///     Fatal,
/// }
///
/// async fn returns_fatal_error() -> Result<(), Error> {
///     Err(Error::Fatal)
/// }
///
/// // Returns a Result of Error::Fatal
/// let result = tryagain::future::retry_if(
///     ExponentialBackoff::default(),
///     || returns_fatal_error(),
///     |error, _iterations| match error {
///         Error::Fatal => false, // This error isn't recoverable.
///         _ => true,
///     },
/// ).await;
/// # };
/// ```
pub fn retry_if<B, F, P, T, E, Fut>(backoff: B, func: F, predicate: P) -> RetryFuture<F, Fut, P, B>
where
    B: Backoff,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    P: Fn(&E, u32) -> bool,
{
    let future = func();

    RetryFuture {
        factory: func,
        future,
        predicate,
        backoff,
        paused_until: None,
        iterations: 0,
    }
}

#[pin_project::pin_project]
/// A future that will retry an operation.
pub struct RetryFuture<F, Fut, P, B> {
    factory: F,
    #[pin]
    future: Fut,
    predicate: P,
    backoff: B,
    paused_until: Option<Instant>,
    iterations: u32,
}

impl<T, E, F, Fut, P, B> Future for RetryFuture<F, Fut, P, B>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    P: Fn(&E, u32) -> bool,
    B: Backoff,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if let Some(paused_until) = this.paused_until {
            if Instant::now() < *paused_until {
                return Poll::Pending;
            }

            *this.paused_until = None;
        }

        let result = match this.future.as_mut().poll(cx) {
            Poll::Ready(res) => res,
            Poll::Pending => return Poll::Pending,
        };

        match result {
            Ok(value) => return Poll::Ready(Ok(value)),
            Err(e) => {
                *this.iterations += 1;
                let can_continue = (this.predicate)(&e, *this.iterations);

                if !can_continue {
                    return Poll::Ready(Err(e));
                }

                let new_future = (this.factory)();
                this.future.set(new_future);

                let duration = this.backoff.backoff_period(*this.iterations);
                let waker = cx.waker().clone();

                *this.paused_until = Some(Instant::now() + duration);

                // This is a hack to call the waker, I don't have a better way
                // to do this other than looping, which would block.
                #[cfg(feature = "runtime-tokio")]
                tokio::spawn(async move {
                    tokio::time::sleep(duration).await;
                    waker.wake();
                });

                #[cfg(feature = "runtime-async-std")]
                async_std::task::spawn(async move {
                    async_std::task::sleep(duration).await;
                    waker.wake();
                });

                Poll::Pending
            }
        }
    }
}