tower_async_http/compression/
predicate.rs

1//! Predicates for disabling compression of responses.
2//!
3//! Predicates are applied with [`Compression::compress_when`] or
4//! [`CompressionLayer::compress_when`].
5//!
6//! [`Compression::compress_when`]: super::Compression::compress_when
7//! [`CompressionLayer::compress_when`]: super::CompressionLayer::compress_when
8
9use http::{header, Extensions, HeaderMap, StatusCode, Version};
10use http_body::Body;
11use std::{fmt, sync::Arc};
12
13/// Predicate used to determine if a response should be compressed or not.
14pub trait Predicate: Clone {
15    /// Should this response be compressed or not?
16    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
17    where
18        B: Body;
19
20    /// Combine two predicates into one.
21    ///
22    /// The resulting predicate enables compression if both inner predicates do.
23    fn and<Other>(self, other: Other) -> And<Self, Other>
24    where
25        Self: Sized,
26        Other: Predicate,
27    {
28        And {
29            lhs: self,
30            rhs: other,
31        }
32    }
33}
34
35impl<F> Predicate for F
36where
37    F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone,
38{
39    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
40    where
41        B: Body,
42    {
43        let status = response.status();
44        let version = response.version();
45        let headers = response.headers();
46        let extensions = response.extensions();
47        self(status, version, headers, extensions)
48    }
49}
50
51impl<T> Predicate for Option<T>
52where
53    T: Predicate,
54{
55    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
56    where
57        B: Body,
58    {
59        self.as_ref()
60            .map(|inner| inner.should_compress(response))
61            .unwrap_or(true)
62    }
63}
64
/// Conjunction of two predicates: compression is only allowed when both agree.
///
/// Build one with [`Predicate::and`].
#[derive(Clone, Copy, Debug, Default)]
pub struct And<Lhs, Rhs> {
    // Consulted first.
    lhs: Lhs,
    // Consulted only when `lhs` allows compression.
    rhs: Rhs,
}
73
74impl<Lhs, Rhs> Predicate for And<Lhs, Rhs>
75where
76    Lhs: Predicate,
77    Rhs: Predicate,
78{
79    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
80    where
81        B: Body,
82    {
83        self.lhs.should_compress(response) && self.rhs.should_compress(response)
84    }
85}
86
87/// The default predicate used by [`Compression`] and [`CompressionLayer`].
88///
89/// This will compress responses unless:
90///
91/// - They're gRPC, which has its own protocol specific compression scheme.
92/// - It's an image as determined by the `content-type` starting with `image/`.
93/// - The response is less than 32 bytes.
94///
95/// # Configuring the defaults
96///
97/// `DefaultPredicate` doesn't support any configuration. Instead you can build your own predicate
98/// by combining types in this module:
99///
100/// ```rust
101/// use tower_async_http::compression::predicate::{SizeAbove, NotForContentType, Predicate};
102///
103/// // slightly large min size than the default 32
104/// let predicate = SizeAbove::new(256)
105///     // still don't compress gRPC
106///     .and(NotForContentType::GRPC)
107///     // still don't compress images
108///     .and(NotForContentType::IMAGES)
109///     // also don't compress JSON
110///     .and(NotForContentType::const_new("application/json"));
111/// ```
112///
113/// [`Compression`]: super::Compression
114/// [`CompressionLayer`]: super::CompressionLayer
115#[derive(Clone)]
116pub struct DefaultPredicate(And<And<SizeAbove, NotForContentType>, NotForContentType>);
117
118impl DefaultPredicate {
119    /// Create a new `DefaultPredicate`.
120    pub fn new() -> Self {
121        let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE)
122            .and(NotForContentType::GRPC)
123            .and(NotForContentType::IMAGES);
124        Self(inner)
125    }
126}
127
128impl Default for DefaultPredicate {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl Predicate for DefaultPredicate {
135    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
136    where
137        B: Body,
138    {
139        self.0.should_compress(response)
140    }
141}
142
/// [`Predicate`] that only allows compression of responses whose size reaches a configured
/// minimum number of bytes.
///
/// Responses whose size cannot be determined are always allowed.
#[derive(Debug, Clone, Copy)]
pub struct SizeAbove(u16);
146
147impl SizeAbove {
148    pub(crate) const DEFAULT_MIN_SIZE: u16 = 32;
149
150    /// Create a new `SizeAbove` predicate that will only compress responses larger than
151    /// `min_size_bytes`.
152    ///
153    /// The response will be compressed if the exact size cannot be determined through either the
154    /// `content-length` header or [`Body::size_hint`].
155    pub const fn new(min_size_bytes: u16) -> Self {
156        Self(min_size_bytes)
157    }
158}
159
160impl Default for SizeAbove {
161    fn default() -> Self {
162        Self(Self::DEFAULT_MIN_SIZE)
163    }
164}
165
166impl Predicate for SizeAbove {
167    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
168    where
169        B: Body,
170    {
171        let content_size = response.body().size_hint().exact().or_else(|| {
172            response
173                .headers()
174                .get(header::CONTENT_LENGTH)
175                .and_then(|h| h.to_str().ok())
176                .and_then(|val| val.parse().ok())
177        });
178
179        match content_size {
180            Some(size) => size >= (self.0 as u64),
181            _ => true,
182        }
183    }
184}
185
186/// Predicate that wont allow responses with a specific `content-type` to be compressed.
187#[derive(Clone, Debug)]
188pub struct NotForContentType {
189    content_type: Str,
190    exception: Option<Str>,
191}
192
193impl NotForContentType {
194    /// Predicate that wont compress gRPC responses.
195    pub const GRPC: Self = Self::const_new("application/grpc");
196
197    /// Predicate that wont compress images.
198    pub const IMAGES: Self = Self {
199        content_type: Str::Static("image/"),
200        exception: Some(Str::Static("image/svg+xml")),
201    };
202
203    /// Create a new `NotForContentType`.
204    pub fn new(content_type: &str) -> Self {
205        Self {
206            content_type: Str::Shared(content_type.into()),
207            exception: None,
208        }
209    }
210
211    /// Create a new `NotForContentType` from a static string.
212    pub const fn const_new(content_type: &'static str) -> Self {
213        Self {
214            content_type: Str::Static(content_type),
215            exception: None,
216        }
217    }
218}
219
220impl Predicate for NotForContentType {
221    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
222    where
223        B: Body,
224    {
225        if let Some(except) = &self.exception {
226            if content_type(response) == except.as_str() {
227                return true;
228            }
229        }
230
231        !content_type(response).starts_with(self.content_type.as_str())
232    }
233}
234
/// String storage that is either a `'static` borrow (usable in `const` items) or a
/// reference-counted owned string.
#[derive(Clone)]
enum Str {
    Static(&'static str),
    Shared(Arc<str>),
}
240
241impl Str {
242    fn as_str(&self) -> &str {
243        match self {
244            Str::Static(s) => s,
245            Str::Shared(s) => s,
246        }
247    }
248}
249
250impl fmt::Debug for Str {
251    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252        match self {
253            Self::Static(inner) => inner.fmt(f),
254            Self::Shared(inner) => inner.fmt(f),
255        }
256    }
257}
258
259fn content_type<B>(response: &http::Response<B>) -> &str {
260    response
261        .headers()
262        .get(header::CONTENT_TYPE)
263        .and_then(|h| h.to_str().ok())
264        .unwrap_or_default()
265}