simbld_http/helpers/
unified_middleware_helper.rs

1use actix_service::{Service, Transform};
2use actix_web::{
3    dev::{ServiceRequest, ServiceResponse},
4    error::Error as ActixError,
5    http::{header, StatusCode},
6    HttpResponse,
7};
8use futures_util::future::{ready, LocalBoxFuture, Ready};
9use std::{
10    collections::{HashMap, HashSet},
11    rc::Rc,
12    sync::{Arc, Mutex},
13    time::{Duration, Instant},
14};
15use thiserror::Error;
16
17pub type ConditionFunction = Rc<Box<dyn for<'a> Fn(&'a ServiceRequest) -> bool + 'static>>;
18pub type InterceptFunction = Rc<dyn Fn(&ServiceRequest) -> bool>;
19pub type RateLimiters = Arc<Mutex<HashMap<String, (u64, Instant)>>>;
20pub type AllowedOrigins = HashSet<String>;
21
22impl std::fmt::Debug for UnifiedMiddleware {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("UnifiedMiddleware")
25            .field("allowed_origins", &self.allowed_origins)
26            .field("max_requests", &self.max_requests)
27            .field("window_duration", &self.window_duration)
28            .finish()
29    }
30}
31
32pub struct UnifiedMiddleware {
33    pub allowed_origins: AllowedOrigins,
34    pub rate_limiters: RateLimiters,
35    pub max_requests: usize,
36    pub window_duration: Duration,
37    pub intercept_dependencies: InterceptFunction,
38    pub condition: ConditionFunction,
39}
40
41#[derive(Debug, Error)]
42pub enum UnifiedError {
43    #[error("An internal error occurred in the middleware.")]
44    InternalMiddlewareError,
45    #[error("Invalid request.")]
46    InvalidRequest,
47    #[error("Unauthorized access.")]
48    Unauthorized,
49}
50
51impl actix_web::ResponseError for UnifiedError {
52    fn status_code(&self) -> StatusCode {
53        match self {
54            UnifiedError::InternalMiddlewareError => StatusCode::INTERNAL_SERVER_ERROR,
55            UnifiedError::InvalidRequest => StatusCode::BAD_REQUEST,
56            UnifiedError::Unauthorized => StatusCode::UNAUTHORIZED,
57        }
58    }
59
60    fn error_response(&self) -> HttpResponse {
61        HttpResponse::build(self.status_code())
62            .content_type("application/json")
63            .body(format!("{{\"error\": \"{}\"}}", self))
64    }
65}
66
67pub type OptionalConditionFunction =
68    Option<Box<dyn for<'a> Fn(&'a ServiceRequest) -> bool + 'static>>;
69
70impl UnifiedMiddleware {
71    /// Creates a new middleware with complete and flexible configuration.
72    ///
73    /// # arguments
74    ///
75    /// * `Allowed_origins' - Authorized Kid Origins, separated by commas (ex:" http: //example.com,http: // Localhost: 3000 ")
76    /// * `rate_limiters` - Storage of rate limiters by IP
77    /// * `Max_requests' - Maximum number of requests authorized in the time window
78    /// * `Window_Duration` - Duration of the window for the rate limiter
79    /// * `intercept_dependencies' - function that determines if the request must be intercepted
80    /// * `Condition ' - Additional condition to apply the middleware
81    ///
82    pub fn new(
83        allowed_origins: String,
84        rate_limiters: RateLimiters,
85        max_requests: usize,
86        window_duration: Duration,
87        intercept_dependencies: InterceptFunction,
88        condition: OptionalConditionFunction,
89    ) -> Self {
90        let origins: AllowedOrigins =
91            allowed_origins.split(',').map(|s| s.trim().to_string()).collect();
92
93        let default_condition: Box<dyn for<'a> Fn(&'a ServiceRequest) -> bool + 'static> =
94            Box::new(|_| true);
95
96        Self {
97            allowed_origins: origins,
98            rate_limiters,
99            max_requests,
100            window_duration,
101            intercept_dependencies,
102            condition: Rc::new(condition.unwrap_or(default_condition)),
103        }
104    }
105
106    /// Simplified version for current use cases.
107    ///
108    /// This function automatically initializes data and functions structures
109    /// necessary with reasonable default values.
110    ///
111    /// # arguments
112    ///
113    /// * `Allowed_origins` - List of authorized Cors
114    /// * `Max_requests' - Maximum number of requests authorized in the time window
115    /// * `Window_Duration` - Duration of the window for the rate limiter
116    ///
117    pub fn simple(
118        allowed_origins: Vec<String>,
119        max_requests: usize,
120        window_duration: Duration,
121    ) -> Self {
122        Self::new(
123            allowed_origins.join(","),
124            Arc::new(Mutex::new(HashMap::new())),
125            max_requests,
126            window_duration,
127            Rc::new(|_| true),
128            Some(Box::new(|_| true)),
129        )
130    }
131}
132
133impl<S, B> Transform<S, ServiceRequest> for UnifiedMiddleware
134where
135    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
136    B: 'static,
137{
138    type Response = ServiceResponse<B>;
139    type Error = ActixError;
140    type Transform = UnifiedMiddlewareService<S>;
141    type InitError = ();
142    type Future = Ready<Result<Self::Transform, Self::InitError>>;
143
144    fn new_transform(&self, service: S) -> Self::Future {
145        ready(Ok(UnifiedMiddlewareService {
146            service: Rc::new(service),
147            allowed_origins: self.allowed_origins.clone(),
148            rate_limiters: self.rate_limiters.clone(),
149            max_requests: self.max_requests,
150            window_duration: self.window_duration,
151            intercept_dependencies: self.intercept_dependencies.clone(),
152            condition: self.condition.clone(),
153        }))
154    }
155}
156
157pub struct UnifiedMiddlewareService<S> {
158    service: Rc<S>,
159    allowed_origins: AllowedOrigins,
160    rate_limiters: RateLimiters,
161    max_requests: usize,
162    window_duration: Duration,
163    intercept_dependencies: InterceptFunction,
164    condition: ConditionFunction,
165}
166
167impl<S, B> Service<ServiceRequest> for UnifiedMiddlewareService<S>
168where
169    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
170    B: 'static,
171{
172    type Response = ServiceResponse<B>;
173    type Error = ActixError;
174    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
175
176    fn poll_ready(
177        &self,
178        cx: &mut std::task::Context<'_>,
179    ) -> std::task::Poll<Result<(), Self::Error>> {
180        self.service.poll_ready(cx)
181    }
182
183    fn call(&self, req: ServiceRequest) -> Self::Future {
184        let service = self.service.clone();
185        let condition = self.condition.clone();
186        let intercept = self.intercept_dependencies.clone();
187        let allowed_origins = self.allowed_origins.clone();
188        let rate_limiters = self.rate_limiters.clone();
189        let max_requests = self.max_requests;
190        let window_duration = self.window_duration;
191
192        Box::pin(async move {
193            // Check if the conditions are met to apply the middleware
194            if !(*condition)(&req) {
195                return service.call(req).await;
196            }
197
198            if (*intercept)(&req) {
199                // Check the origin if it is a CORS request
200                check_origin(&req, &allowed_origins)?;
201
202                // Check the rate limiting
203                check_rate_limit(&req, rate_limiters, max_requests, window_duration)?;
204            }
205
206            service.call(req).await
207        })
208    }
209}
210
211// Function to check the origin of the request
212fn check_origin(req: &ServiceRequest, allowed_origins: &AllowedOrigins) -> Result<(), ActixError> {
213    if let Some(origin) = req.headers().get(header::ORIGIN) {
214        if let Ok(origin_str) = origin.to_str() {
215            if !allowed_origins.contains(origin_str) && !allowed_origins.contains("*") {
216                return Err(UnifiedError::Unauthorized.into());
217            }
218        }
219    }
220    Ok(())
221}
222
223// Function to update the rate limiter for a specific client IP
224fn check_rate_limit(
225    req: &ServiceRequest,
226    rate_limiters: RateLimiters,
227    max_requests: usize,
228    window_duration: Duration,
229) -> Result<(), ActixError> {
230    let client_ip = match req.connection_info().realip_remote_addr() {
231        Some(ip) => ip.to_string(),
232        None => "unknown".to_string(),
233    };
234
235    let should_limit =
236        update_rate_limiter(&client_ip, rate_limiters, max_requests, window_duration)?;
237
238    if should_limit {
239        return Err(ActixError::from(UnifiedError::InvalidRequest));
240    }
241
242    Ok(())
243}
244
245// Function to update the rate limiter for a specific client IP
246fn update_rate_limiter(
247    client_ip: &str,
248    rate_limiters: RateLimiters,
249    max_requests: usize,
250    window_duration: Duration,
251) -> Result<bool, ActixError> {
252    let mut limiters = rate_limiters.lock().map_err(|_| UnifiedError::InternalMiddlewareError)?;
253
254    let now = Instant::now();
255    let entry = limiters.entry(client_ip.to_string()).or_insert_with(|| (0, now));
256
257    if now.duration_since(entry.1) > window_duration {
258        // Reinitialize the entry if the window has expired
259        *entry = (1, now);
260        Ok(false)
261    } else {
262        // Increment the request count
263        entry.0 += 1;
264        Ok(entry.0 > max_requests as u64)
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{
        dev::Service,
        http::StatusCode,
        test::{init_service, TestRequest},
        web, App, HttpResponse,
    };
    use std::{sync::Arc, time::Duration};

    /// Trivial endpoint mounted at `/test` behind the middleware.
    async fn test_handler() -> HttpResponse {
        HttpResponse::Ok().body("test success")
    }

    /// Middleware that intercepts every request and uses a fresh rate-limiter map.
    fn build_middleware(origins: &str, max_requests: usize, window: Duration) -> UnifiedMiddleware {
        UnifiedMiddleware::new(
            origins.to_string(),
            Arc::new(Mutex::new(HashMap::new())),
            max_requests,
            window,
            Rc::new(|_| true),
            None,
        )
    }

    #[actix_web::test]
    async fn test_rate_limiting() {
        let limit = 2;
        let app = init_service(
            App::new()
                .wrap(build_middleware("*", limit, Duration::from_secs(1)))
                .route("/test", web::get().to(test_handler)),
        )
        .await;

        // Requests up to the limit are let through.
        for _ in 0..limit {
            let resp = app.call(TestRequest::get().uri("/test").to_request()).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        // One more request within the same window is rejected.
        let rejected = app.call(TestRequest::get().uri("/test").to_request()).await;
        assert!(rejected.is_err());
    }

    #[actix_web::test]
    async fn test_allowed_origins() {
        let app = init_service(
            App::new()
                .wrap(build_middleware(
                    "https://example.com,https://test.com",
                    100,
                    Duration::from_secs(60),
                ))
                .route("/test", web::get().to(test_handler)),
        )
        .await;

        // Listed origin: accepted.
        let allowed = TestRequest::get()
            .uri("/test")
            .insert_header((header::ORIGIN, "https://example.com"))
            .to_request();
        let resp = app.call(allowed).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        // Unlisted origin: rejected.
        let denied = TestRequest::get()
            .uri("/test")
            .insert_header((header::ORIGIN, "https://unauthorized.com"))
            .to_request();
        assert!(app.call(denied).await.is_err());
    }

    #[actix_web::test]
    async fn test_reset_rate_limiting_window() {
        // A short window keeps the sleep below brief.
        let window = Duration::from_millis(10);
        let app = init_service(
            App::new()
                .wrap(build_middleware("*", 1, window))
                .route("/test", web::get().to(test_handler)),
        )
        .await;

        let first = app.call(TestRequest::get().uri("/test").to_request()).await.unwrap();
        assert_eq!(first.status(), StatusCode::OK);

        let second = app.call(TestRequest::get().uri("/test").to_request()).await;
        assert!(second.is_err());

        // Once the window has elapsed, the client gets a fresh budget.
        tokio::time::sleep(window * 2).await;

        let third = app.call(TestRequest::get().uri("/test").to_request()).await.unwrap();
        assert_eq!(third.status(), StatusCode::OK);
    }
}