// simbld_http/helpers/auth_middleware.rs

//! # Authentication Middleware
//!
//! Provides middleware that authenticates requests by validating a token.
//! The middleware intercepts each incoming request, reads the token from the
//! `key` query-string parameter, and either lets the request proceed to the
//! route handler or returns an error response without calling the handler.

7use actix_web::{
8    body::{BoxBody, EitherBody, MessageBody},
9    dev::{Service, ServiceRequest, ServiceResponse, Transform},
10    http::StatusCode,
11    web, Error, HttpResponse,
12};
13use futures_util::future::{ok, LocalBoxFuture, Ready};
14use std::task::{Context, Poll};
15
16/// Query parameters for extracting the authentication token.
17///
18/// Used to parse the `key` parameter from the request URL query string.
19///
20#[derive(serde::Deserialize)]
21pub struct TokenParams {
22    /// Optional token value to be validated.
23    pub key: Option<String>,
24}
25
/// Actix Web middleware factory that rejects requests lacking a valid token.
///
/// Register it with `App::wrap(AuthMiddleware)`. For each service it wraps it
/// produces an [`AuthMiddlewareService`], which performs the actual token check
/// before the request can reach a route handler.
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthMiddleware;

33impl<S, B> Transform<S, ServiceRequest> for AuthMiddleware
34where
35    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
36    B: MessageBody + 'static,
37{
38    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
39    type Error = Error;
40    type Transform = AuthMiddlewareService<S>;
41    type InitError = ();
42    type Future = Ready<Result<Self::Transform, Self::InitError>>;
43
44    fn new_transform(&self, service: S) -> Self::Future {
45        ok(AuthMiddlewareService { service })
46    }
47}
48
/// Request-processing service produced by [`AuthMiddleware`].
///
/// Reads the token from each incoming request's query string and forwards
/// the request to the wrapped service only when the token is accepted.
pub struct AuthMiddlewareService<S> {
    /// Wrapped inner service that authenticated requests are forwarded to.
    service: S,
}

58impl<S, B> Service<ServiceRequest> for AuthMiddlewareService<S>
59where
60    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
61    B: MessageBody + 'static,
62{
63    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
64    type Error = Error;
65    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67    /// Checks if the service is ready to process a request.
68    ///
69    fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70        Poll::Ready(Ok(()))
71    }
72
73    /// Processes an incoming request by extracting and validating the token.
74    ///
75    /// If a valid token is found, the request is passed to the inner service.
76    /// Otherwise, an appropriate error response is returned.
77    ///
78    fn call(&self, req: ServiceRequest) -> Self::Future {
79        // We check the request parameters, we parse the QueryString: ?key=...
80        let query = web::Query::<TokenParams>::from_query(req.query_string()).ok();
81        let token = query.and_then(|q| q.key.clone()); // Clone the token to avoid invalid references
82
83        let response = match token.as_deref() {
84            Some("validated") => HttpResponse::build(StatusCode::from_u16(200).unwrap())
85                .insert_header(("X-HTTP-Status-Code", "200"))
86                .body("Authentication Successful"),
87            Some("expired") => HttpResponse::build(StatusCode::from_u16(401).unwrap())
88                .insert_header(("X-HTTP-Status-Code", "401"))
89                .insert_header(("X-Auth-Error", "Token Expired"))
90                .body("Your authentication token has expired, please log in again"),
91            Some(_) => HttpResponse::build(StatusCode::from_u16(401).unwrap())
92                .insert_header(("X-HTTP-Status-Code", "401"))
93                .insert_header(("X-Auth-Error", "Invalid Token"))
94                .body("Invalid Token"),
95            None => HttpResponse::build(StatusCode::from_u16(400).unwrap())
96                .insert_header(("X-HTTP-Status-Code", "400"))
97                .insert_header(("X-Auth-Error", "Missing Token"))
98                .body("Missing auth token"),
99        };
100
101        // Here, if it is 200 => we let the request pass through to the route handler
102        if response.status() == StatusCode::from_u16(200).unwrap() {
103            let fut = self.service.call(req);
104            return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
105        }
106
107        // Otherwise, we return the response immediately, without reaching the route handler
108        Box::pin(async move { Ok(req.into_response(response.map_into_right_body())) })
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, App};

    /// Asserts that `resp` was rejected with `status`, that `X-HTTP-Status-Code`
    /// mirrors that status, and that `X-Auth-Error` equals `reason`.
    fn assert_rejected<B>(resp: &ServiceResponse<B>, status: StatusCode, reason: &str) {
        assert_eq!(resp.status(), status, "unexpected status for `{reason}` rejection");
        assert_eq!(resp.headers().get("X-HTTP-Status-Code").unwrap(), status.as_str());
        assert_eq!(resp.headers().get("X-Auth-Error").unwrap(), reason);
    }

    /// Tests that the AuthMiddleware correctly handles each token scenario.
    ///
    /// Verifies behavior with:
    /// - Valid token (request reaches the route handler)
    /// - Expired token
    /// - Invalid token
    /// - Empty token
    /// - Missing token
    #[actix_web::test]
    async fn test_auth_middleware() {
        let app = test::init_service(App::new().wrap(AuthMiddleware).route(
            "/protected",
            web::get().to(|| async { HttpResponse::Ok().body("Access Granted") }),
        ))
        .await;

        // Valid token: the request must pass through to the route handler.
        let req_valid = test::TestRequest::get().uri("/protected?key=validated").to_request();
        let resp_valid = test::call_service(&app, req_valid).await;
        assert_eq!(
            resp_valid.status(),
            StatusCode::OK,
            "Expected status code 200 for a validated token."
        );
        let body = test::read_body(resp_valid).await;
        assert_eq!(
            body,
            web::Bytes::from_static(b"Access Granted"),
            "A validated request should be served by the route handler."
        );

        // Rejections: the middleware answers without calling the handler.
        let rejected_cases = [
            ("/protected?key=expired", StatusCode::UNAUTHORIZED, "Token Expired"),
            ("/protected?key=invalid", StatusCode::UNAUTHORIZED, "Invalid Token"),
            ("/protected?key=", StatusCode::UNAUTHORIZED, "Invalid Token"),
            ("/protected", StatusCode::BAD_REQUEST, "Missing Token"),
        ];
        for &(uri, status, reason) in &rejected_cases {
            let req = test::TestRequest::get().uri(uri).to_request();
            let resp = test::call_service(&app, req).await;
            assert_rejected(&resp, status, reason);
        }
    }
}