simbld_http/helpers/
auth_middleware.rs1use actix_web::{
8 body::{BoxBody, EitherBody, MessageBody},
9 dev::{Service, ServiceRequest, ServiceResponse, Transform},
10 http::StatusCode,
11 web, Error, HttpResponse,
12};
13use futures_util::future::{ok, LocalBoxFuture, Ready};
14use std::task::{Context, Poll};
15
16#[derive(serde::Deserialize)]
21pub struct TokenParams {
22 pub key: Option<String>,
24}
25
/// Middleware factory that gates requests on the `key` query parameter.
///
/// Register it with `App::wrap(AuthMiddleware)`; every request routed through
/// the wrapped app is checked before reaching its handler.
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthMiddleware;
32
33impl<S, B> Transform<S, ServiceRequest> for AuthMiddleware
34where
35 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
36 B: MessageBody + 'static,
37{
38 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
39 type Error = Error;
40 type Transform = AuthMiddlewareService<S>;
41 type InitError = ();
42 type Future = Ready<Result<Self::Transform, Self::InitError>>;
43
44 fn new_transform(&self, service: S) -> Self::Future {
45 ok(AuthMiddlewareService { service })
46 }
47}
48
/// Per-service instance of [`AuthMiddleware`].
///
/// Holds the next service in the chain, which is only invoked for requests
/// whose token is accepted.
pub struct AuthMiddlewareService<S> {
    // The downstream service that authenticated requests are forwarded to.
    service: S,
}
57
58impl<S, B> Service<ServiceRequest> for AuthMiddlewareService<S>
59where
60 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
61 B: MessageBody + 'static,
62{
63 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
64 type Error = Error;
65 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67 fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70 Poll::Ready(Ok(()))
71 }
72
73 fn call(&self, req: ServiceRequest) -> Self::Future {
79 let query = web::Query::<TokenParams>::from_query(req.query_string()).ok();
81 let token = query.and_then(|q| q.key.clone()); let response = match token.as_deref() {
84 Some("validated") => HttpResponse::build(StatusCode::from_u16(200).unwrap())
85 .insert_header(("X-HTTP-Status-Code", "200"))
86 .body("Authentication Successful"),
87 Some("expired") => HttpResponse::build(StatusCode::from_u16(401).unwrap())
88 .insert_header(("X-HTTP-Status-Code", "401"))
89 .insert_header(("X-Auth-Error", "Token Expired"))
90 .body("Your authentication token has expired, please log in again"),
91 Some(_) => HttpResponse::build(StatusCode::from_u16(401).unwrap())
92 .insert_header(("X-HTTP-Status-Code", "401"))
93 .insert_header(("X-Auth-Error", "Invalid Token"))
94 .body("Invalid Token"),
95 None => HttpResponse::build(StatusCode::from_u16(400).unwrap())
96 .insert_header(("X-HTTP-Status-Code", "400"))
97 .insert_header(("X-Auth-Error", "Missing Token"))
98 .body("Missing auth token"),
99 };
100
101 if response.status() == StatusCode::from_u16(200).unwrap() {
103 let fut = self.service.call(req);
104 return Box::pin(async move { fut.await.map(|res| res.map_into_left_body()) });
105 }
106
107 Box::pin(async move { Ok(req.into_response(response.map_into_right_body())) })
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, App};

    /// Exercises every branch of `AuthMiddleware`: an accepted token reaches the
    /// wrapped handler, and each rejection carries the expected status and headers.
    #[actix_web::test]
    async fn test_auth_middleware() {
        let app = test::init_service(App::new().wrap(AuthMiddleware).route(
            "/protected",
            web::get().to(|| async { HttpResponse::build(StatusCode::OK).body("Access Granted") }),
        ))
        .await;

        // Accepted token: the request must reach the wrapped handler.
        let req_valid = test::TestRequest::get().uri("/protected?key=validated").to_request();
        let resp_valid = test::call_service(&app, req_valid).await;
        assert_eq!(
            resp_valid.status(),
            StatusCode::OK,
            "Expected status code 200 for a validated token."
        );
        let body = test::read_body(resp_valid).await;
        assert_eq!(&body[..], b"Access Granted", "Expected the protected handler's body.");

        // Rejected requests: (uri, expected status, expected X-Auth-Error, description).
        let rejections = [
            ("/protected?key=expired", StatusCode::UNAUTHORIZED, "Token Expired", "an expired token"),
            ("/protected?key=invalid", StatusCode::UNAUTHORIZED, "Invalid Token", "an invalid token"),
            ("/protected", StatusCode::BAD_REQUEST, "Missing Token", "a missing token"),
        ];

        for &(uri, expected_status, expected_error, description) in &rejections {
            let req = test::TestRequest::get().uri(uri).to_request();
            let resp = test::call_service(&app, req).await;
            assert_eq!(
                resp.status(),
                expected_status,
                "Expected status code {} for {}.",
                expected_status.as_u16(),
                description
            );
            assert_eq!(resp.headers().get("X-Auth-Error").unwrap(), expected_error);
            assert_eq!(
                resp.headers().get("X-HTTP-Status-Code").unwrap(),
                expected_status.as_str()
            );
        }
    }
}