tower_oauth2_resource_server/
layer.rs

1use futures_util::{Future, future::BoxFuture};
2use http::{Request, Response};
3use pin_project::pin_project;
4use serde::de::DeserializeOwned;
5
6use std::{
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll, ready},
10};
11use tower::{Layer, Service};
12
13use crate::{error_handler::ErrorHandler, server::OAuth2ResourceServer};
14
15trait Authorize<ReqBody, ResBody> {
16    type Future: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>;
17
18    fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future;
19}
20
21impl<S, ReqBody, ResBody, Claims> Authorize<ReqBody, ResBody>
22    for OAuth2ResourceServerService<S, ResBody, Claims>
23where
24    Claims: DeserializeOwned + Clone + Send + Sync + 'static,
25    ReqBody: Send + 'static,
26    ResBody: Send + 'static,
27{
28    type Future = BoxFuture<'static, Result<Request<ReqBody>, Response<ResBody>>>;
29
30    fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future {
31        let auth = self.auth_manager.clone();
32        let error_handler = self.error_handler.clone();
33        Box::pin(async move {
34            match auth.authorize_request(request).await {
35                Ok(request) => Ok(request),
36                Err(error) => Err(error_handler.map_error(error)),
37            }
38        })
39    }
40}
41
42pub struct OAuth2ResourceServerLayer<ResBody, Claims> {
43    auth_manager: OAuth2ResourceServer<Claims>,
44    error_handler: Arc<dyn ErrorHandler<ResBody>>,
45}
46
47impl<ResBody, Claims> Clone for OAuth2ResourceServerLayer<ResBody, Claims>
48where
49    Claims: Clone,
50{
51    fn clone(&self) -> Self {
52        Self {
53            auth_manager: self.auth_manager.clone(),
54            error_handler: self.error_handler.clone(),
55        }
56    }
57}
58
59impl<S, ResBody, Claims> Layer<S> for OAuth2ResourceServerLayer<ResBody, Claims>
60where
61    Claims: Clone + DeserializeOwned + Send + 'static,
62{
63    type Service = OAuth2ResourceServerService<S, ResBody, Claims>;
64
65    fn layer(&self, inner: S) -> Self::Service {
66        OAuth2ResourceServerService::new(
67            inner,
68            self.auth_manager.clone(),
69            self.error_handler.clone(),
70        )
71    }
72}
73
74impl<ResBody, Claims> OAuth2ResourceServerLayer<ResBody, Claims> {
75    pub(crate) fn new(
76        auth_manager: OAuth2ResourceServer<Claims>,
77        error_handler: Arc<dyn ErrorHandler<ResBody>>,
78    ) -> Self {
79        OAuth2ResourceServerLayer {
80            auth_manager,
81            error_handler,
82        }
83    }
84}
85
86pub struct OAuth2ResourceServerService<S, ResBody, Claims> {
87    inner: S,
88    auth_manager: OAuth2ResourceServer<Claims>,
89    error_handler: Arc<dyn ErrorHandler<ResBody>>,
90}
91
92impl<S, ResBody, Claims> Clone for OAuth2ResourceServerService<S, ResBody, Claims>
93where
94    S: Clone,
95    Claims: Clone,
96{
97    fn clone(&self) -> Self {
98        Self {
99            inner: self.inner.clone(),
100            auth_manager: self.auth_manager.clone(),
101            error_handler: self.error_handler.clone(),
102        }
103    }
104}
105
106impl<S, ResBody, Claims> OAuth2ResourceServerService<S, ResBody, Claims> {
107    fn new(
108        inner: S,
109        auth_manager: OAuth2ResourceServer<Claims>,
110        error_handler: Arc<dyn ErrorHandler<ResBody>>,
111    ) -> Self {
112        Self {
113            inner,
114            auth_manager,
115            error_handler,
116        }
117    }
118}
119
120impl<S, ReqBody, ResBody, Claims> Service<Request<ReqBody>>
121    for OAuth2ResourceServerService<S, ResBody, Claims>
122where
123    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
124    ResBody: Default + Send + 'static,
125    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
126    ReqBody: Send + 'static,
127{
128    type Response = S::Response;
129    type Error = S::Error;
130    type Future = ResponseFuture<S, ReqBody, ResBody, Claims>;
131
132    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133        self.inner.poll_ready(cx)
134    }
135
136    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
137        let inner = self.inner.clone();
138        let authorize = self.authorize(request);
139
140        ResponseFuture {
141            state: State::Authorize { authorize },
142            service: inner,
143        }
144    }
145}
146
147type AuthorizeFuture<S, ReqBody, ResBody, Claims> =
148    <OAuth2ResourceServerService<S, ResBody, Claims> as Authorize<ReqBody, ResBody>>::Future;
149
150#[pin_project]
151pub struct ResponseFuture<S, ReqBody, ResBody, Claims>
152where
153    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
154    ReqBody: Send + 'static,
155    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
156    ResBody: Send + 'static,
157{
158    #[pin]
159    state: State<AuthorizeFuture<S, ReqBody, ResBody, Claims>, S::Future>,
160    service: S,
161}
162
163#[pin_project(project = StateProj)]
164enum State<A, SFut> {
165    Authorize {
166        #[pin]
167        authorize: A,
168    },
169    Authorized {
170        #[pin]
171        fut: SFut,
172    },
173}
174
175impl<S, ReqBody, ResBody, Claims> Future for ResponseFuture<S, ReqBody, ResBody, Claims>
176where
177    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
178    ResBody: Default + Send + 'static,
179    ReqBody: Send + 'static,
180    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
181{
182    type Output = Result<Response<ResBody>, S::Error>;
183
184    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
185        let mut this = self.project();
186
187        loop {
188            match this.state.as_mut().project() {
189                StateProj::Authorize { authorize } => {
190                    let auth = ready!(authorize.poll(cx));
191                    match auth {
192                        Ok(req) => {
193                            let fut = this.service.call(req);
194                            this.state.set(State::Authorized { fut })
195                        }
196                        Err(res) => {
197                            return Poll::Ready(Ok(res));
198                        }
199                    };
200                }
201                StateProj::Authorized { fut } => {
202                    return fut.poll(cx);
203                }
204            }
205        }
206    }
207}