simbld_http/helpers/
auth_middleware.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use actix_web::{
  body::{BoxBody, EitherBody, MessageBody},
  dev::{Service, ServiceRequest, ServiceResponse, Transform},
  http::StatusCode,
  web, Error, HttpResponse,
};
use futures_util::future::{ok, LocalBoxFuture, Ready};
use serde::Deserialize;
use std::task::{Context, Poll};

/// Query-string parameters inspected by [`AuthMiddleware`].
///
/// Deserialized from the request's query string, e.g. `?key=validated`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct TokenParams {
  /// Authentication token read from the `key` query parameter; `None` when it is absent.
  pub key: Option<String>,
}

/// Actix-web middleware factory that gates requests on a `key` query-string token.
///
/// Register it with `App::wrap(AuthMiddleware)`; its [`Transform`] implementation wraps the
/// inner service in an [`AuthMiddlewareService`], which performs the actual check.
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthMiddleware;

/// Service produced by [`AuthMiddleware`] that checks the request token before
/// delegating to the wrapped `service`.
pub struct AuthMiddlewareService<S> {
  /// The next service in the chain; only invoked when the token is accepted.
  service: S,
}

/// Builds an [`AuthMiddlewareService`] around the next service in the chain.
///
/// Construction cannot fail, so the returned future is already resolved.
impl<S, B> Transform<S, ServiceRequest> for AuthMiddleware
where
  S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
  B: MessageBody + 'static,
{
  type Error = Error;
  type InitError = ();
  type Response = ServiceResponse<EitherBody<B, BoxBody>>;
  type Transform = AuthMiddlewareService<S>;
  type Future = Ready<Result<Self::Transform, Self::InitError>>;

  /// Wraps `service` so every request first passes the token check.
  fn new_transform(&self, service: S) -> Self::Future {
    let wrapped = AuthMiddlewareService { service };
    ok(wrapped)
  }
}

impl<S, B> Service<ServiceRequest> for AuthMiddlewareService<S>
where
  S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
  B: MessageBody + 'static,
{
  type Response = ServiceResponse<EitherBody<B, BoxBody>>;
  type Error = Error;
  type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

  fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
    Poll::Ready(Ok(()))
  }

  fn call(&self, req: ServiceRequest) -> Self::Future {
    // Check the request parameters
    let query = web::Query::<TokenParams>::from_query(req.query_string()).ok();
    let token = query.and_then(|q| q.key.clone()); // Clone the token to avoid invalid references

    let response = match token.as_deref() {
      Some("validated") => {
        HttpResponse::build(StatusCode::from_u16(222).unwrap()).body("Authentication Successful")
      },
      Some("expired") => {
        HttpResponse::build(StatusCode::from_u16(419).unwrap()).body("Page Expired")
      },
      Some(_) => HttpResponse::build(StatusCode::from_u16(498).unwrap()).body("Invalid Token"),
      None => HttpResponse::build(StatusCode::from_u16(983).unwrap()).body("Missing Token"),
    };

    if response.status() == StatusCode::from_u16(222).unwrap() {
      let fut = self.service.call(req);
      return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
    }

    Box::pin(async move { Ok(req.into_response(response.map_into_right_body())) })
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use actix_web::{test, App};

  /// Drives every authentication outcome through a real actix service stack.
  ///
  /// Both the status code and the body are checked. The protected handler also answers
  /// with 222, so the body is what distinguishes a forwarded request from a response
  /// produced by the middleware itself.
  #[actix_web::test]
  async fn test_auth_middleware() {
    let app = test::init_service(App::new().wrap(AuthMiddleware).route(
      "/protected",
      web::get().to(|| async {
        HttpResponse::build(StatusCode::from_u16(222).unwrap()).body("Access Granted")
      }),
    ))
    .await;

    // (uri, expected status, expected body, description)
    let cases: [(&str, u16, &str, &str); 4] = [
      ("/protected?key=validated", 222, "Access Granted", "a validated token"),
      ("/protected?key=expired", 419, "Page Expired", "an expired token"),
      ("/protected?key=invalid", 498, "Invalid Token", "an invalid token"),
      ("/protected", 983, "Missing Token", "a missing token"),
    ];

    for (uri, code, expected_body, description) in cases {
      let req = test::TestRequest::get().uri(uri).to_request();
      let resp = test::call_service(&app, req).await;
      assert_eq!(
        resp.status(),
        StatusCode::from_u16(code).unwrap(),
        "Expected status code {} for {}.",
        code,
        description
      );

      let body = test::read_body(resp).await;
      assert_eq!(
        &body[..],
        expected_body.as_bytes(),
        "Unexpected response body for {}.",
        description
      );
    }
  }
}