1use anyhow::Result;
2use backoff::{future::retry, Error as BackoffError, ExponentialBackoff};
3use std::time::Duration;
4use tonic::Code;
5
/// Default overall time budget for a retried RPC operation: two minutes.
///
/// NOTE(review): presumably the `timeout` that `RetryableRpc::with_retry` implementations pass
/// on to the retry loop — confirm against implementors.
pub const DEFAULT_RETRY_TIMEOUT: Duration = Duration::from_secs(2 * 60);
8
9#[async_trait::async_trait]
11pub trait RetryableRpc {
12 async fn with_retry<'a, T, F, Fut>(&'a self, operation: F, operation_name: &str) -> Result<T>
14 where
15 F: Fn() -> Fut + Send + Sync + 'a,
16 Fut: std::future::Future<Output = Result<T>> + Send,
17 T: Send;
18
19 async fn with_retry_timeout<'a, T, F, Fut>(
21 &'a self,
22 operation: F,
23 timeout: Duration,
24 operation_name: &str,
25 ) -> Result<T>
26 where
27 F: Fn() -> Fut + Send + Sync + 'a,
28 Fut: std::future::Future<Output = Result<T>> + Send,
29 T: Send;
30}
31
32pub async fn retry_operation<T, F, Fut>(
34 operation: F,
35 timeout: Option<Duration>,
36 operation_name: &str,
37) -> Result<T>
38where
39 F: Fn() -> Fut + Send + Sync,
40 Fut: std::future::Future<Output = Result<T>> + Send,
41{
42 let backoff = ExponentialBackoff {
43 initial_interval: Duration::from_secs(1),
44 max_interval: Duration::from_secs(120),
45 max_elapsed_time: timeout,
46 ..Default::default()
47 };
48
49 retry(backoff, || async {
50 match operation().await {
51 Ok(result) => Ok(result),
52 Err(e) => {
53 if let Some(status) = e.downcast_ref::<tonic::Status>() {
55 match status.code() {
56 Code::Unavailable |
57 Code::DeadlineExceeded |
58 Code::Internal |
59 Code::Aborted => {
60 tracing::warn!(
61 "Network temporarily unavailable when {} due to {}, retrying...",
62 operation_name,
63 status.message(),
64 );
65 Err(BackoffError::transient(e))
66 }
67 Code::NotFound => {
68 tracing::error!(
69 "{} not found due to {}",
70 operation_name,
71 status.message(),
72 );
73 Err(BackoffError::permanent(e))
74 }
75 _ => {
76 tracing::error!(
77 "Permanent error encountered when {}: {} ({})",
78 operation_name,
79 status.message(),
80 status.code()
81 );
82 Err(BackoffError::permanent(e))
83 }
84 }
85 } else {
86 let error_msg = e.to_string().to_lowercase();
88 let error_debug_msg = format!("{e:?}");
89
90 if error_debug_msg.contains("no native certs found") {
91 tracing::error!(
92 "Permanent error when {}: no native certs found",
93 operation_name
94 );
95 Err(BackoffError::permanent(e))
96 } else {
97 let is_transient = error_msg.contains("tls handshake") ||
98 error_msg.contains("dns error") ||
99 error_msg.contains("connection reset") ||
100 error_msg.contains("broken pipe") ||
101 error_msg.contains("transport error") ||
102 error_msg.contains("failed to lookup") ||
103 error_msg.contains("timeout") ||
104 error_msg.contains("deadline exceeded") ||
105 error_msg.contains("error sending request for url");
106
107 if is_transient {
108 tracing::warn!(
109 "Transient transport error when {}: {}, retrying...",
110 operation_name,
111 error_msg
112 );
113 Err(BackoffError::transient(e))
114 } else {
115 tracing::error!(
116 "Permanent error when {}: {}",
117 operation_name,
118 error_msg
119 );
120 Err(BackoffError::permanent(e))
121 }
122 }
123 }
124 }
125 }
126 })
127 .await
128}