1use crate::{
2 Client, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY, NamespacedClient, NoRetryOnMatching,
3 Result, RetryConfig, raw::IsUserLongPoll,
4};
5use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff};
6use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
7use std::{error::Error, fmt::Debug, future::Future, sync::Arc, time::Duration};
8use tonic::{Code, Request};
9
/// gRPC status codes that this module's retry error handler always treats as retryable,
/// regardless of the kind of call that produced them.
pub const RETRYABLE_ERROR_CODES: [Code; 7] = [
    Code::DataLoss,
    Code::Internal,
    Code::Unknown,
    Code::ResourceExhausted,
    Code::Aborted,
    Code::OutOfRange,
    Code::Unavailable,
];
/// Grace window for task long polls: while the retry backoff reports an elapsed time no greater
/// than this, errors whose codes are not otherwise retryable are still retried instead of being
/// forwarded to the caller.
const LONG_POLL_FATAL_GRACE: Duration = Duration::from_secs(60);
21
/// Pairs a client of type `SG` with the [`RetryConfig`] applied to calls made through it.
#[derive(Debug, Clone)]
pub struct RetryClient<SG> {
    /// The wrapped client.
    client: SG,
    /// Retry settings, held behind an `Arc` so clones of this wrapper share one allocation.
    retry_config: Arc<RetryConfig>,
}
29
30impl<SG> RetryClient<SG> {
31 pub fn new(client: SG, retry_config: RetryConfig) -> Self {
33 Self {
34 client,
35 retry_config: Arc::new(retry_config),
36 }
37 }
38}
39
40impl<SG> RetryClient<SG> {
41 pub fn get_client(&self) -> &SG {
43 &self.client
44 }
45
46 pub fn get_client_mut(&mut self) -> &mut SG {
48 &mut self.client
49 }
50
51 pub fn into_inner(self) -> SG {
53 self.client
54 }
55
56 pub(crate) fn get_call_info<R>(
57 &self,
58 call_name: &'static str,
59 request: Option<&Request<R>>,
60 ) -> CallInfo {
61 let mut call_type = CallType::Normal;
62 let mut retry_short_circuit = None;
63 if let Some(r) = request.as_ref() {
64 let ext = r.extensions();
65 if ext.get::<IsUserLongPoll>().is_some() {
66 call_type = CallType::UserLongPoll;
67 } else if ext.get::<IsWorkerTaskLongPoll>().is_some() {
68 call_type = CallType::TaskLongPoll;
69 }
70
71 retry_short_circuit = ext.get::<NoRetryOnMatching>().cloned();
72 }
73 let retry_cfg = if call_type == CallType::TaskLongPoll {
74 RetryConfig::task_poll_retry_policy()
75 } else {
76 (*self.retry_config).clone()
77 };
78 CallInfo {
79 call_type,
80 call_name,
81 retry_cfg,
82 retry_short_circuit,
83 }
84 }
85
86 pub(crate) fn make_future_retry<R, F, Fut>(
87 info: CallInfo,
88 factory: F,
89 ) -> FutureRetry<F, TonicErrorHandler<SystemClock>>
90 where
91 F: FnMut() -> Fut + Unpin,
92 Fut: Future<Output = Result<R>>,
93 {
94 FutureRetry::new(
95 factory,
96 TonicErrorHandler::new(info, RetryConfig::throttle_retry_policy()),
97 )
98 }
99}
100
101impl NamespacedClient for RetryClient<Client> {
102 fn namespace(&self) -> &str {
103 &self.client.namespace
104 }
105
106 fn get_identity(&self) -> &str {
107 &self.client.options().identity
108 }
109}
110
/// Decides, for each failed gRPC attempt, whether the call should be retried and after how long.
#[derive(Debug)]
pub(crate) struct TonicErrorHandler<C: Clock> {
    /// Backoff that supplies the wait before an ordinary retry.
    backoff: ExponentialBackoff<C>,
    /// Extra backoff consulted only for `ResourceExhausted` errors; the longer of the two waits
    /// is used.
    throttle_backoff: ExponentialBackoff<C>,
    /// Attempt cap. The cap is only enforced when this is greater than zero.
    max_retries: usize,
    /// Kind of call being retried; task long polls get more lenient handling.
    call_type: CallType,
    /// Name of the call, included in retry log messages.
    call_name: &'static str,
    /// Set once a `Cancelled` error caused by a closed connection has been granted its retry.
    have_retried_goaway_cancel: bool,
    /// Optional predicate; when it matches an error, that error is forwarded without retrying.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
121impl TonicErrorHandler<SystemClock> {
122 fn new(call_info: CallInfo, throttle_cfg: RetryConfig) -> Self {
123 Self::new_with_clock(
124 call_info,
125 throttle_cfg,
126 SystemClock::default(),
127 SystemClock::default(),
128 )
129 }
130}
131impl<C> TonicErrorHandler<C>
132where
133 C: Clock,
134{
135 fn new_with_clock(
136 call_info: CallInfo,
137 throttle_cfg: RetryConfig,
138 clock: C,
139 throttle_clock: C,
140 ) -> Self {
141 Self {
142 call_type: call_info.call_type,
143 call_name: call_info.call_name,
144 max_retries: call_info.retry_cfg.max_retries,
145 backoff: call_info.retry_cfg.into_exp_backoff(clock),
146 throttle_backoff: throttle_cfg.into_exp_backoff(throttle_clock),
147 have_retried_goaway_cancel: false,
148 retry_short_circuit: call_info.retry_short_circuit,
149 }
150 }
151
152 fn maybe_log_retry(&self, cur_attempt: usize, err: &tonic::Status) {
153 let mut do_log = false;
154 if self.max_retries == 0 && cur_attempt > 5 {
156 do_log = true;
157 }
158 if self.max_retries > 0 && cur_attempt * 2 >= self.max_retries {
160 do_log = true;
161 }
162
163 if do_log {
164 if self.max_retries == 0 && cur_attempt > 15 {
166 error!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
167 } else {
168 warn!(error=?err, "gRPC call {} retried {} times", self.call_name, cur_attempt);
169 }
170 }
171 }
172}
173
/// Per-call settings handed to the retry error handler.
#[derive(Clone, Debug)]
pub(crate) struct CallInfo {
    /// Kind of call (normal, user long poll, or task long poll).
    pub call_type: CallType,
    /// Name of the call, used in retry log messages.
    call_name: &'static str,
    /// Retry configuration that applies to this call.
    retry_cfg: RetryConfig,
    /// Optional predicate that, when it matches an error, stops retrying immediately.
    retry_short_circuit: Option<NoRetryOnMatching>,
}
181
/// Classification of a call, used to pick the retry behaviour applied to it.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CallType {
    /// A call that is not a long poll.
    Normal,
    /// A long poll issued on behalf of a user.
    UserLongPoll,
    /// A long poll issued by a worker to fetch tasks.
    TaskLongPoll,
}

impl CallType {
    /// Returns `true` for both long-poll variants and `false` for [`CallType::Normal`].
    pub(crate) fn is_long(&self) -> bool {
        match self {
            Self::UserLongPoll | Self::TaskLongPoll => true,
            Self::Normal => false,
        }
    }
}
197
198impl<C> ErrorHandler<tonic::Status> for TonicErrorHandler<C>
199where
200 C: Clock,
201{
202 type OutError = tonic::Status;
203
204 fn handle(
205 &mut self,
206 current_attempt: usize,
207 mut e: tonic::Status,
208 ) -> RetryPolicy<tonic::Status> {
209 if self.max_retries > 0 && current_attempt >= self.max_retries {
211 return RetryPolicy::ForwardError(e);
212 }
213
214 if let Some(sc) = self.retry_short_circuit.as_ref()
215 && (sc.predicate)(&e)
216 {
217 return RetryPolicy::ForwardError(e);
218 }
219
220 if e.code() == Code::ResourceExhausted
222 && (e
223 .message()
224 .starts_with("grpc: received message larger than max")
225 || e.message()
226 .starts_with("grpc: message after decompression larger than max")
227 || e.message()
228 .starts_with("grpc: received message after decompression larger than max"))
229 {
230 e.metadata_mut().insert(
232 MESSAGE_TOO_LARGE_KEY,
233 tonic::metadata::MetadataValue::from(0),
234 );
235 return RetryPolicy::ForwardError(e);
236 }
237
238 let long_poll_allowed = self.call_type == CallType::TaskLongPoll
241 && [Code::Cancelled, Code::DeadlineExceeded].contains(&e.code());
242
243 let mut goaway_retry_allowed = false;
248 if !self.have_retried_goaway_cancel
249 && e.code() == Code::Cancelled
250 && let Some(e) = e
251 .source()
252 .and_then(|e| e.downcast_ref::<tonic::transport::Error>())
253 .and_then(|te| te.source())
254 .and_then(|tec| tec.downcast_ref::<hyper::Error>())
255 && format!("{e:?}").contains("connection closed")
256 {
257 goaway_retry_allowed = true;
258 self.have_retried_goaway_cancel = true;
259 }
260
261 if RETRYABLE_ERROR_CODES.contains(&e.code()) || long_poll_allowed || goaway_retry_allowed {
262 if current_attempt == 1 {
263 debug!(error=?e, "gRPC call {} failed on first attempt", self.call_name);
264 } else {
265 self.maybe_log_retry(current_attempt, &e);
266 }
267
268 match self.backoff.next_backoff() {
269 None => RetryPolicy::ForwardError(e), Some(backoff) => {
271 if e.code() == Code::ResourceExhausted {
274 let extended_backoff =
275 backoff.max(self.throttle_backoff.next_backoff().unwrap_or_default());
276 RetryPolicy::WaitRetry(extended_backoff)
277 } else {
278 RetryPolicy::WaitRetry(backoff)
279 }
280 }
281 }
282 } else if self.call_type == CallType::TaskLongPoll
283 && self.backoff.get_elapsed_time() <= LONG_POLL_FATAL_GRACE
284 {
285 RetryPolicy::WaitRetry(self.backoff.max_interval)
288 } else {
289 RetryPolicy::ForwardError(e)
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use backoff::Clock;
    use std::{ops::Add, time::Instant};
    use squads_temporal_sdk_core_protos::temporal::api::workflowservice::v1::{
        PollActivityTaskQueueRequest, PollNexusTaskQueueRequest, PollWorkflowTaskQueueRequest,
    };
    use tonic::{IntoRequest, Status};

    /// Fast, deterministic retry config (no jitter) shared by the tests below.
    const TEST_RETRY_CONFIG: RetryConfig = RetryConfig {
        initial_interval: Duration::from_millis(1),
        randomization_factor: 0.0,
        multiplier: 1.1,
        max_interval: Duration::from_millis(2),
        max_elapsed_time: None,
        max_retries: 10,
    };

    const POLL_WORKFLOW_METH_NAME: &str = "poll_workflow_task_queue";
    const POLL_ACTIVITY_METH_NAME: &str = "poll_activity_task_queue";
    const POLL_NEXUS_METH_NAME: &str = "poll_nexus_task_queue";

    /// Clock that always reports the instant it holds; tests move it forward by hand.
    struct FixedClock(Instant);
    impl Clock for FixedClock {
        fn now(&self) -> Instant {
            self.0
        }
    }

    /// Builds a task-long-poll handler on fixed clocks using `TEST_RETRY_CONFIG` for retries.
    fn task_poll_handler(
        call_name: &'static str,
        throttle_cfg: RetryConfig,
        retry_short_circuit: Option<NoRetryOnMatching>,
    ) -> TonicErrorHandler<FixedClock> {
        let call_info = CallInfo {
            call_type: CallType::TaskLongPoll,
            call_name,
            retry_cfg: TEST_RETRY_CONFIG,
            retry_short_circuit,
        };
        TonicErrorHandler::new_with_clock(
            call_info,
            throttle_cfg,
            FixedClock(Instant::now()),
            FixedClock(Instant::now()),
        )
    }

    /// Moves `clock` forward by `by`.
    fn advance(clock: &mut FixedClock, by: Duration) {
        clock.0 = clock.0.add(by);
    }

    // Non-retryable codes are tolerated for task long polls only until the grace window passes.
    #[tokio::test]
    async fn long_poll_non_retryable_errors() {
        let fatal_codes = [
            Code::InvalidArgument,
            Code::NotFound,
            Code::AlreadyExists,
            Code::PermissionDenied,
            Code::FailedPrecondition,
            Code::Unauthenticated,
            Code::Unimplemented,
        ];
        for code in fatal_codes {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut handler = task_poll_handler(call_name, TEST_RETRY_CONFIG, None);

                let within_grace = handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(within_grace, RetryPolicy::WaitRetry(_));

                advance(
                    &mut handler.backoff.clock,
                    LONG_POLL_FATAL_GRACE + Duration::from_secs(1),
                );
                let after_grace = handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(after_grace, RetryPolicy::ForwardError(_));
            }
        }
    }

    // Retryable codes keep being retried for task long polls even after the grace window.
    #[tokio::test]
    async fn long_poll_retryable_errors_never_fatal() {
        for code in RETRYABLE_ERROR_CODES {
            for call_name in [POLL_WORKFLOW_METH_NAME, POLL_ACTIVITY_METH_NAME] {
                let mut handler = task_poll_handler(call_name, TEST_RETRY_CONFIG, None);

                let first = handler.handle(1, Status::new(code, "Ahh"));
                assert_matches!(first, RetryPolicy::WaitRetry(_));

                advance(
                    &mut handler.backoff.clock,
                    LONG_POLL_FATAL_GRACE + Duration::from_secs(1),
                );
                let second = handler.handle(2, Status::new(code, "Ahh"));
                assert_matches!(second, RetryPolicy::WaitRetry(_));
            }
        }
    }

    // `ResourceExhausted` waits for the throttle backoff when it exceeds the regular backoff.
    #[tokio::test]
    async fn retry_resource_exhausted() {
        let throttle_cfg = RetryConfig {
            initial_interval: Duration::from_millis(2),
            randomization_factor: 0.0,
            multiplier: 4.0,
            max_interval: Duration::from_millis(10),
            max_elapsed_time: None,
            max_retries: 10,
        };
        let mut handler = task_poll_handler(POLL_WORKFLOW_METH_NAME, throttle_cfg, None);

        let first = handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        let RetryPolicy::WaitRetry(first_wait) = first else {
            panic!()
        };
        assert_eq!(first_wait, Duration::from_millis(2));

        advance(&mut handler.backoff.clock, Duration::from_millis(10));
        advance(&mut handler.throttle_backoff.clock, Duration::from_millis(10));

        let second = handler.handle(2, Status::new(Code::ResourceExhausted, "leave me alone"));
        let RetryPolicy::WaitRetry(second_wait) = second else {
            panic!()
        };
        assert_eq!(second_wait, Duration::from_millis(8));
    }

    // A matching short-circuit predicate forwards an error that would otherwise be retried.
    #[tokio::test]
    async fn retry_short_circuit() {
        let short_circuit = NoRetryOnMatching {
            predicate: |s: &Status| s.code() == Code::ResourceExhausted,
        };
        let mut handler =
            task_poll_handler(POLL_WORKFLOW_METH_NAME, TEST_RETRY_CONFIG, Some(short_circuit));

        let outcome = handler.handle(1, Status::new(Code::ResourceExhausted, "leave me alone"));
        assert_matches!(outcome, RetryPolicy::ForwardError(_))
    }

    // Every known "message too large" prefix is forwarded rather than retried.
    #[tokio::test]
    async fn message_too_large_not_retried() {
        let mut handler = task_poll_handler(POLL_WORKFLOW_METH_NAME, TEST_RETRY_CONFIG, None);

        for message in [
            "grpc: received message larger than max",
            "grpc: message after decompression larger than max",
            "grpc: received message after decompression larger than max",
        ] {
            let outcome = handler.handle(1, Status::new(Code::ResourceExhausted, message));
            assert_matches!(outcome, RetryPolicy::ForwardError(_));
        }
    }

    // Requests marked as worker task long polls are retried regardless of the attempt number.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_forever<R>(
        #[values(
            (
                POLL_WORKFLOW_METH_NAME,
                PollWorkflowTaskQueueRequest::default(),
            ),
            (
                POLL_ACTIVITY_METH_NAME,
                PollActivityTaskQueueRequest::default(),
            ),
            (
                POLL_NEXUS_METH_NAME,
                PollNexusTaskQueueRequest::default(),
            ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        // No real client is needed: only the call info derived from the request matters here.
        let retry_client = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut request = req.into_request();
        request.extensions_mut().insert(IsWorkerTaskLongPoll);

        for attempt in 1..=50 {
            let call_info = retry_client.get_call_info::<R>(call_name, Some(&request));
            let mut handler =
                TonicErrorHandler::new(call_info, RetryConfig::throttle_retry_policy());
            let outcome = handler.handle(attempt, Status::new(Code::Unknown, "Ahh"));
            assert_matches!(outcome, RetryPolicy::WaitRetry(_));
        }
    }

    // Cancelled and timed-out worker task long polls are retried.
    #[rstest::rstest]
    #[tokio::test]
    async fn task_poll_retries_deadline_exceeded<R>(
        #[values(
            (
                POLL_WORKFLOW_METH_NAME,
                PollWorkflowTaskQueueRequest::default(),
            ),
            (
                POLL_ACTIVITY_METH_NAME,
                PollActivityTaskQueueRequest::default(),
            ),
            (
                POLL_NEXUS_METH_NAME,
                PollNexusTaskQueueRequest::default(),
            ),
        )]
        (call_name, req): (&'static str, R),
    ) {
        let retry_client = RetryClient::new((), TEST_RETRY_CONFIG);
        let mut request = req.into_request();
        request.extensions_mut().insert(IsWorkerTaskLongPoll);

        for code in [Code::Cancelled, Code::DeadlineExceeded] {
            let call_info = retry_client.get_call_info::<R>(call_name, Some(&request));
            let mut handler =
                TonicErrorHandler::new(call_info, RetryConfig::throttle_retry_policy());
            for attempt in 1..=5 {
                let outcome = handler.handle(attempt, Status::new(code, "retryable failure"));
                assert_matches!(outcome, RetryPolicy::WaitRetry(_));
            }
        }
    }
}