1use http::{header, Request, Response, StatusCode};
116use http_body::Body;
117use mime::{Mime, MimeIter};
118use std::{fmt, marker::PhantomData, sync::Arc};
119use tower_async_layer::Layer;
120use tower_async_service::Service;
121
/// [`Layer`] that wraps services in [`ValidateRequestHeader`].
///
/// `T` is the validator; it is cloned into every service this layer produces.
#[derive(Clone, Debug)]
pub struct ValidateRequestHeaderLayer<T> {
    validate: T,
}
129
130impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
131 pub fn accept(value: &str) -> Self
153 where
154 ResBody: Body + Default,
155 {
156 Self::custom(AcceptHeader::new(value))
157 }
158}
159
160impl<T> ValidateRequestHeaderLayer<T> {
161 pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
163 Self { validate }
164 }
165}
166
167impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
168where
169 T: Clone,
170{
171 type Service = ValidateRequestHeader<S, T>;
172
173 fn layer(&self, inner: S) -> Self::Service {
174 ValidateRequestHeader::new(inner, self.validate.clone())
175 }
176}
177
/// Middleware that runs a validator on each request before forwarding it to `inner`.
///
/// If the validator rejects the request, its response is returned and `inner` is not called.
#[derive(Debug, Clone)]
pub struct ValidateRequestHeader<S, T> {
    inner: S,
    validate: T,
}
186
187impl<S, T> ValidateRequestHeader<S, T> {
188 fn new(inner: S, validate: T) -> Self {
189 Self::custom(inner, validate)
190 }
191
192 define_inner_service_accessors!();
193}
194
195impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
196 pub fn accept(inner: S, value: &str) -> Self
205 where
206 ResBody: Body + Default,
207 {
208 Self::custom(inner, AcceptHeader::new(value))
209 }
210}
211
212impl<S, T> ValidateRequestHeader<S, T> {
213 pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> {
215 Self { inner, validate }
216 }
217}
218
219impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V>
220where
221 V: ValidateRequest<ReqBody, ResponseBody = ResBody>,
222 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
223{
224 type Response = Response<ResBody>;
225 type Error = S::Error;
226
227 async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
228 match self.validate.validate(&mut req) {
229 Ok(_) => self.inner.call(req).await,
230 Err(res) => Ok(res),
231 }
232 }
233}
234
235pub trait ValidateRequest<B> {
237 type ResponseBody;
239
240 fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>;
244}
245
246impl<B, F, ResBody> ValidateRequest<B> for F
247where
248 F: Fn(&mut Request<B>) -> Result<(), Response<ResBody>>,
249{
250 type ResponseBody = ResBody;
251
252 fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
253 self(request)
254 }
255}
256
257pub struct AcceptHeader<ResBody> {
259 header_value: Arc<Mime>,
260 _ty: PhantomData<fn() -> ResBody>,
261}
262
263impl<ResBody> AcceptHeader<ResBody> {
264 fn new(header_value: &str) -> Self
270 where
271 ResBody: Body + Default,
272 {
273 Self {
274 header_value: Arc::new(
275 header_value
276 .parse::<Mime>()
277 .expect("value is not a valid header value"),
278 ),
279 _ty: PhantomData,
280 }
281 }
282}
283
284impl<ResBody> Clone for AcceptHeader<ResBody> {
285 fn clone(&self) -> Self {
286 Self {
287 header_value: self.header_value.clone(),
288 _ty: PhantomData,
289 }
290 }
291}
292
293impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
294 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295 f.debug_struct("AcceptHeader")
296 .field("header_value", &self.header_value)
297 .finish()
298 }
299}
300
301impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody>
302where
303 ResBody: Body + Default,
304{
305 type ResponseBody = ResBody;
306
307 fn validate(&self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
308 if !req.headers().contains_key(header::ACCEPT) {
309 return Ok(());
310 }
311 if req
312 .headers()
313 .get_all(header::ACCEPT)
314 .into_iter()
315 .filter_map(|header| header.to_str().ok())
316 .any(|h| {
317 MimeIter::new(h)
318 .map(|mim| {
319 if let Ok(mim) = mim {
320 let typ = self.header_value.type_();
321 let subtype = self.header_value.subtype();
322 match (mim.type_(), mim.subtype()) {
323 (t, s) if t == typ && s == subtype => true,
324 (t, mime::STAR) if t == typ => true,
325 (mime::STAR, mime::STAR) => true,
326 _ => false,
327 }
328 } else {
329 false
330 }
331 })
332 .reduce(|acc, mim| acc || mim)
333 .unwrap_or(false)
334 })
335 {
336 return Ok(());
337 }
338 let mut res = Response::new(ResBody::default());
339 *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
340 Err(res)
341 }
342}
343
#[cfg(test)]
mod tests {
    // Brings the parent's private `use` items (Request, Response, Service, ...) into scope;
    // the allow silences the warning for names the tests happen not to use.
    #[allow(unused_imports)]
    use super::*;

    use crate::test_helpers::Body;

    use http::{header, StatusCode};
    use tower_async::{BoxError, ServiceBuilder};

    #[tokio::test]
    async fn valid_accept_header() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "application/json")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_accept_header_is_accepted() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "application/*")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "*/*")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "invalid")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "application/strings")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "text/strings")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "text/strings")
            .header(header::ACCEPT, "invalid, application/json")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/json"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, "text/strings, invalid, application/json")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("application/xml"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, value)
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept("text/html"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::ACCEPT, value)
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn custom_closure_validator() {
        // Exercises the blanket `ValidateRequest` impl for closures.
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::custom(
                |req: &mut Request<Body>| -> Result<(), Response<Body>> {
                    if req.headers().contains_key(header::AUTHORIZATION) {
                        Ok(())
                    } else {
                        let mut res = Response::new(Body::empty());
                        *res.status_mut() = StatusCode::UNAUTHORIZED;
                        Err(res)
                    }
                },
            ))
            .service_fn(echo);

        let rejected = Request::get("/").body(Body::empty()).unwrap();
        let res = service.call(rejected).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let allowed = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer token")
            .body(Body::empty())
            .unwrap();
        let res = service.call(allowed).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}