Skip to main content

suture/
encoding.rs

//! HTTP content-encoding handling for the proxy: a small `Encoding` model plus
//! streaming decode and re-encode, so repair operates on plaintext.
3
4use bytes::Bytes;
5use futures_util::Stream;
6use std::io;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10
11/// A boxed byte stream with io errors — the common currency of the codec layer.
12pub type ByteStream = Pin<Box<dyn Stream<Item = io::Result<Bytes>> + Send>>;
13
/// An HTTP content coding, as named in `Content-Encoding` / `Accept-Encoding`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Encoding {
    /// No transformation; an empty (or whitespace-only) token also parses to this.
    Identity,
    /// `gzip`, including the `x-gzip` alias.
    Gzip,
    /// `br`.
    Brotli,
    /// `deflate`.
    Deflate,
    /// A coding we do not handle — caller must NOT attempt to decode/repair.
    Unknown,
}

impl Encoding {
    /// Parse a single `Content-Encoding` / `Accept-Encoding` token.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
    /// Anything that is not one of the recognised names — including a token that
    /// still carries parameters such as `;q=0.5` — yields [`Encoding::Unknown`].
    pub fn from_token(token: &str) -> Self {
        // Recognised names (lowercase) and the coding each one selects.
        const TOKENS: [(&str, Encoding); 6] = [
            ("", Encoding::Identity),
            ("identity", Encoding::Identity),
            ("gzip", Encoding::Gzip),
            ("x-gzip", Encoding::Gzip),
            ("br", Encoding::Brotli),
            ("deflate", Encoding::Deflate),
        ];

        // Compare in place instead of building a lowercased copy of the token.
        let token = token.trim();
        TOKENS
            .iter()
            .find(|(name, _)| token.eq_ignore_ascii_case(name))
            .map_or(Encoding::Unknown, |&(_, coding)| coding)
    }

    /// The header value to advertise for this coding.
    ///
    /// Returns `None` for `Identity` and `Unknown`, which have no value to advertise.
    pub fn header_value(self) -> Option<&'static str> {
        match self {
            Encoding::Identity | Encoding::Unknown => None,
            Encoding::Gzip => Some("gzip"),
            Encoding::Brotli => Some("br"),
            Encoding::Deflate => Some("deflate"),
        }
    }
}
47
48/// Wrap a content-encoded byte stream so it yields the DECODED plaintext bytes.
49/// `Identity`/`Unknown` pass through unchanged (the caller decides not to repair
50/// an `Unknown`-coded body).
51pub fn decode_stream<S>(s: S, enc: Encoding) -> ByteStream
52where
53    S: Stream<Item = io::Result<Bytes>> + Send + 'static,
54{
55    use async_compression::tokio::bufread;
56    use tokio_util::io::{ReaderStream, StreamReader};
57    match enc {
58        Encoding::Identity | Encoding::Unknown => Box::pin(s),
59        Encoding::Gzip => Box::pin(ReaderStream::new(bufread::GzipDecoder::new(
60            StreamReader::new(s),
61        ))),
62        Encoding::Brotli => Box::pin(ReaderStream::new(bufread::BrotliDecoder::new(
63            StreamReader::new(s),
64        ))),
65        Encoding::Deflate => Box::pin(ReaderStream::new(bufread::ZlibDecoder::new(
66            StreamReader::new(s),
67        ))),
68    }
69}
70
71/// An in-memory `AsyncWrite` sink whose buffer we drain after each flush. The
72/// buffer is shared (Arc) so the encode loop can take its contents independent of
73/// the encoder's concrete type.
74struct SharedSink(Arc<Mutex<Vec<u8>>>);
75
76impl tokio::io::AsyncWrite for SharedSink {
77    fn poll_write(
78        self: Pin<&mut Self>,
79        _: &mut Context<'_>,
80        buf: &[u8],
81    ) -> Poll<io::Result<usize>> {
82        self.0.lock().unwrap().extend_from_slice(buf);
83        Poll::Ready(Ok(buf.len()))
84    }
85    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
86        Poll::Ready(Ok(()))
87    }
88    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
89        Poll::Ready(Ok(()))
90    }
91}
92
93/// Re-encode a plaintext byte stream with `enc`, flushing after every input chunk so
94/// output is emitted promptly (never buffered until end of stream). `Identity`/`Unknown`
95/// pass through unchanged. Uses fast compression levels (high quality costs ms per call
96/// and buys almost nothing on tiny flushed payloads).
97pub fn encode_stream<S>(s: S, enc: Encoding) -> ByteStream
98where
99    S: futures_util::Stream<Item = io::Result<Bytes>> + Send + 'static,
100{
101    use async_compression::tokio::write::{BrotliEncoder, GzipEncoder, ZlibEncoder};
102    use async_compression::Level;
103    match enc {
104        Encoding::Identity | Encoding::Unknown => Box::pin(s),
105        Encoding::Gzip => {
106            let buf = Arc::new(Mutex::new(Vec::new()));
107            encode_with(
108                GzipEncoder::with_quality(SharedSink(buf.clone()), Level::Default),
109                s,
110                buf,
111            )
112        }
113        Encoding::Brotli => {
114            let buf = Arc::new(Mutex::new(Vec::new()));
115            encode_with(
116                BrotliEncoder::with_quality(SharedSink(buf.clone()), Level::Fastest),
117                s,
118                buf,
119            )
120        }
121        Encoding::Deflate => {
122            let buf = Arc::new(Mutex::new(Vec::new()));
123            encode_with(
124                ZlibEncoder::with_quality(SharedSink(buf.clone()), Level::Default),
125                s,
126                buf,
127            )
128        }
129    }
130}
131
132fn encode_with<E, S>(encoder: E, input: S, buf: Arc<Mutex<Vec<u8>>>) -> ByteStream
133where
134    E: tokio::io::AsyncWrite + Unpin + Send + 'static,
135    S: futures_util::Stream<Item = io::Result<Bytes>> + Send + 'static,
136{
137    use futures_util::StreamExt;
138    use tokio::io::AsyncWriteExt;
139    let input = Box::pin(input);
140    Box::pin(futures_util::stream::unfold(
141        (input, encoder, buf, false),
142        |(mut input, mut encoder, buf, done)| async move {
143            if done {
144                return None;
145            }
146            match input.next().await {
147                Some(Ok(chunk)) => {
148                    if let Err(e) = encoder.write_all(&chunk).await {
149                        return Some((Err(e), (input, encoder, buf, true)));
150                    }
151                    if let Err(e) = encoder.flush().await {
152                        return Some((Err(e), (input, encoder, buf, true)));
153                    }
154                    let out = std::mem::take(&mut *buf.lock().unwrap());
155                    Some((Ok(Bytes::from(out)), (input, encoder, buf, false)))
156                }
157                Some(Err(e)) => Some((Err(e), (input, encoder, buf, true))),
158                None => {
159                    if let Err(e) = encoder.shutdown().await {
160                        return Some((Err(e), (input, encoder, buf, true)));
161                    }
162                    let out = std::mem::take(&mut *buf.lock().unwrap());
163                    if out.is_empty() {
164                        None
165                    } else {
166                        Some((Ok(Bytes::from(out)), (input, encoder, buf, true)))
167                    }
168                }
169            }
170        },
171    ))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use futures::StreamExt;
    use std::io::Write;

    /// Gzip `data` in one shot with flate2, as an independent reference encoder.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut e = GzEncoder::new(Vec::new(), Compression::default());
        e.write_all(data).unwrap();
        e.finish().unwrap()
    }

    /// Build an input stream that yields each of `parts` as its own `Ok` chunk.
    fn stream_of(parts: &[&[u8]]) -> ByteStream {
        let items: Vec<std::io::Result<Bytes>> = parts
            .iter()
            .map(|p| Ok(Bytes::copy_from_slice(p)))
            .collect();
        Box::pin(futures::stream::iter(items))
    }

    /// Drain a stream into one buffer, panicking on the first error item.
    async fn collect(s: impl futures::Stream<Item = std::io::Result<Bytes>>) -> Vec<u8> {
        futures::pin_mut!(s);
        let mut out = Vec::new();
        while let Some(item) = s.next().await {
            out.extend_from_slice(&item.unwrap());
        }
        out
    }

    /// Encode `plain` with `enc`, then decode the result with the same coding.
    async fn round_trip(plain: &[u8], enc: Encoding) -> Vec<u8> {
        let encoded = collect(encode_stream(stream_of(&[plain]), enc)).await;
        collect(decode_stream(stream_of(&[encoded.as_slice()]), enc)).await
    }

    #[test]
    fn parses_encoding_tokens() {
        assert_eq!(Encoding::from_token(""), Encoding::Identity);
        assert_eq!(Encoding::from_token("identity"), Encoding::Identity);
        assert_eq!(Encoding::from_token("gzip"), Encoding::Gzip);
        assert_eq!(Encoding::from_token("GZIP"), Encoding::Gzip);
        assert_eq!(Encoding::from_token("x-gzip"), Encoding::Gzip);
        assert_eq!(Encoding::from_token("  GZip "), Encoding::Gzip);
        assert_eq!(Encoding::from_token("br"), Encoding::Brotli);
        assert_eq!(Encoding::from_token("deflate"), Encoding::Deflate);
        assert_eq!(Encoding::from_token("weird"), Encoding::Unknown);
    }

    #[test]
    fn header_values_parse_back_to_same_coding() {
        for enc in [Encoding::Gzip, Encoding::Brotli, Encoding::Deflate] {
            let value = enc.header_value().expect("codec has a header value");
            assert_eq!(Encoding::from_token(value), enc);
        }
        assert_eq!(Encoding::Identity.header_value(), None);
        assert_eq!(Encoding::Unknown.header_value(), None);
    }

    #[tokio::test]
    async fn decodes_gzip_stream() {
        let plain = b"data: {\"a\":1}\n\ndata: [DONE]\n\n";
        let comp = gzip(plain);
        // feed the compressed bytes in two chunks to exercise streaming
        let mid = comp.len() / 2;
        let input = stream_of(&[&comp[..mid], &comp[mid..]]);
        let decoded = collect(decode_stream(input, Encoding::Gzip)).await;
        assert_eq!(decoded, plain);
    }

    #[tokio::test]
    async fn identity_and_unknown_passthrough() {
        let bytes = b"raw bytes not compressed";
        for enc in [Encoding::Identity, Encoding::Unknown] {
            let out = collect(decode_stream(stream_of(&[bytes.as_slice()]), enc)).await;
            assert_eq!(out, bytes);
        }
    }

    #[tokio::test]
    async fn gzip_round_trips() {
        let plain = b"data: {\"city\":\"Paris\"}\n\ndata: [DONE]\n\n";
        assert_eq!(round_trip(plain, Encoding::Gzip).await, plain);
    }

    #[tokio::test]
    async fn brotli_round_trips() {
        let plain = b"hello brotli streaming world";
        assert_eq!(round_trip(plain, Encoding::Brotli).await, plain);
    }

    #[tokio::test]
    async fn deflate_round_trips() {
        let plain = b"hello deflate streaming world";
        assert_eq!(round_trip(plain, Encoding::Deflate).await, plain);
    }

    #[tokio::test]
    async fn identity_encode_passthrough() {
        let plain = b"unchanged";
        let out = collect(encode_stream(stream_of(&[plain.as_slice()]), Encoding::Identity)).await;
        assert_eq!(out, plain);
    }

    /// PROMPTNESS: per-chunk flush must emit decodable output per input chunk, not
    /// buffer everything until stream end. We feed N distinct chunks one at a time and
    /// assert the encoder yields at least N NON-EMPTY output chunks (one flushed block
    /// per input). A buffer-until-end encoder would yield ~1 non-empty chunk (only at
    /// shutdown), failing this.
    #[tokio::test]
    async fn encoder_flushes_per_chunk() {
        let parts: Vec<&[u8]> = vec![b"first ", b"second ", b"third"];
        let encoded = encode_stream(stream_of(&parts), Encoding::Gzip);
        futures::pin_mut!(encoded);

        let mut nonempty = 0usize;
        let mut all = Vec::new();
        while let Some(item) = encoded.next().await {
            let b = item.unwrap();
            if !b.is_empty() {
                nonempty += 1;
            }
            all.extend_from_slice(&b);
        }
        assert!(
            nonempty >= parts.len(),
            "expected >= {} non-empty flushed output chunks (per-chunk flush), got {}",
            parts.len(),
            nonempty
        );
        // and it still round-trips
        let decoded = collect(decode_stream(stream_of(&[all.as_slice()]), Encoding::Gzip)).await;
        assert_eq!(decoded, b"first second third");
    }
}