squads_temporal_client/
retry.rs

1use crate::{
2    Client, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY, NamespacedClient, NoRetryOnMatching,
3    Result, RetryConfig, raw::IsUserLongPoll,
4};
5use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff};
6use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
7use std::{error::Error, fmt::Debug, future::Future, sync::Arc, time::Duration};
8use tonic::{Code, Request};
9
/// List of gRPC error codes that client will retry.
///
/// `ResourceExhausted` retries additionally consult the throttle backoff in
/// `TonicErrorHandler::handle`, and "message too large" `ResourceExhausted` errors are never
/// retried.
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
    Code::DataLoss,
    Code::Internal,
    Code::Unknown,
    Code::ResourceExhausted,
    Code::Aborted,
    Code::OutOfRange,
    Code::Unavailable,
];
/// How long (measured as the call's backoff elapsed time) a worker task long poll keeps retrying
/// errors whose codes are not in [`RETRYABLE_ERROR_CODES`] before forwarding them as fatal.
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);
21
/// A wrapper for a [WorkflowClientTrait] or [crate::WorkflowService] implementor which performs
/// auto-retries.
#[derive(Debug, Clone)]
pub struct RetryClient<SG> {
    // The wrapped client that actually issues the calls.
    client: SG,
    // Retry policy for calls that are not worker task long polls (those use
    // `RetryConfig::task_poll_retry_policy()` instead). Held in an `Arc` so clones of this
    // wrapper share one config.
    retry_config: Arc<RetryConfig>,
}
29
30impl<SG> RetryClient<SG> {
31    /// Use the provided retry config with the provided client
32    pub fn new(client: SG, retry_config: RetryConfig) -> Self {
33        Self {
34            client,
35            retry_config: Arc::new(retry_config),
36        }
37    }
38}
39
40impl<SG> RetryClient<SG> {
41    /// Return the inner client type
42    pub fn get_client(&self) -> &SG {
43        &self.client
44    }
45
46    /// Return the inner client type mutably
47    pub fn get_client_mut(&mut self) -> &mut SG {
48        &mut self.client
49    }
50
51    /// Disable retry and return the inner client type
52    pub fn into_inner(self) -> SG {
53        self.client
54    }
55
56    pub(crate) fn get_call_info<R>(
57        &self,
58        call_name: &'static str,
59        request: Option<&Request<R>>,
60    ) -> CallInfo {
61        let mut call_type = CallType::Normal;
62        let mut retry_short_circuit = None;
63        if let Some(r) = request.as_ref() {
64            let ext = r.extensions();
65            if ext.get::<IsUserLongPoll>().is_some() {
66                call_type = CallType::UserLongPoll;
67            } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
68                call_type = CallType::TaskLongPoll;
69            }
70
71            retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
72        }
73        let retry_cfg = if call_type == CallType::TaskLongPoll {
74            RetryConfig::task_poll_retry_policy()
75        } else {
76            (*self.retry_config).clone()
77        };
78        CallInfo {
79            call_type,
80            call_name,
81            retry_cfg,
82            retry_short_circuit,
83        }
84    }
85
86    pub(crate) fn make_future_retry<R, F, Fut>(
87        info: CallInfo,
88        factory: F,
89    ) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
90    where
91        F: FnMut() -> Fut + Unpin,
92        Fut: Future<Output = Result<R>>,
93    {
94        FutureRetry::new(
95            factory,
96            TonicErrorHandler::new(info, RetryConfig::throttle_retry_policy()),
97        )
98    }
99}
100
/// Namespace/identity lookups are answered by the wrapped [`Client`]; no retrying is involved.
impl NamespacedClient for RetryClient<Client> {
    /// The wrapped client's `namespace` field.
    fn namespace(&self) -> &str {
        &self.client.namespace
    }

    /// The `identity` from the wrapped client's options.
    fn get_identity(&self) -> &str {
        &self.client.options().identity
    }
}
110
/// [`ErrorHandler`] for tonic calls: decides, per failed attempt, whether to retry and how long
/// to wait before doing so.
#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
    // Backoff for ordinary retries, built from the call's `RetryConfig`.
    backoff: ExponentialBackoff<C>,
    // Extra backoff consulted for `ResourceExhausted`; the larger of the two delays is used.
    throttle_backoff: ExponentialBackoff<C>,
    // Attempt limit; 0 means unlimited.
    max_retries: usize,
    // Controls which codes are tolerated (e.g. task long polls accept cancels/timeouts).
    call_type: CallType,
    // Used for logging.
    call_name: &'static str,
    // Set once a "connection closed" cancellation has been retried, so that retry happens once.
    have_retried_goaway_cancel: bool,
    // Optional predicate; errors matching it are forwarded immediately without retrying.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
121impl TonicErrorHandler<SystemClock> {
122    fn new(call_info: CallInfo, throttle_cfg: RetryConfig) -> Self {
123        Self::new_with_clock(
124            call_info,
125            throttle_cfg,
126            SystemClock::default(),
127            SystemClock::default(),
128        )
129    }
130}
131impl<C> TonicErrorHandler<C>
132where
133    C: Clock,
134{
135    fn new_with_clock(
136        call_info: CallInfo,
137        throttle_cfg: RetryConfig,
138        clock: C,
139        throttle_clock: C,
140    ) -> Self {
141        Self {
142            call_type: call_info.call_type,
143            call_name: call_info.call_name,
144            max_retries: call_info.retry_cfg.max_retries,
145            backoff: call_info.retry_cfg.into_exp_backoff(clock),
146            throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
147            have_retried_goaway_cancel: false,
148            retry_short_circuit: call_info.retry_short_circuit,
149        }
150    }
151
152    fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
153        let mut do_log = false;
154        // Warn on more than 5 retries for unlimited retrying
155        if self.max_retries == 0 && cur_attempt > 5 {
156            do_log = true;
157        }
158        // Warn if the attempts are more than 50% of max retries
159        if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
160            do_log = true;
161        }
162
163        if do_log {
164            // Error if unlimited retries have been going on for a while
165            if self.max_retries == 0 && cur_attempt > 15 {
166                error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
167            } else {
168                warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
169            }
170        }
171    }
172}
173
/// Per-call retry settings, produced by `RetryClient::get_call_info` and consumed when building a
/// `TonicErrorHandler`.
#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
    // How the call was classified from its request extensions.
    pub call_type: CallType,
    // Name of the gRPC call, used in log messages.
    call_name: &'static str,
    // Retry policy for this call (the task-poll policy for worker task long polls).
    retry_cfg: RetryConfig,
    // Copied from a `NoRetryOnMatching` request extension, if one was present.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
181
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
    /// Any call not marked as a long poll.
    Normal,
    /// A long poll but won't always retry timeouts/cancels. EX: Get workflow history
    UserLongPoll,
    /// A worker is polling for a task
    TaskLongPoll,
}

impl CallType {
    /// Whether this call is either kind of long poll (user or worker task).
    pub(crate) fn is_long(&self) -> bool {
        match self {
            Self::Normal => false,
            Self::UserLongPoll | Self::TaskLongPoll => true,
        }
    }
}
197
198impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
199where
200    C: Clock,
201{
202    type OutError = tonic::Status;
203
204    fn handle(
205        &mut self,
206        current_attempt: usize,
207        mut e: tonic::Status,
208    ) -> RetryPolicy<tonic::Status> {
209        // 0 max retries means unlimited retries
210        if self.max_retries > 0 && current_attempt >= self.max_retries {
211            return RetryPolicy::ForwardError(e);
212        }
213
214        if let Some(sc) = self.retry_short_circuit.as_ref()
215            && (sc.predicate)(&e)
216        {
217            return RetryPolicy::ForwardError(e);
218        }
219
220        // Short circuit if message is too large - this is not retryable
221        if e.code() == Code::ResourceExhausted
222            && (e
223                .message()
224                .starts_with("grpc: received message larger than max")
225                || e.message()
226                    .starts_with("grpc: message after decompression larger than max")
227                || e.message()
228                    .starts_with("grpc: received message after decompression larger than max"))
229        {
230            // Leave a marker so we don't have duplicate detection logic in the workflow
231            e.metadata_mut().insert(
232                MESSAGE_TOO_LARGE_KEY,
233                tonic::metadata::MetadataValue::from(0),
234            );
235            return RetryPolicy::ForwardError(e);
236        }
237
238        // Task polls are OK with being cancelled or running into the timeout because there's
239        // nothing to do but retry anyway
240        let long_poll_allowed = self.call_type == CallType::TaskLongPoll
241            && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());
242
243        // Sometimes we can get a GOAWAY that, for whatever reason, isn't quite properly handled
244        // by hyper or some other internal lib, and we want to retry that still. We'll retry that
245        // at most once. Ideally this bit should be removed eventually if we can repro the upstream
246        // bug and it is fixed.
247        let mut goaway_retry_allowed = false;
248        if !self.have_retried_goaway_cancel
249            && e.code() == Code::Cancelled
250            && let Some(e) = e
251                .source()
252                .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
253                .and_then(|te| te.source())
254                .and_then(|tec| tec.downcast_ref::<hyper::Error>())
255            && format!("{e:?}").contains("connection closed")
256        {
257            goaway_retry_allowed = true;
258            self.have_retried_goaway_cancel = true;
259        }
260
261        if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
262            if current_attempt == 1 {
263                debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
264            } else {
265                self.maybe_log_retry(current_attempt, &e);
266            }
267
268            match self.backoff.next_backoff() {
269                None => RetryPolicy::ForwardError(e), // None is returned when we've ran out of time
270                Some(backoff) => {
271                    // We treat ResourceExhausted as a special case and backoff more
272                    // so we don't overload the server
273                    if e.code() == Code::ResourceExhausted {
274                        let extended_backoff =
275                            backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
276                        RetryPolicy::WaitRetry(extended_backoff)
277                    } else {
278                        RetryPolicy::WaitRetry(backoff)
279                    }
280                }
281            }
282        } else if self.call_type == CallType::TaskLongPoll
283            && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
284        {
285            // We permit "fatal" errors while long polling for a while, because some proxies return
286            // stupid error codes while getting ready, among other weird infra issues
287            RetryPolicy::WaitRetry(self.backoff.max_interval)
288        } else {
289            RetryPolicy::ForwardError(e)
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use backoff::Clock;
    use std::{ops::Add, time::Instant};
    use squads_temporal_sdk_core_protos::temporal::api::workflowservice::v1::{
        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
    };
    use tonic::{IntoRequest, Status};

    /// Predefined retry configs with low durations to make unit tests faster
    const TEST_RETRY_CONFIG: RetryConfig = RetryConfig {
        initial_interval: Duration::from_millis(1),
        randomization_factor: 0.0,
        multiplier: 1.1,
        max_interval: Duration::from_millis(2),
        max_elapsed_time: None,
        max_retries: 10,
    };

    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";

    // A clock that only moves when a test advances its inner `Instant`, making the backoff's
    // elapsed-time checks (e.g. against `LONG_POLL_FATAL_GRACE`) deterministic.
    struct FixedClock(Instant);
    impl Clock for FixedClock {
        fn now(&self) -> Instant {
            self.0
        }
    }

    // Non-retryable codes are tolerated on task long polls only inside the grace window; once the
    // backoff clock is advanced past `LONG_POLL_FATAL_GRACE` they must be forwarded.
    #[tokio::test]
    async fn long_poll_non_retryable_errors() {
        for code in [
            Code::InvalidArgument,
            Code::NotFound,
            Code::AlreadyExists,
            Code::PermissionDenied,
            Code::FailedPrecondition,
            Code::Unauthenticated,
            Code::Unimplemented,
        ] {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                // Jump the ordinary backoff's clock beyond the grace period.
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::ForwardError(_));
            }
        }
    }

    // Retryable codes keep being retried on task long polls even after the grace window passes.
    #[tokio::test]
    async fn long_poll_retryable_errors_never_fatal() {
        for code in RETRYABLE_ERROR_CODES {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }

    // `ResourceExhausted` waits for the larger of the ordinary and throttle backoffs. With no
    // randomization: first attempt uses the throttle's 2ms initial interval, the second its
    // 2ms * 4.0 = 8ms (under the 10ms cap), both exceeding the ordinary backoff's <= 2ms.
    #[tokio::test]
    async fn retry_resource_exhausted() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            RetryConfig {
                initial_interval: Duration::from_millis(2),
                randomization_factor: 0.0,
                multiplier: 4.0,
                max_interval: Duration::from_millis(10),
                max_elapsed_time: None,
                max_retries: 10,
            },
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
            _ => panic!(),
        }
        err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
        err_handler.throttle_backoff.clock.0 = err_handler
            .throttle_backoff
            .clock
            .0
            .add(Duration::from_millis(10));
        let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
            _ => panic!(),
        }
    }

    // A matching `NoRetryOnMatching` predicate forwards an otherwise-retryable error at once.
    #[tokio::test]
    async fn retry_short_circuit() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: Some(NoRetryOnMatching {
                    predicate: |s: &Status| s.code() == Code::ResourceExhausted,
                }),
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        assert_matches!(result, RetryPolicy::ForwardError(_))
    }

    // Each of the recognized "message too large" messages is forwarded, never retried.
    #[tokio::test]
    async fn message_too_large_not_retried() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));
    }

    // Requests marked `IsWorkerTaskLongPoll` must get the task-poll policy rather than
    // TEST_RETRY_CONFIG (whose max_retries = 10 would forward by attempt 10); presumably the
    // task-poll policy has no attempt cap, or one above 50 — verify against RetryConfig.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_forever<R>(
        #[values(
                (
                    POLL_WORKFLOW_METH_NAME,
                    PollWorkflowTaskQueueRequest::default(),
                ),
                (
                    POLL_ACTIVITY_METH_NAME,
                    PollActivityTaskQueueRequest::default(),
                ),
                (
                    POLL_NEXUS_METH_NAME,
                    PollNexusTaskQueueRequest::default(),
                ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        // A bit odd, but we don't need a real client to test the retry client passes through the
        // correct retry config
        let fake_retry = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for i in 1..=50 {
            let mut err_handler = TonicErrorHandler::new(
                fake_retry.get_call_info::<R>(call_name, Some(&req)),
                RetryConfig::throttle_retry_policy(),
            );
            let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
            assert_matches!(result, RetryPolicy::WaitRetry(_));
        }
    }

    // Task long polls retry `Cancelled` and `DeadlineExceeded`, which are not in
    // RETRYABLE_ERROR_CODES, via the `long_poll_allowed` path.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_deadline_exceeded<R>(
        #[values(
                (
                    POLL_WORKFLOW_METH_NAME,
                    PollWorkflowTaskQueueRequest::default(),
                ),
                (
                    POLL_ACTIVITY_METH_NAME,
                    PollActivityTaskQueueRequest::default(),
                ),
                (
                    POLL_NEXUS_METH_NAME,
                    PollNexusTaskQueueRequest::default(),
                ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let fake_retry = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        // For some reason we will get cancelled in these situations occasionally (always?) too
        for code in [Code::Cancelled, Code::DeadlineExceeded] {
            let mut err_handler = TonicErrorHandler::new(
                fake_retry.get_call_info::<R>(call_name, Some(&req)),
                RetryConfig::throttle_retry_policy(),
            );
            for i in 1..=5 {
                let result = err_handler.handle(i, Status::new(code, "retryable failure"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }
}