1use bytes::Bytes;
5use futures_util::Stream;
6use std::io;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10
11pub type ByteStream = Pin<Box<dyn Stream<Item = io::Result<Bytes>> + Send>>;
13
/// Content coding recognised by this module.
///
/// `decode_stream` and `encode_stream` treat `Unknown` exactly like `Identity`: the bytes are
/// passed through untouched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Encoding {
    /// No transformation (`identity`, or an empty token).
    Identity,
    /// `gzip`, also accepted under the alias `x-gzip`.
    Gzip,
    /// Brotli, spelled `br` on the wire.
    Brotli,
    /// `deflate`; this module handles it with zlib framing.
    Deflate,
    /// Any token not listed above.
    Unknown,
}

impl Encoding {
    /// Recognised non-empty tokens and the coding each one maps to.
    const TOKENS: [(&'static str, Encoding); 5] = [
        ("identity", Encoding::Identity),
        ("gzip", Encoding::Gzip),
        ("x-gzip", Encoding::Gzip),
        ("br", Encoding::Brotli),
        ("deflate", Encoding::Deflate),
    ];

    /// Parses a single coding token such as `" GZIP "`.
    ///
    /// Surrounding whitespace is ignored and comparison is ASCII case-insensitive. An empty (or
    /// all-whitespace) token means `Identity`; anything unrecognised yields `Unknown`.
    pub fn from_token(token: &str) -> Self {
        let token = token.trim();
        if token.is_empty() {
            return Encoding::Identity;
        }
        Self::TOKENS
            .iter()
            .find(|entry| token.eq_ignore_ascii_case(entry.0))
            .map_or(Encoding::Unknown, |entry| entry.1)
    }

    /// Canonical token for this coding, or `None` for `Identity` and `Unknown`, for which no
    /// coding token is produced.
    pub fn header_value(self) -> Option<&'static str> {
        let token = match self {
            Encoding::Gzip => "gzip",
            Encoding::Brotli => "br",
            Encoding::Deflate => "deflate",
            Encoding::Identity | Encoding::Unknown => return None,
        };
        Some(token)
    }
}
47
48pub fn decode_stream<S>(s: S, enc: Encoding) -> ByteStream
52where
53 S: Stream<Item = io::Result<Bytes>> + Send + 'static,
54{
55 use async_compression::tokio::bufread;
56 use tokio_util::io::{ReaderStream, StreamReader};
57 match enc {
58 Encoding::Identity | Encoding::Unknown => Box::pin(s),
59 Encoding::Gzip => Box::pin(ReaderStream::new(bufread::GzipDecoder::new(
60 StreamReader::new(s),
61 ))),
62 Encoding::Brotli => Box::pin(ReaderStream::new(bufread::BrotliDecoder::new(
63 StreamReader::new(s),
64 ))),
65 Encoding::Deflate => Box::pin(ReaderStream::new(bufread::ZlibDecoder::new(
66 StreamReader::new(s),
67 ))),
68 }
69}
70
71struct SharedSink(Arc<Mutex<Vec<u8>>>);
75
76impl tokio::io::AsyncWrite for SharedSink {
77 fn poll_write(
78 self: Pin<&mut Self>,
79 _: &mut Context<'_>,
80 buf: &[u8],
81 ) -> Poll<io::Result<usize>> {
82 self.0.lock().unwrap().extend_from_slice(buf);
83 Poll::Ready(Ok(buf.len()))
84 }
85 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
86 Poll::Ready(Ok(()))
87 }
88 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
89 Poll::Ready(Ok(()))
90 }
91}
92
93pub fn encode_stream<S>(s: S, enc: Encoding) -> ByteStream
98where
99 S: futures_util::Stream<Item = io::Result<Bytes>> + Send + 'static,
100{
101 use async_compression::tokio::write::{BrotliEncoder, GzipEncoder, ZlibEncoder};
102 use async_compression::Level;
103 match enc {
104 Encoding::Identity | Encoding::Unknown => Box::pin(s),
105 Encoding::Gzip => {
106 let buf = Arc::new(Mutex::new(Vec::new()));
107 encode_with(
108 GzipEncoder::with_quality(SharedSink(buf.clone()), Level::Default),
109 s,
110 buf,
111 )
112 }
113 Encoding::Brotli => {
114 let buf = Arc::new(Mutex::new(Vec::new()));
115 encode_with(
116 BrotliEncoder::with_quality(SharedSink(buf.clone()), Level::Fastest),
117 s,
118 buf,
119 )
120 }
121 Encoding::Deflate => {
122 let buf = Arc::new(Mutex::new(Vec::new()));
123 encode_with(
124 ZlibEncoder::with_quality(SharedSink(buf.clone()), Level::Default),
125 s,
126 buf,
127 )
128 }
129 }
130}
131
132fn encode_with<E, S>(encoder: E, input: S, buf: Arc<Mutex<Vec<u8>>>) -> ByteStream
133where
134 E: tokio::io::AsyncWrite + Unpin + Send + 'static,
135 S: futures_util::Stream<Item = io::Result<Bytes>> + Send + 'static,
136{
137 use futures_util::StreamExt;
138 use tokio::io::AsyncWriteExt;
139 let input = Box::pin(input);
140 Box::pin(futures_util::stream::unfold(
141 (input, encoder, buf, false),
142 |(mut input, mut encoder, buf, done)| async move {
143 if done {
144 return None;
145 }
146 match input.next().await {
147 Some(Ok(chunk)) => {
148 if let Err(e) = encoder.write_all(&chunk).await {
149 return Some((Err(e), (input, encoder, buf, true)));
150 }
151 if let Err(e) = encoder.flush().await {
152 return Some((Err(e), (input, encoder, buf, true)));
153 }
154 let out = std::mem::take(&mut *buf.lock().unwrap());
155 Some((Ok(Bytes::from(out)), (input, encoder, buf, false)))
156 }
157 Some(Err(e)) => Some((Err(e), (input, encoder, buf, true))),
158 None => {
159 if let Err(e) = encoder.shutdown().await {
160 return Some((Err(e), (input, encoder, buf, true)));
161 }
162 let out = std::mem::take(&mut *buf.lock().unwrap());
163 if out.is_empty() {
164 None
165 } else {
166 Some((Ok(Bytes::from(out)), (input, encoder, buf, true)))
167 }
168 }
169 }
170 },
171 ))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use futures::StreamExt;
    use std::io::Write;

    /// Compresses `data` as a single gzip member with flate2, independently of the code under test.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut e = GzEncoder::new(Vec::new(), Compression::default());
        e.write_all(data).unwrap();
        e.finish().unwrap()
    }

    /// Concatenates every chunk of `s`, panicking on the first error item.
    async fn collect(s: impl futures::Stream<Item = std::io::Result<Bytes>>) -> Vec<u8> {
        futures::pin_mut!(s);
        let mut out = Vec::new();
        while let Some(item) = s.next().await {
            out.extend_from_slice(&item.unwrap());
        }
        out
    }

    #[test]
    fn parses_encoding_tokens() {
        assert_eq!(Encoding::from_token(""), Encoding::Identity);
        assert_eq!(Encoding::from_token("identity"), Encoding::Identity);
        assert_eq!(Encoding::from_token("gzip"), Encoding::Gzip);
        assert_eq!(Encoding::from_token("GZIP"), Encoding::Gzip);
        assert_eq!(Encoding::from_token(" x-gzip "), Encoding::Gzip);
        assert_eq!(Encoding::from_token("br"), Encoding::Brotli);
        assert_eq!(Encoding::from_token("Br"), Encoding::Brotli);
        assert_eq!(Encoding::from_token("deflate"), Encoding::Deflate);
        assert_eq!(Encoding::from_token("weird"), Encoding::Unknown);
    }

    #[test]
    fn header_values_round_trip_through_from_token() {
        assert_eq!(Encoding::Gzip.header_value(), Some("gzip"));
        assert_eq!(Encoding::Brotli.header_value(), Some("br"));
        assert_eq!(Encoding::Deflate.header_value(), Some("deflate"));
        assert_eq!(Encoding::Identity.header_value(), None);
        assert_eq!(Encoding::Unknown.header_value(), None);
        // Every emitted token must parse back to the same coding.
        for enc in [Encoding::Gzip, Encoding::Brotli, Encoding::Deflate] {
            assert_eq!(Encoding::from_token(enc.header_value().unwrap()), enc);
        }
    }

    #[tokio::test]
    async fn decodes_gzip_stream() {
        let plain = b"data: {\"a\":1}\n\ndata: [DONE]\n\n";
        let comp = gzip(plain);
        let mid = comp.len() / 2;
        // Split the compressed body so the decoder has to work across chunk boundaries.
        let chunks = vec![
            Ok(Bytes::copy_from_slice(&comp[..mid])),
            Ok(Bytes::copy_from_slice(&comp[mid..])),
        ];
        let input = futures::stream::iter(chunks);
        let decoded = collect(decode_stream(Box::pin(input), Encoding::Gzip)).await;
        assert_eq!(decoded, plain);
    }

    #[tokio::test]
    async fn identity_and_unknown_passthrough() {
        let bytes = b"raw bytes not compressed";
        for enc in [Encoding::Identity, Encoding::Unknown] {
            let input = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(bytes))]);
            let out = collect(decode_stream(Box::pin(input), enc)).await;
            assert_eq!(out, bytes);
        }
    }

    #[tokio::test]
    async fn gzip_round_trips() {
        let plain = b"data: {\"city\":\"Paris\"}\n\ndata: [DONE]\n\n";
        let input = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(plain))]);
        let encoded = collect(encode_stream(Box::pin(input), Encoding::Gzip)).await;
        let dec_in = futures::stream::iter(vec![Ok(Bytes::from(encoded))]);
        let decoded = collect(decode_stream(Box::pin(dec_in), Encoding::Gzip)).await;
        assert_eq!(decoded, plain);
    }

    #[tokio::test]
    async fn brotli_round_trips() {
        let plain = b"hello brotli streaming world";
        let input = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(plain))]);
        let encoded = collect(encode_stream(Box::pin(input), Encoding::Brotli)).await;
        let dec_in = futures::stream::iter(vec![Ok(Bytes::from(encoded))]);
        let decoded = collect(decode_stream(Box::pin(dec_in), Encoding::Brotli)).await;
        assert_eq!(decoded, plain);
    }

    #[tokio::test]
    async fn deflate_round_trips() {
        let plain = b"data: {\"deflate\":true}\n\ndata: [DONE]\n\n";
        let input = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(plain))]);
        let encoded = collect(encode_stream(Box::pin(input), Encoding::Deflate)).await;
        let dec_in = futures::stream::iter(vec![Ok(Bytes::from(encoded))]);
        let decoded = collect(decode_stream(Box::pin(dec_in), Encoding::Deflate)).await;
        assert_eq!(decoded, plain);
    }

    #[tokio::test]
    async fn gzip_encodes_empty_input_to_valid_stream() {
        let input = futures::stream::iter(Vec::<std::io::Result<Bytes>>::new());
        let encoded = collect(encode_stream(Box::pin(input), Encoding::Gzip)).await;
        // Even with no input, shutting the encoder down must produce a complete gzip stream.
        assert!(!encoded.is_empty());
        let dec_in = futures::stream::iter(vec![Ok(Bytes::from(encoded))]);
        let decoded = collect(decode_stream(Box::pin(dec_in), Encoding::Gzip)).await;
        assert!(decoded.is_empty());
    }

    #[tokio::test]
    async fn identity_encode_passthrough() {
        let plain = b"unchanged";
        let input = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(plain))]);
        let out = collect(encode_stream(Box::pin(input), Encoding::Identity)).await;
        assert_eq!(out, plain);
    }

    /// Each input chunk must produce its own non-empty compressed output item.
    /// Consumers such as SSE clients can then decode data as it arrives.
    #[tokio::test]
    async fn encoder_flushes_per_chunk() {
        let parts: Vec<&[u8]> = vec![b"first ", b"second ", b"third"];
        let input = futures::stream::iter(
            parts
                .iter()
                .map(|p| Ok(Bytes::copy_from_slice(p)))
                .collect::<Vec<_>>(),
        );
        let encoded = encode_stream(Box::pin(input), Encoding::Gzip);
        futures::pin_mut!(encoded);

        let mut nonempty = 0usize;
        let mut all = Vec::new();
        while let Some(item) = encoded.next().await {
            let b = item.unwrap();
            if !b.is_empty() {
                nonempty += 1;
            }
            all.extend_from_slice(&b);
        }
        assert!(
            nonempty >= parts.len(),
            "expected >= {} non-empty flushed output chunks (per-chunk flush), got {}",
            parts.len(),
            nonempty
        );
        let dec_in = futures::stream::iter(vec![Ok(Bytes::from(all))]);
        let decoded = collect(decode_stream(Box::pin(dec_in), Encoding::Gzip)).await;
        assert_eq!(decoded, b"first second third");
    }
}