squads_temporal_client/retry.rs

use crate::{
    ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY,
    NamespacedClient, NoRetryOnMatching, Result, RetryConfig, raw::IsUserLongPoll,
};
use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff};
use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
use std::{error::Error, fmt::Debug, future::Future, sync::Arc, time::Duration};
use tonic::{Code, Request};

/// List of gRPC error codes that the client will retry.
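///
/// A status can be checked against this list the same way the retry handler below does (a
/// minimal sketch; `status` stands in for any `tonic::Status` already in hand):
///
/// ```ignore
/// let is_retryable = RETRYABLE_ERROR_CODES.contains(&status.code());
/// ```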
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
    Code::DataLoss,
    Code::Internal,
    Code::Unknown,
    Code::ResourceExhausted,
    Code::Aborted,
    Code::OutOfRange,
    Code::Unavailable,
];
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);

/// A wrapper for a [WorkflowClientTrait] or [crate::WorkflowService] implementor which performs
/// auto-retries
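///
/// A minimal construction sketch (hedged: `client` stands in for any concrete service
/// implementor, and the field values are purely illustrative, not recommended defaults):
///
/// ```ignore
/// use std::time::Duration;
///
/// let retrying_client = RetryClient::new(
///     client,
///     RetryConfig {
///         initial_interval: Duration::from_millis(100),
///         randomization_factor: 0.2,
///         multiplier: 1.5,
///         max_interval: Duration::from_secs(5),
///         max_elapsed_time: Some(Duration::from_secs(60)),
///         max_retries: 10,
///     },
/// );
/// // Calls made through `retrying_client` are retried automatically; `into_inner()` unwraps it.
/// ```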
#[derive(Debug, Clone)]
pub struct RetryClient<SG> {
    client: SG,
    retry_config: Arc<RetryConfig>,
}

impl<SG> RetryClient<SG> {
    /// Use the provided retry config with the provided client
    pub fn new(client: SG, retry_config: RetryConfig) -> Self {
        Self {
            client,
            retry_config: Arc::new(retry_config),
        }
    }
}

impl<SG> RetryClient<SG> {
    /// Return the inner client type
    pub fn get_client(&self) -> &SG {
        &self.client
    }

    /// Return the inner client type mutably
    pub fn get_client_mut(&mut self) -> &mut SG {
        &mut self.client
    }

    /// Disable retry and return the inner client type
    pub fn into_inner(self) -> SG {
        self.client
    }

    pub(crate) fn get_call_info<R>(
        &self,
        call_name: &'static str,
        request: Option<&Request<R>>,
    ) -> CallInfo {
        let mut call_type = CallType::Normal;
        let mut retry_short_circuit = None;
        if let Some(r) = request.as_ref() {
            let ext = r.extensions();
            if ext.get::<IsUserLongPoll>().is_some() {
                call_type = CallType::UserLongPoll;
            } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
                call_type = CallType::TaskLongPoll;
            }

            retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
        }
        let retry_cfg = if call_type == CallType::TaskLongPoll {
            RetryConfig::task_poll_retry_policy()
        } else {
            (*self.retry_config).clone()
        };
        CallInfo {
            call_type,
            call_name,
            retry_cfg,
            retry_short_circuit,
        }
    }

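    /// Wrap `factory` in a [FutureRetry] driven by a [TonicErrorHandler] built from `info`.
    ///
    /// A hedged usage sketch (the concrete client type, `request`, and `call_the_rpc` are
    /// illustrative stand-ins, not items defined in this module):
    ///
    /// ```ignore
    /// let info = retry_client.get_call_info("poll_workflow_task_queue", Some(&request));
    /// let retrying_call = RetryClient::<MyClient>::make_future_retry(info, || call_the_rpc());
    /// // Awaiting this re-runs the factory, handing each error to the `TonicErrorHandler` until
    /// // the call succeeds or the handler returns `RetryPolicy::ForwardError`.
    /// let outcome = retrying_call.await;
    /// ```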
    pub(crate) fn make_future_retry<R, F, Fut>(
        info: CallInfo,
        factory: F,
    ) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
    where
        F: FnMut() -> Fut + Unpin,
        Fut: Future<Output = Result<R>>,
    {
        FutureRetry::new(
            factory,
            TonicErrorHandler::new(info, RetryConfig::throttle_retry_policy()),
        )
    }
}

impl<SG> NamespacedClient for RetryClient<SG>
where
    SG: NamespacedClient,
{
    fn namespace(&self) -> String {
        self.client.namespace()
    }

    fn identity(&self) -> String {
        self.client.identity()
    }
}

#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
    backoff: ExponentialBackoff<C>,
    throttle_backoff: ExponentialBackoff<C>,
    max_retries: usize,
    call_type: CallType,
    call_name: &'static str,
    have_retried_goaway_cancel: bool,
    retry_short_circuit: Option<NoRetryOnMatching>,
}
impl TonicErrorHandler<SystemClock> {
    fn new(call_info: CallInfo, throttle_cfg: RetryConfig) -> Self {
        Self::new_with_clock(
            call_info,
            throttle_cfg,
            SystemClock::default(),
            SystemClock::default(),
        )
    }
}
impl<C> TonicErrorHandler<C>
where
    C: Clock,
{
    fn new_with_clock(
        call_info: CallInfo,
        throttle_cfg: RetryConfig,
        clock: C,
        throttle_clock: C,
    ) -> Self {
        Self {
            call_type: call_info.call_type,
            call_name: call_info.call_name,
            max_retries: call_info.retry_cfg.max_retries,
            backoff: call_info.retry_cfg.into_exp_backoff(clock),
            throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
            have_retried_goaway_cancel: false,
            retry_short_circuit: call_info.retry_short_circuit,
        }
    }

    fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
        let mut do_log = false;
        // Warn on more than 5 retries for unlimited retrying
        if self.max_retries == 0 && cur_attempt > 5 {
            do_log = true;
        }
        // Warn once attempts reach at least 50% of max retries (e.g. attempt 5 of 10)
        if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
            do_log = true;
        }

        if do_log {
            // Error if unlimited retries have been going on for a while
            if self.max_retries == 0 && cur_attempt > 15 {
                error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
            } else {
                warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
            }
        }
    }
}

#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
    pub call_type: CallType,
    call_name: &'static str,
    retry_cfg: RetryConfig,
    retry_short_circuit: Option<NoRetryOnMatching>,
}

#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
    Normal,
    // A long poll that won't always retry timeouts/cancels, e.g. fetching workflow history
    UserLongPoll,
    // A worker is polling for a task
    TaskLongPoll,
}

impl CallType {
    pub(crate) fn is_long(&self) -> bool {
        matches!(self, Self::UserLongPoll | Self::TaskLongPoll)
    }
}

impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
where
    C: Clock,
{
    type OutError = tonic::Status;

    fn handle(
        &mut self,
        current_attempt: usize,
        mut e: tonic::Status,
    ) -> RetryPolicy<tonic::Status> {
        // 0 max retries means unlimited retries
        if self.max_retries > 0 && current_attempt >= self.max_retries {
            return RetryPolicy::ForwardError(e);
        }

        if let Some(sc) = self.retry_short_circuit.as_ref()
            && (sc.predicate)(&e)
        {
            e.metadata_mut().insert(
                ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT,
                tonic::metadata::MetadataValue::from(0),
            );
            return RetryPolicy::ForwardError(e);
        }

        // Short circuit if message is too large - this is not retryable
        if e.code() == Code::ResourceExhausted
            && (e
                .message()
                .starts_with("grpc: received message larger than max")
                || e.message()
                    .starts_with("grpc: message after decompression larger than max")
                || e.message()
                    .starts_with("grpc: received message after decompression larger than max"))
        {
            // Leave a marker so we don't have duplicate detection logic in the workflow
            e.metadata_mut().insert(
                MESSAGE_TOO_LARGE_KEY,
                tonic::metadata::MetadataValue::from(0),
            );
            return RetryPolicy::ForwardError(e);
        }

        // Task polls are OK with being cancelled or running into the timeout because there's
        // nothing to do but retry anyway
        let long_poll_allowed = self.call_type == CallType::TaskLongPoll
            && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());

        // Sometimes we get a GOAWAY that, for whatever reason, isn't properly handled by hyper or
        // some other internal lib, and we still want to retry it (at most once). Ideally this
        // workaround can be removed once the upstream bug is reproduced and fixed.
        let mut goaway_retry_allowed = false;
        if !self.have_retried_goaway_cancel
            && e.code() == Code::Cancelled
            && let Some(e) = e
                .source()
                .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
                .and_then(|te| te.source())
                .and_then(|tec| tec.downcast_ref::<hyper::Error>())
            && format!("{e:?}").contains("connection closed")
        {
            goaway_retry_allowed = true;
            self.have_retried_goaway_cancel = true;
        }

        if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
            if current_attempt == 1 {
                debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
            } else {
                self.maybe_log_retry(current_attempt, &e);
            }

            match self.backoff.next_backoff() {
                None => RetryPolicy::ForwardError(e), // None is returned when we've run out of time
                Some(backoff) => {
                    // We treat ResourceExhausted as a special case and back off more
                    // so we don't overload the server
                    if e.code() == Code::ResourceExhausted {
                        let extended_backoff =
                            backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
                        RetryPolicy::WaitRetry(extended_backoff)
                    } else {
                        RetryPolicy::WaitRetry(backoff)
                    }
                }
            }
        } else if self.call_type == CallType::TaskLongPoll
            && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
        {
            // We permit "fatal" errors while long polling for a while, because some proxies return
            // stupid error codes while getting ready, among other weird infra issues
            RetryPolicy::WaitRetry(self.backoff.max_interval)
        } else {
            RetryPolicy::ForwardError(e)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use backoff::Clock;
    use std::{ops::Add, time::Instant};
    use temporal_sdk_core_protos::temporal::api::workflowservice::v1::{
        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
    };
    use tonic::{IntoRequest, Status};

    /// A predefined retry config with low durations to make unit tests faster
    const TEST_RETRY_CONFIG: RetryConfig = RetryConfig {
        initial_interval: Duration::from_millis(1),
        randomization_factor: 0.0,
        multiplier: 1.1,
        max_interval: Duration::from_millis(2),
        max_elapsed_time: None,
        max_retries: 10,
    };

    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";

    struct FixedClock(Instant);
    impl Clock for FixedClock {
        fn now(&self) -> Instant {
            self.0
        }
    }
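
    // The tests below advance time by mutating the `FixedClock`'s inner `Instant` directly (e.g.
    // pushing it past LONG_POLL_FATAL_GRACE), so the backoff's elapsed-time accounting can be
    // exercised without real sleeps.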

    #[tokio::test]
    async fn long_poll_non_retryable_errors() {
        for code in [
            Code::InvalidArgument,
            Code::NotFound,
            Code::AlreadyExists,
            Code::PermissionDenied,
            Code::FailedPrecondition,
            Code::Unauthenticated,
            Code::Unimplemented,
        ] {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::ForwardError(_));
            }
        }
    }

    #[tokio::test]
    async fn long_poll_retryable_errors_never_fatal() {
        for code in RETRYABLE_ERROR_CODES {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }

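    // Worked example of the `ResourceExhausted` escalation in `handle`: the base policy below
    // yields waits of 1ms and then ~1.1ms (capped at 2ms), while the throttle policy (initial
    // 2ms, multiplier 4.0) yields 2ms and then 8ms, so the asserted waits are max(1ms, 2ms) = 2ms
    // followed by max(~1.1ms, 8ms) = 8ms.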
    #[tokio::test]
    async fn retry_resource_exhausted() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            RetryConfig {
                initial_interval: Duration::from_millis(2),
                randomization_factor: 0.0,
                multiplier: 4.0,
                max_interval: Duration::from_millis(10),
                max_elapsed_time: None,
                max_retries: 10,
            },
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
            _ => panic!(),
        }
        err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
        err_handler.throttle_backoff.clock.0 = err_handler
            .throttle_backoff
            .clock
            .0
            .add(Duration::from_millis(10));
        let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
            _ => panic!(),
        }
    }

    #[tokio::test]
    async fn retry_short_circuit() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: Some(NoRetryOnMatching {
                    predicate: |s: &Status| s.code() == Code::ResourceExhausted,
                }),
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        let e = assert_matches!(result, RetryPolicy::ForwardError(e) => e);
        assert!(
            e.metadata()
                .get(ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT)
                .is_some()
        );
    }

    #[tokio::test]
    async fn message_too_large_not_retried() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_forever<R>(
        #[values(
                (
                    POLL_WORKFLOW_METH_NAME,
                    PollWorkflowTaskQueueRequest::default(),
                ),
                (
                    POLL_ACTIVITY_METH_NAME,
                    PollActivityTaskQueueRequest::default(),
                ),
                (
                    POLL_NEXUS_METH_NAME,
                    PollNexusTaskQueueRequest::default(),
                ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        // A bit odd, but we don't need a real client to verify that the retry client passes
        // through the correct retry config
        let fake_retry = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for i in 1..=50 {
            let mut err_handler = TonicErrorHandler::new(
                fake_retry.get_call_info::<R>(call_name, Some(&req)),
                RetryConfig::throttle_retry_policy(),
            );
            let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
            assert_matches!(result, RetryPolicy::WaitRetry(_));
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_deadline_exceeded<R>(
        #[values(
                (
                    POLL_WORKFLOW_METH_NAME,
                    PollWorkflowTaskQueueRequest::default(),
                ),
                (
                    POLL_ACTIVITY_METH_NAME,
                    PollActivityTaskQueueRequest::default(),
                ),
                (
                    POLL_NEXUS_METH_NAME,
                    PollNexusTaskQueueRequest::default(),
                ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let fake_retry = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        // For some reason we will get cancelled in these situations occasionally (always?) too
        for code in [Code::Cancelled, Code::DeadlineExceeded] {
            let mut err_handler = TonicErrorHandler::new(
                fake_retry.get_call_info::<R>(call_name, Some(&req)),
                RetryConfig::throttle_retry_policy(),
            );
            for i in 1..=5 {
                let result = err_handler.handle(i, Status::new(code, "retryable failure"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }
}