ts_webapi/middleware/authorization.rs

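//! Authorization middleware.
//!
//! [`Authorization`] is a tower [`Layer`] built from a plain `fn` that receives
//! each incoming request and returns an authorization future.
//! [`AuthorizationService`] is the service it produces: it forwards readiness
//! checks to the wrapped service and, on every call, combines the authorization
//! future with that service in an `UndefinedFuture`.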
3
4use core::{
5    mem,
6    task::{Context, Poll},
7};
8
9use http::{Request, Response};
10use http_body::Body;
11use tower_layer::Layer;
12use tower_service::Service;
13
14use crate::middleware::futures::{UndefinedFuture, undefined::DefiningFuture};
15
16/// Authorization middleware layer.
17#[derive(Debug)]
18pub struct Authorization<B, Fut>
19where
20    Fut: DefiningFuture<B>,
21{
22    /// The function to authorize the request.
23    authorize: fn(Request<B>) -> Fut,
24}
25
26impl<B, Fut> Clone for Authorization<B, Fut>
27where
28    Fut: DefiningFuture<B>,
29{
30    fn clone(&self) -> Self {
31        Self {
32            authorize: self.authorize,
33        }
34    }
35}
36
37impl<B, Fut> Authorization<B, Fut>
38where
39    Fut: DefiningFuture<B>,
40{
41    /// Create a new authorization layer.
42    pub fn new(authorize: fn(Request<B>) -> Fut) -> Self {
43        Self { authorize }
44    }
45}
46
47impl<Svc, B, Fut> Layer<Svc> for Authorization<B, Fut>
48where
49    Svc: Clone,
50    Fut: DefiningFuture<B>,
51{
52    type Service = AuthorizationService<Svc, B, Fut>;
53
54    fn layer(&self, inner: Svc) -> Self::Service {
55        AuthorizationService {
56            inner,
57            auth: self.clone(),
58        }
59    }
60}
61
62/// Tower service for the middleware.
63#[derive(Debug)]
64pub struct AuthorizationService<Svc, B, Fut>
65where
66    Svc: Clone,
67    Fut: DefiningFuture<B>,
68{
69    /// Inner service.
70    inner: Svc,
71    /// The logic layer.
72    auth: Authorization<B, Fut>,
73}
74
75impl<Svc, B, Fut> Clone for AuthorizationService<Svc, B, Fut>
76where
77    Svc: Clone,
78    Fut: DefiningFuture<B>,
79{
80    fn clone(&self) -> Self {
81        Self {
82            inner: self.inner.clone(),
83            auth: self.auth.clone(),
84        }
85    }
86}
87
88impl<Svc, B, Fut> AuthorizationService<Svc, B, Fut>
89where
90    Svc: Clone,
91    Fut: DefiningFuture<B>,
92{
93    /// Create a new service.
94    pub fn new(inner: Svc, auth: Authorization<B, Fut>) -> Self {
95        Self { inner, auth }
96    }
97}
98
99impl<Svc, ReqBody, ResBody, Fut> Service<Request<ReqBody>>
100    for AuthorizationService<Svc, ReqBody, Fut>
101where
102    Svc: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
103    ResBody: Body + Send + Default,
104    ReqBody: Send + 'static,
105    Fut: DefiningFuture<ReqBody>,
106{
107    type Response = Svc::Response;
108    type Error = Svc::Error;
109    type Future = UndefinedFuture<Svc, ReqBody>;
110
111    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112        self.inner.poll_ready(cx)
113    }
114
115    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
116        let auth_future = (self.auth.clone().authorize)(request);
117        let mut inner = self.inner.clone();
118        // mem::swap due to https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
119        mem::swap(&mut self.inner, &mut inner);
120
121        UndefinedFuture::define(Box::pin(auth_future), inner)
122    }
123}
124
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;
    use ts_token::jwt::TokenType;

    use crate::{
        middleware::{authorization::Authorization, test::get_request},
        test::ResponseTestExt,
    };

    /// A pass-through authorizer must let the request reach the route handler.
    #[tokio::test]
    async fn axum() {
        // Authorizer that accepts every request unchanged.
        async fn authorize<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
            Ok(request)
        }

        let layer = Authorization::new(authorize);
        let request = get_request(Some(TokenType::Common));

        let app = Router::new()
            .route("/resource/id", get(|| async move { StatusCode::OK }))
            .layer(layer);

        let response = app.oneshot(request).await.unwrap();

        response.expect_status(StatusCode::OK);
    }
}
155}