tower_http/auth/
async_require_authorization.rs

1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//! use futures_util::future::BoxFuture;
12//! use bytes::Bytes;
13//! use http_body_util::Full;
14//!
15//! #[derive(Clone, Copy)]
16//! struct MyAuth;
17//!
18//! impl<B> AsyncAuthorizeRequest<B> for MyAuth
19//! where
20//!     B: Send + Sync + 'static,
21//! {
22//!     type RequestBody = B;
23//!     type ResponseBody = Full<Bytes>;
24//!     type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
25//!
26//!     fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
27//!         Box::pin(async {
28//!             if let Some(user_id) = check_auth(&request).await {
29//!                 // Set `user_id` as a request extension so it can be accessed by other
30//!                 // services down the stack.
31//!                 request.extensions_mut().insert(user_id);
32//!
33//!                 Ok(request)
34//!             } else {
35//!                 let unauthorized_response = Response::builder()
36//!                     .status(StatusCode::UNAUTHORIZED)
37//!                     .body(Full::<Bytes>::default())
38//!                     .unwrap();
39//!
40//!                 Err(unauthorized_response)
41//!             }
42//!         })
43//!     }
44//! }
45//!
46//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
47//!     // ...
48//!     # None
49//! }
50//!
51//! #[derive(Debug, Clone)]
52//! struct UserId(String);
53//!
54//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
//!     // Access the `UserId` that was set in `MyAuth::authorize`. If `handle` gets called the
//!     // request was authorized and `UserId` will be present.
57//!     let user_id = request
58//!         .extensions()
59//!         .get::<UserId>()
60//!         .expect("UserId will be there if request was authorized");
61//!
62//!     println!("request from {:?}", user_id);
63//!
64//!     Ok(Response::new(Full::default()))
65//! }
66//!
67//! # #[tokio::main]
68//! # async fn main() -> Result<(), BoxError> {
69//! let service = ServiceBuilder::new()
70//!     // Authorize requests using `MyAuth`
71//!     .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
72//!     .service_fn(handle);
73//! # Ok(())
74//! # }
75//! ```
76//!
77//! Or using a closure:
78//!
79//! ```
80//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
81//! use http::{Request, Response, StatusCode};
82//! use tower::{Service, ServiceExt, ServiceBuilder, BoxError};
83//! use futures_util::future::BoxFuture;
84//! use http_body_util::Full;
85//! use bytes::Bytes;
86//!
87//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
88//!     // ...
89//!     # None
90//! }
91//!
92//! #[derive(Debug)]
93//! struct UserId(String);
94//!
95//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
96//!     # todo!();
97//!     // ...
98//! }
99//!
100//! # #[tokio::main]
101//! # async fn main() -> Result<(), BoxError> {
102//! let service = ServiceBuilder::new()
103//!     .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move {
104//!         if let Some(user_id) = check_auth(&request).await {
105//!             Ok(request)
106//!         } else {
107//!             let unauthorized_response = Response::builder()
108//!                 .status(StatusCode::UNAUTHORIZED)
109//!                 .body(Full::<Bytes>::default())
110//!                 .unwrap();
111//!
112//!             Err(unauthorized_response)
113//!         }
114//!     }))
115//!     .service_fn(handle);
116//! # Ok(())
117//! # }
118//! ```
119
120use http::{Request, Response};
121use pin_project_lite::pin_project;
122use std::{
123    future::Future,
124    pin::Pin,
125    task::{ready, Context, Poll},
126};
127use tower_layer::Layer;
128use tower_service::Service;
129
/// Layer that wraps services in [`AsyncRequireAuthorization`], which authorizes every request
/// (typically based on the [`Authorization`] header) before it reaches the wrapped service.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorizationLayer<T> {
    // The authorizer; a clone of it is given to each service produced by this layer.
    auth: T,
}
140
141impl<T> AsyncRequireAuthorizationLayer<T> {
142    /// Authorize requests using a custom scheme.
143    pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
144        Self { auth }
145    }
146}
147
148impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
149where
150    T: Clone,
151{
152    type Service = AsyncRequireAuthorization<S, T>;
153
154    fn layer(&self, inner: S) -> Self::Service {
155        AsyncRequireAuthorization::new(inner, self.auth.clone())
156    }
157}
158
/// Middleware that runs an asynchronous authorizer on every request (typically inspecting the
/// [`Authorization`] header) and only forwards accepted requests to the wrapped service.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
    // The wrapped service that receives authorized requests.
    inner: S,
    // The authorizer consulted for each incoming request.
    auth: T,
}
169
170impl<S, T> AsyncRequireAuthorization<S, T> {
171    define_inner_service_accessors!();
172}
173
174impl<S, T> AsyncRequireAuthorization<S, T> {
175    /// Authorize requests using a custom scheme.
176    ///
177    /// The `Authorization` header is required to have the value provided.
178    pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
179        Self { inner, auth }
180    }
181
182    /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`]
183    /// middleware.
184    ///
185    /// [`Layer`]: tower_layer::Layer
186    pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
187        AsyncRequireAuthorizationLayer::new(auth)
188    }
189}
190
191impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
192where
193    Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
194    S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone,
195{
196    type Response = Response<ResBody>;
197    type Error = S::Error;
198    type Future = ResponseFuture<Auth, S, ReqBody>;
199
200    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
201        self.inner.poll_ready(cx)
202    }
203
204    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
205        let inner = self.inner.clone();
206        let authorize = self.auth.authorize(req);
207
208        ResponseFuture {
209            state: State::Authorize { authorize },
210            service: inner,
211        }
212    }
213}
214
215pin_project! {
216    /// Response future for [`AsyncRequireAuthorization`].
217    pub struct ResponseFuture<Auth, S, ReqBody>
218    where
219        Auth: AsyncAuthorizeRequest<ReqBody>,
220        S: Service<Request<Auth::RequestBody>>,
221    {
222        #[pin]
223        state: State<Auth::Future, S::Future>,
224        service: S,
225    }
226}
227
228pin_project! {
229    #[project = StateProj]
230    enum State<A, SFut> {
231        Authorize {
232            #[pin]
233            authorize: A,
234        },
235        Authorized {
236            #[pin]
237            fut: SFut,
238        },
239    }
240}
241
242impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody>
243where
244    Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = B>,
245    S: Service<Request<Auth::RequestBody>, Response = Response<B>>,
246{
247    type Output = Result<Response<B>, S::Error>;
248
249    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250        let mut this = self.project();
251
252        loop {
253            match this.state.as_mut().project() {
254                StateProj::Authorize { authorize } => {
255                    let auth = ready!(authorize.poll(cx));
256                    match auth {
257                        Ok(req) => {
258                            let fut = this.service.call(req);
259                            this.state.set(State::Authorized { fut })
260                        }
261                        Err(res) => {
262                            return Poll::Ready(Ok(res));
263                        }
264                    };
265                }
266                StateProj::Authorized { fut } => {
267                    return fut.poll(cx);
268                }
269            }
270        }
271    }
272}
273
274/// Trait for authorizing requests.
275pub trait AsyncAuthorizeRequest<B> {
276    /// The type of request body returned by `authorize`.
277    ///
278    /// Set this to `B` unless you need to change the request body type.
279    type RequestBody;
280
281    /// The body type used for responses to unauthorized requests.
282    type ResponseBody;
283
284    /// The Future type returned by `authorize`
285    type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
286
287    /// Authorize the request.
288    ///
289    /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
290    fn authorize(&mut self, request: Request<B>) -> Self::Future;
291}
292
293impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
294where
295    F: FnMut(Request<B>) -> Fut,
296    Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
297{
298    type RequestBody = ReqBody;
299    type ResponseBody = ResBody;
300    type Future = Fut;
301
302    fn authorize(&mut self, request: Request<B>) -> Self::Future {
303        self(request)
304    }
305}
306
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use futures_util::future::BoxFuture;
    use http::{header, StatusCode};
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Identity attached to requests accepted by [`MyAuth`].
    #[derive(Clone, Debug)]
    struct UserId(String);

    /// Authorizer that accepts exactly one hard-coded bearer token.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth
    where
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;
        type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

        fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
            Box::pin(async move {
                // `None` when the header is missing, not valid UTF-8, or not a bearer token.
                let token = request
                    .headers()
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok())
                    .and_then(|value| value.strip_prefix("Bearer "));

                if token != Some("69420") {
                    let rejection = Response::builder()
                        .status(StatusCode::UNAUTHORIZED)
                        .body(Body::empty())
                        .unwrap();
                    return Err(rejection);
                }

                request.extensions_mut().insert(UserId("6969".to_owned()));
                Ok(request)
            })
        }
    }

    /// Inner service that sends the request body straight back.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Sends `GET /` with the given `Authorization` header through the middleware and returns
    /// the status of the response.
    async fn status_for(authorization: &'static str) -> StatusCode {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, authorization)
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        res.status()
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        assert_eq!(status_for("Bearer 69420").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        assert_eq!(status_for("Bearer deez").await, StatusCode::UNAUTHORIZED);
    }
}
389}