tower_http/compression/
future.rs1#![allow(unused_imports)]
2
3use super::{body::BodyInner, CompressionBody};
4use crate::compression::predicate::Predicate;
5use crate::compression::CompressionLevel;
6use crate::compression_utils::WrapBody;
7use crate::content_encoding::Encoding;
8use http::{header, HeaderMap, HeaderValue, Response};
9use http_body::Body;
10use pin_project_lite::pin_project;
11use std::{
12 future::Future,
13 pin::Pin,
14 task::{ready, Context, Poll},
15};
16
17pin_project! {
18 #[derive(Debug)]
22 pub struct ResponseFuture<F, P> {
23 #[pin]
24 pub(crate) inner: F,
25 pub(crate) encoding: Option<Encoding>,
26 pub(crate) predicate: P,
27 pub(crate) quality: CompressionLevel,
28 }
29}
30
31impl<F, B, E, P> Future for ResponseFuture<F, P>
32where
33 F: Future<Output = Result<Response<B>, E>>,
34 B: Body,
35 P: Predicate,
36{
37 type Output = Result<Response<CompressionBody<B>>, E>;
38
39 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 let res = ready!(self.as_mut().project().inner.poll(cx)?);
41
42 let encoding = match self.encoding {
43 Some(enc) => enc,
44 None => {
45 let mut res = res;
51 *res.status_mut() = http::StatusCode::NOT_ACCEPTABLE;
52 if !res.headers().get_all(header::VARY).iter().any(|value| {
53 contains_ignore_ascii_case(
54 value.as_bytes(),
55 header::ACCEPT_ENCODING.as_str().as_bytes(),
56 )
57 }) {
58 res.headers_mut()
59 .append(header::VARY, header::ACCEPT_ENCODING.into());
60 }
61 let (parts, body) = res.into_parts();
62 return Poll::Ready(Ok(Response::from_parts(
63 parts,
64 CompressionBody::new(BodyInner::identity(body)),
65 )));
66 }
67 };
68
69 let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
71 && !res.headers().contains_key(header::CONTENT_RANGE)
73 && self.predicate.should_compress(&res);
74
75 let (mut parts, body) = res.into_parts();
76
77 if should_compress
78 && !parts.headers.get_all(header::VARY).iter().any(|value| {
79 contains_ignore_ascii_case(
80 value.as_bytes(),
81 header::ACCEPT_ENCODING.as_str().as_bytes(),
82 )
83 })
84 {
85 parts
86 .headers
87 .append(header::VARY, header::ACCEPT_ENCODING.into());
88 }
89
90 let body = match (should_compress, encoding) {
91 (false, _) | (_, Encoding::Identity) => {
93 return Poll::Ready(Ok(Response::from_parts(
94 parts,
95 CompressionBody::new(BodyInner::identity(body)),
96 )))
97 }
98
99 #[cfg(feature = "compression-gzip")]
100 (_, Encoding::Gzip) => {
101 CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
102 }
103 #[cfg(feature = "compression-deflate")]
104 (_, Encoding::Deflate) => {
105 CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
106 }
107 #[cfg(feature = "compression-br")]
108 (_, Encoding::Brotli) => {
109 CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
110 }
111 #[cfg(feature = "compression-zstd")]
112 (_, Encoding::Zstd) => {
113 CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
114 }
115 #[cfg(feature = "fs")]
116 #[allow(unreachable_patterns)]
117 (true, _) => {
118 return Poll::Ready(Ok(Response::from_parts(
133 parts,
134 CompressionBody::new(BodyInner::identity(body)),
135 )));
136 }
137 };
138
139 parts.headers.remove(header::ACCEPT_RANGES);
140 parts.headers.remove(header::CONTENT_LENGTH);
141
142 parts
143 .headers
144 .insert(header::CONTENT_ENCODING, encoding.into_header_value());
145
146 let res = Response::from_parts(parts, body);
147 Poll::Ready(Ok(res))
148 }
149}
150
/// Returns `true` when `needle` occurs anywhere inside `haystack`, comparing
/// bytes ASCII case-insensitively. An empty `needle` matches every haystack.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `slice::windows` panics on a zero width, so the trivially-matching
    // empty needle is answered up front.
    if needle.is_empty() {
        return true;
    }

    // A haystack shorter than the needle yields no windows, hence `false`.
    haystack
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}