1use crate::{ZlibDecompressionError, ZlibStreamDecompressor};
2use flate2::DecompressError;
3use futures_util::{Stream, StreamExt};
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8pub struct ZlibStream<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> {
9 decompressor: ZlibStreamDecompressor,
10 stream: T,
11}
12
13impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> ZlibStream<V, T> {
14 pub fn new(stream: T) -> Self {
17 Self {
18 decompressor: Default::default(),
19 stream,
20 }
21 }
22
23 pub fn new_with_decompressor(decompressor: ZlibStreamDecompressor, stream: T) -> Self {
26 Self {
27 decompressor,
28 stream,
29 }
30 }
31}
32
33impl<V: AsRef<[u8]> + Sized, T: Stream<Item = V> + Unpin> Stream for ZlibStream<V, T> {
34 type Item = Result<Vec<u8>, DecompressError>;
35
36 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
37 match Pin::new(&mut self.stream.next()).poll(cx) {
38 Poll::Ready(vec) => {
39 if let Some(vec) = vec {
40 #[cfg(feature = "tokio-runtime")]
41 let result = tokio::task::block_in_place(|| self.decompressor.decompress(vec));
42
43 #[cfg(not(feature = "tokio-runtime"))]
44 let result = self.decompressor.decompress(vec);
45
46 match result {
47 Ok(data) => Poll::Ready(Some(Ok(data))),
48 Err(ZlibDecompressionError::NeedMoreData) => {
49 cx.waker().wake_by_ref();
50 Poll::Pending
51 }
52 Err(ZlibDecompressionError::DecompressError(err)) => {
53 Poll::Ready(Some(Err(err)))
54 }
55 }
56 } else {
57 Poll::Ready(None)
58 }
59 }
60 Poll::Pending => Poll::Pending,
61 }
62 }
63}