ts_webapi/middleware/
authorization.rs1use core::{
5 mem,
6 task::{Context, Poll},
7};
8
9use http::{Request, Response};
10use http_body::Body;
11use tower_layer::Layer;
12use tower_service::Service;
13
14use crate::middleware::futures::{UndefinedFuture, undefined::DefiningFuture};
15
16#[derive(Debug)]
18pub struct Authorization<B, Fut>
19where
20 Fut: DefiningFuture<B>,
21{
22 authorize: fn(Request<B>) -> Fut,
24}
25
26impl<B, Fut> Clone for Authorization<B, Fut>
27where
28 Fut: DefiningFuture<B>,
29{
30 fn clone(&self) -> Self {
31 Self {
32 authorize: self.authorize,
33 }
34 }
35}
36
37impl<B, Fut> Authorization<B, Fut>
38where
39 Fut: DefiningFuture<B>,
40{
41 pub fn new(authorize: fn(Request<B>) -> Fut) -> Self {
43 Self { authorize }
44 }
45}
46
47impl<Svc, B, Fut> Layer<Svc> for Authorization<B, Fut>
48where
49 Svc: Clone,
50 Fut: DefiningFuture<B>,
51{
52 type Service = AuthorizationService<Svc, B, Fut>;
53
54 fn layer(&self, inner: Svc) -> Self::Service {
55 AuthorizationService {
56 inner,
57 auth: self.clone(),
58 }
59 }
60}
61
62#[derive(Debug)]
64pub struct AuthorizationService<Svc, B, Fut>
65where
66 Svc: Clone,
67 Fut: DefiningFuture<B>,
68{
69 inner: Svc,
71 auth: Authorization<B, Fut>,
73}
74
75impl<Svc, B, Fut> Clone for AuthorizationService<Svc, B, Fut>
76where
77 Svc: Clone,
78 Fut: DefiningFuture<B>,
79{
80 fn clone(&self) -> Self {
81 Self {
82 inner: self.inner.clone(),
83 auth: self.auth.clone(),
84 }
85 }
86}
87
88impl<Svc, B, Fut> AuthorizationService<Svc, B, Fut>
89where
90 Svc: Clone,
91 Fut: DefiningFuture<B>,
92{
93 pub fn new(inner: Svc, auth: Authorization<B, Fut>) -> Self {
95 Self { inner, auth }
96 }
97}
98
99impl<Svc, ReqBody, ResBody, Fut> Service<Request<ReqBody>>
100 for AuthorizationService<Svc, ReqBody, Fut>
101where
102 Svc: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
103 ResBody: Body + Send + Default,
104 ReqBody: Send + 'static,
105 Fut: DefiningFuture<ReqBody>,
106{
107 type Response = Svc::Response;
108 type Error = Svc::Error;
109 type Future = UndefinedFuture<Svc, ReqBody>;
110
111 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 self.inner.poll_ready(cx)
113 }
114
115 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
116 let auth_future = (self.auth.clone().authorize)(request);
117 let mut inner = self.inner.clone();
118 mem::swap(&mut self.inner, &mut inner);
120
121 UndefinedFuture::define(Box::pin(auth_future), inner)
122 }
123}
124
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;
    use ts_token::jwt::TokenType;

    use crate::{
        middleware::{authorization::Authorization, test::get_request},
        test::ResponseTestExt,
    };

    /// An authorizer that accepts every request unchanged must let the request
    /// reach the route handler, which answers `200 OK`.
    #[tokio::test]
    async fn axum() {
        async fn allow_all<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
            Ok(request)
        }

        let layer = Authorization::new(allow_all);
        let request = get_request(Some(TokenType::Common));

        let app = Router::new()
            .route("/resource/id", get(|| async { StatusCode::OK }))
            .layer(layer);

        let response = app.oneshot(request).await.unwrap();
        response.expect_status(StatusCode::OK);
    }
}