tower_oauth2_resource_server/
layer.rs1use futures_util::{Future, future::BoxFuture};
2use http::{Request, Response};
3use pin_project::pin_project;
4use serde::de::DeserializeOwned;
5
6use std::{
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll, ready},
10};
11use tower::{Layer, Service};
12
13use crate::{error_handler::ErrorHandler, server::OAuth2ResourceServer};
14
15trait Authorize<ReqBody, ResBody> {
16 type Future: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>;
17
18 fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future;
19}
20
21impl<S, ReqBody, ResBody, Claims> Authorize<ReqBody, ResBody>
22 for OAuth2ResourceServerService<S, ResBody, Claims>
23where
24 Claims: DeserializeOwned + Clone + Send + Sync + 'static,
25 ReqBody: Send + 'static,
26 ResBody: Send + 'static,
27{
28 type Future = BoxFuture<'static, Result<Request<ReqBody>, Response<ResBody>>>;
29
30 fn authorize(&mut self, request: Request<ReqBody>) -> Self::Future {
31 let auth = self.auth_manager.clone();
32 let error_handler = self.error_handler.clone();
33 Box::pin(async move {
34 match auth.authorize_request(request).await {
35 Ok(request) => Ok(request),
36 Err(error) => Err(error_handler.map_error(error)),
37 }
38 })
39 }
40}
41
42pub struct OAuth2ResourceServerLayer<ResBody, Claims> {
43 auth_manager: OAuth2ResourceServer<Claims>,
44 error_handler: Arc<dyn ErrorHandler<ResBody>>,
45}
46
47impl<ResBody, Claims> Clone for OAuth2ResourceServerLayer<ResBody, Claims>
48where
49 Claims: Clone,
50{
51 fn clone(&self) -> Self {
52 Self {
53 auth_manager: self.auth_manager.clone(),
54 error_handler: self.error_handler.clone(),
55 }
56 }
57}
58
59impl<S, ResBody, Claims> Layer<S> for OAuth2ResourceServerLayer<ResBody, Claims>
60where
61 Claims: Clone + DeserializeOwned + Send + 'static,
62{
63 type Service = OAuth2ResourceServerService<S, ResBody, Claims>;
64
65 fn layer(&self, inner: S) -> Self::Service {
66 OAuth2ResourceServerService::new(
67 inner,
68 self.auth_manager.clone(),
69 self.error_handler.clone(),
70 )
71 }
72}
73
74impl<ResBody, Claims> OAuth2ResourceServerLayer<ResBody, Claims> {
75 pub(crate) fn new(
76 auth_manager: OAuth2ResourceServer<Claims>,
77 error_handler: Arc<dyn ErrorHandler<ResBody>>,
78 ) -> Self {
79 OAuth2ResourceServerLayer {
80 auth_manager,
81 error_handler,
82 }
83 }
84}
85
86pub struct OAuth2ResourceServerService<S, ResBody, Claims> {
87 inner: S,
88 auth_manager: OAuth2ResourceServer<Claims>,
89 error_handler: Arc<dyn ErrorHandler<ResBody>>,
90}
91
92impl<S, ResBody, Claims> Clone for OAuth2ResourceServerService<S, ResBody, Claims>
93where
94 S: Clone,
95 Claims: Clone,
96{
97 fn clone(&self) -> Self {
98 Self {
99 inner: self.inner.clone(),
100 auth_manager: self.auth_manager.clone(),
101 error_handler: self.error_handler.clone(),
102 }
103 }
104}
105
106impl<S, ResBody, Claims> OAuth2ResourceServerService<S, ResBody, Claims> {
107 fn new(
108 inner: S,
109 auth_manager: OAuth2ResourceServer<Claims>,
110 error_handler: Arc<dyn ErrorHandler<ResBody>>,
111 ) -> Self {
112 Self {
113 inner,
114 auth_manager,
115 error_handler,
116 }
117 }
118}
119
120impl<S, ReqBody, ResBody, Claims> Service<Request<ReqBody>>
121 for OAuth2ResourceServerService<S, ResBody, Claims>
122where
123 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
124 ResBody: Default + Send + 'static,
125 Claims: Clone + DeserializeOwned + Send + Sync + 'static,
126 ReqBody: Send + 'static,
127{
128 type Response = S::Response;
129 type Error = S::Error;
130 type Future = ResponseFuture<S, ReqBody, ResBody, Claims>;
131
132 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133 self.inner.poll_ready(cx)
134 }
135
136 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
137 let inner = self.inner.clone();
138 let authorize = self.authorize(request);
139
140 ResponseFuture {
141 state: State::Authorize { authorize },
142 service: inner,
143 }
144 }
145}
146
/// Future type returned by the `Authorize` impl of
/// [`OAuth2ResourceServerService`] for the given type parameters.
type AuthorizeFuture<S, ReqBody, ResBody, Claims> =
    <OAuth2ResourceServerService<S, ResBody, Claims> as Authorize<ReqBody, ResBody>>::Future;

/// Response future of [`OAuth2ResourceServerService`].
///
/// Runs authorization first. On success it calls the wrapped service and
/// resolves to that call's result. On failure it resolves to `Ok` carrying the
/// error handler's response.
#[pin_project]
pub struct ResponseFuture<S, ReqBody, ResBody, Claims>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ReqBody: Send + 'static,
    Claims: Clone + DeserializeOwned + Send + Sync + 'static,
    ResBody: Send + 'static,
{
    // Current stage: authorizing, or awaiting the inner service's future.
    // Structurally pinned because both stages hold futures that are polled in place.
    #[pin]
    state: State<AuthorizeFuture<S, ReqBody, ResBody, Claims>, S::Future>,
    // Inner service, called once authorization succeeds (not pinned; only `&mut` is needed).
    service: S,
}

/// Two-stage state machine driven by `ResponseFuture::poll`.
#[pin_project(project = StateProj)]
enum State<A, SFut> {
    // Waiting for the authorization future `A` to resolve.
    Authorize {
        #[pin]
        authorize: A,
    },
    // Authorization succeeded; waiting for the inner service's future `SFut`.
    Authorized {
        #[pin]
        fut: SFut,
    },
}
174
175impl<S, ReqBody, ResBody, Claims> Future for ResponseFuture<S, ReqBody, ResBody, Claims>
176where
177 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
178 ResBody: Default + Send + 'static,
179 ReqBody: Send + 'static,
180 Claims: Clone + DeserializeOwned + Send + Sync + 'static,
181{
182 type Output = Result<Response<ResBody>, S::Error>;
183
184 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
185 let mut this = self.project();
186
187 loop {
188 match this.state.as_mut().project() {
189 StateProj::Authorize { authorize } => {
190 let auth = ready!(authorize.poll(cx));
191 match auth {
192 Ok(req) => {
193 let fut = this.service.call(req);
194 this.state.set(State::Authorized { fut })
195 }
196 Err(res) => {
197 return Poll::Ready(Ok(res));
198 }
199 };
200 }
201 StateProj::Authorized { fut } => {
202 return fut.poll(cx);
203 }
204 }
205 }
206 }
207}