Skip to main content

s2_api/v1/stream/
s2s.rs

1use std::{
2    io::{Read, Write},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use enum_ordinalize::Ordinalize;
9use flate2::{Compression, read::GzDecoder, write::GzEncoder};
10use futures::Stream;
11
12/*
13  REGULAR MESSAGE:
14  ┌─────────────┬────────────┬─────────────────────────────┐
15  │   LENGTH    │   FLAGS    │        PAYLOAD DATA         │
16  │  (3 bytes)  │  (1 byte)  │     (variable length)       │
17  ├─────────────┼────────────┼─────────────────────────────┤
18  │ 0x00 00 XX  │ 0 CA XXXXX │  Compressed proto message   │
19  └─────────────┴────────────┴─────────────────────────────┘
20
21  TERMINAL MESSAGE:
22  ┌─────────────┬────────────┬─────────────┬───────────────┐
23  │   LENGTH    │   FLAGS    │ STATUS CODE │   JSON BODY   │
24  │  (3 bytes)  │  (1 byte)  │  (2 bytes)  │  (variable)   │
25  ├─────────────┼────────────┼─────────────┼───────────────┤
26  │ 0x00 00 XX  │ 1 CA XXXXX │   HTTP Code │   JSON data   │
27  └─────────────┴────────────┴─────────────┴───────────────┘
28
29  LENGTH = size of (FLAGS + PAYLOAD), does NOT include length header itself
30  Implemented limit: 2 MiB (smaller than 24-bit protocol maximum)
31*/
32
/// Width, in bytes, of the big-endian LENGTH prefix that starts every frame.
const LENGTH_PREFIX_SIZE: usize = 3;
/// Width, in bytes, of the big-endian status code that starts a terminal message payload.
const STATUS_CODE_SIZE: usize = 2;
/// Regular payloads shorter than this are always sent uncompressed.
const COMPRESSION_THRESHOLD_BYTES: usize = 1024; // 1 KiB
/// Largest LENGTH value (flags byte + payload) that is encoded or accepted.
const MAX_FRAME_BYTES: usize = 2 * 1024 * 1024; // 2 MiB

// The LENGTH field is only `LENGTH_PREFIX_SIZE` bytes wide; a frame limit that does not fit
// would not round-trip through the prefix, so reject such a configuration at compile time.
const _: () = assert!(MAX_FRAME_BYTES < 1 << (8 * LENGTH_PREFIX_SIZE));
37
38/*
39Flag byte layout:
40  ┌───┬───┬───┬───┬───┬───┬───┬───┐
41  │ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │  Bit positions
42  ├───┼───┴───┼───┴───┴───┴───┴───┤
43  │ T │  C C  │   Reserved (0s)   │  Purpose
44  └───┴───────┴───────────────────┘
45
46  T = Terminal flag (1 bit)
47  C = Compression (2 bits, encodes 0-3)
48*/
49
50const FLAG_TOTAL_SIZE: usize = 1;
51// The frame length budget includes one flag byte, so payload bytes are capped at budget - flag.
52const MAX_FRAME_PAYLOAD_BYTES: usize = MAX_FRAME_BYTES - FLAG_TOTAL_SIZE;
53const MAX_DECOMPRESSED_PAYLOAD_BYTES: usize = MAX_FRAME_PAYLOAD_BYTES;
54const FLAG_TERMINAL: u8 = 0b1000_0000;
55const FLAG_COMPRESSION_MASK: u8 = 0b0110_0000;
56const FLAG_COMPRESSION_SHIFT: u8 = 5;
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Ordinalize)]
59#[repr(u8)]
60pub enum CompressionAlgorithm {
61    None = 0,
62    Zstd = 1,
63    Gzip = 2,
64}
65
66impl CompressionAlgorithm {
67    pub fn from_accept_encoding(headers: &http::HeaderMap) -> Self {
68        let mut gzip = false;
69        for header_value in headers.get_all(http::header::ACCEPT_ENCODING) {
70            if let Ok(value) = header_value.to_str() {
71                for encoding in value.split(',') {
72                    let encoding = encoding.trim().split(';').next().unwrap_or("").trim();
73                    if encoding.eq_ignore_ascii_case("zstd") {
74                        return Self::Zstd;
75                    } else if encoding.eq_ignore_ascii_case("gzip") {
76                        gzip = true;
77                    }
78                }
79            }
80        }
81        if gzip { Self::Gzip } else { Self::None }
82    }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq)]
86pub struct CompressedData {
87    compression: CompressionAlgorithm,
88    payload: Bytes,
89}
90
91impl CompressedData {
92    pub fn for_proto(
93        compression: CompressionAlgorithm,
94        proto: &impl prost::Message,
95    ) -> std::io::Result<Self> {
96        Self::compress(compression, proto.encode_to_vec())
97    }
98
99    fn compress(compression: CompressionAlgorithm, data: Vec<u8>) -> std::io::Result<Self> {
100        if data.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
101            return Err(std::io::Error::new(
102                std::io::ErrorKind::InvalidInput,
103                "payload exceeds decompressed limit",
104            ));
105        }
106
107        if compression == CompressionAlgorithm::None || data.len() < COMPRESSION_THRESHOLD_BYTES {
108            return Ok(Self {
109                compression: CompressionAlgorithm::None,
110                payload: data.into(),
111            });
112        }
113        let mut buf = Vec::with_capacity(data.len());
114        match compression {
115            CompressionAlgorithm::Gzip => {
116                let mut encoder = GzEncoder::new(buf, Compression::default());
117                encoder.write_all(data.as_slice())?;
118                buf = encoder.finish()?;
119            }
120            CompressionAlgorithm::Zstd => {
121                zstd::stream::copy_encode(data.as_slice(), &mut buf, 0)?;
122            }
123            CompressionAlgorithm::None => unreachable!("handled above"),
124        };
125        let payload = Bytes::from(buf.into_boxed_slice());
126        if payload.len() > MAX_FRAME_PAYLOAD_BYTES {
127            return Err(std::io::Error::new(
128                std::io::ErrorKind::InvalidInput,
129                "compressed payload exceeds frame limit",
130            ));
131        }
132        Ok(Self {
133            compression,
134            payload,
135        })
136    }
137
138    fn decompressed(self) -> std::io::Result<Bytes> {
139        let initial_capacity = self
140            .payload
141            .len()
142            .saturating_mul(2)
143            .clamp(COMPRESSION_THRESHOLD_BYTES, MAX_DECOMPRESSED_PAYLOAD_BYTES);
144
145        // Decode at most `MAX_DECOMPRESSED_PAYLOAD_BYTES + 1` bytes
146        fn read_to_end_limited(
147            mut reader: impl Read,
148            initial_capacity: usize,
149        ) -> std::io::Result<Bytes> {
150            let mut limited = reader
151                .by_ref()
152                .take((MAX_DECOMPRESSED_PAYLOAD_BYTES + 1) as u64);
153            let mut buf = Vec::with_capacity(initial_capacity);
154            limited.read_to_end(&mut buf)?;
155            if buf.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
156                return Err(std::io::Error::new(
157                    std::io::ErrorKind::InvalidData,
158                    "decompressed payload exceeds limit",
159                ));
160            }
161            Ok(Bytes::from(buf.into_boxed_slice()))
162        }
163
164        match self.compression {
165            CompressionAlgorithm::None => {
166                if self.payload.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
167                    return Err(std::io::Error::new(
168                        std::io::ErrorKind::InvalidData,
169                        "decompressed payload exceeds limit",
170                    ));
171                }
172                Ok(self.payload)
173            }
174            CompressionAlgorithm::Gzip => {
175                let mut decoder = GzDecoder::new(&self.payload[..]);
176                read_to_end_limited(&mut decoder, initial_capacity)
177            }
178            CompressionAlgorithm::Zstd => {
179                let mut decoder = zstd::stream::Decoder::new(&self.payload[..])?;
180                read_to_end_limited(&mut decoder, initial_capacity)
181            }
182        }
183    }
184
185    pub fn try_into_proto<P: prost::Message + Default>(self) -> std::io::Result<P> {
186        let payload = self.decompressed()?;
187        P::decode(payload.as_ref())
188            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
189    }
190}
191
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub struct TerminalMessage {
194    pub status: u16,
195    pub body: String,
196}
197
198#[derive(Debug, Clone, PartialEq, Eq)]
199pub enum SessionMessage {
200    Regular(CompressedData),
201    Terminal(TerminalMessage),
202}
203
204impl From<CompressedData> for SessionMessage {
205    fn from(data: CompressedData) -> Self {
206        Self::Regular(data)
207    }
208}
209
210impl From<TerminalMessage> for SessionMessage {
211    fn from(msg: TerminalMessage) -> Self {
212        Self::Terminal(msg)
213    }
214}
215
216impl SessionMessage {
217    pub fn regular(
218        compression: CompressionAlgorithm,
219        proto: &impl prost::Message,
220    ) -> std::io::Result<Self> {
221        Ok(Self::Regular(CompressedData::for_proto(
222            compression,
223            proto,
224        )?))
225    }
226
227    pub fn encode(&self) -> Bytes {
228        let encoded_size = FLAG_TOTAL_SIZE + self.payload_size();
229        assert!(
230            encoded_size <= MAX_FRAME_BYTES,
231            "payload exceeds encoder limit"
232        );
233        let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + encoded_size);
234        buf.put_uint(encoded_size as u64, 3);
235        match self {
236            Self::Regular(msg) => {
237                let flag =
238                    (msg.compression.ordinal() << FLAG_COMPRESSION_SHIFT) & FLAG_COMPRESSION_MASK;
239                buf.put_u8(flag);
240                buf.extend_from_slice(&msg.payload);
241            }
242            Self::Terminal(msg) => {
243                buf.put_u8(FLAG_TERMINAL);
244                buf.put_u16(msg.status);
245                buf.extend_from_slice(msg.body.as_bytes());
246            }
247        }
248        buf.freeze()
249    }
250
251    fn decode_message(mut buf: Bytes) -> std::io::Result<Self> {
252        if buf.is_empty() {
253            return Err(std::io::Error::new(
254                std::io::ErrorKind::UnexpectedEof,
255                "empty frame payload",
256            ));
257        }
258        let flag = buf.get_u8();
259
260        let is_terminal = (flag & FLAG_TERMINAL) != 0;
261        if is_terminal {
262            if buf.len() < STATUS_CODE_SIZE {
263                return Err(std::io::Error::new(
264                    std::io::ErrorKind::InvalidData,
265                    "terminal message missing status code",
266                ));
267            }
268            let status = buf.get_u16();
269            let body = String::from_utf8(buf.into()).map_err(|_| {
270                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid utf-8")
271            })?;
272            return Ok(TerminalMessage { status, body }.into());
273        }
274
275        let compression_bits = (flag & FLAG_COMPRESSION_MASK) >> FLAG_COMPRESSION_SHIFT;
276        let Some(compression) = CompressionAlgorithm::from_ordinal(compression_bits) else {
277            return Err(std::io::Error::new(
278                std::io::ErrorKind::InvalidData,
279                "unknown compression algorithm",
280            ));
281        };
282
283        Ok(CompressedData {
284            compression,
285            payload: buf,
286        }
287        .into())
288    }
289
290    fn payload_size(&self) -> usize {
291        match self {
292            Self::Regular(msg) => msg.payload.len(),
293            Self::Terminal(msg) => STATUS_CODE_SIZE + msg.body.len(),
294        }
295    }
296}
297
298pub struct FramedMessageStream<S> {
299    inner: S,
300    compression: CompressionAlgorithm,
301    terminated: bool,
302}
303
304impl<S> FramedMessageStream<S> {
305    pub fn new(compression: CompressionAlgorithm, inner: S) -> Self {
306        Self {
307            inner,
308            compression,
309            terminated: false,
310        }
311    }
312}
313
314impl<S, P, E> Stream for FramedMessageStream<S>
315where
316    S: Stream<Item = Result<P, E>> + Unpin,
317    P: prost::Message,
318    E: Into<TerminalMessage>,
319{
320    type Item = std::io::Result<Bytes>;
321
322    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323        if self.terminated {
324            return Poll::Ready(None);
325        }
326
327        match Pin::new(&mut self.inner).poll_next(cx) {
328            Poll::Ready(Some(Ok(item))) => match SessionMessage::regular(self.compression, &item) {
329                Ok(msg) => Poll::Ready(Some(Ok(msg.encode()))),
330                Err(err) => {
331                    self.terminated = true;
332                    Poll::Ready(Some(Err(err)))
333                }
334            },
335            Poll::Ready(Some(Err(e))) => {
336                self.terminated = true;
337                let bytes = SessionMessage::Terminal(e.into()).encode();
338                Poll::Ready(Some(Ok(bytes)))
339            }
340            Poll::Ready(None) => {
341                self.terminated = true;
342                Poll::Ready(None)
343            }
344            Poll::Pending => Poll::Pending,
345        }
346    }
347}
348
349pub struct FrameDecoder;
350
351impl tokio_util::codec::Decoder for FrameDecoder {
352    type Item = SessionMessage;
353    type Error = std::io::Error;
354
355    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
356        if src.len() < LENGTH_PREFIX_SIZE {
357            return Ok(None);
358        }
359
360        let length = ((src[0] as usize) << 16) | ((src[1] as usize) << 8) | (src[2] as usize);
361
362        if length > MAX_FRAME_BYTES {
363            return Err(std::io::Error::new(
364                std::io::ErrorKind::InvalidInput,
365                "frame exceeds decode limit",
366            ));
367        }
368
369        let total_size = LENGTH_PREFIX_SIZE + length;
370        if src.len() < total_size {
371            return Ok(None);
372        }
373
374        src.advance(LENGTH_PREFIX_SIZE);
375        let frame_bytes = src.split_to(length).freeze();
376        Ok(Some(SessionMessage::decode_message(frame_bytes)?))
377    }
378}
379
#[cfg(test)]
mod test {
    // Unit and property tests for the frame codec, the compression helpers, and
    // `FramedMessageStream`.
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use futures::StreamExt;
    use http::HeaderValue;
    use proptest::{collection::vec, prelude::*};
    use prost::Message;
    use tokio_util::codec::Decoder;

    use super::*;

    // Minimal prost message with a single `bytes` field (tag 1), used as the payload type.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestProto {
        #[prost(bytes, tag = "1")]
        payload: Vec<u8>,
    }

    impl TestProto {
        fn new(payload: Vec<u8>) -> Self {
            Self { payload }
        }
    }

    // Stand-in for an inner-stream error; converts into the `TerminalMessage` the stream emits.
    #[derive(Debug, Clone)]
    struct TestError {
        status: u16,
        body: &'static str,
    }

    impl From<TestError> for TerminalMessage {
        fn from(val: TestError) -> Self {
            TerminalMessage {
                status: val.status,
                body: val.body.to_string(),
            }
        }
    }

    // Decodes exactly one frame from `bytes`; an incomplete frame is reported as
    // `UnexpectedEof`.
    fn decode_once(bytes: &Bytes) -> io::Result<SessionMessage> {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(bytes.as_ref());
        decoder
            .decode(&mut buf)?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "frame incomplete"))
    }

    // Proptest strategy yielding one of the three compression algorithms.
    fn compression_strategy() -> impl proptest::strategy::Strategy<Value = CompressionAlgorithm> {
        prop_oneof![
            Just(CompressionAlgorithm::None),
            Just(CompressionAlgorithm::Gzip),
            Just(CompressionAlgorithm::Zstd),
        ]
    }

    // Splits `data` into consecutive non-empty chunks sized from `pattern`
    // (`hint % remaining + 1`). Leftover bytes form a final chunk, and empty input yields a
    // single empty chunk, so concatenating the chunks always reproduces `data`.
    fn chunk_bytes(data: &Bytes, pattern: &[usize]) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        let mut offset = 0;
        for &hint in pattern {
            if offset >= data.len() {
                break;
            }
            let remaining = data.len() - offset;
            let take = (hint % remaining).saturating_add(1).min(remaining);
            chunks.push(data.slice(offset..offset + take));
            offset += take;
        }
        if offset < data.len() {
            chunks.push(data.slice(offset..));
        }
        if chunks.is_empty() {
            chunks.push(data.clone());
        }
        chunks
    }

    proptest! {
        // Any payload round-trips through encode/decode, and compression is applied only
        // when requested and the encoded proto reaches the threshold.
        #[test]
        fn regular_session_message_round_trips_proptest(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4)
        ) {
            let proto = TestProto::new(payload.clone());
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let decoded = decode_once(&encoded).unwrap();

            prop_assert!(matches!(decoded, SessionMessage::Regular(_)));
            let SessionMessage::Regular(data) = decoded else { unreachable!() };

            let expected_compression = if algo == CompressionAlgorithm::None || proto.encoded_len() < COMPRESSION_THRESHOLD_BYTES {
                CompressionAlgorithm::None
            } else {
                algo
            };
            let actual_compression = data.compression;

            let restored = data.try_into_proto::<TestProto>().unwrap();
            prop_assert_eq!(restored.payload, payload);
            prop_assert_eq!(actual_compression, expected_compression);
        }

        // Feeding a frame in arbitrary chunks yields nothing until the last chunk, which
        // produces the same message as a one-shot decode and consumes the whole buffer.
        #[test]
        fn frame_decoder_handles_chunked_frames(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4),
            chunk_pattern in vec(0usize..=16, 0..=16)
        ) {
            let proto = TestProto::new(payload);
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let expected = decode_once(&encoded).unwrap();

            let chunks = chunk_bytes(&encoded, &chunk_pattern);
            prop_assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), encoded.len());

            let mut decoder = FrameDecoder;
            let mut buf = BytesMut::new();
            let mut decoded = None;

            for (idx, chunk) in chunks.iter().enumerate() {
                buf.extend_from_slice(chunk.as_ref());
                let result = decoder.decode(&mut buf).expect("decode invocation failed");
                if idx < chunks.len() - 1 {
                    prop_assert!(result.is_none());
                } else {
                    let message = result.expect("final chunk should produce frame");
                    prop_assert!(buf.is_empty());
                    decoded = Some(message);
                }
            }

            let decoded = decoded.expect("decoder never emitted frame");
            prop_assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn from_accept_encoding_prefers_zstd() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip, zstd, br"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Zstd);
    }

    #[test]
    fn from_accept_encoding_falls_back_to_gzip() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip;q=0.8, deflate"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Gzip);
    }

    #[test]
    fn from_accept_encoding_defaults_to_none() {
        let headers = http::HeaderMap::new();
        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::None);
    }

    #[test]
    fn regular_session_message_round_trips() {
        let proto = TestProto::new(vec![1, 2, 3, 4]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::None);
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    #[test]
    fn terminal_session_message_round_trips() {
        let terminal = TerminalMessage {
            status: 418,
            body: "short-circuit".to_string(),
        };
        let msg = SessionMessage::from(terminal.clone());
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(decoded_terminal) => {
                assert_eq!(decoded_terminal, terminal);
            }
        }
    }

    #[test]
    fn frame_decoder_waits_for_complete_frame() {
        let proto = TestProto::new(vec![9, 9, 9]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let mut decoder = FrameDecoder;

        // Withhold the final byte so the first decode sees an incomplete frame.
        let split_idx = encoded.len() - 1;
        let mut buf = BytesMut::from(&encoded[..split_idx]);
        assert!(decoder.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(&encoded[split_idx..]);
        let decoded = decoder.decode(&mut buf).unwrap().unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
        assert!(buf.is_empty());
    }

    #[test]
    fn frame_decoder_rejects_frames_exceeding_decode_limit() {
        // A bare 3-byte prefix declaring one byte over the limit is rejected before any
        // body bytes arrive.
        let length = MAX_FRAME_BYTES + 1;
        let prefix = [
            ((length >> 16) & 0xFF) as u8,
            ((length >> 8) & 0xFF) as u8,
            (length & 0xFF) as u8,
        ];
        let mut buf = BytesMut::from(prefix.as_slice());
        let mut decoder = FrameDecoder;
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    #[should_panic(expected = "encoder limit")]
    fn session_message_encode_rejects_frames_over_limit() {
        // MAX_FRAME_BYTES payload bytes plus the flags byte is one byte over the limit.
        let data = CompressedData {
            compression: CompressionAlgorithm::None,
            payload: Bytes::from(vec![0u8; MAX_FRAME_BYTES]),
        };
        let msg = SessionMessage::from(data);
        let _ = msg.encode();
    }

    #[test]
    fn frame_decoder_rejects_unknown_compression() {
        // Flags 0x60 sets both compression bits (ordinal 3), which maps to no algorithm.
        let mut raw = vec![0, 0, 1];
        raw.push(0x60);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn frame_decoder_rejects_terminal_without_status() {
        // LENGTH = 1 leaves room for the terminal flags byte but not the 2-byte status.
        let mut raw = vec![0, 0, 1];
        raw.push(FLAG_TERMINAL);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn frame_decoder_handles_empty_payload() {
        // LENGTH = 0 means the frame lacks even the flags byte.
        let raw = vec![0, 0, 0];
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn compressed_data_round_trip_gzip() {
        let payload = vec![42; 1_200_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Gzip, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Gzip);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    #[test]
    fn compressed_data_round_trip_zstd() {
        let payload = vec![7; 1_100_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Zstd, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Zstd);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    #[test]
    fn decompression_rejects_payloads_exceeding_limit() {
        // Compress an over-limit message directly, bypassing `compress`'s input check, so
        // that only the decompression limit can reject it.
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1];
        let proto = TestProto::new(payload);
        let encoded = proto.encode_to_vec();

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            let compressed = match algo {
                CompressionAlgorithm::Gzip => {
                    let mut out = Vec::new();
                    let mut encoder = GzEncoder::new(&mut out, Compression::default());
                    encoder.write_all(encoded.as_slice()).unwrap();
                    encoder.finish().unwrap();
                    out
                }
                CompressionAlgorithm::Zstd => {
                    let mut out = Vec::new();
                    zstd::stream::copy_encode(encoded.as_slice(), &mut out, 0).unwrap();
                    out
                }
                CompressionAlgorithm::None => unreachable!("explicitly excluded in test"),
            };

            let data = CompressedData {
                compression: algo,
                payload: Bytes::from(compressed),
            };
            assert!(data.payload.len() <= MAX_FRAME_PAYLOAD_BYTES);

            let err = data.try_into_proto::<TestProto>().expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(
                err.to_string()
                    .contains("decompressed payload exceeds limit")
            );
        }
    }

    #[test]
    fn compress_rejects_payloads_exceeding_decompressed_limit() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1];
        let proto = TestProto::new(payload);

        let err = CompressedData::compress(CompressionAlgorithm::Gzip, proto.encode_to_vec())
            .expect_err("should fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(
            err.to_string()
                .contains("payload exceeds decompressed limit")
        );
    }

    #[test]
    fn compress_allows_payload_at_exact_limit_without_encode_panic() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES];
        let data = CompressedData::compress(CompressionAlgorithm::None, payload).unwrap();
        let encoded = SessionMessage::from(data).encode();
        assert_eq!(encoded.len(), LENGTH_PREFIX_SIZE + MAX_FRAME_BYTES);
    }

    #[test]
    fn compress_rejects_incompressible_payload_that_exceeds_frame_limit_after_compression() {
        // A xorshift32 PRNG fills the buffer with bytes that are not expected to compress,
        // so compression overhead pushes the output past the frame limit.
        let mut payload = vec![0u8; MAX_DECOMPRESSED_PAYLOAD_BYTES];
        let mut x = 0x1234_5678u32;
        for byte in &mut payload {
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            *byte = (x & 0xFF) as u8;
        }

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            let err = CompressedData::compress(algo, payload.clone()).expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
            assert!(
                err.to_string()
                    .contains("compressed payload exceeds frame limit")
            );
        }
    }

    #[test]
    fn framed_message_stream_yields_terminal_on_error() {
        // The item after the error must never be emitted.
        let proto = TestProto::new(vec![1, 2, 3]);
        let items = vec![
            Ok(proto.clone()),
            Err(TestError {
                status: 500,
                body: "boom",
            }),
            Ok(proto.clone()),
        ];

        let stream = futures::stream::iter(items);
        let framed = FramedMessageStream::new(CompressionAlgorithm::None, stream);
        let outputs = futures::executor::block_on(async {
            framed.collect::<Vec<std::io::Result<Bytes>>>().await
        });

        assert_eq!(outputs.len(), 2);

        let first = outputs[0].as_ref().expect("first frame ok");
        match decode_once(first).unwrap() {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }

        let second = outputs[1].as_ref().expect("second frame ok");
        match decode_once(second).unwrap() {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(term) => {
                assert_eq!(term.status, 500);
                assert_eq!(term.body, "boom");
            }
        }
    }

    #[test]
    fn framed_message_stream_stops_after_termination() {
        let mut stream = FramedMessageStream::new(
            CompressionAlgorithm::None,
            futures::stream::iter(vec![
                Ok(TestProto::new(vec![0])),
                Err(TestError {
                    status: 400,
                    body: "bad",
                }),
            ]),
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Regular(_) => {}
                SessionMessage::Terminal(_) => panic!("expected regular message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Terminal(term) => {
                    assert_eq!(term.status, 400);
                    assert_eq!(term.body, "bad");
                }
                SessionMessage::Regular(_) => panic!("expected terminal message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate, got {other:?}"),
        }
    }

    #[test]
    fn framed_message_stream_terminates_after_encoding_error() {
        // The first item is over the payload limit, so encoding fails and the second item
        // must never be polled out.
        let oversized = MAX_DECOMPRESSED_PAYLOAD_BYTES + 1;
        let items: Vec<Result<TestProto, TestError>> = vec![
            Ok(TestProto::new(vec![0u8; oversized])),
            Ok(TestProto::new(vec![1u8; oversized])),
        ];
        let mut stream =
            FramedMessageStream::new(CompressionAlgorithm::None, futures::stream::iter(items));

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Err(err))) => {
                assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
                assert!(
                    err.to_string()
                        .contains("payload exceeds decompressed limit")
                );
            }
            other => panic!("expected encoding error, got {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate after encoding error, got {other:?}"),
        }
    }
}