tower_http/compression/
predicate.rs1use http::{header, Extensions, HeaderMap, StatusCode, Version};
10use http_body::Body;
11use std::{fmt, sync::Arc};
12
13pub trait Predicate: Clone {
15    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
17    where
18        B: Body;
19
20    fn and<Other>(self, other: Other) -> And<Self, Other>
24    where
25        Self: Sized,
26        Other: Predicate,
27    {
28        And {
29            lhs: self,
30            rhs: other,
31        }
32    }
33}
34
35impl<F> Predicate for F
36where
37    F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone,
38{
39    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
40    where
41        B: Body,
42    {
43        let status = response.status();
44        let version = response.version();
45        let headers = response.headers();
46        let extensions = response.extensions();
47        self(status, version, headers, extensions)
48    }
49}
50
51impl<T> Predicate for Option<T>
52where
53    T: Predicate,
54{
55    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
56    where
57        B: Body,
58    {
59        self.as_ref()
60            .map(|inner| inner.should_compress(response))
61            .unwrap_or(true)
62    }
63}
64
/// Pair of two predicates, produced by [`Predicate::and`].
///
/// `lhs` is the predicate `and` was called on, `rhs` is the one passed to it.
#[derive(Clone, Copy, Debug, Default)]
pub struct And<Lhs, Rhs> {
    /// Left-hand predicate.
    lhs: Lhs,
    /// Right-hand predicate.
    rhs: Rhs,
}
73
74impl<Lhs, Rhs> Predicate for And<Lhs, Rhs>
75where
76    Lhs: Predicate,
77    Rhs: Predicate,
78{
79    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
80    where
81        B: Body,
82    {
83        self.lhs.should_compress(response) && self.rhs.should_compress(response)
84    }
85}
86
87#[derive(Clone)]
117pub struct DefaultPredicate(
118    And<And<And<SizeAbove, NotForContentType>, NotForContentType>, NotForContentType>,
119);
120
121impl DefaultPredicate {
122    pub fn new() -> Self {
124        let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE)
125            .and(NotForContentType::GRPC)
126            .and(NotForContentType::IMAGES)
127            .and(NotForContentType::SSE);
128        Self(inner)
129    }
130}
131
132impl Default for DefaultPredicate {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138impl Predicate for DefaultPredicate {
139    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
140    where
141        B: Body,
142    {
143        self.0.should_compress(response)
144    }
145}
146
/// Predicate carrying a minimum response size, in bytes.
#[derive(Debug, Clone, Copy)]
pub struct SizeAbove(u16);
150
151impl SizeAbove {
152    pub(crate) const DEFAULT_MIN_SIZE: u16 = 32;
153
154    pub const fn new(min_size_bytes: u16) -> Self {
160        Self(min_size_bytes)
161    }
162}
163
164impl Default for SizeAbove {
165    fn default() -> Self {
166        Self(Self::DEFAULT_MIN_SIZE)
167    }
168}
169
170impl Predicate for SizeAbove {
171    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
172    where
173        B: Body,
174    {
175        let content_size = response.body().size_hint().exact().or_else(|| {
176            response
177                .headers()
178                .get(header::CONTENT_LENGTH)
179                .and_then(|h| h.to_str().ok())
180                .and_then(|val| val.parse().ok())
181        });
182
183        match content_size {
184            Some(size) => size >= (self.0 as u64),
185            _ => true,
186        }
187    }
188}
189
190#[derive(Clone, Debug)]
192pub struct NotForContentType {
193    content_type: Str,
194    exception: Option<Str>,
195}
196
197impl NotForContentType {
198    pub const GRPC: Self = Self::const_new("application/grpc");
200
201    pub const IMAGES: Self = Self {
203        content_type: Str::Static("image/"),
204        exception: Some(Str::Static("image/svg+xml")),
205    };
206
207    pub const SSE: Self = Self::const_new("text/event-stream");
209
210    pub fn new(content_type: &str) -> Self {
212        Self {
213            content_type: Str::Shared(content_type.into()),
214            exception: None,
215        }
216    }
217
218    pub const fn const_new(content_type: &'static str) -> Self {
220        Self {
221            content_type: Str::Static(content_type),
222            exception: None,
223        }
224    }
225}
226
227impl Predicate for NotForContentType {
228    fn should_compress<B>(&self, response: &http::Response<B>) -> bool
229    where
230        B: Body,
231    {
232        if let Some(except) = &self.exception {
233            if content_type(response) == except.as_str() {
234                return true;
235            }
236        }
237
238        !content_type(response).starts_with(self.content_type.as_str())
239    }
240}
241
/// String that is either a `'static` literal or a shared, reference-counted
/// copy.
#[derive(Clone)]
enum Str {
    /// Borrowed for the whole program lifetime.
    Static(&'static str),
    /// Owned string behind an [`Arc`].
    Shared(Arc<str>),
}
247
248impl Str {
249    fn as_str(&self) -> &str {
250        match self {
251            Str::Static(s) => s,
252            Str::Shared(s) => s,
253        }
254    }
255}
256
257impl fmt::Debug for Str {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        match self {
260            Self::Static(inner) => inner.fmt(f),
261            Self::Shared(inner) => inner.fmt(f),
262        }
263    }
264}
265
266fn content_type<B>(response: &http::Response<B>) -> &str {
267    response
268        .headers()
269        .get(header::CONTENT_TYPE)
270        .and_then(|h| h.to_str().ok())
271        .unwrap_or_default()
272}