spider_middleware/
retry.rs1use async_trait::async_trait;
12use log::{info, trace, warn};
13use std::time::Duration;
14
15use crate::middleware::{Middleware, MiddlewareAction};
16use spider_util::error::SpiderError;
17use spider_util::request::Request;
18use spider_util::response::Response;
19
/// Configuration for retrying requests that failed with a retryable HTTP
/// status or a transient connection/timeout error.
///
/// Retry delays grow as `backoff_factor * 2^attempt` seconds and are capped at
/// `max_delay`.
#[derive(Clone, Debug)]
pub struct RetryMiddleware {
    /// Number of retry attempts allowed per request before it is dropped.
    pub max_retries: u32,

    /// HTTP status codes that cause a response to be retried.
    pub retry_http_codes: Vec<u16>,

    /// Multiplier, in seconds, applied to the exponential backoff term.
    pub backoff_factor: f64,

    /// Upper bound on the delay between two attempts.
    pub max_delay: Duration,
}
32
33impl Default for RetryMiddleware {
34 fn default() -> Self {
35 let middleware = RetryMiddleware {
36 max_retries: 3,
37 retry_http_codes: vec![500, 502, 503, 504, 408, 429],
38 backoff_factor: 1.0,
39 max_delay: Duration::from_secs(180),
40 };
41 info!("Initializing RetryMiddleware with config: {:?}", middleware);
42 middleware
43 }
44}
45
46impl RetryMiddleware {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn max_retries(mut self, max_retries: u32) -> Self {
54 self.max_retries = max_retries;
55 self
56 }
57
58 pub fn retry_http_codes(mut self, retry_http_codes: Vec<u16>) -> Self {
60 self.retry_http_codes = retry_http_codes;
61 self
62 }
63
64 pub fn backoff_factor(mut self, backoff_factor: f64) -> Self {
66 self.backoff_factor = backoff_factor;
67 self
68 }
69
70 pub fn max_delay(mut self, max_delay: Duration) -> Self {
72 self.max_delay = max_delay;
73 self
74 }
75}
76
77#[async_trait]
78impl<C: Send + Sync> Middleware<C> for RetryMiddleware {
79 fn name(&self) -> &str {
80 "RetryMiddleware"
81 }
82
83 async fn process_request(
84 &mut self,
85 _client: &C,
86 request: Request,
87 ) -> Result<MiddlewareAction<Request>, SpiderError> {
88 Ok(MiddlewareAction::Continue(request))
89 }
90
91 async fn process_response(
92 &mut self,
93 response: Response,
94 ) -> Result<MiddlewareAction<Response>, SpiderError> {
95 trace!(
96 "Processing response for URL: {} with status: {}",
97 response.url, response.status
98 );
99
100 if self.retry_http_codes.contains(&response.status.as_u16()) {
101 let mut request = response.request_from_response();
102 let current_attempts = request.get_retry_attempts();
103
104 if current_attempts < self.max_retries {
105 request.increment_retry_attempts();
106 let delay = self.calculate_delay(current_attempts);
107 info!(
108 "Retrying {} (status: {}, attempt {}/{}) after {:?}",
109 request.url,
110 response.status,
111 current_attempts + 1,
112 self.max_retries,
113 delay
114 );
115 return Ok(MiddlewareAction::Retry(Box::new(request), delay));
116 } else {
117 warn!(
118 "Max retries ({}) reached for {} (status: {}). Dropping response.",
119 self.max_retries, request.url, response.status
120 );
121 return Ok(MiddlewareAction::Drop);
122 }
123 } else {
124 trace!(
125 "Response status {} is not in retry codes, continuing",
126 response.status
127 );
128 }
129
130 Ok(MiddlewareAction::Continue(response))
131 }
132
133 async fn handle_error(
134 &mut self,
135 request: &Request,
136 error: &SpiderError,
137 ) -> Result<MiddlewareAction<Request>, SpiderError> {
138 trace!("Handling error for request {}: {:?}", request.url, error);
139
140 if let SpiderError::ReqwestError(err_details) = error
141 && (err_details.is_connect || err_details.is_timeout)
142 {
143 let mut new_request = request.clone();
144 let current_attempts = new_request.get_retry_attempts();
145
146 if current_attempts < self.max_retries {
147 new_request.increment_retry_attempts();
148 let delay = self.calculate_delay(current_attempts);
149 info!(
150 "Retrying {} (error: {}, attempt {}/{}) after {:?}",
151 new_request.url,
152 err_details.message,
153 current_attempts + 1,
154 self.max_retries,
155 delay
156 );
157 return Ok(MiddlewareAction::Retry(Box::new(new_request), delay));
158 } else {
159 warn!(
160 "Max retries ({}) reached for {} (error: {}). Dropping request.",
161 self.max_retries, new_request.url, err_details.message
162 );
163 return Ok(MiddlewareAction::Drop);
164 }
165 } else {
166 trace!("Error is not a retryable error, returning original error");
167 }
168
169 Err(error.clone())
170 }
171}
172
173impl RetryMiddleware {
174 fn calculate_delay(&self, retries: u32) -> Duration {
175 let delay_secs = self.backoff_factor * (2.0f64.powi(retries as i32));
176 let delay = Duration::from_secs_f64(delay_secs);
177 delay.min(self.max_delay)
178 }
179}