use crate::{
    ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY,
    NamespacedClient, NoRetryOnMatching, Result, RetryConfig, raw::IsUserLongPoll,
};
use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff};
use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
use std::{error::Error, fmt::Debug, future::Future, sync::Arc, time::Duration};
use tonic::{Code, Request};

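/// gRPC error codes that are considered retryable for any call type.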
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
    Code::DataLoss,
    Code::Internal,
    Code::Unknown,
    Code::ResourceExhausted,
    Code::Aborted,
    Code::OutOfRange,
    Code::Unavailable,
];
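/// Grace period during which task long polls keep retrying even otherwise-fatal errors.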
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);

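/// Wraps a client implementation, retrying failed gRPC calls according to the configured
/// [RetryConfig]. Worker task polls use a dedicated retry policy instead.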
#[derive(Debug, Clone)]
pub struct RetryClient<SG> {
    client: SG,
    retry_config: Arc<RetryConfig>,
}

impl<SG> RetryClient<SG> {
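    /// Wrap the given client, retrying calls according to `retry_config`.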
    pub fn new(client: SG, retry_config: RetryConfig) -> Self {
        Self {
            client,
            retry_config: Arc::new(retry_config),
        }
    }
}

impl<SG> RetryClient<SG> {
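    /// Return a reference to the wrapped client.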
    pub fn get_client(&self) -> &SG {
        &self.client
    }

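    /// Return a mutable reference to the wrapped client.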
    pub fn get_client_mut(&mut self) -> &mut SG {
        &mut self.client
    }

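    /// Consume this wrapper and return the wrapped client.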
    pub fn into_inner(self) -> SG {
        self.client
    }

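    /// Determine the call type, the retry configuration, and any retry short-circuit predicate
    /// for a call by inspecting the request's extensions.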
    pub(crate) fn get_call_info<R>(
        &self,
        call_name: &'static str,
        request: Option<&Request<R>>,
    ) -> CallInfo {
        let mut call_type = CallType::Normal;
        let mut retry_short_circuit = None;
        if let Some(r) = request.as_ref() {
            let ext = r.extensions();
            if ext.get::<IsUserLongPoll>().is_some() {
                call_type = CallType::UserLongPoll;
            } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
                call_type = CallType::TaskLongPoll;
            }

            retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
        }
        let retry_cfg = if call_type == CallType::TaskLongPoll {
            RetryConfig::task_poll_retry_policy()
        } else {
            (*self.retry_config).clone()
        };
        CallInfo {
            call_type,
            call_name,
            retry_cfg,
            retry_short_circuit,
        }
    }

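    /// Wrap the futures produced by `factory` in a [FutureRetry] driven by a [TonicErrorHandler].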
    pub(crate) fn make_future_retry<R, F, Fut>(
        info: CallInfo,
        factory: F,
    ) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
    where
        F: FnMut() -> Fut + Unpin,
        Fut: Future<Output = Result<R>>,
    {
        FutureRetry::new(
            factory,
            TonicErrorHandler::new(info, RetryConfig::throttle_retry_policy()),
        )
    }
}

impl<SG> NamespacedClient for RetryClient<SG>
where
    SG: NamespacedClient,
{
    fn namespace(&self) -> String {
        self.client.namespace()
    }

    fn identity(&self) -> String {
        self.client.identity()
    }
}

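/// Decides whether a failed call should be retried, and with what delay, based on the returned
/// [tonic::Status].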
#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
    backoff: ExponentialBackoff<C>,
    throttle_backoff: ExponentialBackoff<C>,
    max_retries: usize,
    call_type: CallType,
    call_name: &'static str,
    have_retried_goaway_cancel: bool,
    retry_short_circuit: Option<NoRetryOnMatching>,
}
impl TonicErrorHandler<SystemClock> {
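    /// Build a handler that uses the system clock for both backoff timers.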
    fn new(call_info: CallInfo, throttle_cfg: RetryConfig) -> Self {
        Self::new_with_clock(
            call_info,
            throttle_cfg,
            SystemClock::default(),
            SystemClock::default(),
        )
    }
}
impl<C> TonicErrorHandler<C>
where
    C: Clock,
{
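    /// Build a handler with caller-provided clocks so tests can control elapsed time.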
    fn new_with_clock(
        call_info: CallInfo,
        throttle_cfg: RetryConfig,
        clock: C,
        throttle_clock: C,
    ) -> Self {
        Self {
            call_type: call_info.call_type,
            call_name: call_info.call_name,
            max_retries: call_info.retry_cfg.max_retries,
            backoff: call_info.retry_cfg.into_exp_backoff(clock),
            throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
            have_retried_goaway_cancel: false,
            retry_short_circuit: call_info.retry_short_circuit,
        }
    }

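    /// Log a warning (or an error, for calls retried many times) once retries start to pile up.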
    fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
        let mut do_log = false;
        if self.max_retries == 0 && cur_attempt > 5 {
            do_log = true;
        }
        if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
            do_log = true;
        }

        if do_log {
            if self.max_retries == 0 && cur_attempt > 15 {
                error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
            } else {
                warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
            }
        }
    }
}

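/// Information about a call that the error handler needs in order to decide how to retry it.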
#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
    pub call_type: CallType,
    call_name: &'static str,
    retry_cfg: RetryConfig,
    retry_short_circuit: Option<NoRetryOnMatching>,
}

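/// The kind of call being made, which determines which retry rules apply.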
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
    Normal,
    UserLongPoll,
    TaskLongPoll,
}

impl CallType {
    pub(crate) fn is_long(&self) -> bool {
        matches!(self, Self::UserLongPoll | Self::TaskLongPoll)
    }
}

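// Retry decision order: enforce `max_retries`, honor any short-circuit predicate, never retry
// "message too large" errors, then apply exponential backoff to retryable codes (plus extra
// allowances for task long polls and one retry after a GOAWAY-style connection close).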
impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
where
    C: Clock,
{
    type OutError = tonic::Status;

    fn handle(
        &mut self,
        current_attempt: usize,
        mut e: tonic::Status,
    ) -> RetryPolicy<tonic::Status> {
        // A `max_retries` of zero means there is no limit on the number of attempts.
        if self.max_retries > 0 && current_attempt >= self.max_retries {
            return RetryPolicy::ForwardError(e);
        }

        // If the caller attached a short-circuit predicate and it matches, forward the error
        // immediately, tagging it so callers can tell why it was not retried.
        if let Some(sc) = self.retry_short_circuit.as_ref()
            && (sc.predicate)(&e)
        {
            e.metadata_mut().insert(
                ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT,
                tonic::metadata::MetadataValue::from(0),
            );
            return RetryPolicy::ForwardError(e);
        }

        // "Message too large" errors surface as ResourceExhausted but will never succeed on
        // retry, so tag and forward them as well.
        if e.code() == Code::ResourceExhausted
            && (e
                .message()
                .starts_with("grpc: received message larger than max")
                || e.message()
                    .starts_with("grpc: message after decompression larger than max")
                || e.message()
                    .starts_with("grpc: received message after decompression larger than max"))
        {
            e.metadata_mut().insert(
                MESSAGE_TOO_LARGE_KEY,
                tonic::metadata::MetadataValue::from(0),
            );
            return RetryPolicy::ForwardError(e);
        }

        // Task long polls routinely time out or get cancelled, so those codes are retryable
        // for them.
        let long_poll_allowed = self.call_type == CallType::TaskLongPoll
            && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());

        // Allow one retry of a cancellation caused by the server closing the connection
        // (an HTTP/2 GOAWAY).
        let mut goaway_retry_allowed = false;
        if !self.have_retried_goaway_cancel
            && e.code() == Code::Cancelled
            && let Some(e) = e
                .source()
                .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
                .and_then(|te| te.source())
                .and_then(|tec| tec.downcast_ref::<hyper::Error>())
            && format!("{e:?}").contains("connection closed")
        {
            goaway_retry_allowed = true;
            self.have_retried_goaway_cancel = true;
        }

        if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
            if current_attempt == 1 {
                debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
            } else {
                self.maybe_log_retry(current_attempt, &e);
            }

            match self.backoff.next_backoff() {
                None => RetryPolicy::ForwardError(e),
                Some(backoff) => {
                    // For ResourceExhausted, wait at least as long as the throttle backoff would.
                    if e.code() == Code::ResourceExhausted {
                        let extended_backoff =
                            backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
                        RetryPolicy::WaitRetry(extended_backoff)
                    } else {
                        RetryPolicy::WaitRetry(backoff)
                    }
                }
            }
        } else if self.call_type == CallType::TaskLongPoll
            && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
        {
            // Task long polls get a grace period before otherwise-fatal errors are surfaced.
            RetryPolicy::WaitRetry(self.backoff.max_interval)
        } else {
            RetryPolicy::ForwardError(e)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use backoff::Clock;
    use std::{ops::Add, time::Instant};
    use temporal_sdk_core_protos::temporal::api::workflowservice::v1::{
        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
    };
    use tonic::{IntoRequest, Status};

    const TEST_RETRY_CONFIG: RetryConfig = RetryConfig {
        initial_interval: Duration::from_millis(1),
        randomization_factor: 0.0,
        multiplier: 1.1,
        max_interval: Duration::from_millis(2),
        max_elapsed_time: None,
        max_retries: 10,
    };

    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";

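    /// A clock for tests that always returns a fixed, manually adjustable instant.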
    struct FixedClock(Instant);
    impl Clock for FixedClock {
        fn now(&self) -> Instant {
            self.0
        }
    }

    #[tokio::test]
    async fn long_poll_non_retryable_errors() {
        for code in [
            Code::InvalidArgument,
            Code::NotFound,
            Code::AlreadyExists,
            Code::PermissionDenied,
            Code::FailedPrecondition,
            Code::Unauthenticated,
            Code::Unimplemented,
        ] {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::ForwardError(_));
            }
        }
    }

    #[tokio::test]
    async fn long_poll_retryable_errors_never_fatal() {
        for code in RETRYABLE_ERROR_CODES {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut err_handler = TonicErrorHandler::new_with_clock(
                    CallInfo {
                        call_type: CallType::TaskLongPoll,
                        call_name,
                        retry_cfg: TEST_RETRY_CONFIG,
                        retry_short_circuit: None,
                    },
                    TEST_RETRY_CONFIG,
                    FixedClock(Instant::now()),
                    FixedClock(Instant::now()),
                );
                let result = err_handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
                err_handler.backoff.clock.0 = err_handler
                    .backoff
                    .clock
                    .0
                    .add(LONG_POLL_FATAL_GRACE + Duration::from_secs(1));
                let result = err_handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }

    #[tokio::test]
    async fn retry_resource_exhausted() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            RetryConfig {
                initial_interval: Duration::from_millis(2),
                randomization_factor: 0.0,
                multiplier: 4.0,
                max_interval: Duration::from_millis(10),
                max_elapsed_time: None,
                max_retries: 10,
            },
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(2)),
            _ => panic!(),
        }
        err_handler.backoff.clock.0 = err_handler.backoff.clock.0.add(Duration::from_millis(10));
        err_handler.throttle_backoff.clock.0 = err_handler
            .throttle_backoff
            .clock
            .0
            .add(Duration::from_millis(10));
        let result = err_handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
        match result {
            RetryPolicy::WaitRetry(duration) => assert_eq!(duration, Duration::from_millis(8)),
            _ => panic!(),
        }
    }

    #[tokio::test]
    async fn retry_short_circuit() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: Some(NoRetryOnMatching {
                    predicate: |s: &Status| s.code() == Code::ResourceExhausted,
                }),
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        let e = assert_matches!(result, RetryPolicy::ForwardError(e) => e);
        assert!(
            e.metadata()
                .get(ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT)
                .is_some()
        );
    }

    #[tokio::test]
    async fn message_too_large_not_retried() {
        let mut err_handler = TonicErrorHandler::new_with_clock(
            CallInfo {
                call_type: CallType::TaskLongPoll,
                call_name: POLL_WORKFLOW_METH_NAME,
                retry_cfg: TEST_RETRY_CONFIG,
                retry_short_circuit: None,
            },
            TEST_RETRY_CONFIG,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        );
        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));

        let result = err_handler.handle(
            1,
            Status::new(
                Code::ResourceExhausted,
                "grpc: received message after decompression larger than max",
            ),
        );
        assert_matches!(result, RetryPolicy::ForwardError(_));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_forever<R>(
        #[values(
            (
                POLL_WORKFLOW_METH_NAME,
                PollWorkflowTaskQueueRequest::default(),
            ),
            (
                POLL_ACTIVITY_METH_NAME,
                PollActivityTaskQueueRequest::default(),
            ),
            (
                POLL_NEXUS_METH_NAME,
                PollNexusTaskQueueRequest::default(),
            ),
        )]
        (call_name, req): (&'static str, R),
    ) where
        // Bound required for `into_request` below.
        R: IntoRequest<R>,
    {
        let fake_retry = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for i in 1..=50 {
            let mut err_handler = TonicErrorHandler::new(
                fake_retry.get_call_info::<R>(call_name, Some(&req)),
                RetryConfig::throttle_retry_policy(),
            );
            let result = err_handler.handle(i, Status::new(Code::Unknown, "Ahh"));
            assert_matches!(result, RetryPolicy::WaitRetry(_));
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_deadline_exceeded<R>(
        #[values(
            (
                POLL_WORKFLOW_METH_NAME,
                PollWorkflowTaskQueueRequest::default(),
            ),
            (
                POLL_ACTIVITY_METH_NAME,
                PollActivityTaskQueueRequest::default(),
            ),
            (
                POLL_NEXUS_METH_NAME,
                PollNexusTaskQueueRequest::default(),
            ),
        )]
        (call_name, req): (&'static str, R),
    ) where
        // Bound required for `into_request` below.
        R: IntoRequest<R>,
    {
        let fake_retry = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut req = req.into_request();
        req.extensions_mut().insert(IsWorkerTaskLongPoll);
        for code in [Code::Cancelled, Code::DeadlineExceeded] {
            let mut err_handler = TonicErrorHandler::new(
                fake_retry.get_call_info::<R>(call_name, Some(&req)),
                RetryConfig::throttle_retry_policy(),
            );
            for i in 1..=5 {
                let result = err_handler.handle(i, Status::new(code, "retryable failure"));
                assert_matches!(result, RetryPolicy::WaitRetry(_));
            }
        }
    }
}
568}