1#![deprecated(since = "0.6.7", note = "too basic to be useful in real applications")]
2use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
59use base64::Engine as _;
60use http::{
61 header::{self, HeaderValue},
62 Request, Response, StatusCode,
63};
64use std::{fmt, marker::PhantomData};
65
/// Base64 engine used to encode `username:password` for HTTP Basic credentials
/// (the `base64` crate's standard alphabet with padding).
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
67
68impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
69 pub fn basic(inner: S, username: &str, value: &str) -> Self
77 where
78 ResBody: Default,
79 {
80 Self::custom(inner, Basic::new(username, value))
81 }
82}
83
84impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
85 pub fn basic(username: &str, password: &str) -> Self
93 where
94 ResBody: Default,
95 {
96 Self::custom(Basic::new(username, password))
97 }
98}
99
100impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
101 pub fn bearer(inner: S, token: &str) -> Self
109 where
110 ResBody: Default,
111 {
112 Self::custom(inner, Bearer::new(token))
113 }
114}
115
116impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
117 pub fn bearer(token: &str) -> Self
125 where
126 ResBody: Default,
127 {
128 Self::custom(Bearer::new(token))
129 }
130}
131
132pub struct Bearer<ResBody> {
136 header_value: HeaderValue,
137 _ty: PhantomData<fn() -> ResBody>,
138}
139
140impl<ResBody> Bearer<ResBody> {
141 fn new(token: &str) -> Self
142 where
143 ResBody: Default,
144 {
145 Self {
146 header_value: format!("Bearer {}", token)
147 .parse()
148 .expect("token is not a valid header value"),
149 _ty: PhantomData,
150 }
151 }
152}
153
154impl<ResBody> Clone for Bearer<ResBody> {
155 fn clone(&self) -> Self {
156 Self {
157 header_value: self.header_value.clone(),
158 _ty: PhantomData,
159 }
160 }
161}
162
163impl<ResBody> fmt::Debug for Bearer<ResBody> {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.debug_struct("Bearer")
166 .field("header_value", &self.header_value)
167 .finish()
168 }
169}
170
171impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
172where
173 ResBody: Default,
174{
175 type ResponseBody = ResBody;
176
177 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
178 match request.headers().get(header::AUTHORIZATION) {
179 Some(actual) if actual == self.header_value => Ok(()),
180 _ => {
181 let mut res = Response::new(ResBody::default());
182 *res.status_mut() = StatusCode::UNAUTHORIZED;
183 Err(res)
184 }
185 }
186 }
187}
188
189pub struct Basic<ResBody> {
193 header_value: HeaderValue,
194 _ty: PhantomData<fn() -> ResBody>,
195}
196
197impl<ResBody> Basic<ResBody> {
198 fn new(username: &str, password: &str) -> Self
199 where
200 ResBody: Default,
201 {
202 let encoded = BASE64.encode(format!("{}:{}", username, password));
203 let header_value = format!("Basic {}", encoded).parse().unwrap();
204 Self {
205 header_value,
206 _ty: PhantomData,
207 }
208 }
209}
210
211impl<ResBody> Clone for Basic<ResBody> {
212 fn clone(&self) -> Self {
213 Self {
214 header_value: self.header_value.clone(),
215 _ty: PhantomData,
216 }
217 }
218}
219
220impl<ResBody> fmt::Debug for Basic<ResBody> {
221 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222 f.debug_struct("Basic")
223 .field("header_value", &self.header_value)
224 .finish()
225 }
226}
227
228impl<B, ResBody> ValidateRequest<B> for Basic<ResBody>
229where
230 ResBody: Default,
231{
232 type ResponseBody = ResBody;
233
234 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
235 match request.headers().get(header::AUTHORIZATION) {
236 Some(actual) if actual == self.header_value => Ok(()),
237 _ => {
238 let mut res = Response::new(ResBody::default());
239 *res.status_mut() = StatusCode::UNAUTHORIZED;
240 res.headers_mut()
241 .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
242 Err(res)
243 }
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use crate::validate_request::ValidateRequestHeaderLayer;

    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    #[tokio::test]
    async fn valid_basic_token() {
        let credentials = format!("Basic {}", BASE64.encode("foo:bar"));
        let res = call_basic(Some(credentials.as_str())).await;

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let credentials = format!("Basic {}", BASE64.encode("wrong:credentials"));
        let res = call_basic(Some(credentials.as_str())).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn missing_basic_credentials() {
        let res = call_basic(None).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let res = call_bearer(Some("Bearer foobar")).await;

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let credentials = format!("basic {}", BASE64.encode("foo:bar"));
        let res = call_basic(Some(credentials.as_str())).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let credentials = format!("Basic {}", BASE64.encode("Foo:bar"));
        let res = call_basic(Some(credentials.as_str())).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let res = call_bearer(Some("Bearer wat")).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn missing_bearer_token() {
        let res = call_bearer(None).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let res = call_bearer(Some("bearer foobar")).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let res = call_bearer(Some("Bearer Foobar")).await;

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Builds an empty `GET /` request, adding an `Authorization` header only when
    /// `authorization` is `Some`.
    fn request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::get("/");
        if let Some(value) = authorization {
            builder = builder.header(header::AUTHORIZATION, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Sends a request through `echo` guarded by `ValidateRequestHeaderLayer::basic("foo", "bar")`.
    async fn call_basic(authorization: Option<&str>) -> Response<Body> {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        service
            .ready()
            .await
            .unwrap()
            .call(request(authorization))
            .await
            .unwrap()
    }

    /// Sends a request through `echo` guarded by `ValidateRequestHeaderLayer::bearer("foobar")`.
    async fn call_bearer(authorization: Option<&str>) -> Response<Body> {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        service
            .ready()
            .await
            .unwrap()
            .call(request(authorization))
            .await
            .unwrap()
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}