tower_http/compression/future.rs

1#![allow(unused_imports)]
2
3use super::{body::BodyInner, CompressionBody};
4use crate::compression::predicate::Predicate;
5use crate::compression::CompressionLevel;
6use crate::compression_utils::WrapBody;
7use crate::content_encoding::Encoding;
8use http::{header, HeaderMap, HeaderValue, Response};
9use http_body::Body;
10use pin_project_lite::pin_project;
11use std::{
12    future::Future,
13    pin::Pin,
14    task::{ready, Context, Poll},
15};
16
pin_project! {
    /// Response future of [`Compression`].
    ///
    /// Wraps the inner service's response future. Once that future resolves, the response body
    /// is wrapped in a `CompressionBody`, compressed or passed through depending on `encoding`
    /// and `predicate`.
    ///
    /// [`Compression`]: super::Compression
    #[derive(Debug)]
    pub struct ResponseFuture<F, P> {
        // The inner service's response future; pinned so it can be polled through `project()`.
        #[pin]
        pub(crate) inner: F,
        // Encoding selected for this request; `Encoding::Identity` leaves the body uncompressed.
        pub(crate) encoding: Encoding,
        // Consulted once the response is available to decide whether it should be compressed.
        pub(crate) predicate: P,
        // Compression level handed to the encoder when the body is compressed.
        pub(crate) quality: CompressionLevel,
    }
}
30
/// Resolves to the inner service's response with its body wrapped in a [`CompressionBody`],
/// compressing the body when both the response and the selected encoding allow it.
impl<F, B, E, P> Future for ResponseFuture<F, P>
where
    F: Future<Output = Result<Response<B>, E>>,
    B: Body,
    P: Predicate,
{
    /// Errors produced by the inner future are passed through unchanged.
    type Output = Result<Response<CompressionBody<B>>, E>;

    /// Polls the inner future and post-processes the response once it is ready.
    ///
    /// The body is compressed only if all of the following hold:
    /// - the response has no `Content-Encoding` header,
    /// - the response has no `Content-Range` header,
    /// - `self.predicate` accepts the response, and
    /// - `self.encoding` is not [`Encoding::Identity`].
    ///
    /// Header changes:
    /// - If the first three conditions hold, `Vary: Accept-Encoding` is appended, unless an
    ///   existing `Vary` value already contains `accept-encoding`. That check is an ASCII
    ///   case-insensitive substring match. The header is added even when the encoding is
    ///   identity and the body ends up uncompressed.
    /// - If the body is compressed, `Accept-Ranges` and `Content-Length` are removed and
    ///   `Content-Encoding` is set to `self.encoding`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `?` on the `Poll<Result<..>>` returns `Ready(Err(..))` as soon as the inner future
        // fails; `ready!` returns `Pending` while the inner future is still running.
        let res = ready!(self.as_mut().project().inner.poll(cx)?);

        // never recompress responses that are already compressed
        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
            // never compress responses that are ranges
            && !res.headers().contains_key(header::CONTENT_RANGE)
            && self.predicate.should_compress(&res);

        let (mut parts, body) = res.into_parts();

        // A compressible response's representation depends on the encoding chosen for the
        // request. `self.encoding` is presumably negotiated from `Accept-Encoding` (see the
        // `AcceptEncoding` note in the fallback arm below), so `Vary` must advertise it.
        // This runs before `self.encoding` is inspected, so identity responses carry it too.
        if should_compress
            && !parts.headers.get_all(header::VARY).iter().any(|value| {
                contains_ignore_ascii_case(
                    value.as_bytes(),
                    header::ACCEPT_ENCODING.as_str().as_bytes(),
                )
            })
        {
            parts
                .headers
                .append(header::VARY, header::ACCEPT_ENCODING.into());
        }

        // NOTE(review): exhaustiveness of this match depends on which `Encoding` variants exist
        // under the enabled crate features; those `cfg`s live outside this file.
        let body = match (should_compress, self.encoding) {
            // if compression is _not_ supported or the client doesn't accept it
            (false, _) | (_, Encoding::Identity) => {
                return Poll::Ready(Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                )))
            }

            #[cfg(feature = "compression-gzip")]
            (_, Encoding::Gzip) => {
                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-deflate")]
            (_, Encoding::Deflate) => {
                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-br")]
            (_, Encoding::Brotli) => {
                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-zstd")]
            (_, Encoding::Zstd) => {
                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "fs")]
            #[allow(unreachable_patterns)]
            (true, _) => {
                // This should never happen because the `AcceptEncoding` struct which is used to determine
                // `self.encoding` will only enable the different compression algorithms if the
                // corresponding crate feature has been enabled. This means
                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
                // features enabled.
                //
                // The match arm is still required though because the `fs` feature uses the
                // Encoding struct independently and requires no compression logic to be enabled.
                // This means a combination of an individual compression feature and `fs` will fail
                // to compile without this branch even though it will never be reached.
                //
                // To safeguard against refactors that changes this relationship or other bugs the
                // server will return an uncompressed response instead of panicking since that could
                // become a ddos attack vector.
                return Poll::Ready(Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                )));
            }
        };

        // Only reached when the body is actually being compressed. The original length no
        // longer matches the encoded stream, so `Content-Length` is dropped. Range support is
        // withdrawn as well.
        parts.headers.remove(header::ACCEPT_RANGES);
        parts.headers.remove(header::CONTENT_LENGTH);

        // `insert` replaces any existing value. The guard above ensures there was none here,
        // because `should_compress` requires the absence of `Content-Encoding`.
        parts
            .headers
            .insert(header::CONTENT_ENCODING, self.encoding.into_header_value());

        let res = Response::from_parts(parts, body);
        Poll::Ready(Ok(res))
    }
}
123
/// Reports whether `needle` appears as a contiguous run inside `haystack`, treating ASCII
/// letters case-insensitively (non-ASCII bytes must match exactly).
///
/// An empty `needle` is contained in every `haystack`, including an empty one. A `needle`
/// longer than `haystack` is never contained.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows(0)` panics, so the trivially-contained empty needle is answered up front.
    if needle.is_empty() {
        return true;
    }

    // Slide a needle-sized window across the haystack; `windows` yields nothing when the
    // needle is longer than the haystack, which correctly produces `false`.
    haystack
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}