1use crate::{content_encoding::SupportedEncodings, BoxError};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use http::HeaderValue;
7use http_body::{Body, Frame};
8use pin_project_lite::pin_project;
9use std::{
10 io,
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14use tokio::io::AsyncRead;
15use tokio_util::io::StreamReader;
16
/// Set of content codings that may be advertised or accepted.
///
/// Every flag is additionally gated at compile time by the matching cargo
/// features (see the `SupportedEncodings` impl for this type), so a `true`
/// flag on its own does not guarantee the coding is available.
#[derive(Clone, Copy, Debug)]
pub(crate) struct AcceptEncoding {
    /// Whether `gzip` is enabled.
    pub(crate) gzip: bool,
    /// Whether `deflate` is enabled.
    pub(crate) deflate: bool,
    /// Whether `br` is enabled.
    pub(crate) br: bool,
    /// Whether `zstd` is enabled.
    pub(crate) zstd: bool,
}
24
25impl AcceptEncoding {
26 #[allow(dead_code)]
27 pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
28 let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
29 (true, true, true, false) => "gzip,deflate,br",
30 (true, true, false, false) => "gzip,deflate",
31 (true, false, true, false) => "gzip,br",
32 (true, false, false, false) => "gzip",
33 (false, true, true, false) => "deflate,br",
34 (false, true, false, false) => "deflate",
35 (false, false, true, false) => "br",
36 (true, true, true, true) => "zstd,gzip,deflate,br",
37 (true, true, false, true) => "zstd,gzip,deflate",
38 (true, false, true, true) => "zstd,gzip,br",
39 (true, false, false, true) => "zstd,gzip",
40 (false, true, true, true) => "zstd,deflate,br",
41 (false, true, false, true) => "zstd,deflate",
42 (false, false, true, true) => "zstd,br",
43 (false, false, false, true) => "zstd",
44 (false, false, false, false) => return None,
45 };
46 Some(HeaderValue::from_static(accept))
47 }
48
49 #[allow(dead_code)]
50 pub(crate) fn set_gzip(&mut self, enable: bool) {
51 self.gzip = enable;
52 }
53
54 #[allow(dead_code)]
55 pub(crate) fn set_deflate(&mut self, enable: bool) {
56 self.deflate = enable;
57 }
58
59 #[allow(dead_code)]
60 pub(crate) fn set_br(&mut self, enable: bool) {
61 self.br = enable;
62 }
63
64 #[allow(dead_code)]
65 pub(crate) fn set_zstd(&mut self, enable: bool) {
66 self.zstd = enable;
67 }
68}
69
70impl SupportedEncodings for AcceptEncoding {
71 #[allow(dead_code)]
72 fn gzip(&self) -> bool {
73 #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))]
74 return self.gzip;
75
76 #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))]
77 return false;
78 }
79
80 #[allow(dead_code)]
81 fn deflate(&self) -> bool {
82 #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))]
83 return self.deflate;
84
85 #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))]
86 return false;
87 }
88
89 #[allow(dead_code)]
90 fn br(&self) -> bool {
91 #[cfg(any(feature = "decompression-br", feature = "compression-br"))]
92 return self.br;
93
94 #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))]
95 return false;
96 }
97
98 #[allow(dead_code)]
99 fn zstd(&self) -> bool {
100 #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
101 return self.zstd;
102
103 #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
104 return false;
105 }
106}
107
108impl Default for AcceptEncoding {
109 fn default() -> Self {
110 AcceptEncoding {
111 gzip: true,
112 deflate: true,
113 br: true,
114 zstd: true,
115 }
116 }
117}
118
119pub(crate) type AsyncReadBody<B> =
121 StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
122
123pub(crate) trait DecorateAsyncRead {
125 type Input: AsyncRead;
126 type Output: AsyncRead;
127
128 fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
130
131 fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
135}
136
pin_project! {
    /// A body adapter that reads its inner body through the `AsyncRead`
    /// decorator `M` and emits the decorator's output as `Bytes` data frames.
    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
        #[pin]
        // The decorated reader: `M::apply` on top of the body viewed as an
        // `AsyncRead` (built in `WrapBody::new`).
        pub read: M::Output,
        // Buffer that reads land in; each successful read is split off and
        // frozen into a data frame, and any unused capacity is kept for the
        // next read.
        buf: BytesMut,
        // Set once the decorated reader reports EOF; after that only the
        // underlying body's remaining frames are polled.
        read_all_data: bool,
    }
}
150
151impl<M: DecorateAsyncRead> WrapBody<M> {
152 const INTERNAL_BUF_CAPACITY: usize = 4096;
153}
154
155impl<M: DecorateAsyncRead> WrapBody<M> {
156 #[allow(dead_code)]
157 pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
158 where
159 B: Body,
160 M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
161 {
162 let stream = BodyIntoStream::new(body);
164
165 let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
168
169 let read = StreamReader::new(stream);
171
172 let read = M::apply(read, quality);
174
175 Self {
176 read,
177 buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
178 read_all_data: false,
179 }
180 }
181}
182
183impl<B, M> Body for WrapBody<M>
184where
185 B: Body,
186 B::Error: Into<BoxError>,
187 M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
188{
189 type Data = Bytes;
190 type Error = BoxError;
191
192 fn poll_frame(
193 self: Pin<&mut Self>,
194 cx: &mut Context<'_>,
195 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
196 let mut this = self.project();
197
198 if !*this.read_all_data {
199 if this.buf.capacity() == 0 {
200 this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
201 }
202
203 let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);
204
205 match ready!(result) {
206 Ok(0) => {
207 *this.read_all_data = true;
208 }
209 Ok(_) => {
210 let chunk = this.buf.split().freeze();
211 return Poll::Ready(Some(Ok(Frame::data(chunk))));
212 }
213 Err(err) => {
214 let body_error: Option<B::Error> = M::get_pin_mut(this.read.as_mut())
215 .get_pin_mut()
216 .project()
217 .error
218 .take();
219
220 let read_some_data = M::get_pin_mut(this.read.as_mut())
221 .get_pin_mut()
222 .project()
223 .read_some_data;
224
225 if let Some(body_error) = body_error {
226 return Poll::Ready(Some(Err(body_error.into())));
227 } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
228 unreachable!()
231 } else if *read_some_data {
232 return Poll::Ready(Some(Err(err.into())));
233 }
234 }
235 }
236 }
237
238 let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
240 match ready!(body.poll_frame(cx)) {
241 Some(Ok(frame)) if frame.is_trailers() => Poll::Ready(Some(Ok(
242 frame.map_data(|mut data| data.copy_to_bytes(data.remaining()))
243 ))),
244 Some(Ok(frame)) => {
245 if let Ok(bytes) = frame.into_data() {
246 if bytes.has_remaining() {
247 return Poll::Ready(Some(Err(
248 "there are extra bytes after body has been decompressed".into(),
249 )));
250 }
251 }
252 Poll::Ready(None)
253 }
254 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
255 None => Poll::Ready(None),
256 }
257 }
258}
259
pin_project! {
    /// Adapts a body into a `Stream` of its data chunks.
    ///
    /// The stream ends when the body ends or at the first non-data frame.
    /// That frame is stashed rather than dropped, and the `Body` impl later
    /// hands it out along with any frames that follow.
    pub(crate) struct BodyIntoStream<B>
    where
        B: Body,
    {
        #[pin]
        // The wrapped body.
        body: B,
        // Set once the body ended or produced a non-data frame; the `Stream`
        // impl yields `None` from then on.
        yielded_all_data: bool,
        // The first non-data frame (e.g. trailers), held until `poll_frame`
        // returns it.
        non_data_frame: Option<Frame<B::Data>>,
    }
}
271
272#[allow(dead_code)]
273impl<B> BodyIntoStream<B>
274where
275 B: Body,
276{
277 pub(crate) fn new(body: B) -> Self {
278 Self {
279 body,
280 yielded_all_data: false,
281 non_data_frame: None,
282 }
283 }
284
285 pub(crate) fn get_ref(&self) -> &B {
287 &self.body
288 }
289
290 pub(crate) fn get_mut(&mut self) -> &mut B {
292 &mut self.body
293 }
294
295 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
297 self.project().body
298 }
299
300 pub(crate) fn into_inner(self) -> B {
302 self.body
303 }
304}
305
306impl<B> Stream for BodyIntoStream<B>
307where
308 B: Body,
309{
310 type Item = Result<B::Data, B::Error>;
311
312 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
313 loop {
314 let this = self.as_mut().project();
315
316 if *this.yielded_all_data {
317 return Poll::Ready(None);
318 }
319
320 match std::task::ready!(this.body.poll_frame(cx)) {
321 Some(Ok(frame)) => match frame.into_data() {
322 Ok(data) => return Poll::Ready(Some(Ok(data))),
323 Err(frame) => {
324 *this.yielded_all_data = true;
325 *this.non_data_frame = Some(frame);
326 }
327 },
328 Some(Err(err)) => return Poll::Ready(Some(Err(err))),
329 None => {
330 *this.yielded_all_data = true;
331 }
332 }
333 }
334 }
335}
336
337impl<B> Body for BodyIntoStream<B>
338where
339 B: Body,
340{
341 type Data = B::Data;
342 type Error = B::Error;
343
344 fn poll_frame(
345 mut self: Pin<&mut Self>,
346 cx: &mut Context<'_>,
347 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
348 if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
351 return Poll::Ready(Some(frame.map(Frame::data)));
352 }
353
354 let this = self.project();
355
356 if let Some(frame) = this.non_data_frame.take() {
358 return Poll::Ready(Some(Ok(frame)));
359 }
360
361 this.body.poll_frame(cx)
364 }
365
366 #[inline]
367 fn size_hint(&self) -> http_body::SizeHint {
368 self.body.size_hint()
369 }
370}
371
pin_project! {
    /// Converts a stream of `Result<T, E>` into one whose errors are
    /// `io::Error`s, so it can feed a `StreamReader`.
    ///
    /// `E` has no trait bounds here, so it is not wrapped inside the
    /// `io::Error`. Instead it is stashed in `error`, and a placeholder error
    /// carrying `SENTINEL_ERROR_CODE` is yielded in its place. The owner
    /// retrieves the real error afterwards.
    pub(crate) struct StreamErrorIntoIoError<S, E> {
        #[pin]
        // The wrapped stream.
        inner: S,
        // Most recent error from `inner`, waiting to be taken by the owner.
        error: Option<E>,
        // Whether `inner` has yielded at least one `Ok` item.
        read_some_data: bool
    }
}
380
381impl<S, E> StreamErrorIntoIoError<S, E> {
382 pub(crate) fn new(inner: S) -> Self {
383 Self {
384 inner,
385 error: None,
386 read_some_data: false,
387 }
388 }
389
390 pub(crate) fn get_ref(&self) -> &S {
392 &self.inner
393 }
394
395 pub(crate) fn get_mut(&mut self) -> &mut S {
397 &mut self.inner
398 }
399
400 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
402 self.project().inner
403 }
404
405 pub(crate) fn into_inner(self) -> S {
407 self.inner
408 }
409}
410
411impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
412where
413 S: Stream<Item = Result<T, E>>,
414{
415 type Item = Result<T, io::Error>;
416
417 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
418 let this = self.project();
419 match ready!(this.inner.poll_next(cx)) {
420 None => Poll::Ready(None),
421 Some(Ok(value)) => {
422 *this.read_some_data = true;
423 Poll::Ready(Some(Ok(value)))
424 }
425 Some(Err(err)) => {
426 *this.error = Some(err);
427 Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
428 }
429 }
430 }
431}
432
/// Raw OS error code of the placeholder `io::Error` that
/// `StreamErrorIntoIoError` yields in place of a body error.
///
/// The real error is stashed next to it and recovered by `WrapBody`. Seeing
/// this code without a stashed error indicates a bug.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837_459_418;
434
/// Level of compression that data should be compressed with.
///
/// When a compression feature is enabled, each variant maps onto the
/// same-named `async_compression::Level` variant.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
pub enum CompressionLevel {
    /// The backend's fastest compression setting.
    Fastest,
    /// The backend's best (strongest) compression setting.
    Best,
    /// The default setting of the selected compression algorithm.
    ///
    /// This is the value returned by `CompressionLevel::default()`.
    #[default]
    Default,
    /// An explicit, algorithm-specific quality value.
    ///
    /// The value is passed to the backend unchanged. How it is interpreted,
    /// and which range is meaningful, depends on the algorithm behind it.
    Precise(i32),
}
456
457#[cfg(any(
458 feature = "compression-br",
459 feature = "compression-gzip",
460 feature = "compression-deflate",
461 feature = "compression-zstd"
462))]
463use async_compression::Level as AsyncCompressionLevel;
464
#[cfg(any(
    feature = "compression-br",
    feature = "compression-gzip",
    feature = "compression-deflate",
    feature = "compression-zstd"
))]
impl CompressionLevel {
    /// Maps this level onto the same-named `async_compression` level.
    /// `Precise` values are passed through unchanged.
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Precise(quality) => AsyncCompressionLevel::Precise(quality),
            Self::Default => AsyncCompressionLevel::Default,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Fastest => AsyncCompressionLevel::Fastest,
        }
    }
}