tower_http/compression/
layer.rs1use super::{Compression, Predicate};
2use crate::compression::predicate::DefaultPredicate;
3use crate::compression::CompressionLevel;
4use crate::compression_utils::AcceptEncoding;
5use tower_layer::Layer;
6
/// [`Layer`] that wraps services in the [`Compression`] middleware.
///
/// The layer only stores configuration. Each call to [`Layer::layer`] copies it into
/// the produced [`Compression`] service.
///
/// The default type parameter means a plain `CompressionLayer` uses [`DefaultPredicate`].
/// A different predicate can be installed with `compress_when`.
#[derive(Clone, Debug, Default)]
pub struct CompressionLayer<P = DefaultPredicate> {
    // Which content encodings are enabled; toggled by the `gzip`/`no_gzip`-style builders.
    accept: AcceptEncoding,
    // Predicate handed to the `Compression` service (presumably consulted to decide
    // whether a given response gets compressed — confirm in the service implementation).
    predicate: P,
    // Compression level forwarded to the `Compression` service.
    quality: CompressionLevel,
}
19
20impl<S, P> Layer<S> for CompressionLayer<P>
21where
22 P: Predicate,
23{
24 type Service = Compression<S, P>;
25
26 fn layer(&self, inner: S) -> Self::Service {
27 Compression {
28 inner,
29 accept: self.accept,
30 predicate: self.predicate.clone(),
31 quality: self.quality,
32 }
33 }
34}
35
36impl CompressionLayer {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 #[cfg(feature = "compression-gzip")]
44 pub fn gzip(mut self, enable: bool) -> Self {
45 self.accept.set_gzip(enable);
46 self
47 }
48
49 #[cfg(feature = "compression-deflate")]
51 pub fn deflate(mut self, enable: bool) -> Self {
52 self.accept.set_deflate(enable);
53 self
54 }
55
56 #[cfg(feature = "compression-br")]
58 pub fn br(mut self, enable: bool) -> Self {
59 self.accept.set_br(enable);
60 self
61 }
62
63 #[cfg(feature = "compression-zstd")]
65 pub fn zstd(mut self, enable: bool) -> Self {
66 self.accept.set_zstd(enable);
67 self
68 }
69
70 pub fn quality(mut self, quality: CompressionLevel) -> Self {
72 self.quality = quality;
73 self
74 }
75
76 pub fn no_gzip(mut self) -> Self {
80 self.accept.set_gzip(false);
81 self
82 }
83
84 pub fn no_deflate(mut self) -> Self {
88 self.accept.set_deflate(false);
89 self
90 }
91
92 pub fn no_br(mut self) -> Self {
96 self.accept.set_br(false);
97 self
98 }
99
100 pub fn no_zstd(mut self) -> Self {
104 self.accept.set_zstd(false);
105 self
106 }
107
108 pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C>
112 where
113 C: Predicate,
114 {
115 CompressionLayer {
116 accept: self.accept,
117 predicate,
118 quality: self.quality,
119 }
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header::ACCEPT_ENCODING, Request, Response};
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use tokio::fs::File;
    use tokio_util::io::ReaderStream;
    use tower::{Service, ServiceBuilder, ServiceExt};

    /// Test handler that streams the crate's `Cargo.toml` back as the response body.
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let file = File::open("Cargo.toml").await.expect("file missing");
        let stream = ReaderStream::new(file);
        let body = Body::from_stream(stream);
        Ok(Response::new(body))
    }

    /// Disabling encodings on the layer must steer negotiation towards the one that
    /// remains, even when the client accepts several.
    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> {
        use std::io::Read;

        /// Drains a synchronous decoder into a byte vector.
        fn decode<R: Read>(mut r: R) -> std::io::Result<Vec<u8>> {
            let mut buf = Vec::new();
            r.read_to_end(&mut buf)?;
            Ok(buf)
        }

        let expected = tokio::fs::read("Cargo.toml").await?;

        // Deflate only: brotli and gzip are disabled.
        let deflate_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_gzip();

        let mut service = ServiceBuilder::new()
            .layer(deflate_only_layer)
            .service_fn(handle);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "deflate");

        let deflate_body = response.into_body().collect().await?.to_bytes();

        // The HTTP `deflate` coding is zlib-wrapped, hence `ZlibDecoder`.
        let decoded = decode(flate2::bufread::ZlibDecoder::new(&deflate_body[..]))?;
        assert_eq!(decoded, expected);

        // Brotli only: gzip and deflate are disabled.
        let br_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_gzip()
            .no_deflate();

        let mut service = ServiceBuilder::new()
            .layer(br_only_layer)
            .service_fn(handle);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "br");

        let br_body = response.into_body().collect().await?.to_bytes();

        let decoded = decode(brotli::Decompressor::new(&br_body[..], 4096))?;
        assert_eq!(decoded, expected);

        Ok(())
    }

    /// zstd output at the best quality must be decodable with a window of at most
    /// 2^23 bytes (8 MiB). That is the limit RFC 9659 sets for the `zstd` HTTP content
    /// coding.
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), crate::BoxError> {
        // 18 MiB of input: large enough that an unconstrained encoder at the best level
        // could choose a window above the 8 MiB limit.
        const BODY_LEN: usize = 18_874_368;

        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; BODY_LEN])))
        }

        let zstd_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_deflate()
            .no_gzip();

        let mut service = ServiceBuilder::new().layer(zstd_layer).service_fn(zeroes);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "zstd");

        let bytes = response.into_body().collect().await?.to_bytes();
        let mut decoder = zstd::Decoder::new(&*bytes)?;
        // Reject frames that need a window larger than 2^23 bytes; decoding errors out
        // if the encoder exceeded it.
        decoder.window_log_max(23)?;

        // Also check that the whole body round-tripped. Without this, a truncated stream
        // would pass the test.
        let decoded_len = std::io::copy(&mut decoder, &mut std::io::sink())?;
        assert_eq!(decoded_len, BODY_LEN as u64);

        Ok(())
    }
}