sp1_sdk/network/
retry.rs

1use anyhow::Result;
2use backoff::{future::retry, Error as BackoffError, ExponentialBackoff};
3use std::time::Duration;
4use tonic::Code;
5
6/// Default timeout for retry operations.
///
/// The value is 120 seconds (two minutes).
7pub const DEFAULT_RETRY_TIMEOUT: Duration = Duration::from_secs(120);
8
9/// Trait for implementing retryable RPC operations.
///
/// Both methods are required: the trait declares their signatures only and
/// supplies no default bodies. Each method receives `operation` as an
/// `Fn() -> Fut` closure rather than as a single future, so it can be called
/// more than once to produce a fresh future for every attempt.
///
/// NOTE(review): the actual retry policy is up to the implementor; presumably
/// implementations delegate to `retry_operation` in this module — confirm
/// against the implementors.
10#[async_trait::async_trait]
11pub trait RetryableRpc {
12    /// Execute an operation with retries using default timeout.
    ///
    /// * `operation` - closure that builds the future to run; it must be
    ///   `Send + Sync` and live at least as long as the `&'a self` borrow.
    /// * `operation_name` - label for the operation (presumably used in log
    ///   messages, as `retry_operation` does — confirm against implementors).
    ///
    /// Returns the operation's `T` on success, or an error otherwise.
    ///
    /// NOTE(review): "default timeout" presumably means
    /// `DEFAULT_RETRY_TIMEOUT`; the trait itself does not enforce this.
13    async fn with_retry<'a, T, F, Fut>(&'a self, operation: F, operation_name: &str) -> Result<T>
14    where
15        F: Fn() -> Fut + Send + Sync + 'a,
16        Fut: std::future::Future<Output = Result<T>> + Send,
17        T: Send;
18
19    /// Execute an operation with retries using custom timeout.
    ///
    /// * `operation` - closure that builds the future to run; it must be
    ///   `Send + Sync` and live at least as long as the `&'a self` borrow.
    /// * `timeout` - caller-chosen time limit for the retries (how it is
    ///   applied is decided by the implementation).
    /// * `operation_name` - label for the operation (presumably used in log
    ///   messages, as `retry_operation` does — confirm against implementors).
    ///
    /// Returns the operation's `T` on success, or an error otherwise.
20    async fn with_retry_timeout<'a, T, F, Fut>(
21        &'a self,
22        operation: F,
23        timeout: Duration,
24        operation_name: &str,
25    ) -> Result<T>
26    where
27        F: Fn() -> Fut + Send + Sync + 'a,
28        Fut: std::future::Future<Output = Result<T>> + Send,
29        T: Send;
30}
31
32/// Execute an async operation with exponential backoff retries.
33pub async fn retry_operation<T, F, Fut>(
34    operation: F,
35    timeout: Option<Duration>,
36    operation_name: &str,
37) -> Result<T>
38where
39    F: Fn() -> Fut + Send + Sync,
40    Fut: std::future::Future<Output = Result<T>> + Send,
41{
42    let backoff = ExponentialBackoff {
43        initial_interval: Duration::from_secs(1),
44        max_interval: Duration::from_secs(120),
45        max_elapsed_time: timeout,
46        ..Default::default()
47    };
48
49    retry(backoff, || async {
50        match operation().await {
51            Ok(result) => Ok(result),
52            Err(e) => {
53                // Check for tonic status errors.
54                if let Some(status) = e.downcast_ref::<tonic::Status>() {
55                    match status.code() {
56                        Code::Unavailable |
57                        Code::DeadlineExceeded |
58                        Code::Internal |
59                        Code::Aborted => {
60                            tracing::warn!(
61                                "Network temporarily unavailable when {} due to {}, retrying...",
62                                operation_name,
63                                status.message(),
64                            );
65                            Err(BackoffError::transient(e))
66                        }
67                        Code::NotFound => {
68                            tracing::error!(
69                                "{} not found due to {}",
70                                operation_name,
71                                status.message(),
72                            );
73                            Err(BackoffError::permanent(e))
74                        }
75                        _ => {
76                            tracing::error!(
77                                "Permanent error encountered when {}: {} ({})",
78                                operation_name,
79                                status.message(),
80                                status.code()
81                            );
82                            Err(BackoffError::permanent(e))
83                        }
84                    }
85                } else {
86                    // Check for common transport errors.
87                    let error_msg = e.to_string().to_lowercase();
88                    let error_debug_msg = format!("{e:?}");
89
90                    if error_debug_msg.contains("no native certs found") {
91                        tracing::error!(
92                            "Permanent error when {}: no native certs found",
93                            operation_name
94                        );
95                        Err(BackoffError::permanent(e))
96                    } else {
97                        let is_transient = error_msg.contains("tls handshake") ||
98                            error_msg.contains("dns error") ||
99                            error_msg.contains("connection reset") ||
100                            error_msg.contains("broken pipe") ||
101                            error_msg.contains("transport error") ||
102                            error_msg.contains("failed to lookup") ||
103                            error_msg.contains("timeout") ||
104                            error_msg.contains("deadline exceeded") ||
105                            error_msg.contains("error sending request for url");
106
107                        if is_transient {
108                            tracing::warn!(
109                                "Transient transport error when {}: {}, retrying...",
110                                operation_name,
111                                error_msg
112                            );
113                            Err(BackoffError::transient(e))
114                        } else {
115                            tracing::error!(
116                                "Permanent error when {}: {}",
117                                operation_name,
118                                error_msg
119                            );
120                            Err(BackoffError::permanent(e))
121                        }
122                    }
123                }
124            }
125        }
126    })
127    .await
128}