tower_async_http/auth/
async_require_authorization.rs

1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_async_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use http_body_util::Full;
11//! use bytes::Bytes;
12//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
13//! use futures_util::future::BoxFuture;
14//!
15//! #[derive(Clone, Copy)]
16//! struct MyAuth;
17//!
18//! impl<B> AsyncAuthorizeRequest<B> for MyAuth
19//! where
20//!     B: Send + Sync + 'static,
21//! {
22//!     type RequestBody = B;
23//!     type ResponseBody = Full<Bytes>;
24//!
25//!     async fn authorize(&self, mut request: Request<B>) -> Result<Request<B>, Response<Self::ResponseBody>> {
26//!         if let Some(user_id) = check_auth(&request).await {
27//!             // Set `user_id` as a request extension so it can be accessed by other
28//!             // services down the stack.
29//!             request.extensions_mut().insert(user_id);
30//!
31//!             Ok(request)
32//!         } else {
33//!             let unauthorized_response = Response::builder()
34//!                 .status(StatusCode::UNAUTHORIZED)
35//!                 .body(Full::<Bytes>::default())
36//!                 .unwrap();
37//!
38//!             Err(unauthorized_response)
39//!         }
40//!     }
41//! }
42//!
43//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
44//!     // ...
45//!     # None
46//! }
47//!
48//! #[derive(Clone, Debug)]
49//! struct UserId(String);
50//!
51//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
//!     // Access the `UserId` that was inserted by `MyAuth::authorize`. If `handle` gets called,
//!     // the request was authorized and `UserId` will be present.
54//!     let user_id = request
55//!         .extensions()
56//!         .get::<UserId>()
57//!         .expect("UserId will be there if request was authorized");
58//!
59//!     println!("request from {:?}", user_id);
60//!
61//!     Ok(Response::new(Full::<Bytes>::default()))
62//! }
63//!
64//! # #[tokio::main]
65//! # async fn main() -> Result<(), BoxError> {
66//! let service = ServiceBuilder::new()
67//!     // Authorize requests using `MyAuth`
68//!     .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
69//!     .service_fn(handle);
70//! # Ok(())
71//! # }
72//! ```
73//!
74//! Or using a closure:
75//!
76//! ```
77//! use tower_async_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
78//! use http::{Request, Response, StatusCode};
79//! use http_body_util::Full;
80//! use bytes::Bytes;
81//! use tower_async::{Service, ServiceExt, ServiceBuilder, BoxError};
82//! use futures_util::future::BoxFuture;
83//!
84//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
85//!     // ...
86//!     # None
87//! }
88//!
89//! #[derive(Debug)]
90//! struct UserId(String);
91//!
92//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
93//!     # todo!();
94//!     // ...
95//! }
96//!
97//! # #[tokio::main]
98//! # async fn main() -> Result<(), BoxError> {
99//! let service = ServiceBuilder::new()
100//!     .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move {
101//!         if let Some(user_id) = check_auth(&request).await {
102//!             Ok(request)
103//!         } else {
104//!             let unauthorized_response = Response::builder()
105//!                 .status(StatusCode::UNAUTHORIZED)
106//!                 .body(Full::<Bytes>::default())
107//!                 .unwrap();
108//!
109//!             Err(unauthorized_response)
110//!         }
111//!     }))
112//!     .service_fn(handle);
113//! # Ok(())
114//! # }
115//! ```
116
117use http::{Request, Response};
118use std::future::Future;
119use tower_async_layer::Layer;
120use tower_async_service::Service;
121
122/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the
123/// [`Authorization`] header.
124///
125/// See the [module docs](crate::auth::async_require_authorization) for an example.
126///
127/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
128#[derive(Debug, Clone)]
129pub struct AsyncRequireAuthorizationLayer<T> {
130    auth: T,
131}
132
133impl<T> AsyncRequireAuthorizationLayer<T> {
134    /// Authorize requests using a custom scheme.
135    pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
136        Self { auth }
137    }
138}
139
140impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
141where
142    T: Clone,
143{
144    type Service = AsyncRequireAuthorization<S, T>;
145
146    fn layer(&self, inner: S) -> Self::Service {
147        AsyncRequireAuthorization::new(inner, self.auth.clone())
148    }
149}
150
151/// Middleware that authorizes all requests using the [`Authorization`] header.
152///
153/// See the [module docs](crate::auth::async_require_authorization) for an example.
154///
155/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
156#[derive(Clone, Debug)]
157pub struct AsyncRequireAuthorization<S, T> {
158    inner: S,
159    auth: T,
160}
161
162impl<S, T> AsyncRequireAuthorization<S, T> {
163    define_inner_service_accessors!();
164}
165
166impl<S, T> AsyncRequireAuthorization<S, T> {
167    /// Authorize requests using a custom scheme.
168    ///
169    /// The `Authorization` header is required to have the value provided.
170    pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
171        Self { inner, auth }
172    }
173
174    /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`]
175    /// middleware.
176    ///
177    /// [`Layer`]: tower_async_layer::Layer
178    pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
179        AsyncRequireAuthorizationLayer::new(auth)
180    }
181}
182
183impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
184where
185    Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
186    S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone,
187{
188    type Response = Response<ResBody>;
189    type Error = S::Error;
190
191    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
192        let req = match self.auth.authorize(req).await {
193            Ok(req) => req,
194            Err(res) => return Ok(res),
195        };
196        self.inner.call(req).await
197    }
198}
199
200/// Trait for authorizing requests.
201pub trait AsyncAuthorizeRequest<B> {
202    /// The type of request body returned by `authorize`.
203    ///
204    /// Set this to `B` unless you need to change the request body type.
205    type RequestBody;
206
207    /// The body type used for responses to unauthorized requests.
208    type ResponseBody;
209
210    /// Authorize the request.
211    ///
212    /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
213    fn authorize(
214        &self,
215        request: Request<B>,
216    ) -> impl std::future::Future<
217        Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>,
218    >;
219}
220
221impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
222where
223    F: Fn(Request<B>) -> Fut,
224    Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
225{
226    type RequestBody = ReqBody;
227    type ResponseBody = ResBody;
228
229    async fn authorize(
230        &self,
231        request: Request<B>,
232    ) -> Result<Request<Self::RequestBody>, Response<Self::ResponseBody>> {
233        self(request).await
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_helpers::Body;

    use http::{header, StatusCode};
    use tower_async::{BoxError, ServiceBuilder};

    /// User id that `MyAuth` attaches to every authorized request.
    const AUTHORIZED_USER: &str = "6969";

    /// Authorizer used by the tests.
    ///
    /// It accepts `Authorization: Bearer 69420` and tags the request with a [`UserId`]
    /// extension. Everything else is rejected with `401 Unauthorized`.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth {
        type RequestBody = B;
        type ResponseBody = Body;

        async fn authorize(
            &self,
            mut request: Request<B>,
        ) -> Result<Request<Self::RequestBody>, Response<Self::ResponseBody>> {
            let authorized = request
                .headers()
                .get(header::AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
                .and_then(|value| value.strip_prefix("Bearer "))
                .is_some_and(|token| token == "69420");

            if authorized {
                request
                    .extensions_mut()
                    .insert(UserId(AUTHORIZED_USER.to_owned()));
                Ok(request)
            } else {
                Err(unauthorized())
            }
        }
    }

    #[derive(Debug, Clone)]
    struct UserId(String);

    /// Builds an empty `401 Unauthorized` response.
    fn unauthorized() -> Response<Body> {
        Response::builder()
            .status(StatusCode::UNAUTHORIZED)
            .body(Body::empty())
            .unwrap()
    }

    /// Builds an empty `GET /` request, optionally carrying an `Authorization` header.
    fn get_request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::get("/");
        if let Some(value) = authorization {
            builder = builder.header(header::AUTHORIZATION, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let res = service
            .call(get_request(Some("Bearer 69420")))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_forwards_extensions() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(expect_user_id);

        let res = service
            .call(get_request(Some("Bearer 69420")))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let res = service
            .call(get_request(Some("Bearer deez")))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_missing_header_401() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let res = service.call(get_request(None)).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_with_closure() {
        // Exercises the blanket `AsyncAuthorizeRequest` impl for closures.
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(
                |request: Request<Body>| async move {
                    if request.headers().contains_key(header::AUTHORIZATION) {
                        Ok(request)
                    } else {
                        Err(unauthorized())
                    }
                },
            ))
            .service_fn(echo);

        let res = service
            .call(get_request(Some("Bearer anything")))
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let res = service.call(get_request(None)).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that echoes the request body back with `200 OK`.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Inner service that asserts `MyAuth` attached the expected [`UserId`] extension.
    async fn expect_user_id<B>(req: Request<B>) -> Result<Response<Body>, BoxError> {
        let user_id = req
            .extensions()
            .get::<UserId>()
            .expect("authorized requests carry a `UserId` extension");
        assert_eq!(user_id.0, AUTHORIZED_USER);
        Ok(Response::new(Body::empty()))
    }
}