// tower_async_http/auth/require_authorization.rs

//! Authorize requests by checking the [`Authorization`] header with [`ValidateRequest`].
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_async_http::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use http_body_util::Full;
11//! use bytes::Bytes;
12//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
13//!
14//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
15//!     Ok(Response::new(Full::default()))
16//! }
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
//! let service = ServiceBuilder::new()
21//!     // Require the `Authorization` header to be `Bearer passwordlol`
22//!     .layer(ValidateRequestHeaderLayer::bearer("passwordlol"))
23//!     .service_fn(handle);
24//!
25//! // Requests with the correct token are allowed through
26//! let request = Request::builder()
27//!     .header(AUTHORIZATION, "Bearer passwordlol")
28//!     .body(Full::<Bytes>::default())
29//!     .unwrap();
30//!
//! let response = service
//!     .call(request)
//!     .await?;
35//!
36//! assert_eq!(StatusCode::OK, response.status());
37//!
//! // Requests without a valid token get a `401 Unauthorized` response. This one
//! // carries no `Authorization` header at all.
39//! let request = Request::builder()
40//!     .body(Full::<Bytes>::default())
41//!     .unwrap();
42//!
43//! let response = service
44//!     .call(request)
45//!     .await?;
46//!
47//! assert_eq!(StatusCode::UNAUTHORIZED, response.status());
48//! # Ok(())
49//! # }
50//! ```
51//!
//! Custom validation logic can be plugged in by implementing [`ValidateRequest`].
53
54use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
55use base64::Engine as _;
56use http::{
57    header::{self, HeaderValue},
58    Request, Response, StatusCode,
59};
60use http_body::Body;
61use std::{fmt, marker::PhantomData};
62
// The base64 crate's `STANDARD` engine. `Basic::new` uses it to encode the
// `{username}:{password}` pair for `Authorization: Basic` headers.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
64
65impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
66    /// Authorize requests using a username and password pair.
67    ///
68    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
69    /// `base64_encode("{username}:{password}")`.
70    ///
71    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
72    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
73    pub fn basic(inner: S, username: &str, value: &str) -> Self
74    where
75        ResBody: Body + Default,
76    {
77        Self::custom(inner, Basic::new(username, value))
78    }
79}
80
81impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
82    /// Authorize requests using a username and password pair.
83    ///
84    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
85    /// `base64_encode("{username}:{password}")`.
86    ///
87    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
88    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
89    pub fn basic(username: &str, password: &str) -> Self
90    where
91        ResBody: Body + Default,
92    {
93        Self::custom(Basic::new(username, password))
94    }
95}
96
97impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
98    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
99    ///
100    /// The `Authorization` header is required to be `Bearer {token}`.
101    ///
102    /// # Panics
103    ///
104    /// Panics if the token is not a valid [`HeaderValue`].
105    pub fn bearer(inner: S, token: &str) -> Self
106    where
107        ResBody: Body + Default,
108    {
109        Self::custom(inner, Bearer::new(token))
110    }
111}
112
113impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
114    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
115    ///
116    /// The `Authorization` header is required to be `Bearer {token}`.
117    ///
118    /// # Panics
119    ///
120    /// Panics if the token is not a valid [`HeaderValue`].
121    pub fn bearer(token: &str) -> Self
122    where
123        ResBody: Body + Default,
124    {
125        Self::custom(Bearer::new(token))
126    }
127}
128
129/// Type that performs "bearer token" authorization.
130///
131/// See [`ValidateRequestHeader::bearer`] for more details.
132pub struct Bearer<ResBody> {
133    header_value: HeaderValue,
134    _ty: PhantomData<fn() -> ResBody>,
135}
136
137impl<ResBody> Bearer<ResBody> {
138    fn new(token: &str) -> Self
139    where
140        ResBody: Body + Default,
141    {
142        Self {
143            header_value: format!("Bearer {}", token)
144                .parse()
145                .expect("token is not a valid header value"),
146            _ty: PhantomData,
147        }
148    }
149}
150
151impl<ResBody> Clone for Bearer<ResBody> {
152    fn clone(&self) -> Self {
153        Self {
154            header_value: self.header_value.clone(),
155            _ty: PhantomData,
156        }
157    }
158}
159
160impl<ResBody> fmt::Debug for Bearer<ResBody> {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        f.debug_struct("Bearer")
163            .field("header_value", &self.header_value)
164            .finish()
165    }
166}
167
168impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
169where
170    ResBody: Body + Default,
171{
172    type ResponseBody = ResBody;
173
174    fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
175        match request.headers().get(header::AUTHORIZATION) {
176            Some(actual) if actual == self.header_value => Ok(()),
177            _ => {
178                let mut res = Response::new(ResBody::default());
179                *res.status_mut() = StatusCode::UNAUTHORIZED;
180                Err(res)
181            }
182        }
183    }
184}
185
186/// Type that performs basic authorization.
187///
188/// See [`ValidateRequestHeader::basic`] for more details.
189pub struct Basic<ResBody> {
190    header_value: HeaderValue,
191    _ty: PhantomData<fn() -> ResBody>,
192}
193
194impl<ResBody> Basic<ResBody> {
195    fn new(username: &str, password: &str) -> Self
196    where
197        ResBody: Body + Default,
198    {
199        let encoded = BASE64.encode(format!("{}:{}", username, password));
200        let header_value = format!("Basic {}", encoded).parse().unwrap();
201        Self {
202            header_value,
203            _ty: PhantomData,
204        }
205    }
206}
207
208impl<ResBody> Clone for Basic<ResBody> {
209    fn clone(&self) -> Self {
210        Self {
211            header_value: self.header_value.clone(),
212            _ty: PhantomData,
213        }
214    }
215}
216
217impl<ResBody> fmt::Debug for Basic<ResBody> {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        f.debug_struct("Basic")
220            .field("header_value", &self.header_value)
221            .finish()
222    }
223}
224
225impl<B, ResBody> ValidateRequest<B> for Basic<ResBody>
226where
227    ResBody: Body + Default,
228{
229    type ResponseBody = ResBody;
230
231    fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
232        match request.headers().get(header::AUTHORIZATION) {
233            Some(actual) if actual == self.header_value => Ok(()),
234            _ => {
235                let mut res = Response::new(ResBody::default());
236                *res.status_mut() = StatusCode::UNAUTHORIZED;
237                res.headers_mut()
238                    .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
239                Err(res)
240            }
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    use crate::test_helpers::Body;
    use crate::validate_request::ValidateRequestHeaderLayer;

    use http::header;
    use tower_async::{BoxError, ServiceBuilder};
    use tower_async_service::Service;

    #[tokio::test]
    async fn valid_basic_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    // A request with no `Authorization` header at all must be rejected with a challenge.
    #[tokio::test]
    async fn missing_basic_credentials() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // A request with no `Authorization` header at all must be rejected.
    #[tokio::test]
    async fn missing_bearer_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that echoes the request body back. The generic is named `B` so it does
    /// not shadow the imported test `Body` type.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}