tower_async_http/compression/
predicate.rs1use http::{header, Extensions, HeaderMap, StatusCode, Version};
10use http_body::Body;
11use std::{fmt, sync::Arc};
12
13pub trait Predicate: Clone {
15 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
17 where
18 B: Body;
19
20 fn and<Other>(self, other: Other) -> And<Self, Other>
24 where
25 Self: Sized,
26 Other: Predicate,
27 {
28 And {
29 lhs: self,
30 rhs: other,
31 }
32 }
33}
34
35impl<F> Predicate for F
36where
37 F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone,
38{
39 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
40 where
41 B: Body,
42 {
43 let status = response.status();
44 let version = response.version();
45 let headers = response.headers();
46 let extensions = response.extensions();
47 self(status, version, headers, extensions)
48 }
49}
50
51impl<T> Predicate for Option<T>
52where
53 T: Predicate,
54{
55 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
56 where
57 B: Body,
58 {
59 self.as_ref()
60 .map(|inner| inner.should_compress(response))
61 .unwrap_or(true)
62 }
63}
64
/// Pair of two values combined as a conjunction; produced by
/// `Predicate::and`.
#[derive(Clone, Copy, Debug, Default)]
pub struct And<Lhs, Rhs> {
    /// Left-hand operand, consulted first.
    lhs: Lhs,
    /// Right-hand operand, consulted only if the left-hand one accepts.
    rhs: Rhs,
}
73
74impl<Lhs, Rhs> Predicate for And<Lhs, Rhs>
75where
76 Lhs: Predicate,
77 Rhs: Predicate,
78{
79 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
80 where
81 B: Body,
82 {
83 self.lhs.should_compress(response) && self.rhs.should_compress(response)
84 }
85}
86
87#[derive(Clone)]
116pub struct DefaultPredicate(And<And<SizeAbove, NotForContentType>, NotForContentType>);
117
118impl DefaultPredicate {
119 pub fn new() -> Self {
121 let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE)
122 .and(NotForContentType::GRPC)
123 .and(NotForContentType::IMAGES);
124 Self(inner)
125 }
126}
127
128impl Default for DefaultPredicate {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl Predicate for DefaultPredicate {
135 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
136 where
137 B: Body,
138 {
139 self.0.should_compress(response)
140 }
141}
142
/// Predicate that accepts responses whose known size is at least the wrapped
/// number of bytes.
#[derive(Debug, Clone, Copy)]
pub struct SizeAbove(u16);
146
147impl SizeAbove {
148 pub(crate) const DEFAULT_MIN_SIZE: u16 = 32;
149
150 pub const fn new(min_size_bytes: u16) -> Self {
156 Self(min_size_bytes)
157 }
158}
159
160impl Default for SizeAbove {
161 fn default() -> Self {
162 Self(Self::DEFAULT_MIN_SIZE)
163 }
164}
165
166impl Predicate for SizeAbove {
167 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
168 where
169 B: Body,
170 {
171 let content_size = response.body().size_hint().exact().or_else(|| {
172 response
173 .headers()
174 .get(header::CONTENT_LENGTH)
175 .and_then(|h| h.to_str().ok())
176 .and_then(|val| val.parse().ok())
177 });
178
179 match content_size {
180 Some(size) => size >= (self.0 as u64),
181 _ => true,
182 }
183 }
184}
185
186#[derive(Clone, Debug)]
188pub struct NotForContentType {
189 content_type: Str,
190 exception: Option<Str>,
191}
192
193impl NotForContentType {
194 pub const GRPC: Self = Self::const_new("application/grpc");
196
197 pub const IMAGES: Self = Self {
199 content_type: Str::Static("image/"),
200 exception: Some(Str::Static("image/svg+xml")),
201 };
202
203 pub fn new(content_type: &str) -> Self {
205 Self {
206 content_type: Str::Shared(content_type.into()),
207 exception: None,
208 }
209 }
210
211 pub const fn const_new(content_type: &'static str) -> Self {
213 Self {
214 content_type: Str::Static(content_type),
215 exception: None,
216 }
217 }
218}
219
220impl Predicate for NotForContentType {
221 fn should_compress<B>(&self, response: &http::Response<B>) -> bool
222 where
223 B: Body,
224 {
225 if let Some(except) = &self.exception {
226 if content_type(response) == except.as_str() {
227 return true;
228 }
229 }
230
231 !content_type(response).starts_with(self.content_type.as_str())
232 }
233}
234
/// String that is either a `'static` literal or a shared, reference-counted
/// allocation.
#[derive(Clone)]
enum Str {
    /// Borrowed string with `'static` lifetime.
    Static(&'static str),
    /// Heap-allocated string held in an [`Arc`].
    Shared(Arc<str>),
}
240
241impl Str {
242 fn as_str(&self) -> &str {
243 match self {
244 Str::Static(s) => s,
245 Str::Shared(s) => s,
246 }
247 }
248}
249
250impl fmt::Debug for Str {
251 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252 match self {
253 Self::Static(inner) => inner.fmt(f),
254 Self::Shared(inner) => inner.fmt(f),
255 }
256 }
257}
258
259fn content_type<B>(response: &http::Response<B>) -> &str {
260 response
261 .headers()
262 .get(header::CONTENT_TYPE)
263 .and_then(|h| h.to_str().ok())
264 .unwrap_or_default()
265}