simbld_http/helpers/
auth_middleware.rs1use actix_web::{
7 body::{BoxBody, EitherBody, MessageBody},
8 dev::{Service, ServiceRequest, ServiceResponse, Transform},
9 http::StatusCode,
10 web, Error, HttpResponse,
11};
12use futures_util::future::{ok, LocalBoxFuture, Ready};
13use serde::Deserialize;
14use std::task::{Context, Poll};
15
16#[derive(Deserialize)]
18pub struct TokenParams {
19 pub key: Option<String>,
20}
21
/// Middleware factory that gates requests on the `key` query parameter.
///
/// Usage: `App::new().wrap(AuthMiddleware)`. Each wrapped service is an
/// [`AuthMiddlewareService`].
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthMiddleware;
23
/// Service created by [`AuthMiddleware`] around the next service in the pipeline.
#[derive(Debug)]
pub struct AuthMiddlewareService<S> {
    /// Wrapped service; only requests carrying a validated token are forwarded to it.
    service: S,
}
27
28impl<S, B> Transform<S, ServiceRequest> for AuthMiddleware
29where
30 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
31 B: MessageBody + 'static,
32{
33 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
34 type Error = Error;
35 type Transform = AuthMiddlewareService<S>;
36 type InitError = ();
37 type Future = Ready<Result<Self::Transform, Self::InitError>>;
38
39 fn new_transform(&self, service: S) -> Self::Future {
40 ok(AuthMiddlewareService {
41 service,
42 })
43 }
44}
45
46impl<S, B> Service<ServiceRequest> for AuthMiddlewareService<S>
47where
48 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
49 B: MessageBody + 'static,
50{
51 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
52 type Error = Error;
53 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
54
55 fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56 Poll::Ready(Ok(()))
57 }
58
59 fn call(&self, req: ServiceRequest) -> Self::Future {
60 let query = web::Query::<TokenParams>::from_query(req.query_string()).ok();
62 let token = query.and_then(|q| q.key.clone()); let response = match token.as_deref() {
65 Some("validated") => {
66 HttpResponse::build(StatusCode::from_u16(222).unwrap()).body("Authentication Successful")
67 },
68 Some("expired") => {
69 HttpResponse::build(StatusCode::from_u16(419).unwrap()).body("Page Expired")
70 },
71 Some(_) => HttpResponse::build(StatusCode::from_u16(498).unwrap()).body("Invalid Token"),
72 None => HttpResponse::build(StatusCode::from_u16(983).unwrap()).body("Missing Token"),
73 };
74
75 if response.status() == StatusCode::from_u16(222).unwrap() {
76 let fut = self.service.call(req);
77 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
78 }
79
80 Box::pin(async move { Ok(req.into_response(response.map_into_right_body())) })
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, App};

    /// Drives every token branch of `AuthMiddleware` against one protected route.
    #[actix_web::test]
    async fn test_auth_middleware() {
        let protected = web::get().to(|| async {
            HttpResponse::build(StatusCode::from_u16(222).unwrap()).body("Access Granted")
        });
        let app =
            test::init_service(App::new().wrap(AuthMiddleware).route("/protected", protected))
                .await;

        // (request URI, expected status code, assertion message)
        let cases = [
            ("/protected?key=validated", 222, "Expected status code 222 for a validated token."),
            ("/protected?key=expired", 419, "Expected status code 419 for an expired token."),
            ("/protected?key=invalid", 498, "Expected status code 498 for an invalid token."),
            ("/protected", 983, "Expected status code 983 for a missing token."),
        ];

        for (uri, code, message) in cases {
            let request = test::TestRequest::get().uri(uri).to_request();
            let response = test::call_service(&app, request).await;
            assert_eq!(response.status(), StatusCode::from_u16(code).unwrap(), "{}", message);
        }
    }
}