1use crate::{content_encoding::SupportedEncodings, BoxError};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use http::HeaderValue;
7use http_body::{Body, Frame};
8use pin_project_lite::pin_project;
9use std::{
10 io,
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14use tokio::io::AsyncRead;
15use tokio_util::io::StreamReader;
16
/// Which content codings are enabled for `Accept-Encoding` negotiation.
///
/// The raw flags are read through the `SupportedEncodings` accessors, which additionally report
/// `false` for any coding whose cargo feature is not compiled in.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AcceptEncoding {
    /// Whether `gzip` is enabled.
    pub(crate) gzip: bool,
    /// Whether `deflate` is enabled.
    pub(crate) deflate: bool,
    /// Whether `br` (Brotli) is enabled.
    pub(crate) br: bool,
    /// Whether `zstd` is enabled.
    pub(crate) zstd: bool,
}
24
25impl AcceptEncoding {
26 #[allow(dead_code)]
27 pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
28 let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
29 (true, true, true, false) => "gzip,deflate,br",
30 (true, true, false, false) => "gzip,deflate",
31 (true, false, true, false) => "gzip,br",
32 (true, false, false, false) => "gzip",
33 (false, true, true, false) => "deflate,br",
34 (false, true, false, false) => "deflate",
35 (false, false, true, false) => "br",
36 (true, true, true, true) => "zstd,gzip,deflate,br",
37 (true, true, false, true) => "zstd,gzip,deflate",
38 (true, false, true, true) => "zstd,gzip,br",
39 (true, false, false, true) => "zstd,gzip",
40 (false, true, true, true) => "zstd,deflate,br",
41 (false, true, false, true) => "zstd,deflate",
42 (false, false, true, true) => "zstd,br",
43 (false, false, false, true) => "zstd",
44 (false, false, false, false) => return None,
45 };
46 Some(HeaderValue::from_static(accept))
47 }
48
49 #[allow(dead_code)]
50 pub(crate) fn set_gzip(&mut self, enable: bool) {
51 self.gzip = enable;
52 }
53
54 #[allow(dead_code)]
55 pub(crate) fn set_deflate(&mut self, enable: bool) {
56 self.deflate = enable;
57 }
58
59 #[allow(dead_code)]
60 pub(crate) fn set_br(&mut self, enable: bool) {
61 self.br = enable;
62 }
63
64 #[allow(dead_code)]
65 pub(crate) fn set_zstd(&mut self, enable: bool) {
66 self.zstd = enable;
67 }
68}
69
70impl SupportedEncodings for AcceptEncoding {
71 #[allow(dead_code)]
72 fn gzip(&self) -> bool {
73 #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))]
74 return self.gzip;
75
76 #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))]
77 return false;
78 }
79
80 #[allow(dead_code)]
81 fn deflate(&self) -> bool {
82 #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))]
83 return self.deflate;
84
85 #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))]
86 return false;
87 }
88
89 #[allow(dead_code)]
90 fn br(&self) -> bool {
91 #[cfg(any(feature = "decompression-br", feature = "compression-br"))]
92 return self.br;
93
94 #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))]
95 return false;
96 }
97
98 #[allow(dead_code)]
99 fn zstd(&self) -> bool {
100 #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
101 return self.zstd;
102
103 #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
104 return false;
105 }
106}
107
108impl Default for AcceptEncoding {
109 fn default() -> Self {
110 AcceptEncoding {
111 gzip: true,
112 deflate: true,
113 br: true,
114 zstd: true,
115 }
116 }
117}
118
119pub(crate) type AsyncReadBody<B> =
121 StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
122
123pub(crate) trait DecorateAsyncRead {
125 type Input: AsyncRead;
126 type Output: AsyncRead;
127
128 fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
130
131 fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
135}
136
137pin_project! {
138 pub(crate) struct WrapBody<M: DecorateAsyncRead> {
140 #[pin]
141 pub read: M::Output,
144 buf: BytesMut,
147 read_all_data: bool,
148 }
149}
150
151impl<M: DecorateAsyncRead> WrapBody<M> {
152 const INTERNAL_BUF_CAPACITY: usize = 4096;
153}
154
155impl<M: DecorateAsyncRead> WrapBody<M> {
156 #[allow(dead_code)]
157 pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
158 where
159 B: Body,
160 M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
161 {
162 let stream = BodyIntoStream::new(body);
164
165 let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
168
169 let read = StreamReader::new(stream);
171
172 let read = M::apply(read, quality);
174
175 Self {
176 read,
177 buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
178 read_all_data: false,
179 }
180 }
181}
182
impl<B, M> Body for WrapBody<M>
where
    B: Body,
    B::Error: Into<BoxError>,
    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
{
    type Data = Bytes;
    type Error = BoxError;

    /// Yields the decorated reader's output as data frames. Once the reader reaches EOF, it
    /// forwards whatever frames the innermost body still produces (e.g. trailers).
    ///
    /// If the reader failed because the underlying body errored, the body's own error is
    /// returned (converted into `BoxError`) instead of the sentinel `io::Error` that carried it.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        let mut this = self.project();

        if !*this.read_all_data {
            // `split()` below hands the filled bytes off as a frame and leaves `buf` with only
            // its spare capacity; allocate a fresh chunk once that spare capacity is used up.
            if this.buf.capacity() == 0 {
                this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
            }

            let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);

            match ready!(result) {
                // EOF from the decorated reader: fall through to the body's remaining frames.
                Ok(0) => {
                    *this.read_all_data = true;
                }
                Ok(_) => {
                    let chunk = this.buf.split().freeze();
                    return Poll::Ready(Some(Ok(Frame::data(chunk))));
                }
                Err(err) => {
                    // The body error parked by `StreamErrorIntoIoError`, if the body is what
                    // failed.
                    let body_error: Option<B::Error> = M::get_pin_mut(this.read.as_mut())
                        .get_pin_mut()
                        .project()
                        .error
                        .take();

                    // Whether the body yielded any data item before the failure.
                    let read_some_data = M::get_pin_mut(this.read.as_mut())
                        .get_pin_mut()
                        .project()
                        .read_some_data;

                    if let Some(body_error) = body_error {
                        return Poll::Ready(Some(Err(body_error.into())));
                    } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
                        // The sentinel code is only produced together with a stored body error,
                        // and that case is handled by the branch above.
                        unreachable!()
                    } else if *read_some_data {
                        return Poll::Ready(Some(Err(err.into())));
                    }
                    // Otherwise the decorator failed before the body yielded any data. The error
                    // is dropped and we fall through to the body's remaining frames. Presumably
                    // this tolerates empty bodies that still carry an encoding header — confirm.
                    // NOTE(review): `read_all_data` stays `false` on this path, so the next poll
                    // polls the decorated reader again.
                }
            }
        }

        // Forward any frames still left in the innermost body (e.g. trailers), copying their
        // data into `Bytes`.
        let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
        body.poll_frame(cx).map(|option| {
            option.map(|result| {
                result
                    .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
                    .map_err(|err| err.into())
            })
        })
    }
}
249
250pin_project! {
251 pub(crate) struct BodyIntoStream<B>
252 where
253 B: Body,
254 {
255 #[pin]
256 body: B,
257 yielded_all_data: bool,
258 non_data_frame: Option<Frame<B::Data>>,
259 }
260}
261
262#[allow(dead_code)]
263impl<B> BodyIntoStream<B>
264where
265 B: Body,
266{
267 pub(crate) fn new(body: B) -> Self {
268 Self {
269 body,
270 yielded_all_data: false,
271 non_data_frame: None,
272 }
273 }
274
275 pub(crate) fn get_ref(&self) -> &B {
277 &self.body
278 }
279
280 pub(crate) fn get_mut(&mut self) -> &mut B {
282 &mut self.body
283 }
284
285 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
287 self.project().body
288 }
289
290 pub(crate) fn into_inner(self) -> B {
292 self.body
293 }
294}
295
296impl<B> Stream for BodyIntoStream<B>
297where
298 B: Body,
299{
300 type Item = Result<B::Data, B::Error>;
301
302 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303 loop {
304 let this = self.as_mut().project();
305
306 if *this.yielded_all_data {
307 return Poll::Ready(None);
308 }
309
310 match std::task::ready!(this.body.poll_frame(cx)) {
311 Some(Ok(frame)) => match frame.into_data() {
312 Ok(data) => return Poll::Ready(Some(Ok(data))),
313 Err(frame) => {
314 *this.yielded_all_data = true;
315 *this.non_data_frame = Some(frame);
316 }
317 },
318 Some(Err(err)) => return Poll::Ready(Some(Err(err))),
319 None => {
320 *this.yielded_all_data = true;
321 }
322 }
323 }
324 }
325}
326
327impl<B> Body for BodyIntoStream<B>
328where
329 B: Body,
330{
331 type Data = B::Data;
332 type Error = B::Error;
333
334 fn poll_frame(
335 mut self: Pin<&mut Self>,
336 cx: &mut Context<'_>,
337 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
338 if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
341 return Poll::Ready(Some(frame.map(Frame::data)));
342 }
343
344 let this = self.project();
345
346 if let Some(frame) = this.non_data_frame.take() {
348 return Poll::Ready(Some(Ok(frame)));
349 }
350
351 this.body.poll_frame(cx)
354 }
355
356 #[inline]
357 fn size_hint(&self) -> http_body::SizeHint {
358 self.body.size_hint()
359 }
360}
361
362pin_project! {
363 pub(crate) struct StreamErrorIntoIoError<S, E> {
364 #[pin]
365 inner: S,
366 error: Option<E>,
367 read_some_data: bool
368 }
369}
370
371impl<S, E> StreamErrorIntoIoError<S, E> {
372 pub(crate) fn new(inner: S) -> Self {
373 Self {
374 inner,
375 error: None,
376 read_some_data: false,
377 }
378 }
379
380 pub(crate) fn get_ref(&self) -> &S {
382 &self.inner
383 }
384
385 pub(crate) fn get_mut(&mut self) -> &mut S {
387 &mut self.inner
388 }
389
390 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
392 self.project().inner
393 }
394
395 pub(crate) fn into_inner(self) -> S {
397 self.inner
398 }
399}
400
401impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
402where
403 S: Stream<Item = Result<T, E>>,
404{
405 type Item = Result<T, io::Error>;
406
407 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
408 let this = self.project();
409 match ready!(this.inner.poll_next(cx)) {
410 None => Poll::Ready(None),
411 Some(Ok(value)) => {
412 *this.read_some_data = true;
413 Poll::Ready(Some(Ok(value)))
414 }
415 Some(Err(err)) => {
416 *this.error = Some(err);
417 Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
418 }
419 }
420 }
421}
422
/// Raw OS error code used as a marker. An `io::Error` with this code means the real error was
/// parked by `StreamErrorIntoIoError` and should be taken from it. The value is negative,
/// presumably so it cannot collide with a genuine OS error code — confirm.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837_459_418;
424
/// The compression level that data should be compressed with.
///
/// When a compression feature is enabled, each variant maps one-to-one onto the
/// `async_compression::Level` variant of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// The encoder's fastest setting.
    Fastest,
    /// The encoder's best (highest-ratio) setting.
    Best,
    /// The encoder's own default setting. This is the `Default` for this type.
    #[default]
    Default,
    /// An algorithm-specific quality value, passed through unchanged to the encoder. Its
    /// meaning and valid range depend on the chosen algorithm.
    Precise(i32),
}
446
447#[cfg(any(
448 feature = "compression-br",
449 feature = "compression-gzip",
450 feature = "compression-deflate",
451 feature = "compression-zstd"
452))]
453use async_compression::Level as AsyncCompressionLevel;
454
#[cfg(any(
    feature = "compression-br",
    feature = "compression-gzip",
    feature = "compression-deflate",
    feature = "compression-zstd"
))]
impl CompressionLevel {
    /// Converts to the equivalent `async_compression` level, variant for variant.
    /// `Precise` values are passed through unchanged.
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Fastest => AsyncCompressionLevel::Fastest,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Default => AsyncCompressionLevel::Default,
            Self::Precise(level) => AsyncCompressionLevel::Precise(level),
        }
    }
}