tower_http/auth/async_require_authorization.rs
1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//! use futures_util::future::BoxFuture;
12//! use bytes::Bytes;
13//! use http_body_util::Full;
14//!
15//! #[derive(Clone, Copy)]
16//! struct MyAuth;
17//!
18//! impl<B> AsyncAuthorizeRequest<B> for MyAuth
19//! where
20//! B: Send + Sync + 'static,
21//! {
22//! type RequestBody = B;
23//! type ResponseBody = Full<Bytes>;
24//! type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
25//!
26//! fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
27//! Box::pin(async {
28//! if let Some(user_id) = check_auth(&request).await {
29//! // Set `user_id` as a request extension so it can be accessed by other
30//! // services down the stack.
31//! request.extensions_mut().insert(user_id);
32//!
33//! Ok(request)
34//! } else {
35//! let unauthorized_response = Response::builder()
36//! .status(StatusCode::UNAUTHORIZED)
37//! .body(Full::<Bytes>::default())
38//! .unwrap();
39//!
40//! Err(unauthorized_response)
41//! }
42//! })
43//! }
44//! }
45//!
46//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
47//! // ...
48//! # None
49//! }
50//!
51//! #[derive(Debug, Clone)]
52//! struct UserId(String);
53//!
54//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
//!     // Access the `UserId` that was set in `MyAuth::authorize`. If `handle` gets called
//!     // the request was authorized and `UserId` will be present.
57//! let user_id = request
58//! .extensions()
59//! .get::<UserId>()
60//! .expect("UserId will be there if request was authorized");
61//!
62//! println!("request from {:?}", user_id);
63//!
64//! Ok(Response::new(Full::default()))
65//! }
66//!
67//! # #[tokio::main]
68//! # async fn main() -> Result<(), BoxError> {
69//! let service = ServiceBuilder::new()
70//! // Authorize requests using `MyAuth`
71//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
72//! .service_fn(handle);
73//! # Ok(())
74//! # }
75//! ```
76//!
77//! Or using a closure:
78//!
79//! ```
80//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
81//! use http::{Request, Response, StatusCode};
82//! use tower::{Service, ServiceExt, ServiceBuilder, BoxError};
83//! use futures_util::future::BoxFuture;
84//! use http_body_util::Full;
85//! use bytes::Bytes;
86//!
87//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
88//! // ...
89//! # None
90//! }
91//!
92//! #[derive(Debug)]
93//! struct UserId(String);
94//!
95//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
96//! # todo!();
97//! // ...
98//! }
99//!
100//! # #[tokio::main]
101//! # async fn main() -> Result<(), BoxError> {
102//! let service = ServiceBuilder::new()
103//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move {
//!         if check_auth(&request).await.is_some() {
105//! Ok(request)
106//! } else {
107//! let unauthorized_response = Response::builder()
108//! .status(StatusCode::UNAUTHORIZED)
109//! .body(Full::<Bytes>::default())
110//! .unwrap();
111//!
112//! Err(unauthorized_response)
113//! }
114//! }))
115//! .service_fn(handle);
116//! # Ok(())
117//! # }
118//! ```
119
120use http::{Request, Response};
121use pin_project_lite::pin_project;
122use std::{
123 future::Future,
124 pin::Pin,
125 task::{ready, Context, Poll},
126};
127use tower_layer::Layer;
128use tower_service::Service;
129
/// Layer that applies [`AsyncRequireAuthorization`], which authorizes every request with a
/// user-supplied [`AsyncAuthorizeRequest`] implementation (for example one that inspects the
/// [`Authorization`] header).
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorizationLayer<T> {
    /// The authorizer; a clone of it is handed to every service this layer wraps.
    auth: T,
}
140
141impl<T> AsyncRequireAuthorizationLayer<T> {
142 /// Authorize requests using a custom scheme.
143 pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
144 Self { auth }
145 }
146}
147
148impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
149where
150 T: Clone,
151{
152 type Service = AsyncRequireAuthorization<S, T>;
153
154 fn layer(&self, inner: S) -> Self::Service {
155 AsyncRequireAuthorization::new(inner, self.auth.clone())
156 }
157}
158
/// Middleware that authorizes every request with a user-supplied [`AsyncAuthorizeRequest`]
/// implementation (for example one that inspects the [`Authorization`] header) before
/// forwarding it to the inner service.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
    /// The wrapped service that receives authorized requests.
    inner: S,
    /// The authorizer consulted for every request.
    auth: T,
}
169
170impl<S, T> AsyncRequireAuthorization<S, T> {
171 define_inner_service_accessors!();
172}
173
174impl<S, T> AsyncRequireAuthorization<S, T> {
175 /// Authorize requests using a custom scheme.
176 ///
177 /// The `Authorization` header is required to have the value provided.
178 pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
179 Self { inner, auth }
180 }
181
182 /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`]
183 /// middleware.
184 ///
185 /// [`Layer`]: tower_layer::Layer
186 pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
187 AsyncRequireAuthorizationLayer::new(auth)
188 }
189}
190
191impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
192where
193 Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
194 S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone,
195{
196 type Response = Response<ResBody>;
197 type Error = S::Error;
198 type Future = ResponseFuture<Auth, S, ReqBody>;
199
200 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
201 self.inner.poll_ready(cx)
202 }
203
204 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
205 let inner = self.inner.clone();
206 let authorize = self.auth.authorize(req);
207
208 ResponseFuture {
209 state: State::Authorize { authorize },
210 service: inner,
211 }
212 }
213}
214
pin_project! {
    /// Response future for [`AsyncRequireAuthorization`].
    ///
    /// It first drives the authorization future. If authorization yields a request, that
    /// request is passed to the inner service and the service's future is driven to
    /// completion. If authorization yields a response instead, that response is returned
    /// directly.
    pub struct ResponseFuture<Auth, S, ReqBody>
    where
        Auth: AsyncAuthorizeRequest<ReqBody>,
        S: Service<Request<Auth::RequestBody>>,
    {
        // Current stage: authorizing, or waiting on the inner service's response.
        #[pin]
        state: State<Auth::Future, S::Future>,
        // Inner service; called once, only after authorization succeeds.
        service: S,
    }
}
227
pin_project! {
    // Two-stage state machine driven by `ResponseFuture::poll`.
    #[project = StateProj]
    enum State<A, SFut> {
        // Waiting for the authorization future `A` to resolve.
        Authorize {
            #[pin]
            authorize: A,
        },
        // Authorization succeeded and the request was handed to the inner service;
        // waiting on that service's response future `SFut`.
        Authorized {
            #[pin]
            fut: SFut,
        },
    }
}
241
242impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody>
243where
244 Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = B>,
245 S: Service<Request<Auth::RequestBody>, Response = Response<B>>,
246{
247 type Output = Result<Response<B>, S::Error>;
248
249 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250 let mut this = self.project();
251
252 loop {
253 match this.state.as_mut().project() {
254 StateProj::Authorize { authorize } => {
255 let auth = ready!(authorize.poll(cx));
256 match auth {
257 Ok(req) => {
258 let fut = this.service.call(req);
259 this.state.set(State::Authorized { fut })
260 }
261 Err(res) => {
262 return Poll::Ready(Ok(res));
263 }
264 };
265 }
266 StateProj::Authorized { fut } => {
267 return fut.poll(cx);
268 }
269 }
270 }
271 }
272}
273
274/// Trait for authorizing requests.
275pub trait AsyncAuthorizeRequest<B> {
276 /// The type of request body returned by `authorize`.
277 ///
278 /// Set this to `B` unless you need to change the request body type.
279 type RequestBody;
280
281 /// The body type used for responses to unauthorized requests.
282 type ResponseBody;
283
284 /// The Future type returned by `authorize`
285 type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
286
287 /// Authorize the request.
288 ///
289 /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
290 fn authorize(&mut self, request: Request<B>) -> Self::Future;
291}
292
293impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
294where
295 F: FnMut(Request<B>) -> Fut,
296 Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
297{
298 type RequestBody = ReqBody;
299 type ResponseBody = ResBody;
300 type Future = Fut;
301
302 fn authorize(&mut self, request: Request<B>) -> Self::Future {
303 self(request)
304 }
305}
306
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use futures_util::future::BoxFuture;
    use http::{header, StatusCode};
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Test authorizer: accepts `Authorization: Bearer 69420` and records a `UserId`
    /// extension on the request; anything else is rejected with `401 Unauthorized`.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth
    where
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;
        type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

        fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
            Box::pin(async move {
                let authorized = request
                    .headers()
                    .get(header::AUTHORIZATION)
                    .and_then(|it| it.to_str().ok())
                    .and_then(|it| it.strip_prefix("Bearer "))
                    == Some("69420");

                if authorized {
                    let user_id = UserId("6969".to_owned());
                    request.extensions_mut().insert(user_id);
                    Ok(request)
                } else {
                    Err(Response::builder()
                        .status(StatusCode::UNAUTHORIZED)
                        .body(Body::empty())
                        .unwrap())
                }
            })
        }
    }

    #[derive(Clone, Debug)]
    struct UserId(String);

    #[tokio::test]
    async fn require_async_auth_works() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer deez")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_missing_header_401() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn authorized_request_reaches_inner_with_extension() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(|req: Request<Body>| async move {
                let user_id = req
                    .extensions()
                    .get::<UserId>()
                    .expect("`MyAuth` inserts `UserId` for authorized requests");
                assert_eq!(user_id.0, "6969");
                Ok::<_, BoxError>(Response::new(Body::empty()))
            });

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn closure_authorizer_works() {
        // Exercises the blanket `AsyncAuthorizeRequest` impl for closures.
        let auth = |request: Request<Body>| async move {
            if request.headers().contains_key(header::AUTHORIZATION) {
                Ok(request)
            } else {
                Err(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::empty())
                    .unwrap())
            }
        };
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(auth))
            .service_fn(echo);

        let allowed = Request::get("/")
            .header(header::AUTHORIZATION, "anything")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(allowed).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let denied = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(denied).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}