Skip to main content

temporalio_client/
retry.rs

1use crate::{
2    ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, MESSAGE_TOO_LARGE_KEY,
3    grpc::IsUserLongPoll,
4    request_extensions::{IsWorkerTaskLongPoll, NoRetryOnMatching, RetryConfigForCall},
5};
6use backoff::{
7    Clock, SystemClock,
8    backoff::Backoff,
9    exponential::{self, ExponentialBackoff},
10};
11use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
12use std::{
13    error::Error,
14    fmt::Debug,
15    future::Future,
16    time::{Duration, Instant},
17};
18use tonic::Code;
19
20/// List of gRPC error codes that client will retry.
21#[doc(hidden)]
22pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
23    Code::DataLoss,
24    Code::Internal,
25    Code::Unknown,
26    Code::ResourceExhausted,
27    Code::Aborted,
28    Code::OutOfRange,
29    Code::Unavailable,
30];
31const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);
32
33/// Configuration for retrying requests to the server
34#[derive(Clone, Debug, PartialEq)]
35pub struct RetryOptions {
36    /// initial wait time before the first retry.
37    pub initial_interval: Duration,
38    /// randomization jitter that is used as a multiplier for the current retry interval
39    /// and is added or subtracted from the interval length.
40    pub randomization_factor: f64,
41    /// rate at which retry time should be increased, until it reaches max_interval.
42    pub multiplier: f64,
43    /// maximum amount of time to wait between retries.
44    pub max_interval: Duration,
45    /// maximum total amount of time requests should be retried for, if None is set then no limit
46    /// will be used.
47    pub max_elapsed_time: Option<Duration>,
48    /// maximum number of retry attempts.
49    pub max_retries: usize,
50}
51
52impl Default for RetryOptions {
53    fn default() -> Self {
54        Self {
55            initial_interval: Duration::from_millis(100), // 100 ms wait by default.
56            randomization_factor: 0.2,                    // +-20% jitter.
57            multiplier: 1.7, // each next retry delay will increase by 70%
58            max_interval: Duration::from_secs(5), // until it reaches 5 seconds.
59            max_elapsed_time: Some(Duration::from_secs(10)), // 10 seconds total allocated time for all retries.
60            max_retries: 10,
61        }
62    }
63}
64
65impl RetryOptions {
66    pub(crate) const fn task_poll_retry_policy() -> Self {
67        Self {
68            initial_interval: Duration::from_millis(200),
69            randomization_factor: 0.2,
70            multiplier: 2.0,
71            max_interval: Duration::from_secs(10),
72            max_elapsed_time: None,
73            max_retries: 0,
74        }
75    }
76
77    pub(crate) const fn throttle_retry_policy() -> Self {
78        Self {
79            initial_interval: Duration::from_secs(1),
80            randomization_factor: 0.2,
81            multiplier: 2.0,
82            max_interval: Duration::from_secs(10),
83            max_elapsed_time: None,
84            max_retries: 0,
85        }
86    }
87
88    /// A retry policy that never retires
89    pub const fn no_retries() -> Self {
90        Self {
91            initial_interval: Duration::from_secs(0),
92            randomization_factor: 0.0,
93            multiplier: 1.0,
94            max_interval: Duration::from_secs(0),
95            max_elapsed_time: None,
96            max_retries: 1,
97        }
98    }
99
100    pub(crate) fn get_call_info<R>(
101        &self,
102        call_name: &'static str,
103        request: Option<&tonic::Request<R>>,
104    ) -> CallInfo {
105        let mut call_type = CallType::Normal;
106        let mut retry_short_circuit = None;
107        let mut retry_cfg_override = None;
108        if let Some(r) = request.as_ref() {
109            let ext = r.extensions();
110            if ext.get::<IsUserLongPoll>().is_some() {
111                call_type = CallType::UserLongPoll;
112            } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
113                call_type = CallType::TaskLongPoll;
114            }
115
116            retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
117            retry_cfg_override = ext.get::<RetryConfigForCall>().cloned();
118        }
119        let retry_cfg = if let Some(ovr) = retry_cfg_override {
120            ovr.0
121        } else if call_type == CallType::TaskLongPoll {
122            RetryOptions::task_poll_retry_policy()
123        } else {
124            self.clone()
125        };
126        CallInfo {
127            call_type,
128            call_name,
129            retry_cfg,
130            retry_short_circuit,
131        }
132    }
133
134    pub(crate) fn into_exp_backoff<C>(self, clock: C) -> exponential::ExponentialBackoff<C> {
135        exponential::ExponentialBackoff {
136            current_interval: self.initial_interval,
137            initial_interval: self.initial_interval,
138            randomization_factor: self.randomization_factor,
139            multiplier: self.multiplier,
140            max_interval: self.max_interval,
141            max_elapsed_time: self.max_elapsed_time,
142            clock,
143            start_time: Instant::now(),
144        }
145    }
146}
147
148impl From<RetryOptions> for backoff::ExponentialBackoff {
149    fn from(c: RetryOptions) -> Self {
150        c.into_exp_backoff(SystemClock::default())
151    }
152}
153
154pub(crate) fn make_future_retry<R, F, Fut>(
155    info: CallInfo,
156    factory: F,
157) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
158where
159    F: FnMut() -> Fut + Unpin,
160    Fut: Future<Output = Result<R, tonic::Status>>,
161{
162    FutureRetry::new(
163        factory,
164        TonicErrorHandler::new(info, RetryOptions::throttle_retry_policy()),
165    )
166}
167
168#[derive(Debug)]
169pub(crate) struct TonicErrorHandler<C: Clock> {
170    backoff: ExponentialBackoff<C>,
171    throttle_backoff: ExponentialBackoff<C>,
172    max_retries: usize,
173    call_type: CallType,
174    call_name: &'static str,
175    have_retried_goaway_cancel: bool,
176    retry_short_circuit: Option<NoRetryOnMatching>,
177}
178impl TonicErrorHandler<SystemClock> {
179    fn new(call_info: CallInfo, throttle_cfg: RetryOptions) -> Self {
180        Self::new_with_clock(
181            call_info,
182            throttle_cfg,
183            SystemClock::default(),
184            SystemClock::default(),
185        )
186    }
187}
188impl<C> TonicErrorHandler<C>
189where
190    C: Clock,
191{
192    fn new_with_clock(
193        call_info: CallInfo,
194        throttle_cfg: RetryOptions,
195        clock: C,
196        throttle_clock: C,
197    ) -> Self {
198        Self {
199            call_type: call_info.call_type,
200            call_name: call_info.call_name,
201            max_retries: call_info.retry_cfg.max_retries,
202            backoff: call_info.retry_cfg.into_exp_backoff(clock),
203            throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
204            have_retried_goaway_cancel: false,
205            retry_short_circuit: call_info.retry_short_circuit,
206        }
207    }
208
209    fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
210        let mut do_log = false;
211        // Warn on more than 5 retries for unlimited retrying
212        if self.max_retries == 0 && cur_attempt > 5 {
213            do_log = true;
214        }
215        // Warn if the attempts are more than 50% of max retries
216        if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
217            do_log = true;
218        }
219
220        if do_log {
221            // Error if unlimited retries have been going on for a while
222            if self.max_retries == 0 && cur_attempt > 15 {
223                error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
224            } else {
225                warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
226            }
227        }
228    }
229}
230
231#[derive(Clone, Debug)]
232pub(crate) struct CallInfo {
233    pub call_type: CallType,
234    call_name: &'static str,
235    retry_cfg: RetryOptions,
236    retry_short_circuit: Option<NoRetryOnMatching>,
237}
238
239#[doc(hidden)]
240#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
241pub enum CallType {
242    Normal,
243    // A long poll but won't always retry timeouts/cancels. EX: Get workflow history
244    UserLongPoll,
245    // A worker is polling for a task
246    TaskLongPoll,
247}
248
249impl CallType {
250    pub(crate) fn is_long(&self) -> bool {
251        matches!(self, Self::UserLongPoll | Self::TaskLongPoll)
252    }
253}
254
255impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
256where
257    C: Clock,
258{
259    type OutError = tonic::Status;
260
261    fn handle(
262        &mut self,
263        current_attempt: usize,
264        mut e: tonic::Status,
265    ) -> RetryPolicy<tonic::Status> {
266        // 0 max retries means unlimited retries
267        if self.max_retries > 0 && current_attempt >= self.max_retries {
268            return RetryPolicy::ForwardError(e);
269        }
270
271        if let Some(sc) = self.retry_short_circuit.as_ref()
272            && (sc.predicate)(&e)
273        {
274            e.metadata_mut().insert(
275                ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT,
276                tonic::metadata::MetadataValue::from(0),
277            );
278            return RetryPolicy::ForwardError(e);
279        }
280
281        // Short circuit if message is too large - this is not retryable
282        if e.code() == Code::ResourceExhausted
283            && (e
284                .message()
285                .starts_with("grpc: received message larger than max")
286                || e.message()
287                    .starts_with("grpc: message after decompression larger than max")
288                || e.message()
289                    .starts_with("grpc: received message after decompression larger than max"))
290        {
291            // Leave a marker so we don't have duplicate detection logic in the workflow
292            e.metadata_mut().insert(
293                MESSAGE_TOO_LARGE_KEY,
294                tonic::metadata::MetadataValue::from(0),
295            );
296            return RetryPolicy::ForwardError(e);
297        }
298
299        // Task polls are OK with being cancelled or running into the timeout because there's
300        // nothing to do but retry anyway
301        let long_poll_allowed = self.call_type == CallType::TaskLongPoll
302            && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());
303
304        // Sometimes we can get a GOAWAY that, for whatever reason, isn't quite properly handled
305        // by hyper or some other internal lib, and we want to retry that still. We'll retry that
306        // at most once. Ideally this bit should be removed eventually if we can repro the upstream
307        // bug and it is fixed.
308        let mut goaway_retry_allowed = false;
309        if !self.have_retried_goaway_cancel
310            && e.code() == Code::Cancelled
311            && let Some(e) = e
312                .source()
313                .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
314                .and_then(|te| te.source())
315                .and_then(|tec| tec.downcast_ref::<hyper::Error>())
316            && format!("{e:?}").contains("connection closed")
317        {
318            goaway_retry_allowed = true;
319            self.have_retried_goaway_cancel = true;
320        }
321
322        if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
323            if current_attempt == 1 {
324                debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
325            } else {
326                self.maybe_log_retry(current_attempt, &e);
327            }
328
329            match self.backoff.next_backoff() {
330                None => RetryPolicy::ForwardError(e), // None is returned when we've ran out of time
331                Some(backoff) => {
332                    // We treat ResourceExhausted as a special case and backoff more
333                    // so we don't overload the server
334                    if e.code() == Code::ResourceExhausted {
335                        let extended_backoff =
336                            backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
337                        RetryPolicy::WaitRetry(extended_backoff)
338                    } else {
339                        RetryPolicy::WaitRetry(backoff)
340                    }
341                }
342            }
343        } else if self.call_type == CallType::TaskLongPoll
344            && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
345        {
346            // We permit "fatal" errors while long polling for a while, because some proxies return
347            // stupid error codes while getting ready, among other weird infra issues
348            RetryPolicy::WaitRetry(self.backoff.max_interval)
349        } else {
350            RetryPolicy::ForwardError(e)
351        }
352    }
353}
354
355#[cfg(test)]
356mod tests {
357    use super::*;
358    use assert_matches::assert_matches;
359    use backoff::Clock;
360    use std::{ops::Add, time::Instant};
361    use temporalio_common::protos::temporal::api::workflowservice::v1::{
362        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
363    };
364    use tonic::{IntoRequest, Status};
365
366    /// Predefined retry configs with low durations to make unit tests faster
367    const TEST_RETRY_CONFIG: RetryOptions = RetryOptions {
368        initial_interval: Duration::from_millis(1),
369        randomization_factor: 0.0,
370        multiplier: 1.1,
371        max_interval: Duration::from_millis(2),
372        max_elapsed_time: None,
373        max_retries: 10,
374    };
375
376    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
377    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
378    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";
379
380    struct FixedClock(Instant);
381    impl Clock for FixedClock {
382        fn now(&self) -> Instant {
383            self.0
384        }
385    }
386
387    #[tokio::test]
388    async fn long_poll_non_retryable_errors() {
389        for code in [
390            Code::InvalidArgument,
391            Code::NotFound,
392            Code::AlreadyExists,
393            Code::PermissionDenied,
394            Code::FailedPrecondition,
395            Code::Unauthenticated,
396            Code::Unimplemented,
397        ] {
398            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
399                let mut err_handler = TonicErrorHandler::new_with_clock(
400                    CallInfo {
401                        call_type: CallType::TaskLongPoll,
402                        call_name,
403                        retry_cfg: TEST_RETRY_CONFIG,
404                        retry_short_circuit: None,
405                    },
406                    TEST_RETRY_CONFIG,
407                    FixedClock(Instant::now()),
408                    FixedClock(Instant::now()),
409                );
410                let result = err_handler.handle(1, Status::new(code, "Ahh"));
411                assert_matches!(result, RetryPolicy::WaitRetry(_));
412                err_handler.backoff.clock.0 = err_handler
413                    .backoff
414                    .clock
415                    .0
416                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
417                let result = err_handler.handle(2, Status::new(code, "Ahh"));
418                assert_matches!(result, RetryPolicy::ForwardError(_));
419            }
420        }
421    }
422
423    #[tokio::test]
424    async fn long_poll_retryable_errors_never_fatal() {
425        for code in RETRYABLE_ERROR_CODES {
426            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
427                let mut err_handler = TonicErrorHandler::new_with_clock(
428                    CallInfo {
429                        call_type: CallType::TaskLongPoll,
430                        call_name,
431                        retry_cfg: TEST_RETRY_CONFIG,
432                        retry_short_circuit: None,
433                    },
434                    TEST_RETRY_CONFIG,
435                    FixedClock(Instant::now()),
436                    FixedClock(Instant::now()),
437                );
438                let result = err_handler.handle(1, Status::new(code, "Ahh"));
439                assert_matches!(result, RetryPolicy::WaitRetry(_));
440                err_handler.backoff.clock.0 = err_handler
441                    .backoff
442                    .clock
443                    .0
444                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
445                let result = err_handler.handle(2, Status::new(code, "Ahh"));
446                assert_matches!(result, RetryPolicy::WaitRetry(_));
447            }
448        }
449    }
450
451    #[tokio::test]
452    async fn retry_resource_exhausted() {
453        let mut err_handler = TonicErrorHandler::new_with_clock(
454            CallInfo {
455                call_type: CallType::TaskLongPoll,
456                call_name: POLL_WORKFLOW_METH_NAME,
457                retry_cfg: TEST_RETRY_CONFIG,
458                retry_short_circuit: None,
459            },
460            RetryOptions {
461                initial_interval: Duration::from_millis(2),
462                randomization_factor: 0.0,
463                multiplier: 4.0,
464                max_interval: Duration::from_millis(10),
465                max_elapsed_time: None,
466                max_retries: 10,
467            },
468            FixedClock(Instant::now()),
469            FixedClock(Instant::now()),
470        );
471        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
472        match result {
473            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
474            _ => panic!(),
475        }
476        err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
477        err_handler.throttle_backoff.clock.0 = err_handler
478            .throttle_backoff
479            .clock
480            .0
481            .add(Duration::from_millis(10));
482        let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
483        match result {
484            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
485            _ => panic!(),
486        }
487    }
488
489    #[tokio::test]
490    async fn retry_short_circuit() {
491        let mut err_handler = TonicErrorHandler::new_with_clock(
492            CallInfo {
493                call_type: CallType::TaskLongPoll,
494                call_name: POLL_WORKFLOW_METH_NAME,
495                retry_cfg: TEST_RETRY_CONFIG,
496                retry_short_circuit: Some(NoRetryOnMatching {
497                    predicate: |s: &Status| s.code() == Code::ResourceExhausted,
498                }),
499            },
500            TEST_RETRY_CONFIG,
501            FixedClock(Instant::now()),
502            FixedClock(Instant::now()),
503        );
504        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
505        let e = assert_matches!(result, RetryPolicy::ForwardError(e) => e);
506        assert!(
507            e.metadata()
508                .get(ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT)
509                .is_some()
510        );
511    }
512
513    #[tokio::test]
514    async fn message_too_large_not_retried() {
515        let mut err_handler = TonicErrorHandler::new_with_clock(
516            CallInfo {
517                call_type: CallType::TaskLongPoll,
518                call_name: POLL_WORKFLOW_METH_NAME,
519                retry_cfg: TEST_RETRY_CONFIG,
520                retry_short_circuit: None,
521            },
522            TEST_RETRY_CONFIG,
523            FixedClock(Instant::now()),
524            FixedClock(Instant::now()),
525        );
526        let result = err_handler.handle(
527            1,
528            Status::new(
529                Code::ResourceExhausted,
530                "grpc: received message larger than max",
531            ),
532        );
533        assert_matches!(result, RetryPolicy::ForwardError(_));
534
535        let result = err_handler.handle(
536            1,
537            Status::new(
538                Code::ResourceExhausted,
539                "grpc: message after decompression larger than max",
540            ),
541        );
542        assert_matches!(result, RetryPolicy::ForwardError(_));
543
544        let result = err_handler.handle(
545            1,
546            Status::new(
547                Code::ResourceExhausted,
548                "grpc: received message after decompression larger than max",
549            ),
550        );
551        assert_matches!(result, RetryPolicy::ForwardError(_));
552    }
553
554    #[rstest::rstest]
555    #[tokio::test]
556    async fn task_poll_retries_forever<R>(
557        #[values(
558                (
559                    POLL_WORKFLOW_METH_NAME,
560                    PollWorkflowTaskQueueRequest::default(),
561                ),
562                (
563                    POLL_ACTIVITY_METH_NAME,
564                    PollActivityTaskQueueRequest::default(),
565                ),
566                (
567                    POLL_NEXUS_METH_NAME,
568                    PollNexusTaskQueueRequest::default(),
569                ),
570        )]
571        (call_name, req): (&'static str, R),
572    ) {
573        // A bit odd, but we don't need a real client to test the retry client passes through the
574        // correct retry config
575        let mut req = req.into_request();
576        req.extensions_mut().insert(IsWorkerTaskLongPoll);
577        for i in 1..=50 {
578            let mut err_handler = TonicErrorHandler::new(
579                TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
580                RetryOptions::throttle_retry_policy(),
581            );
582            let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
583            assert_matches!(result, RetryPolicy::WaitRetry(_));
584        }
585    }
586
587    #[rstest::rstest]
588    #[tokio::test]
589    async fn task_poll_retries_deadline_exceeded<R>(
590        #[values(
591                (
592                    POLL_WORKFLOW_METH_NAME,
593                    PollWorkflowTaskQueueRequest::default(),
594                ),
595                (
596                    POLL_ACTIVITY_METH_NAME,
597                    PollActivityTaskQueueRequest::default(),
598                ),
599                (
600                    POLL_NEXUS_METH_NAME,
601                    PollNexusTaskQueueRequest::default(),
602                ),
603        )]
604        (call_name, req): (&'static str, R),
605    ) {
606        let mut req = req.into_request();
607        req.extensions_mut().insert(IsWorkerTaskLongPoll);
608        // For some reason we will get cancelled in these situations occasionally (always?) too
609        for code in [Code::Cancelled, Code::DeadlineExceeded] {
610            let mut err_handler = TonicErrorHandler::new(
611                TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
612                RetryOptions::throttle_retry_policy(),
613            );
614            for i in 1..=5 {
615                let result = err_handler.handle(i, Status::new(code, "retryable failure"));
616                assert_matches!(result, RetryPolicy::WaitRetry(_));
617            }
618        }
619    }
620}