spider_middleware/
retry.rs1use async_trait::async_trait;
7use log::{info, trace, warn};
8use std::time::Duration;
9
10use crate::middleware::{Middleware, MiddlewareAction};
11use spider_util::error::SpiderError;
12use spider_util::request::Request;
13use spider_util::response::Response;
14
/// Configuration for a middleware that reschedules failed requests with
/// exponential back-off.
///
/// A response is considered retryable when its HTTP status code is listed in
/// [`retry_http_codes`](Self::retry_http_codes); a request error is considered
/// retryable when it is a connect or timeout failure. Once a request has been
/// retried [`max_retries`](Self::max_retries) times it is dropped.
///
/// `PartialEq` is derived so configurations can be compared (e.g. in tests);
/// it is not `Eq` because [`backoff_factor`](Self::backoff_factor) is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryMiddleware {
    /// Maximum number of retry attempts per request; compared against the
    /// request's recorded retry-attempt counter.
    pub max_retries: u32,
    /// HTTP status codes that trigger a retry of the originating request.
    pub retry_http_codes: Vec<u16>,
    /// Base delay in seconds; the wait before retry `k + 1` is
    /// `backoff_factor * 2^k` seconds (capped by `max_delay`).
    pub backoff_factor: f64,
    /// Upper bound applied to the computed back-off delay.
    pub max_delay: Duration,
}
27
28impl Default for RetryMiddleware {
29 fn default() -> Self {
30 let middleware = RetryMiddleware {
31 max_retries: 3,
32 retry_http_codes: vec![500, 502, 503, 504, 408, 429],
33 backoff_factor: 1.0,
34 max_delay: Duration::from_secs(180),
35 };
36 info!("Initializing RetryMiddleware with config: {:?}", middleware);
37 middleware
38 }
39}
40
41impl RetryMiddleware {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn max_retries(mut self, max_retries: u32) -> Self {
49 self.max_retries = max_retries;
50 self
51 }
52
53 pub fn retry_http_codes(mut self, retry_http_codes: Vec<u16>) -> Self {
55 self.retry_http_codes = retry_http_codes;
56 self
57 }
58
59 pub fn backoff_factor(mut self, backoff_factor: f64) -> Self {
61 self.backoff_factor = backoff_factor;
62 self
63 }
64
65 pub fn max_delay(mut self, max_delay: Duration) -> Self {
67 self.max_delay = max_delay;
68 self
69 }
70}
71
72#[async_trait]
73impl<C: Send + Sync> Middleware<C> for RetryMiddleware {
74 fn name(&self) -> &str {
75 "RetryMiddleware"
76 }
77
78 async fn process_request(
79 &self,
80 _client: &C,
81 request: Request,
82 ) -> Result<MiddlewareAction<Request>, SpiderError> {
83 Ok(MiddlewareAction::Continue(request))
84 }
85
86 async fn process_response(
87 &self,
88 response: Response,
89 ) -> Result<MiddlewareAction<Response>, SpiderError> {
90 trace!(
91 "Processing response for URL: {} with status: {}",
92 response.url, response.status
93 );
94
95 if self.retry_http_codes.contains(&response.status.as_u16()) {
96 let mut request = response.request_from_response();
97 let current_attempts = request.get_retry_attempts();
98
99 if current_attempts < self.max_retries {
100 request.increment_retry_attempts();
101 let delay = self.calculate_delay(current_attempts);
102 info!(
103 "Retrying {} (status: {}, attempt {}/{}) after {:?}",
104 request.url,
105 response.status,
106 current_attempts + 1,
107 self.max_retries,
108 delay
109 );
110 return Ok(MiddlewareAction::Retry(Box::new(request), delay));
111 } else {
112 warn!(
113 "Max retries ({}) reached for {} (status: {}). Dropping response.",
114 self.max_retries, request.url, response.status
115 );
116 return Ok(MiddlewareAction::Drop);
117 }
118 } else {
119 trace!(
120 "Response status {} is not in retry codes, continuing",
121 response.status
122 );
123 }
124
125 Ok(MiddlewareAction::Continue(response))
126 }
127
128 async fn handle_error(
129 &self,
130 request: &Request,
131 error: &SpiderError,
132 ) -> Result<MiddlewareAction<Request>, SpiderError> {
133 trace!("Handling error for request {}: {:?}", request.url, error);
134
135 if let SpiderError::ReqwestError(err_details) = error
136 && (err_details.is_connect || err_details.is_timeout)
137 {
138 let mut new_request = request.clone();
139 let current_attempts = new_request.get_retry_attempts();
140
141 if current_attempts < self.max_retries {
142 new_request.increment_retry_attempts();
143 let delay = self.calculate_delay(current_attempts);
144 info!(
145 "Retrying {} (error: {}, attempt {}/{}) after {:?}",
146 new_request.url,
147 err_details.message,
148 current_attempts + 1,
149 self.max_retries,
150 delay
151 );
152 return Ok(MiddlewareAction::Retry(Box::new(new_request), delay));
153 } else {
154 warn!(
155 "Max retries ({}) reached for {} (error: {}). Dropping request.",
156 self.max_retries, new_request.url, err_details.message
157 );
158 return Ok(MiddlewareAction::Drop);
159 }
160 } else {
161 trace!("Error is not a retryable error, returning original error");
162 }
163
164 Err(error.clone())
165 }
166}
167
168impl RetryMiddleware {
169 fn calculate_delay(&self, retries: u32) -> Duration {
170 let delay_secs = self.backoff_factor * (2.0f64.powi(retries as i32));
171 let delay = Duration::from_secs_f64(delay_secs);
172 delay.min(self.max_delay)
173 }
174}