tower_async_http/compression/service.rs
1use super::body::BodyInner;
2use super::{CompressionBody, CompressionLayer};
3use crate::compression::predicate::{DefaultPredicate, Predicate};
4use crate::compression::CompressionLevel;
5use crate::compression_utils::WrapBody;
6use crate::{compression_utils::AcceptEncoding, content_encoding::Encoding};
7use http::{header, Request, Response};
8use http_body::Body;
9use tower_async_service::Service;
10
11/// Compress response bodies of the underlying service.
12///
13/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
14/// `Content-Encoding` header to responses.
15///
16/// See the [module docs](crate::compression) for more details.
17#[derive(Clone, Copy)]
18pub struct Compression<S, P = DefaultPredicate> {
19 pub(crate) inner: S,
20 pub(crate) accept: AcceptEncoding,
21 pub(crate) predicate: P,
22 pub(crate) quality: CompressionLevel,
23}
24
25impl<S> Compression<S, DefaultPredicate> {
26 /// Creates a new `Compression` wrapping the `service`.
27 pub fn new(service: S) -> Compression<S, DefaultPredicate> {
28 Self {
29 inner: service,
30 accept: AcceptEncoding::default(),
31 predicate: DefaultPredicate::default(),
32 quality: CompressionLevel::default(),
33 }
34 }
35}
36
37impl<S, P> Compression<S, P> {
38 define_inner_service_accessors!();
39
40 /// Returns a new [`Layer`] that wraps services with a `Compression` middleware.
41 ///
42 /// [`Layer`]: tower_async_layer::Layer
43 pub fn layer() -> CompressionLayer {
44 CompressionLayer::new()
45 }
46
47 /// Sets whether to enable the gzip encoding.
48 #[cfg(feature = "compression-gzip")]
49 pub fn gzip(mut self, enable: bool) -> Self {
50 self.accept.set_gzip(enable);
51 self
52 }
53
54 /// Sets whether to enable the Deflate encoding.
55 #[cfg(feature = "compression-deflate")]
56 pub fn deflate(mut self, enable: bool) -> Self {
57 self.accept.set_deflate(enable);
58 self
59 }
60
61 /// Sets whether to enable the Brotli encoding.
62 #[cfg(feature = "compression-br")]
63 pub fn br(mut self, enable: bool) -> Self {
64 self.accept.set_br(enable);
65 self
66 }
67
68 /// Sets whether to enable the Zstd encoding.
69 #[cfg(feature = "compression-zstd")]
70 pub fn zstd(mut self, enable: bool) -> Self {
71 self.accept.set_zstd(enable);
72 self
73 }
74
75 /// Sets the compression quality.
76 pub fn quality(mut self, quality: CompressionLevel) -> Self {
77 self.quality = quality;
78 self
79 }
80
81 /// Disables the gzip encoding.
82 ///
83 /// This method is available even if the `gzip` crate feature is disabled.
84 pub fn no_gzip(mut self) -> Self {
85 self.accept.set_gzip(false);
86 self
87 }
88
89 /// Disables the Deflate encoding.
90 ///
91 /// This method is available even if the `deflate` crate feature is disabled.
92 pub fn no_deflate(mut self) -> Self {
93 self.accept.set_deflate(false);
94 self
95 }
96
97 /// Disables the Brotli encoding.
98 ///
99 /// This method is available even if the `br` crate feature is disabled.
100 pub fn no_br(mut self) -> Self {
101 self.accept.set_br(false);
102 self
103 }
104
105 /// Disables the Zstd encoding.
106 ///
107 /// This method is available even if the `zstd` crate feature is disabled.
108 pub fn no_zstd(mut self) -> Self {
109 self.accept.set_zstd(false);
110 self
111 }
112
113 /// Replace the current compression predicate.
114 ///
115 /// Predicates are used to determine whether a response should be compressed or not.
116 ///
117 /// The default predicate is [`DefaultPredicate`]. See its documentation for more
118 /// details on which responses it wont compress.
119 ///
120 /// # Changing the compression predicate
121 ///
122 /// ```
123 /// use tower_async_http::compression::{
124 /// Compression,
125 /// predicate::{Predicate, NotForContentType, DefaultPredicate},
126 /// };
127 /// use tower_async::util::service_fn;
128 ///
129 /// // Placeholder service_fn
130 /// let service = service_fn(|_: ()| async {
131 /// Ok::<_, std::io::Error>(http::Response::new(()))
132 /// });
133 ///
134 /// // build our custom compression predicate
135 /// // its recommended to still include `DefaultPredicate` as part of
136 /// // custom predicates
137 /// let predicate = DefaultPredicate::new()
138 /// // don't compress responses who's `content-type` starts with `application/json`
139 /// .and(NotForContentType::new("application/json"));
140 ///
141 /// let service = Compression::new(service).compress_when(predicate);
142 /// ```
143 ///
144 /// See [`predicate`](super::predicate) for more utilities for building compression predicates.
145 ///
146 /// Responses that are already compressed (ie have a `content-encoding` header) will _never_ be
147 /// recompressed, regardless what they predicate says.
148 pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
149 where
150 C: Predicate,
151 {
152 Compression {
153 inner: self.inner,
154 accept: self.accept,
155 predicate,
156 quality: self.quality,
157 }
158 }
159}
160
161impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for Compression<S, P>
162where
163 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
164 ResBody: Body,
165 P: Predicate,
166{
167 type Response = Response<CompressionBody<ResBody>>;
168 type Error = S::Error;
169
170 #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
171 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
172 let encoding = Encoding::from_headers(req.headers(), self.accept);
173
174 let res = self.inner.call(req).await?;
175
176 // never recompress responses that are already compressed
177 let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
178 && self.predicate.should_compress(&res);
179
180 let (mut parts, body) = res.into_parts();
181
182 let body = match (should_compress, encoding) {
183 // if compression is _not_ support or the client doesn't accept it
184 (false, _) | (_, Encoding::Identity) => {
185 return Ok(Response::from_parts(
186 parts,
187 CompressionBody::new(BodyInner::identity(body)),
188 ))
189 }
190
191 #[cfg(feature = "compression-gzip")]
192 (_, Encoding::Gzip) => {
193 CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
194 }
195 #[cfg(feature = "compression-deflate")]
196 (_, Encoding::Deflate) => {
197 CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
198 }
199 #[cfg(feature = "compression-br")]
200 (_, Encoding::Brotli) => {
201 CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
202 }
203 #[cfg(feature = "compression-zstd")]
204 (_, Encoding::Zstd) => {
205 CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
206 }
207 #[cfg(feature = "fs")]
208 (true, _) => {
209 // This should never happen because the `AcceptEncoding` struct which is used to determine
210 // `self.encoding` will only enable the different compression algorithms if the
211 // corresponding crate feature has been enabled. This means
212 // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
213 // features enabled.
214 //
215 // The match arm is still required though because the `fs` feature uses the
216 // Encoding struct independently and requires no compression logic to be enabled.
217 // This means a combination of an individual compression feature and `fs` will fail
218 // to compile without this branch even though it will never be reached.
219 //
220 // To safeguard against refactors that changes this relationship or other bugs the
221 // server will return an uncompressed response instead of panicking since that could
222 // become a ddos attack vector.
223 return Ok(Response::from_parts(
224 parts,
225 CompressionBody::new(BodyInner::identity(body)),
226 ));
227 }
228 };
229
230 parts.headers.remove(header::CONTENT_LENGTH);
231
232 parts
233 .headers
234 .insert(header::CONTENT_ENCODING, encoding.into_header_value());
235
236 let res = Response::from_parts(parts, body);
237 Ok(res)
238 }
239}