tower_http/compression/
future.rs1#![allow(unused_imports)]
2
3use super::{body::BodyInner, CompressionBody};
4use crate::compression::predicate::Predicate;
5use crate::compression::CompressionLevel;
6use crate::compression_utils::WrapBody;
7use crate::content_encoding::Encoding;
8use http::{header, HeaderMap, HeaderValue, Response};
9use http_body::Body;
10use pin_project_lite::pin_project;
11use std::{
12 future::Future,
13 pin::Pin,
14 task::{ready, Context, Poll},
15};
16
17pin_project! {
18 #[derive(Debug)]
22 pub struct ResponseFuture<F, P> {
23 #[pin]
24 pub(crate) inner: F,
25 pub(crate) encoding: Encoding,
26 pub(crate) predicate: P,
27 pub(crate) quality: CompressionLevel,
28 }
29}
30
31impl<F, B, E, P> Future for ResponseFuture<F, P>
32where
33 F: Future<Output = Result<Response<B>, E>>,
34 B: Body,
35 P: Predicate,
36{
37 type Output = Result<Response<CompressionBody<B>>, E>;
38
39 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 let res = ready!(self.as_mut().project().inner.poll(cx)?);
41
42 let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
44 && !res.headers().contains_key(header::CONTENT_RANGE)
46 && self.predicate.should_compress(&res);
47
48 let (mut parts, body) = res.into_parts();
49
50 if should_compress
51 && !parts.headers.get_all(header::VARY).iter().any(|value| {
52 contains_ignore_ascii_case(
53 value.as_bytes(),
54 header::ACCEPT_ENCODING.as_str().as_bytes(),
55 )
56 })
57 {
58 parts
59 .headers
60 .append(header::VARY, header::ACCEPT_ENCODING.into());
61 }
62
63 let body = match (should_compress, self.encoding) {
64 (false, _) | (_, Encoding::Identity) => {
66 return Poll::Ready(Ok(Response::from_parts(
67 parts,
68 CompressionBody::new(BodyInner::identity(body)),
69 )))
70 }
71
72 #[cfg(feature = "compression-gzip")]
73 (_, Encoding::Gzip) => {
74 CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
75 }
76 #[cfg(feature = "compression-deflate")]
77 (_, Encoding::Deflate) => {
78 CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
79 }
80 #[cfg(feature = "compression-br")]
81 (_, Encoding::Brotli) => {
82 CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
83 }
84 #[cfg(feature = "compression-zstd")]
85 (_, Encoding::Zstd) => {
86 CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
87 }
88 #[cfg(feature = "fs")]
89 #[allow(unreachable_patterns)]
90 (true, _) => {
91 return Poll::Ready(Ok(Response::from_parts(
106 parts,
107 CompressionBody::new(BodyInner::identity(body)),
108 )));
109 }
110 };
111
112 parts.headers.remove(header::ACCEPT_RANGES);
113 parts.headers.remove(header::CONTENT_LENGTH);
114
115 parts
116 .headers
117 .insert(header::CONTENT_ENCODING, self.encoding.into_header_value());
118
119 let res = Response::from_parts(parts, body);
120 Poll::Ready(Ok(res))
121 }
122}
123
/// Reports whether `needle` occurs anywhere in `haystack` as a contiguous run of
/// bytes, comparing ASCII letters case-insensitively. Non-ASCII bytes must match
/// exactly.
///
/// An empty `needle` is considered present in every `haystack`, including an
/// empty one.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows(0)` panics, so the empty needle is answered up front.
    if needle.is_empty() {
        return true;
    }

    haystack
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}