Skip to main content

spider_lib/middlewares/
retry.rs

1// src/middleware/retry.rs
2use async_trait::async_trait;
3use std::time::Duration;
4use tracing::{info, warn};
5
6use crate::error::SpiderError;
7use crate::middleware::{Middleware, MiddlewareAction};
8use crate::request::Request;
9use crate::response::Response;
10
/// Middleware that retries failed requests.
///
/// A response counts as failed when its HTTP status code is listed in
/// `retry_http_codes`. Such a response is turned back into a request and
/// retried with an exponentially growing delay until `max_retries` attempts
/// have been made; after that the response is dropped.
///
/// The `Default` implementation uses 3 retries, the status codes
/// 500, 502, 503, 504, 408 and 429, a backoff factor of 1.0 and a maximum
/// delay of 180 seconds.
#[derive(Debug, Clone)]
pub struct RetryMiddleware {
    /// Maximum number of times to retry a request.
    ///
    /// Compared against the retry-attempt counter stored on the request.
    pub max_retries: u32,
    /// HTTP status codes that should trigger a retry.
    ///
    /// Responses with any other status code are passed on unchanged.
    pub retry_http_codes: Vec<u16>,
    /// Factor for exponential backoff (delay = backoff_factor * (2^retries)).
    ///
    /// The result is interpreted as seconds; `retries` is the number of
    /// attempts already made, so the first retry waits `backoff_factor` seconds.
    pub backoff_factor: f64,
    /// Maximum delay between retries.
    ///
    /// The computed backoff delay is capped at this value.
    pub max_delay: Duration,
}
23
24impl Default for RetryMiddleware {
25    fn default() -> Self {
26        let middleware = RetryMiddleware {
27            max_retries: 3,
28            retry_http_codes: vec![500, 502, 503, 504, 408, 429],
29            backoff_factor: 1.0,
30            max_delay: Duration::from_secs(180),
31        };
32        info!("Initializing RetryMiddleware with config: {:?}", middleware);
33        middleware
34    }
35}
36
37
38impl RetryMiddleware {
39    /// Creates a new `RetryMiddleware` with default settings.
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Sets the maximum number of times to retry a request.
45    pub fn max_retries(mut self, max_retries: u32) -> Self {
46        self.max_retries = max_retries;
47        self
48    }
49
50    /// Sets the HTTP status codes that should trigger a retry.
51    pub fn retry_http_codes(mut self, retry_http_codes: Vec<u16>) -> Self {
52        self.retry_http_codes = retry_http_codes;
53        self
54    }
55
56    /// Sets the factor for exponential backoff.
57    pub fn backoff_factor(mut self, backoff_factor: f64) -> Self {
58        self.backoff_factor = backoff_factor;
59        self
60    }
61
62    /// Sets the maximum delay between retries.
63    pub fn max_delay(mut self, max_delay: Duration) -> Self {
64        self.max_delay = max_delay;
65        self
66    }
67}
68
69#[async_trait]
70impl<C: Send + Sync> Middleware<C> for RetryMiddleware {
71    fn name(&self) -> &str {
72        "RetryMiddleware"
73    }
74
75    async fn process_request(
76        &mut self,
77        _client: &C,
78        request: Request,
79    ) -> Result<MiddlewareAction<Request>, SpiderError> {
80        Ok(MiddlewareAction::Continue(request))
81    }
82
83    async fn process_response(
84        &mut self,
85        response: Response,
86    ) -> Result<MiddlewareAction<Response>, SpiderError> {
87        if self
88            .retry_http_codes
89            .contains(&response.status.as_u16())
90        {
91            let mut request = response.request_from_response();
92            let current_attempts = request.get_retry_attempts();
93
94            if current_attempts < self.max_retries {
95                request.increment_retry_attempts();
96                let delay = self.calculate_delay(current_attempts);
97                warn!(
98                    "Retrying {} (status: {}, attempt {}/{}) after {:?}",
99                    request.url,
100                    response.status,
101                    current_attempts + 1,
102                    self.max_retries,
103                    delay
104                );
105                return Ok(MiddlewareAction::Retry(Box::new(request), delay));
106            } else {
107                warn!(
108                    "Max retries ({}) reached for {} (status: {}). Dropping response.",
109                    self.max_retries, request.url, response.status
110                );
111                return Ok(MiddlewareAction::Drop);
112            }
113        }
114
115        Ok(MiddlewareAction::Continue(response))
116    }
117}
118
119impl RetryMiddleware {
120    fn calculate_delay(&self, retries: u32) -> Duration {
121        let delay_secs = self.backoff_factor * (2.0f64.powi(retries as i32));
122        let delay = Duration::from_secs_f64(delay_secs);
123        delay.min(self.max_delay)
124    }
125}