Skip to main content

tower_http/compression/
future.rs

1#![allow(unused_imports)]
2
3use super::{body::BodyInner, CompressionBody};
4use crate::compression::predicate::Predicate;
5use crate::compression::CompressionLevel;
6use crate::compression_utils::WrapBody;
7use crate::content_encoding::Encoding;
8use http::{header, HeaderMap, HeaderValue, Response};
9use http_body::Body;
10use pin_project_lite::pin_project;
11use std::{
12    future::Future,
13    pin::Pin,
14    task::{ready, Context, Poll},
15};
16
pin_project! {
    /// Response future of [`Compression`].
    ///
    /// Resolves to the inner service's response with its body wrapped in a
    /// [`CompressionBody`]. Whether (and how) the body is encoded is decided only once the
    /// inner future has produced a response.
    ///
    /// [`Compression`]: super::Compression
    #[derive(Debug)]
    pub struct ResponseFuture<F, P> {
        // The inner service's response future. It is structurally pinned so that it can
        // be polled through `Pin<&mut Self>`.
        #[pin]
        pub(crate) inner: F,
        // Encoding to apply to the response body (presumably negotiated from the request
        // by the caller -- TODO confirm). `None` is handled specially by `poll`, which may
        // answer 406 Not Acceptable instead of compressing.
        pub(crate) encoding: Option<Encoding>,
        // Consulted per response (`should_compress`) to decide whether compression
        // applies at all.
        pub(crate) predicate: P,
        // Compression level handed to the encoder wrapper when the body is compressed.
        pub(crate) quality: CompressionLevel,
    }
}
30
31impl<F, B, E, P> Future for ResponseFuture<F, P>
32where
33    F: Future<Output = Result<Response<B>, E>>,
34    B: Body,
35    P: Predicate,
36{
37    type Output = Result<Response<CompressionBody<B>>, E>;
38
39    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        let res = ready!(self.as_mut().project().inner.poll(cx)?);
41
42        let encoding = match self.encoding {
43            Some(enc) => enc,
44            None => {
45                // RFC 9110 ยง12.5.3: the server SHOULD respond with 406 Not Acceptable
46                // when no encoding is satisfiable. This middleware chooses to enforce it.
47                //
48                // Note: the inner service has already been called, so its response body and
49                // headers are passed through. Only the status code is overwritten.
50                let mut res = res;
51                *res.status_mut() = http::StatusCode::NOT_ACCEPTABLE;
52                if !res.headers().get_all(header::VARY).iter().any(|value| {
53                    contains_ignore_ascii_case(
54                        value.as_bytes(),
55                        header::ACCEPT_ENCODING.as_str().as_bytes(),
56                    )
57                }) {
58                    res.headers_mut()
59                        .append(header::VARY, header::ACCEPT_ENCODING.into());
60                }
61                let (parts, body) = res.into_parts();
62                return Poll::Ready(Ok(Response::from_parts(
63                    parts,
64                    CompressionBody::new(BodyInner::identity(body)),
65                )));
66            }
67        };
68
69        // never recompress responses that are already compressed
70        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
71            // never compress responses that are ranges
72            && !res.headers().contains_key(header::CONTENT_RANGE)
73            && self.predicate.should_compress(&res);
74
75        let (mut parts, body) = res.into_parts();
76
77        if should_compress
78            && !parts.headers.get_all(header::VARY).iter().any(|value| {
79                contains_ignore_ascii_case(
80                    value.as_bytes(),
81                    header::ACCEPT_ENCODING.as_str().as_bytes(),
82                )
83            })
84        {
85            parts
86                .headers
87                .append(header::VARY, header::ACCEPT_ENCODING.into());
88        }
89
90        let body = match (should_compress, encoding) {
91            // if compression is _not_ supported or the client doesn't accept it
92            (false, _) | (_, Encoding::Identity) => {
93                return Poll::Ready(Ok(Response::from_parts(
94                    parts,
95                    CompressionBody::new(BodyInner::identity(body)),
96                )))
97            }
98
99            #[cfg(feature = "compression-gzip")]
100            (_, Encoding::Gzip) => {
101                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
102            }
103            #[cfg(feature = "compression-deflate")]
104            (_, Encoding::Deflate) => {
105                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
106            }
107            #[cfg(feature = "compression-br")]
108            (_, Encoding::Brotli) => {
109                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
110            }
111            #[cfg(feature = "compression-zstd")]
112            (_, Encoding::Zstd) => {
113                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
114            }
115            #[cfg(feature = "fs")]
116            #[allow(unreachable_patterns)]
117            (true, _) => {
118                // This should never happen because the `AcceptEncoding` struct which is used to determine
119                // `self.encoding` will only enable the different compression algorithms if the
120                // corresponding crate feature has been enabled. This means
121                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
122                // features enabled.
123                //
124                // The match arm is still required though because the `fs` feature uses the
125                // Encoding struct independently and requires no compression logic to be enabled.
126                // This means a combination of an individual compression feature and `fs` will fail
127                // to compile without this branch even though it will never be reached.
128                //
129                // To safeguard against refactors that changes this relationship or other bugs the
130                // server will return an uncompressed response instead of panicking since that could
131                // become a ddos attack vector.
132                return Poll::Ready(Ok(Response::from_parts(
133                    parts,
134                    CompressionBody::new(BodyInner::identity(body)),
135                )));
136            }
137        };
138
139        parts.headers.remove(header::ACCEPT_RANGES);
140        parts.headers.remove(header::CONTENT_LENGTH);
141
142        parts
143            .headers
144            .insert(header::CONTENT_ENCODING, encoding.into_header_value());
145
146        let res = Response::from_parts(parts, body);
147        Poll::Ready(Ok(res))
148    }
149}
150
/// Reports whether `needle` occurs as a contiguous run anywhere in `haystack`, treating
/// ASCII letters case-insensitively. An empty `needle` always matches.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `slice::windows` panics on a zero width, so the empty needle is answered up front.
    if needle.is_empty() {
        return true;
    }
    // A needle longer than the haystack yields no windows, hence `false`.
    haystack
        .windows(needle.len())
        .any(|candidate| candidate.eq_ignore_ascii_case(needle))
}