tower_http/auth/
require_authorization.rs

1#![deprecated(since = "0.6.7", note = "too basic to be useful in real applications")]
//! Authorize requests by checking the [`Authorization`] header with [`ValidateRequest`].
3//!
4//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
5//!
6//! # Example
7//!
8//! ```
9//! use tower_http::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
10//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
11//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
12//! use bytes::Bytes;
13//! use http_body_util::Full;
14//!
15//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
16//!     Ok(Response::new(Full::default()))
17//! }
18//!
19//! # #[tokio::main]
20//! # async fn main() -> Result<(), BoxError> {
21//! let mut service = ServiceBuilder::new()
22//!     // Require the `Authorization` header to be `Bearer passwordlol`
23//!     .layer(ValidateRequestHeaderLayer::bearer("passwordlol"))
24//!     .service_fn(handle);
25//!
26//! // Requests with the correct token are allowed through
27//! let request = Request::builder()
28//!     .header(AUTHORIZATION, "Bearer passwordlol")
29//!     .body(Full::default())
30//!     .unwrap();
31//!
32//! let response = service
33//!     .ready()
34//!     .await?
35//!     .call(request)
36//!     .await?;
37//!
38//! assert_eq!(StatusCode::OK, response.status());
39//!
//! // Requests without a valid token (here: no `Authorization` header at all) get a `401 Unauthorized` response
41//! let request = Request::builder()
42//!     .body(Full::default())
43//!     .unwrap();
44//!
45//! let response = service
46//!     .ready()
47//!     .await?
48//!     .call(request)
49//!     .await?;
50//!
51//! assert_eq!(StatusCode::UNAUTHORIZED, response.status());
52//! # Ok(())
53//! # }
54//! ```
55//!
//! Custom validation logic can be added by implementing [`ValidateRequest`] yourself.
57
58use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
59use base64::Engine as _;
60use http::{
61    header::{self, HeaderValue},
62    Request, Response, StatusCode,
63};
64use std::{fmt, marker::PhantomData};
65
66const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
67
68impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
69    /// Authorize requests using a username and password pair.
70    ///
71    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
72    /// `base64_encode("{username}:{password}")`.
73    ///
74    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
75    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
76    pub fn basic(inner: S, username: &str, value: &str) -> Self
77    where
78        ResBody: Default,
79    {
80        Self::custom(inner, Basic::new(username, value))
81    }
82}
83
84impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
85    /// Authorize requests using a username and password pair.
86    ///
87    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
88    /// `base64_encode("{username}:{password}")`.
89    ///
90    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
91    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
92    pub fn basic(username: &str, password: &str) -> Self
93    where
94        ResBody: Default,
95    {
96        Self::custom(Basic::new(username, password))
97    }
98}
99
100impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
101    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
102    ///
103    /// The `Authorization` header is required to be `Bearer {token}`.
104    ///
105    /// # Panics
106    ///
107    /// Panics if the token is not a valid [`HeaderValue`].
108    pub fn bearer(inner: S, token: &str) -> Self
109    where
110        ResBody: Default,
111    {
112        Self::custom(inner, Bearer::new(token))
113    }
114}
115
116impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
117    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
118    ///
119    /// The `Authorization` header is required to be `Bearer {token}`.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the token is not a valid [`HeaderValue`].
124    pub fn bearer(token: &str) -> Self
125    where
126        ResBody: Default,
127    {
128        Self::custom(Bearer::new(token))
129    }
130}
131
132/// Type that performs "bearer token" authorization.
133///
134/// See [`ValidateRequestHeader::bearer`] for more details.
135pub struct Bearer<ResBody> {
136    header_value: HeaderValue,
137    _ty: PhantomData<fn() -> ResBody>,
138}
139
140impl<ResBody> Bearer<ResBody> {
141    fn new(token: &str) -> Self
142    where
143        ResBody: Default,
144    {
145        Self {
146            header_value: format!("Bearer {}", token)
147                .parse()
148                .expect("token is not a valid header value"),
149            _ty: PhantomData,
150        }
151    }
152}
153
154impl<ResBody> Clone for Bearer<ResBody> {
155    fn clone(&self) -> Self {
156        Self {
157            header_value: self.header_value.clone(),
158            _ty: PhantomData,
159        }
160    }
161}
162
163impl<ResBody> fmt::Debug for Bearer<ResBody> {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.debug_struct("Bearer")
166            .field("header_value", &self.header_value)
167            .finish()
168    }
169}
170
171impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
172where
173    ResBody: Default,
174{
175    type ResponseBody = ResBody;
176
177    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
178        match request.headers().get(header::AUTHORIZATION) {
179            Some(actual) if actual == self.header_value => Ok(()),
180            _ => {
181                let mut res = Response::new(ResBody::default());
182                *res.status_mut() = StatusCode::UNAUTHORIZED;
183                Err(res)
184            }
185        }
186    }
187}
188
189/// Type that performs basic authorization.
190///
191/// See [`ValidateRequestHeader::basic`] for more details.
192pub struct Basic<ResBody> {
193    header_value: HeaderValue,
194    _ty: PhantomData<fn() -> ResBody>,
195}
196
197impl<ResBody> Basic<ResBody> {
198    fn new(username: &str, password: &str) -> Self
199    where
200        ResBody: Default,
201    {
202        let encoded = BASE64.encode(format!("{}:{}", username, password));
203        let header_value = format!("Basic {}", encoded).parse().unwrap();
204        Self {
205            header_value,
206            _ty: PhantomData,
207        }
208    }
209}
210
211impl<ResBody> Clone for Basic<ResBody> {
212    fn clone(&self) -> Self {
213        Self {
214            header_value: self.header_value.clone(),
215            _ty: PhantomData,
216        }
217    }
218}
219
220impl<ResBody> fmt::Debug for Basic<ResBody> {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        f.debug_struct("Basic")
223            .field("header_value", &self.header_value)
224            .finish()
225    }
226}
227
228impl<B, ResBody> ValidateRequest<B> for Basic<ResBody>
229where
230    ResBody: Default,
231{
232    type ResponseBody = ResBody;
233
234    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
235        match request.headers().get(header::AUTHORIZATION) {
236            Some(actual) if actual == self.header_value => Ok(()),
237            _ => {
238                let mut res = Response::new(ResBody::default());
239                *res.status_mut() = StatusCode::UNAUTHORIZED;
240                res.headers_mut()
241                    .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
242                Err(res)
243            }
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use crate::validate_request::ValidateRequestHeaderLayer;

    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    #[tokio::test]
    async fn valid_basic_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn missing_basic_credentials() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        // No `Authorization` header at all must be rejected just like wrong credentials.
        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn missing_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        // No `Authorization` header at all must be rejected just like a wrong token.
        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_rejects_basic_credentials() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        // Credentials for a different scheme must not be accepted.
        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Test service that echoes the request body back as the response body.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}