1use crate::{
2 ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, MESSAGE_TOO_LARGE_KEY,
3 grpc::IsUserLongPoll,
4 request_extensions::{IsWorkerTaskLongPoll, NoRetryOnMatching, RetryConfigForCall},
5};
6use backoff::{
7 Clock, SystemClock,
8 backoff::Backoff,
9 exponential::{self, ExponentialBackoff},
10};
11use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
12use std::{
13 error::Error,
14 fmt::Debug,
15 future::Future,
16 time::{Duration, Instant},
17};
18use tonic::Code;
19
/// gRPC status codes that `TonicErrorHandler::handle` treats as retryable for every call type.
///
/// The handler still applies its attempt-limit, short-circuit, and "message too large" checks
/// before consulting this list, so a status with one of these codes is not always retried.
#[doc(hidden)]
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
    Code::DataLoss,
    Code::Internal,
    Code::Unknown,
    Code::ResourceExhausted,
    Code::Aborted,
    Code::OutOfRange,
    Code::Unavailable,
];
/// Grace window for task long polls: while the handler's primary backoff reports an elapsed time
/// no greater than this, a task long poll that fails with an otherwise non-retryable code is
/// still retried instead of being forwarded to the caller.
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);
32
/// Backoff and attempt-limit settings used when retrying gRPC calls.
///
/// Every field except `max_retries` is copied into a `backoff` exponential backoff by
/// `RetryOptions::into_exp_backoff`; `max_retries` is enforced by the error handler itself.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryOptions {
    /// Wait before the first retry; used as both the initial and the starting current interval
    /// of the generated backoff.
    pub initial_interval: Duration,
    /// Randomization (jitter) factor handed to the backoff; `0.0` is used in this file where
    /// deterministic delays are wanted.
    pub randomization_factor: f64,
    /// Growth multiplier handed to the backoff for successive retry intervals.
    pub multiplier: f64,
    /// Upper bound on a single retry interval. Also used directly as the wait time when a task
    /// long poll is retried inside its fatal-error grace window.
    pub max_interval: Duration,
    /// Total time budget handed to the backoff; `None` means no elapsed-time limit is configured.
    pub max_elapsed_time: Option<Duration>,
    /// Attempt limit: once the current attempt number reaches this value the error is forwarded
    /// to the caller. `0` disables the attempt limit.
    pub max_retries: usize,
}
51
52impl Default for RetryOptions {
53 fn default() -> Self {
54 Self {
55 initial_interval: Duration::from_millis(100), randomization_factor: 0.2, multiplier: 1.7, max_interval: Duration::from_secs(5), max_elapsed_time: Some(Duration::from_secs(10)), max_retries: 10,
61 }
62 }
63}
64
65impl RetryOptions {
66 pub(crate) const fn task_poll_retry_policy() -> Self {
67 Self {
68 initial_interval: Duration::from_millis(200),
69 randomization_factor: 0.2,
70 multiplier: 2.0,
71 max_interval: Duration::from_secs(10),
72 max_elapsed_time: None,
73 max_retries: 0,
74 }
75 }
76
77 pub(crate) const fn throttle_retry_policy() -> Self {
78 Self {
79 initial_interval: Duration::from_secs(1),
80 randomization_factor: 0.2,
81 multiplier: 2.0,
82 max_interval: Duration::from_secs(10),
83 max_elapsed_time: None,
84 max_retries: 0,
85 }
86 }
87
88 pub const fn no_retries() -> Self {
90 Self {
91 initial_interval: Duration::from_secs(0),
92 randomization_factor: 0.0,
93 multiplier: 1.0,
94 max_interval: Duration::from_secs(0),
95 max_elapsed_time: None,
96 max_retries: 1,
97 }
98 }
99
100 pub(crate) fn get_call_info<R>(
101 &self,
102 call_name: &'static str,
103 request: Option<&tonic::Request<R>>,
104 ) -> CallInfo {
105 let mut call_type = CallType::Normal;
106 let mut retry_short_circuit = None;
107 let mut retry_cfg_override = None;
108 if let Some(r) = request.as_ref() {
109 let ext = r.extensions();
110 if ext.get::<IsUserLongPoll>().is_some() {
111 call_type = CallType::UserLongPoll;
112 } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
113 call_type = CallType::TaskLongPoll;
114 }
115
116 retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
117 retry_cfg_override = ext.get::<RetryConfigForCall>().cloned();
118 }
119 let retry_cfg = if let Some(ovr) = retry_cfg_override {
120 ovr.0
121 } else if call_type == CallType::TaskLongPoll {
122 RetryOptions::task_poll_retry_policy()
123 } else {
124 self.clone()
125 };
126 CallInfo {
127 call_type,
128 call_name,
129 retry_cfg,
130 retry_short_circuit,
131 }
132 }
133
134 pub(crate) fn into_exp_backoff<C>(self, clock: C) -> exponential::ExponentialBackoff<C> {
135 exponential::ExponentialBackoff {
136 current_interval: self.initial_interval,
137 initial_interval: self.initial_interval,
138 randomization_factor: self.randomization_factor,
139 multiplier: self.multiplier,
140 max_interval: self.max_interval,
141 max_elapsed_time: self.max_elapsed_time,
142 clock,
143 start_time: Instant::now(),
144 }
145 }
146}
147
148impl From<RetryOptions> for backoff::ExponentialBackoff {
149 fn from(c: RetryOptions) -> Self {
150 c.into_exp_backoff(SystemClock::default())
151 }
152}
153
154pub(crate) fn make_future_retry<R, F, Fut>(
155 info: CallInfo,
156 factory: F,
157) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
158where
159 F: FnMut() -> Fut + Unpin,
160 Fut: Future<Output = Result<R, tonic::Status>>,
161{
162 FutureRetry::new(
163 factory,
164 TonicErrorHandler::new(info, RetryOptions::throttle_retry_policy()),
165 )
166}
167
/// Decides, for each failed gRPC attempt, whether to retry (and after how long) or to forward
/// the error to the caller. Generic over the clock so tests can control elapsed time.
#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
    /// Primary backoff built from the call's retry configuration.
    backoff: ExponentialBackoff<C>,
    /// Secondary backoff consulted only for `ResourceExhausted` errors; the longer of the two
    /// delays is used.
    throttle_backoff: ExponentialBackoff<C>,
    /// Attempt limit copied from the call's retry configuration; `0` disables the limit.
    max_retries: usize,
    /// Kind of call being retried; task long polls get more lenient treatment.
    call_type: CallType,
    /// Name of the RPC, used only in log messages.
    call_name: &'static str,
    /// Optional predicate that, when it matches an error, forwards that error without retrying.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
177impl TonicErrorHandler<SystemClock> {
178 fn new(call_info: CallInfo, throttle_cfg: RetryOptions) -> Self {
179 Self::new_with_clock(
180 call_info,
181 throttle_cfg,
182 SystemClock::default(),
183 SystemClock::default(),
184 )
185 }
186}
187impl<C> TonicErrorHandler<C>
188where
189 C: Clock,
190{
191 fn new_with_clock(
192 call_info: CallInfo,
193 throttle_cfg: RetryOptions,
194 clock: C,
195 throttle_clock: C,
196 ) -> Self {
197 Self {
198 call_type: call_info.call_type,
199 call_name: call_info.call_name,
200 max_retries: call_info.retry_cfg.max_retries,
201 backoff: call_info.retry_cfg.into_exp_backoff(clock),
202 throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
203 retry_short_circuit: call_info.retry_short_circuit,
204 }
205 }
206
207 fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
208 let mut do_log = false;
209 if self.max_retries == 0 && cur_attempt > 5 {
211 do_log = true;
212 }
213 if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
215 do_log = true;
216 }
217
218 if do_log {
219 if self.max_retries == 0 && cur_attempt > 15 {
221 error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
222 } else {
223 warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
224 }
225 }
226 }
227}
228
/// Per-call retry information produced by `RetryOptions::get_call_info` and consumed when
/// building a `TonicErrorHandler`.
#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
    /// Kind of call (normal, user long poll, or worker task long poll).
    pub call_type: CallType,
    /// Name of the RPC, used in retry log messages.
    call_name: &'static str,
    /// Retry configuration selected for this call.
    retry_cfg: RetryOptions,
    /// Optional predicate that forwards matching errors without retrying.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
236
/// Classification of a gRPC call for retry purposes.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
    /// A call with no long-poll marker in its request extensions (or no request at all).
    Normal,
    /// A call whose request extensions contain `IsUserLongPoll`.
    UserLongPoll,
    /// A call whose request extensions contain `IsWorkerTaskLongPoll` (and not
    /// `IsUserLongPoll`). These use the task-poll retry policy unless overridden.
    TaskLongPoll,
}
246
247impl CallType {
248 pub(crate) fn is_long(&self) -> bool {
249 matches!(self, Self::UserLongPoll | Self::TaskLongPoll)
250 }
251}
252
253impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
254where
255 C: Clock,
256{
257 type OutError = tonic::Status;
258
259 fn handle(
260 &mut self,
261 current_attempt: usize,
262 mut e: tonic::Status,
263 ) -> RetryPolicy<tonic::Status> {
264 if self.max_retries > 0 && current_attempt >= self.max_retries {
266 return RetryPolicy::ForwardError(e);
267 }
268
269 if let Some(sc) = self.retry_short_circuit.as_ref()
270 && (sc.predicate)(&e)
271 {
272 e.metadata_mut().insert(
273 ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT,
274 tonic::metadata::MetadataValue::from(0),
275 );
276 return RetryPolicy::ForwardError(e);
277 }
278
279 if e.code() == Code::ResourceExhausted
281 && (e
282 .message()
283 .starts_with("grpc: received message larger than max")
284 || e.message()
285 .starts_with("grpc: message after decompression larger than max")
286 || e.message()
287 .starts_with("grpc: received message after decompression larger than max"))
288 {
289 e.metadata_mut().insert(
291 MESSAGE_TOO_LARGE_KEY,
292 tonic::metadata::MetadataValue::from(0),
293 );
294 return RetryPolicy::ForwardError(e);
295 }
296
297 let long_poll_allowed = self.call_type == CallType::TaskLongPoll
300 && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());
301
302 let transport_cancel_retry_allowed =
307 e.code() == Code::Cancelled && is_transport_cancelled(&e);
308
309 if RETRYABLE_ERROR_CODES.contains(&e.code())
310 || long_poll_allowed
311 || transport_cancel_retry_allowed
312 {
313 if current_attempt == 1 {
314 debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
315 } else {
316 self.maybe_log_retry(current_attempt, &e);
317 }
318
319 match self.backoff.next_backoff() {
320 None => RetryPolicy::ForwardError(e), Some(backoff) => {
322 if e.code() == Code::ResourceExhausted {
325 let extended_backoff =
326 backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
327 RetryPolicy::WaitRetry(extended_backoff)
328 } else {
329 RetryPolicy::WaitRetry(backoff)
330 }
331 }
332 }
333 } else if self.call_type == CallType::TaskLongPoll
334 && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
335 {
336 RetryPolicy::WaitRetry(self.backoff.max_interval)
339 } else {
340 RetryPolicy::ForwardError(e)
341 }
342 }
343}
344
345fn is_transport_cancelled(status: &tonic::Status) -> bool {
350 status
351 .source()
352 .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
353 .and_then(|te| te.source())
354 .and_then(|tec| tec.downcast_ref::<hyper::Error>())
355 .is_some()
356}
357
// Unit tests for `TonicErrorHandler::handle` and the helpers around it.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use backoff::Clock;
    use std::{ops::Add, time::Instant};
    use temporalio_common::protos::temporal::api::workflowservice::v1::{
        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
    };
    use tonic::{IntoRequest, Status};

    /// Retry configuration with millisecond-scale intervals and a zero randomization factor,
    /// used as the base configuration throughout these tests.
    const TEST_RETRY_CONFIG: RetryOptions = RetryOptions {
        initial_interval: Duration::from_millis(1),
        randomization_factor: 0.0,
        multiplier: 1.1,
        max_interval: Duration::from_millis(2),
        max_elapsed_time: None,
        max_retries: 10,
    };

    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";

    /// Clock that always reports the wrapped instant; tests advance time by overwriting it.
    struct FixedClock(Instant);
    impl Clock for FixedClock {
        fn now(&self) -> Instant {
            self.0
        }
    }

    /// Codes outside the retryable list are retried for task long polls at first, and are
    /// forwarded once the clock has been moved past `LONG_POLL_FATAL_GRACE`.
    #[tokio::test]
    async fn long_poll_non_retryable_errors() {
        for code in [
            Code::InvalidArgument,
            Code::NotFound,
            Code::AlreadyExists,
            Code::PermissionDenied,
            Code::FailedPrecondition,
            Code::Unauthenticated,
            Code::Unimplemented,
        ] {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                // Jump the primary backoff's clock beyond the grace window.
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::ForwardError(_));
            }
        }
    }

    /// Codes in `RETRYABLE_ERROR_CODES` keep being retried for task long polls even after the
    /// grace window has passed.
    #[tokio::test]
    async fn long_poll_retryable_errors_never_fatal() {
        for code in RETRYABLE_ERROR_CODES {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                // Jump the primary backoff's clock beyond the grace window.
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }

    /// `ResourceExhausted` waits for the longer of the regular and throttle backoff delays.
    /// The throttle configuration here (2 ms initial, x4 multiplier) is expected to dominate
    /// the regular one (1 ms initial, x1.1 multiplier) on both attempts.
    #[tokio::test]
    async fn retry_resource_exhausted() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            RetryOptions {
                initial_interval: Duration::from_millis(2),
                randomization_factor: 0.0,
                multiplier: 4.0,
                max_interval: Duration::from_millis(10),
                max_elapsed_time: None,
                max_retries: 10,
            },
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
            _ => panic!(),
        }
        // Advance both clocks by the same amount before the second attempt.
        err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
        err_handler.throttle_backoff.clock.0 = err_handler
            .throttle_backoff
            .clock
            .0
            .add(Duration::from_millis(10));
        let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
            _ => panic!(),
        }
    }

    /// An error matching the short-circuit predicate is forwarded immediately and carries the
    /// short-circuit marker in its metadata, even though its code is otherwise retryable.
    #[tokio::test]
    async fn retry_short_circuit() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: Some(NoRetryOnMatching {
                    predicate: |s: &Status| s.code() == Code::ResourceExhausted,
                }),
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        let e = assert_matches!(result, RetryPolicy::ForwardError(e) => e);
        assert!(
            e.metadata()
                .get(ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT)
                .is_some()
        );
    }

    /// `ResourceExhausted` errors whose message starts with any of the three recognised
    /// "larger than max" prefixes are forwarded rather than retried.
    #[tokio::test]
    async fn message_too_large_not_retried() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));
    }

    /// Task polls keep retrying well past `TEST_RETRY_CONFIG.max_retries`, because
    /// `get_call_info` substitutes the task-poll policy (no attempt limit) for requests marked
    /// as worker task long polls. A fresh handler is built for every attempt number.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_forever<R>(
        #[values(
                (
                    POLL_WORKFLOW_METH_NAME,
                    PollWorkflowTaskQueueRequest::default(),
                ),
                (
                    POLL_ACTIVITY_METH_NAME,
                    PollActivityTaskQueueRequest::default(),
                ),
                (
                    POLL_NEXUS_METH_NAME,
                    PollNexusTaskQueueRequest::default(),
                ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let mut req = req.into_request();
        // Mark the request as a worker task long poll so it is classified as `TaskLongPoll`.
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for i in 1..=50 {
            let mut err_handler = TonicErrorHandler::new(
                TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
                RetryOptions::throttle_retry_policy(),
            );
            let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
            assert_matches!(result, RetryPolicy::WaitRetry(_));
        }
    }

    /// `Cancelled` and `DeadlineExceeded` are retried for task long polls; one handler is
    /// reused across five consecutive attempts per code.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_deadline_exceeded<R>(
        #[values(
                (
                    POLL_WORKFLOW_METH_NAME,
                    PollWorkflowTaskQueueRequest::default(),
                ),
                (
                    POLL_ACTIVITY_METH_NAME,
                    PollActivityTaskQueueRequest::default(),
                ),
                (
                    POLL_NEXUS_METH_NAME,
                    PollNexusTaskQueueRequest::default(),
                ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for code in [Code::Cancelled, Code::DeadlineExceeded] {
            let mut err_handler = TonicErrorHandler::new(
                TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
                RetryOptions::throttle_retry_policy(),
            );
            for i in 1..=5 {
                let result = err_handler.handle(i, Status::new(code, "retryable failure"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }

    /// A `Cancelled` status built with `Status::new` on a normal call is forwarded, not
    /// retried.
    #[tokio::test]
    async fn plain_cancelled_not_retried_on_normal_call() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::Normal,
                call_name: "respond_activity_task_completed",
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::Cancelled, "caller cancelled"));
        assert_matches!(result, RetryPolicy::ForwardError(_));
    }

    /// A status built with `Status::new` is not reported as a transport cancellation.
    #[tokio::test]
    async fn is_transport_cancelled_false_for_plain_status() {
        let status = Status::new(Code::Cancelled, "caller cancelled");
        assert!(!is_transport_cancelled(&status));
    }

    /// A status derived from a real transport connect failure is retried on each of five
    /// attempts of a normal call.
    ///
    /// NOTE(review): the status produced from a failed connect presumably carries a code that
    /// is already in `RETRYABLE_ERROR_CODES` (e.g. `Unavailable`) rather than `Cancelled`, in
    /// which case this test does not exercise the transport-cancel branch — confirm.
    #[tokio::test]
    async fn transport_sourced_cancelled_retried_on_full_budget() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::Normal,
                call_name: "respond_activity_task_completed",
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );

        for i in 1..=5 {
            // The connect attempt is expected to fail (`unwrap_err`), yielding a transport error.
            let endpoint = tonic::transport::Endpoint::from_static("http://[::1]:1")
                .connect_timeout(Duration::from_millis(1));
            let transport_err = endpoint.connect().await.unwrap_err();
            let status = Status::from_error(Box::new(transport_err));

            let result = err_handler.handle(i, status);
            assert_matches!(
                result,
                RetryPolicy::WaitRetry(_),
                "Transport error should be retried on attempt {i}"
            );
        }
    }
}