tower_async_http/auth/async_require_authorization.rs
1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_async_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use http_body_util::Full;
11//! use bytes::Bytes;
12//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
13//! use futures_util::future::BoxFuture;
14//!
15//! #[derive(Clone, Copy)]
16//! struct MyAuth;
17//!
18//! impl<B> AsyncAuthorizeRequest<B> for MyAuth
19//! where
20//! B: Send + Sync + 'static,
21//! {
22//! type RequestBody = B;
23//! type ResponseBody = Full<Bytes>;
24//!
25//! async fn authorize(&self, mut request: Request<B>) -> Result<Request<B>, Response<Self::ResponseBody>> {
26//! if let Some(user_id) = check_auth(&request).await {
27//! // Set `user_id` as a request extension so it can be accessed by other
28//! // services down the stack.
29//! request.extensions_mut().insert(user_id);
30//!
31//! Ok(request)
32//! } else {
33//! let unauthorized_response = Response::builder()
34//! .status(StatusCode::UNAUTHORIZED)
35//! .body(Full::<Bytes>::default())
36//! .unwrap();
37//!
38//! Err(unauthorized_response)
39//! }
40//! }
41//! }
42//!
43//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
44//! // ...
45//! # None
46//! }
47//!
48//! #[derive(Clone, Debug)]
49//! struct UserId(String);
50//!
51//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
//!     // Access the `UserId` that was set in `authorize`. If `handle` gets called the
53//! // request was authorized and `UserId` will be present.
54//! let user_id = request
55//! .extensions()
56//! .get::<UserId>()
57//! .expect("UserId will be there if request was authorized");
58//!
59//! println!("request from {:?}", user_id);
60//!
61//! Ok(Response::new(Full::<Bytes>::default()))
62//! }
63//!
64//! # #[tokio::main]
65//! # async fn main() -> Result<(), BoxError> {
66//! let service = ServiceBuilder::new()
67//! // Authorize requests using `MyAuth`
68//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
69//! .service_fn(handle);
70//! # Ok(())
71//! # }
72//! ```
73//!
74//! Or using a closure:
75//!
76//! ```
77//! use tower_async_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
78//! use http::{Request, Response, StatusCode};
79//! use http_body_util::Full;
80//! use bytes::Bytes;
81//! use tower_async::{Service, ServiceExt, ServiceBuilder, BoxError};
82//! use futures_util::future::BoxFuture;
83//!
84//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
85//! // ...
86//! # None
87//! }
88//!
89//! #[derive(Debug)]
90//! struct UserId(String);
91//!
92//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
93//! # todo!();
94//! // ...
95//! }
96//!
97//! # #[tokio::main]
98//! # async fn main() -> Result<(), BoxError> {
99//! let service = ServiceBuilder::new()
100//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move {
101//! if let Some(user_id) = check_auth(&request).await {
102//! Ok(request)
103//! } else {
104//! let unauthorized_response = Response::builder()
105//! .status(StatusCode::UNAUTHORIZED)
106//! .body(Full::<Bytes>::default())
107//! .unwrap();
108//!
109//! Err(unauthorized_response)
110//! }
111//! }))
112//! .service_fn(handle);
113//! # Ok(())
114//! # }
115//! ```
116
117use http::{Request, Response};
118use std::future::Future;
119use tower_async_layer::Layer;
120use tower_async_service::Service;
121
/// A [`Layer`] that wraps services in [`AsyncRequireAuthorization`], so that every request is
/// authorized (typically via the [`Authorization`] header) before reaching the inner service.
///
/// An example can be found in the [module docs](crate::auth::async_require_authorization).
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorizationLayer<T> {
    // Authorizer handed (by clone) to every service this layer produces.
    auth: T,
}
132
133impl<T> AsyncRequireAuthorizationLayer<T> {
134 /// Authorize requests using a custom scheme.
135 pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
136 Self { auth }
137 }
138}
139
140impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
141where
142 T: Clone,
143{
144 type Service = AsyncRequireAuthorization<S, T>;
145
146 fn layer(&self, inner: S) -> Self::Service {
147 AsyncRequireAuthorization::new(inner, self.auth.clone())
148 }
149}
150
/// Middleware that authorizes all requests using the [`Authorization`] header.
///
/// Every request is first handed to the `auth` implementation; only requests it accepts are
/// forwarded to the wrapped `inner` service, while rejected requests are answered with the
/// response produced by the authorizer.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
    // The wrapped service that receives requests once they have been authorized.
    inner: S,
    // The authorizer consulted for every incoming request.
    auth: T,
}
161
impl<S, T> AsyncRequireAuthorization<S, T> {
    // Expands a macro that is defined outside this file. Judging by its name it presumably
    // generates accessor methods for the wrapped `inner` service.
    // NOTE(review): confirm against the macro definition.
    define_inner_service_accessors!();
}
165
166impl<S, T> AsyncRequireAuthorization<S, T> {
167 /// Authorize requests using a custom scheme.
168 ///
169 /// The `Authorization` header is required to have the value provided.
170 pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
171 Self { inner, auth }
172 }
173
174 /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`]
175 /// middleware.
176 ///
177 /// [`Layer`]: tower_async_layer::Layer
178 pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
179 AsyncRequireAuthorizationLayer::new(auth)
180 }
181}
182
183impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
184where
185 Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
186 S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone,
187{
188 type Response = Response<ResBody>;
189 type Error = S::Error;
190
191 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
192 let req = match self.auth.authorize(req).await {
193 Ok(req) => req,
194 Err(res) => return Ok(res),
195 };
196 self.inner.call(req).await
197 }
198}
199
200/// Trait for authorizing requests.
201pub trait AsyncAuthorizeRequest<B> {
202 /// The type of request body returned by `authorize`.
203 ///
204 /// Set this to `B` unless you need to change the request body type.
205 type RequestBody;
206
207 /// The body type used for responses to unauthorized requests.
208 type ResponseBody;
209
210 /// Authorize the request.
211 ///
212 /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
213 fn authorize(
214 &self,
215 request: Request<B>,
216 ) -> impl std::future::Future<
217 Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>,
218 >;
219}
220
221impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
222where
223 F: Fn(Request<B>) -> Fut,
224 Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
225{
226 type RequestBody = ReqBody;
227 type ResponseBody = ResBody;
228
229 async fn authorize(
230 &self,
231 request: Request<B>,
232 ) -> Result<Request<Self::RequestBody>, Response<Self::ResponseBody>> {
233 self(request).await
234 }
235}
236
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    use crate::test_helpers::Body;

    use http::{header, StatusCode};
    use tower_async::{BoxError, ServiceBuilder};

    /// Test authorizer that accepts exactly the bearer token `69420`.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth
    where
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;

        async fn authorize(
            &self,
            mut request: Request<B>,
        ) -> Result<Request<Self::RequestBody>, Response<Self::ResponseBody>> {
            // A missing header, a non-UTF-8 value, a non-`Bearer` scheme or a wrong token
            // all count as "not authorized".
            let authorized = request
                .headers()
                .get(header::AUTHORIZATION)
                .and_then(|value: &http::HeaderValue| value.to_str().ok())
                .and_then(|value| value.strip_prefix("Bearer "))
                .is_some_and(|token| token == "69420");

            if authorized {
                // Make the authenticated user available to services further down the stack.
                request.extensions_mut().insert(UserId("6969".to_owned()));
                Ok(request)
            } else {
                Err(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::empty())
                    .unwrap())
            }
        }
    }

    /// Extension value attached by `MyAuth` to authorized requests.
    #[derive(Debug, Clone)]
    struct UserId(String);

    #[tokio::test]
    async fn require_async_auth_works() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer deez")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// A request that carries no `Authorization` header at all must be rejected too.
    #[tokio::test]
    async fn require_async_auth_401_without_header() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that echoes the request body back.
    ///
    /// It is only reached for authorized requests, so it also checks that `MyAuth` attached
    /// the expected `UserId` extension.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        let user_id = req.extensions().get::<UserId>().map(|id| id.0.as_str());
        assert_eq!(user_id, Some("6969"));

        Ok(Response::new(req.into_body()))
    }
}