1use crate::{
2 ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, MESSAGE_TOO_LARGE_KEY,
3 grpc::IsUserLongPoll,
4 request_extensions::{IsWorkerTaskLongPoll, NoRetryOnMatching, RetryConfigForCall},
5};
6use backoff::{
7 Clock, SystemClock,
8 backoff::Backoff,
9 exponential::{self, ExponentialBackoff},
10};
11use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
12use std::{
13 error::Error,
14 fmt::Debug,
15 future::Future,
16 time::{Duration, Instant},
17};
18use tonic::Code;
19
#[doc(hidden)]
/// gRPC status codes that are considered transient and are retried for any
/// call type (subject to the configured backoff / attempt limits).
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
    Code::DataLoss,
    Code::Internal,
    Code::Unknown,
    Code::ResourceExhausted,
    Code::Aborted,
    Code::OutOfRange,
    Code::Unavailable,
];
/// Task long-poll calls that fail with a normally-fatal code are still retried
/// (at the backoff's `max_interval`) until this much time has elapsed since
/// the backoff started. See the `else if` arm in `TonicErrorHandler::handle`.
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);
32
/// Configuration for retrying RPC calls with (jittered) exponential backoff.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryOptions {
    /// Interval to wait before the first retry.
    pub initial_interval: Duration,
    /// Jitter factor applied to each interval (0.0 disables randomization).
    pub randomization_factor: f64,
    /// Factor the interval grows by after each attempt.
    pub multiplier: f64,
    /// Upper bound on any single backoff interval.
    pub max_interval: Duration,
    /// Cap on total elapsed time across retries; `None` means no time cap.
    pub max_elapsed_time: Option<Duration>,
    /// Maximum number of attempts; 0 means no attempt limit (retry forever).
    pub max_retries: usize,
}
51
52impl Default for RetryOptions {
53 fn default() -> Self {
54 Self {
55 initial_interval: Duration::from_millis(100), randomization_factor: 0.2, multiplier: 1.7, max_interval: Duration::from_secs(5), max_elapsed_time: Some(Duration::from_secs(10)), max_retries: 10,
61 }
62 }
63}
64
65impl RetryOptions {
66 pub(crate) const fn task_poll_retry_policy() -> Self {
67 Self {
68 initial_interval: Duration::from_millis(200),
69 randomization_factor: 0.2,
70 multiplier: 2.0,
71 max_interval: Duration::from_secs(10),
72 max_elapsed_time: None,
73 max_retries: 0,
74 }
75 }
76
77 pub(crate) const fn throttle_retry_policy() -> Self {
78 Self {
79 initial_interval: Duration::from_secs(1),
80 randomization_factor: 0.2,
81 multiplier: 2.0,
82 max_interval: Duration::from_secs(10),
83 max_elapsed_time: None,
84 max_retries: 0,
85 }
86 }
87
88 pub const fn no_retries() -> Self {
90 Self {
91 initial_interval: Duration::from_secs(0),
92 randomization_factor: 0.0,
93 multiplier: 1.0,
94 max_interval: Duration::from_secs(0),
95 max_elapsed_time: None,
96 max_retries: 1,
97 }
98 }
99
100 pub(crate) fn get_call_info<R>(
101 &self,
102 call_name: &'static str,
103 request: Option<&tonic::Request<R>>,
104 ) -> CallInfo {
105 let mut call_type = CallType::Normal;
106 let mut retry_short_circuit = None;
107 let mut retry_cfg_override = None;
108 if let Some(r) = request.as_ref() {
109 let ext = r.extensions();
110 if ext.get::<IsUserLongPoll>().is_some() {
111 call_type = CallType::UserLongPoll;
112 } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
113 call_type = CallType::TaskLongPoll;
114 }
115
116 retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
117 retry_cfg_override = ext.get::<RetryConfigForCall>().cloned();
118 }
119 let retry_cfg = if let Some(ovr) = retry_cfg_override {
120 ovr.0
121 } else if call_type == CallType::TaskLongPoll {
122 RetryOptions::task_poll_retry_policy()
123 } else {
124 self.clone()
125 };
126 CallInfo {
127 call_type,
128 call_name,
129 retry_cfg,
130 retry_short_circuit,
131 }
132 }
133
134 pub(crate) fn into_exp_backoff<C>(self, clock: C) -> exponential::ExponentialBackoff<C> {
135 exponential::ExponentialBackoff {
136 current_interval: self.initial_interval,
137 initial_interval: self.initial_interval,
138 randomization_factor: self.randomization_factor,
139 multiplier: self.multiplier,
140 max_interval: self.max_interval,
141 max_elapsed_time: self.max_elapsed_time,
142 clock,
143 start_time: Instant::now(),
144 }
145 }
146}
147
148impl From<RetryOptions> for backoff::ExponentialBackoff {
149 fn from(c: RetryOptions) -> Self {
150 c.into_exp_backoff(SystemClock::default())
151 }
152}
153
154pub(crate) fn make_future_retry<R, F, Fut>(
155 info: CallInfo,
156 factory: F,
157) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
158where
159 F: FnMut() -> Fut + Unpin,
160 Fut: Future<Output = Result<R, tonic::Status>>,
161{
162 FutureRetry::new(
163 factory,
164 TonicErrorHandler::new(info, RetryOptions::throttle_retry_policy()),
165 )
166}
167
/// `futures_retry` error handler that decides whether, and how long after, a
/// failed gRPC call should be retried.
#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
    // Primary backoff driving retry intervals and elapsed-time tracking.
    backoff: ExponentialBackoff<C>,
    // Secondary backoff used to further slow down `ResourceExhausted` retries.
    throttle_backoff: ExponentialBackoff<C>,
    // Attempt cap copied from the retry config; 0 means unlimited.
    max_retries: usize,
    call_type: CallType,
    call_name: &'static str,
    // Set once a Cancelled-due-to-connection-close error has been retried,
    // so that case is only retried a single time.
    have_retried_goaway_cancel: bool,
    // Optional predicate that forwards matching errors without retrying.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
178impl TonicErrorHandler<SystemClock> {
179 fn new(call_info: CallInfo, throttle_cfg: RetryOptions) -> Self {
180 Self::new_with_clock(
181 call_info,
182 throttle_cfg,
183 SystemClock::default(),
184 SystemClock::default(),
185 )
186 }
187}
188impl<C> TonicErrorHandler<C>
189where
190 C: Clock,
191{
192 fn new_with_clock(
193 call_info: CallInfo,
194 throttle_cfg: RetryOptions,
195 clock: C,
196 throttle_clock: C,
197 ) -> Self {
198 Self {
199 call_type: call_info.call_type,
200 call_name: call_info.call_name,
201 max_retries: call_info.retry_cfg.max_retries,
202 backoff: call_info.retry_cfg.into_exp_backoff(clock),
203 throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
204 have_retried_goaway_cancel: false,
205 retry_short_circuit: call_info.retry_short_circuit,
206 }
207 }
208
209 fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
210 let mut do_log = false;
211 if self.max_retries == 0 && cur_attempt > 5 {
213 do_log = true;
214 }
215 if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
217 do_log = true;
218 }
219
220 if do_log {
221 if self.max_retries == 0 && cur_attempt > 15 {
223 error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
224 } else {
225 warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
226 }
227 }
228 }
229}
230
/// Resolved per-call retry information produced by
/// `RetryOptions::get_call_info` and consumed by `TonicErrorHandler`.
#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
    pub call_type: CallType,
    // RPC method name, used only for log messages.
    call_name: &'static str,
    // The retry configuration that applies to this specific call.
    retry_cfg: RetryOptions,
    // Optional predicate that forwards matching errors without retrying.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
238
#[doc(hidden)]
/// Classification of an RPC call, which determines its retry behavior.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
    /// An ordinary request/response call.
    Normal,
    /// A long poll initiated by user code.
    UserLongPoll,
    /// A worker task-queue long poll (workflow/activity/nexus task polling).
    TaskLongPoll,
}
248
249impl CallType {
250 pub(crate) fn is_long(&self) -> bool {
251 matches!(self, Self::UserLongPoll | Self::TaskLongPoll)
252 }
253}
254
impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
where
    C: Clock,
{
    type OutError = tonic::Status;

    /// Decide whether the failed attempt should be retried (and after what
    /// delay) or forwarded to the caller as the final error.
    fn handle(
        &mut self,
        current_attempt: usize,
        mut e: tonic::Status,
    ) -> RetryPolicy<tonic::Status> {
        // Attempt cap first; `max_retries == 0` means no cap.
        if self.max_retries > 0 && current_attempt >= self.max_retries {
            return RetryPolicy::ForwardError(e);
        }

        // Per-call short-circuit: if the caller-provided predicate matches,
        // forward immediately, tagging the error's metadata so callers can
        // tell it was returned because of the short circuit.
        if let Some(sc) = self.retry_short_circuit.as_ref()
            && (sc.predicate)(&e)
        {
            e.metadata_mut().insert(
                ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT,
                tonic::metadata::MetadataValue::from(0),
            );
            return RetryPolicy::ForwardError(e);
        }

        // "Message too large" errors arrive as ResourceExhausted but are not
        // transient — retrying can never succeed. Detect them by message
        // prefix, tag the metadata, and forward without retrying.
        if e.code() == Code::ResourceExhausted
            && (e
                .message()
                .starts_with("grpc: received message larger than max")
                || e.message()
                    .starts_with("grpc: message after decompression larger than max")
                || e.message()
                    .starts_with("grpc: received message after decompression larger than max"))
        {
            e.metadata_mut().insert(
                MESSAGE_TOO_LARGE_KEY,
                tonic::metadata::MetadataValue::from(0),
            );
            return RetryPolicy::ForwardError(e);
        }

        // Task long polls additionally treat Cancelled/DeadlineExceeded as
        // retryable (timeouts are expected during long polling).
        let long_poll_allowed = self.call_type == CallType::TaskLongPoll
            && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());

        // A Cancelled status caused by the transport's connection closing
        // (walk the source chain down to a hyper error and check its debug
        // output) is retried exactly once — `have_retried_goaway_cancel`
        // guards against looping on it. Presumably this corresponds to the
        // server sending an HTTP/2 GOAWAY — TODO confirm.
        let mut goaway_retry_allowed = false;
        if !self.have_retried_goaway_cancel
            && e.code() == Code::Cancelled
            && let Some(e) = e
                .source()
                .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
                .and_then(|te| te.source())
                .and_then(|tec| tec.downcast_ref::<hyper::Error>())
            && format!("{e:?}").contains("connection closed")
        {
            goaway_retry_allowed = true;
            self.have_retried_goaway_cancel = true;
        }

        if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
            if current_attempt == 1 {
                debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
            } else {
                self.maybe_log_retry(current_attempt, &e);
            }

            match self.backoff.next_backoff() {
                // The backoff is exhausted (e.g. max_elapsed_time passed).
                None => RetryPolicy::ForwardError(e), Some(backoff) => {
                    // ResourceExhausted retries wait at least as long as the
                    // (slower-growing) throttle backoff dictates.
                    if e.code() == Code::ResourceExhausted {
                        let extended_backoff =
                            backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
                        RetryPolicy::WaitRetry(extended_backoff)
                    } else {
                        RetryPolicy::WaitRetry(backoff)
                    }
                }
            }
        } else if self.call_type == CallType::TaskLongPoll
            && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
        {
            // Task long polls get a grace period during which even
            // normally-fatal codes are retried (at the max interval), since
            // poller failures shortly after startup may be transient.
            RetryPolicy::WaitRetry(self.backoff.max_interval)
        } else {
            RetryPolicy::ForwardError(e)
        }
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use backoff::Clock;
    use std::{ops::Add, time::Instant};
    use temporalio_common::protos::temporal::api::workflowservice::v1::{
        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
    };
    use tonic::{IntoRequest, Status};

    // Fast, jitter-free retry config so tests are deterministic and quick.
    const TEST_RETRY_CONFIG: RetryOptions = RetryOptions {
        initial_interval: Duration::from_millis(1),
        randomization_factor: 0.0,
        multiplier: 1.1,
        max_interval: Duration::from_millis(2),
        max_elapsed_time: None,
        max_retries: 10,
    };

    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";

    // Clock whose "now" is a settable instant, letting tests fast-forward a
    // backoff's elapsed time without actually sleeping.
    struct FixedClock(Instant);
    impl Clock for FixedClock {
        fn now(&self) -> Instant {
            self.0
        }
    }

    // Fatal codes on a task long poll are retried only within
    // LONG_POLL_FATAL_GRACE; once the grace period elapses they are forwarded.
    #[tokio::test]
    async fn long_poll_non_retryable_errors() {
        for code in [
            Code::InvalidArgument,
            Code::NotFound,
            Code::AlreadyExists,
            Code::PermissionDenied,
            Code::FailedPrecondition,
            Code::Unauthenticated,
            Code::Unimplemented,
        ] {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                // Inside the grace period: retried despite the fatal code.
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                // Fast-forward the backoff's clock past the grace period.
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                // Past the grace period: the error is forwarded.
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::ForwardError(_));
            }
        }
    }

    // Retryable codes on a task long poll keep retrying even after the
    // fatal-error grace period has elapsed.
    #[tokio::test]
    async fn long_poll_retryable_errors_never_fatal() {
        for code in RETRYABLE_ERROR_CODES {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                // Even well past the grace period, retryable codes still retry.
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }

    // ResourceExhausted waits are extended to at least the throttle backoff:
    // primary backoff starts at 1ms (TEST_RETRY_CONFIG) while the throttle
    // backoff starts at 2ms with a 4x multiplier.
    #[tokio::test]
    async fn retry_resource_exhausted() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            RetryOptions {
                initial_interval: Duration::from_millis(2),
                randomization_factor: 0.0,
                multiplier: 4.0,
                max_interval: Duration::from_millis(10),
                max_elapsed_time: None,
                max_retries: 10,
            },
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        // First wait: max(primary 1ms, throttle 2ms) = 2ms.
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
            _ => panic!(),
        }
        // Advance both clocks so the backoffs progress.
        err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
        err_handler.throttle_backoff.clock.0 = err_handler
            .throttle_backoff
            .clock
            .0
            .add(Duration::from_millis(10));
        // Second wait: the throttle backoff (2ms * 4 = 8ms) dominates.
        let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
            _ => panic!(),
        }
    }

    // A matching NoRetryOnMatching predicate forwards the error immediately
    // and tags it with the short-circuit metadata marker.
    #[tokio::test]
    async fn retry_short_circuit() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: Some(NoRetryOnMatching {
                    predicate: |s: &Status| s.code() == Code::ResourceExhausted,
                }),
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        let e = assert_matches!(result, RetryPolicy::ForwardError(e) => e);
        assert!(
            e.metadata()
                .get(ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT)
                .is_some()
        );
    }

    // Each recognized "message too large" prefix is forwarded without retry,
    // even though the code (ResourceExhausted) is normally retryable.
    #[tokio::test]
    async fn message_too_large_not_retried() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));
    }

    // Worker task polls resolve to the task-poll policy (max_retries == 0),
    // so arbitrarily high attempt numbers still yield a retry.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_forever<R>(
        #[values(
            (
                POLL_WORKFLOW_METH_NAME,
                PollWorkflowTaskQueueRequest::default(),
            ),
            (
                POLL_ACTIVITY_METH_NAME,
                PollActivityTaskQueueRequest::default(),
            ),
            (
                POLL_NEXUS_METH_NAME,
                PollNexusTaskQueueRequest::default(),
            ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let mut req = req.into_request();
        // Mark the request as a worker task long poll via its extensions.
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for i in 1..=50 {
            let mut err_handler = TonicErrorHandler::new(
                TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
                RetryOptions::throttle_retry_policy(),
            );
            let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
            assert_matches!(result, RetryPolicy::WaitRetry(_));
        }
    }

    // Cancelled / DeadlineExceeded are retryable specifically for task long
    // polls (they are expected while long polling).
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_deadline_exceeded<R>(
        #[values(
            (
                POLL_WORKFLOW_METH_NAME,
                PollWorkflowTaskQueueRequest::default(),
            ),
            (
                POLL_ACTIVITY_METH_NAME,
                PollActivityTaskQueueRequest::default(),
            ),
            (
                POLL_NEXUS_METH_NAME,
                PollNexusTaskQueueRequest::default(),
            ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for code in [Code::Cancelled, Code::DeadlineExceeded] {
            let mut err_handler = TonicErrorHandler::new(
                TEST_RETRY_CONFIG.get_call_info::<R>(call_name, Some(&req)),
                RetryOptions::throttle_retry_policy(),
            );
            for i in 1..=5 {
                let result = err_handler.handle(i, Status::new(code, "retryable failure"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }
}