Skip to main content

s2_api/v1/stream/
s2s.rs

1use std::{
2    io::{Read, Write},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use enum_ordinalize::Ordinalize;
9use flate2::{Compression, read::GzDecoder, write::GzEncoder};
10use futures::Stream;
11
/*
  REGULAR MESSAGE:
  ┌─────────────┬────────────┬─────────────────────────────┐
  │   LENGTH    │   FLAGS    │        PAYLOAD DATA         │
  │  (3 bytes)  │  (1 byte)  │     (variable length)       │
  ├─────────────┼────────────┼─────────────────────────────┤
  │ 0x00 00 XX  │ 0 CC RRRRR │   Proto message (see CC)    │
  └─────────────┴────────────┴─────────────────────────────┘

  TERMINAL MESSAGE:
  ┌─────────────┬────────────┬─────────────┬───────────────┐
  │   LENGTH    │   FLAGS    │ STATUS CODE │   JSON BODY   │
  │  (3 bytes)  │  (1 byte)  │  (2 bytes)  │  (variable)   │
  ├─────────────┼────────────┼─────────────┼───────────────┤
  │ 0x00 00 XX  │ 1 CC RRRRR │   HTTP Code │   JSON data   │
  └─────────────┴────────────┴─────────────┴───────────────┘

  LENGTH = big-endian 24-bit byte count of FLAGS + PAYLOAD (for a terminal message the payload
           is STATUS CODE + JSON BODY); it does NOT include the 3-byte length header itself.
  STATUS CODE = big-endian 16-bit HTTP status.
  CC = compression bits (see the flag byte layout below). A regular payload is compressed with
       the indicated algorithm, or sent raw when CC = 0. The encoder writes CC = 0 for terminal
       messages.
  R  = reserved bits, written as 0.
  Implemented limit: 2 MiB per frame (smaller than the 24-bit protocol maximum of 16 MiB - 1).
*/
32
/// Width of the big-endian length header that precedes every frame.
const LENGTH_PREFIX_SIZE: usize = 3;
/// Width of the big-endian HTTP status code at the start of a terminal message payload.
const STATUS_CODE_SIZE: usize = 2;
/// Payloads shorter than this are sent uncompressed regardless of the requested algorithm.
const COMPRESSION_THRESHOLD_BYTES: usize = 1 << 10; // 1 KiB
/// Largest LENGTH value (flag byte + payload) accepted by the encoder and the decoder.
const MAX_FRAME_BYTES: usize = 2 << 20; // 2 MiB

/*
Flag byte layout:
  ┌───┬───┬───┬───┬───┬───┬───┬───┐
  │ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │  Bit positions
  ├───┼───┴───┼───┴───┴───┴───┴───┤
  │ T │  C C  │   Reserved (0s)   │  Purpose
  └───┴───────┴───────────────────┘

  T = Terminal flag (1 bit)
  C = Compression algorithm (2 bits): 0 = none, 1 = zstd, 2 = gzip, 3 = unassigned
      (a regular frame carrying 3 is rejected by the decoder)
  Reserved bits are written as 0 by the encoder and are not inspected by the decoder.
*/

/// Size of the flag byte that opens every frame body.
const FLAG_TOTAL_SIZE: usize = 1;
/// LENGTH also counts the flag byte, so the payload itself may use only what remains.
const MAX_FRAME_PAYLOAD_BYTES: usize = MAX_FRAME_BYTES - FLAG_TOTAL_SIZE;
/// Cap on a payload before compression and after decompression. It equals the frame payload
/// cap, so an uncompressed payload that passes this check always fits in a single frame.
const MAX_DECOMPRESSED_PAYLOAD_BYTES: usize = MAX_FRAME_PAYLOAD_BYTES;
/// Bit 7: set on terminal messages.
const FLAG_TERMINAL: u8 = 1 << 7;
/// Bits 6..=5: compression algorithm ordinal.
const FLAG_COMPRESSION_MASK: u8 = 0b11 << FLAG_COMPRESSION_SHIFT;
/// Position of the lowest compression bit inside the flag byte.
const FLAG_COMPRESSION_SHIFT: u8 = 5;
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Ordinalize)]
59#[repr(u8)]
60pub enum CompressionAlgorithm {
61    None = 0,
62    Zstd = 1,
63    Gzip = 2,
64}
65
66impl CompressionAlgorithm {
67    pub fn from_accept_encoding(headers: &http::HeaderMap) -> Self {
68        let mut gzip = false;
69        for header_value in headers.get_all(http::header::ACCEPT_ENCODING) {
70            if let Ok(value) = header_value.to_str() {
71                for encoding in value.split(',') {
72                    let encoding = encoding.trim().split(';').next().unwrap_or("").trim();
73                    if encoding.eq_ignore_ascii_case("zstd") {
74                        return Self::Zstd;
75                    } else if encoding.eq_ignore_ascii_case("gzip") {
76                        gzip = true;
77                    }
78                }
79            }
80        }
81        if gzip { Self::Gzip } else { Self::None }
82    }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq)]
86pub struct CompressedData {
87    compression: CompressionAlgorithm,
88    payload: Bytes,
89}
90
91impl CompressedData {
92    pub fn for_proto(
93        compression: CompressionAlgorithm,
94        proto: &impl prost::Message,
95    ) -> std::io::Result<Self> {
96        Self::compress(compression, proto.encode_to_vec())
97    }
98
99    fn compress(compression: CompressionAlgorithm, data: Vec<u8>) -> std::io::Result<Self> {
100        if data.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
101            return Err(std::io::Error::new(
102                std::io::ErrorKind::InvalidInput,
103                "payload exceeds decompressed limit",
104            ));
105        }
106
107        if compression == CompressionAlgorithm::None || data.len() < COMPRESSION_THRESHOLD_BYTES {
108            return Ok(Self {
109                compression: CompressionAlgorithm::None,
110                payload: data.into(),
111            });
112        }
113        let mut buf = Vec::with_capacity(data.len());
114        match compression {
115            CompressionAlgorithm::Gzip => {
116                let mut encoder = GzEncoder::new(buf, Compression::default());
117                encoder.write_all(data.as_slice())?;
118                buf = encoder.finish()?;
119            }
120            CompressionAlgorithm::Zstd => {
121                zstd::stream::copy_encode(data.as_slice(), &mut buf, 0)?;
122            }
123            CompressionAlgorithm::None => unreachable!("handled above"),
124        };
125        let payload = Bytes::from(buf.into_boxed_slice());
126        if payload.len() > MAX_FRAME_PAYLOAD_BYTES {
127            return Err(std::io::Error::new(
128                std::io::ErrorKind::InvalidInput,
129                "compressed payload exceeds frame limit",
130            ));
131        }
132        Ok(Self {
133            compression,
134            payload,
135        })
136    }
137
138    fn decompressed(self) -> std::io::Result<Bytes> {
139        let initial_capacity = self
140            .payload
141            .len()
142            .saturating_mul(2)
143            .clamp(COMPRESSION_THRESHOLD_BYTES, MAX_DECOMPRESSED_PAYLOAD_BYTES);
144
145        // Decode at most `MAX_DECOMPRESSED_PAYLOAD_BYTES + 1` bytes
146        fn read_to_end_limited(
147            mut reader: impl Read,
148            initial_capacity: usize,
149        ) -> std::io::Result<Bytes> {
150            let mut limited = reader
151                .by_ref()
152                .take((MAX_DECOMPRESSED_PAYLOAD_BYTES + 1) as u64);
153            let mut buf = Vec::with_capacity(initial_capacity);
154            limited.read_to_end(&mut buf)?;
155            if buf.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
156                return Err(std::io::Error::new(
157                    std::io::ErrorKind::InvalidData,
158                    "decompressed payload exceeds limit",
159                ));
160            }
161            Ok(Bytes::from(buf.into_boxed_slice()))
162        }
163
164        match self.compression {
165            CompressionAlgorithm::None => {
166                if self.payload.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
167                    return Err(std::io::Error::new(
168                        std::io::ErrorKind::InvalidData,
169                        "decompressed payload exceeds limit",
170                    ));
171                }
172                Ok(self.payload)
173            }
174            CompressionAlgorithm::Gzip => {
175                let mut decoder = GzDecoder::new(&self.payload[..]);
176                read_to_end_limited(&mut decoder, initial_capacity)
177            }
178            CompressionAlgorithm::Zstd => {
179                let mut decoder = zstd::stream::Decoder::new(&self.payload[..])?;
180                read_to_end_limited(&mut decoder, initial_capacity)
181            }
182        }
183    }
184
185    pub fn try_into_proto<P: prost::Message + Default>(self) -> std::io::Result<P> {
186        let payload = self.decompressed()?;
187        P::decode(payload.as_ref())
188            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
189    }
190}
191
/// Payload of a terminal frame: an HTTP status code and a UTF-8 body.
///
/// The wire-format description documents the body as JSON.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TerminalMessage {
    /// HTTP status code, written big-endian as the first two payload bytes.
    pub status: u16,
    /// Body text; decoding a frame rejects bodies that are not valid UTF-8.
    pub body: String,
}
197
198#[derive(Debug, Clone, PartialEq, Eq)]
199pub enum SessionMessage {
200    Regular(CompressedData),
201    Terminal(TerminalMessage),
202}
203
204impl From<CompressedData> for SessionMessage {
205    fn from(data: CompressedData) -> Self {
206        Self::Regular(data)
207    }
208}
209
210impl From<TerminalMessage> for SessionMessage {
211    fn from(msg: TerminalMessage) -> Self {
212        Self::Terminal(msg)
213    }
214}
215
216impl SessionMessage {
217    pub fn regular(
218        compression: CompressionAlgorithm,
219        proto: &impl prost::Message,
220    ) -> std::io::Result<Self> {
221        Ok(Self::Regular(CompressedData::for_proto(
222            compression,
223            proto,
224        )?))
225    }
226
227    pub fn encode(&self) -> Bytes {
228        let encoded_size = FLAG_TOTAL_SIZE + self.payload_size();
229        assert!(
230            encoded_size <= MAX_FRAME_BYTES,
231            "payload exceeds encoder limit"
232        );
233        let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + encoded_size);
234        buf.put_uint(encoded_size as u64, 3);
235        match self {
236            Self::Regular(msg) => {
237                let flag =
238                    (msg.compression.ordinal() << FLAG_COMPRESSION_SHIFT) & FLAG_COMPRESSION_MASK;
239                buf.put_u8(flag);
240                buf.extend_from_slice(&msg.payload);
241            }
242            Self::Terminal(msg) => {
243                buf.put_u8(FLAG_TERMINAL);
244                buf.put_u16(msg.status);
245                buf.extend_from_slice(msg.body.as_bytes());
246            }
247        }
248        buf.freeze()
249    }
250
251    fn decode_message(mut buf: Bytes) -> std::io::Result<Self> {
252        if buf.is_empty() {
253            return Err(std::io::Error::new(
254                std::io::ErrorKind::UnexpectedEof,
255                "empty frame payload",
256            ));
257        }
258        let flag = buf.get_u8();
259
260        let is_terminal = (flag & FLAG_TERMINAL) != 0;
261        if is_terminal {
262            if buf.len() < STATUS_CODE_SIZE {
263                return Err(std::io::Error::new(
264                    std::io::ErrorKind::InvalidData,
265                    "terminal message missing status code",
266                ));
267            }
268            let status = buf.get_u16();
269            let body = String::from_utf8(buf.into()).map_err(|_| {
270                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid utf-8")
271            })?;
272            return Ok(TerminalMessage { status, body }.into());
273        }
274
275        let compression_bits = (flag & FLAG_COMPRESSION_MASK) >> FLAG_COMPRESSION_SHIFT;
276        let Some(compression) = CompressionAlgorithm::from_ordinal(compression_bits) else {
277            return Err(std::io::Error::new(
278                std::io::ErrorKind::InvalidData,
279                "unknown compression algorithm",
280            ));
281        };
282
283        Ok(CompressedData {
284            compression,
285            payload: buf,
286        }
287        .into())
288    }
289
290    fn payload_size(&self) -> usize {
291        match self {
292            Self::Regular(msg) => msg.payload.len(),
293            Self::Terminal(msg) => STATUS_CODE_SIZE + msg.body.len(),
294        }
295    }
296}
297
298pub struct FramedMessageStream<S> {
299    inner: S,
300    compression: CompressionAlgorithm,
301    terminated: bool,
302}
303
304impl<S> FramedMessageStream<S> {
305    pub fn new(compression: CompressionAlgorithm, inner: S) -> Self {
306        Self {
307            inner,
308            compression,
309            terminated: false,
310        }
311    }
312}
313
314impl<S, P, E> Stream for FramedMessageStream<S>
315where
316    S: Stream<Item = Result<P, E>> + Unpin,
317    P: prost::Message,
318    E: Into<TerminalMessage>,
319{
320    type Item = std::io::Result<Bytes>;
321
322    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323        if self.terminated {
324            return Poll::Ready(None);
325        }
326
327        match Pin::new(&mut self.inner).poll_next(cx) {
328            Poll::Ready(Some(Ok(item))) => {
329                let bytes =
330                    SessionMessage::regular(self.compression, &item).map(|msg| msg.encode());
331                Poll::Ready(Some(bytes))
332            }
333            Poll::Ready(Some(Err(e))) => {
334                self.terminated = true;
335                let bytes = SessionMessage::Terminal(e.into()).encode();
336                Poll::Ready(Some(Ok(bytes)))
337            }
338            Poll::Ready(None) => {
339                self.terminated = true;
340                Poll::Ready(None)
341            }
342            Poll::Pending => Poll::Pending,
343        }
344    }
345}
346
/// Codec decoder that splits a byte stream into length-prefixed frames and parses each one
/// into a `SessionMessage`.
///
/// The decoder holds no state of its own; all buffering lives in the `BytesMut` handed to
/// `decode`.
#[derive(Debug, Default, Clone, Copy)]
pub struct FrameDecoder;
348
349impl tokio_util::codec::Decoder for FrameDecoder {
350    type Item = SessionMessage;
351    type Error = std::io::Error;
352
353    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
354        if src.len() < LENGTH_PREFIX_SIZE {
355            return Ok(None);
356        }
357
358        let length = ((src[0] as usize) << 16) | ((src[1] as usize) << 8) | (src[2] as usize);
359
360        if length > MAX_FRAME_BYTES {
361            return Err(std::io::Error::new(
362                std::io::ErrorKind::InvalidInput,
363                "frame exceeds decode limit",
364            ));
365        }
366
367        let total_size = LENGTH_PREFIX_SIZE + length;
368        if src.len() < total_size {
369            return Ok(None);
370        }
371
372        src.advance(LENGTH_PREFIX_SIZE);
373        let frame_bytes = src.split_to(length).freeze();
374        Ok(Some(SessionMessage::decode_message(frame_bytes)?))
375    }
376}
377
// Unit and property tests for the framing codec: compression negotiation, frame encode/decode
// round trips, size limits, malformed-frame rejection, and the FramedMessageStream adapter.
#[cfg(test)]
mod test {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use futures::StreamExt;
    use http::HeaderValue;
    use proptest::{collection::vec, prelude::*};
    use prost::Message;
    use tokio_util::codec::Decoder;

    use super::*;

    // Minimal protobuf message: a single bytes field, so payload size is easy to control.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestProto {
        #[prost(bytes, tag = "1")]
        payload: Vec<u8>,
    }

    impl TestProto {
        fn new(payload: Vec<u8>) -> Self {
            Self { payload }
        }
    }

    // Error type fed through FramedMessageStream; converts into a TerminalMessage.
    #[derive(Debug, Clone)]
    struct TestError {
        status: u16,
        body: &'static str,
    }

    impl From<TestError> for TerminalMessage {
        fn from(val: TestError) -> Self {
            TerminalMessage {
                status: val.status,
                body: val.body.to_string(),
            }
        }
    }

    // Decodes exactly one frame from `bytes`; an incomplete frame is reported as UnexpectedEof.
    fn decode_once(bytes: &Bytes) -> io::Result<SessionMessage> {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(bytes.as_ref());
        decoder
            .decode(&mut buf)?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "frame incomplete"))
    }

    // Strategy producing every supported compression algorithm.
    fn compression_strategy() -> impl proptest::strategy::Strategy<Value = CompressionAlgorithm> {
        prop_oneof![
            Just(CompressionAlgorithm::None),
            Just(CompressionAlgorithm::Gzip),
            Just(CompressionAlgorithm::Zstd),
        ]
    }

    // Splits `data` into consecutive non-empty chunks whose sizes are derived from `pattern`
    // (each chunk is 1..=remaining bytes). Any leftover bytes form one final chunk. For
    // non-empty `data` the chunks always cover it exactly.
    fn chunk_bytes(data: &Bytes, pattern: &[usize]) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        let mut offset = 0;
        for &hint in pattern {
            if offset >= data.len() {
                break;
            }
            let remaining = data.len() - offset;
            let take = (hint % remaining).saturating_add(1).min(remaining);
            chunks.push(data.slice(offset..offset + take));
            offset += take;
        }
        if offset < data.len() {
            chunks.push(data.slice(offset..));
        }
        if chunks.is_empty() {
            chunks.push(data.clone());
        }
        chunks
    }

    proptest! {
        // Any payload round-trips through encode/decode. Compression is applied only when
        // requested and the encoded proto reaches COMPRESSION_THRESHOLD_BYTES.
        #[test]
        fn regular_session_message_round_trips_proptest(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4)
        ) {
            let proto = TestProto::new(payload.clone());
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let decoded = decode_once(&encoded).unwrap();

            prop_assert!(matches!(decoded, SessionMessage::Regular(_)));
            let SessionMessage::Regular(data) = decoded else { unreachable!() };

            let expected_compression = if algo == CompressionAlgorithm::None || proto.encoded_len() < COMPRESSION_THRESHOLD_BYTES {
                CompressionAlgorithm::None
            } else {
                algo
            };
            let actual_compression = data.compression;

            let restored = data.try_into_proto::<TestProto>().unwrap();
            prop_assert_eq!(restored.payload, payload);
            prop_assert_eq!(actual_compression, expected_compression);
        }

        // Feeding a frame in arbitrary chunks yields nothing until the last chunk. The final
        // chunk produces the same message as a one-shot decode and leaves the buffer empty.
        #[test]
        fn frame_decoder_handles_chunked_frames(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4),
            chunk_pattern in vec(0usize..=16, 0..=16)
        ) {
            let proto = TestProto::new(payload);
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let expected = decode_once(&encoded).unwrap();

            let chunks = chunk_bytes(&encoded, &chunk_pattern);
            prop_assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), encoded.len());

            let mut decoder = FrameDecoder;
            let mut buf = BytesMut::new();
            let mut decoded = None;

            for (idx, chunk) in chunks.iter().enumerate() {
                buf.extend_from_slice(chunk.as_ref());
                let result = decoder.decode(&mut buf).expect("decode invocation failed");
                if idx < chunks.len() - 1 {
                    prop_assert!(result.is_none());
                } else {
                    let message = result.expect("final chunk should produce frame");
                    prop_assert!(buf.is_empty());
                    decoded = Some(message);
                }
            }

            let decoded = decoded.expect("decoder never emitted frame");
            prop_assert_eq!(decoded, expected);
        }
    }

    // zstd is chosen when offered, even if gzip appears first.
    #[test]
    fn from_accept_encoding_prefers_zstd() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip, zstd, br"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Zstd);
    }

    // gzip with a non-zero q-value is still accepted when zstd is absent.
    #[test]
    fn from_accept_encoding_falls_back_to_gzip() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip;q=0.8, deflate"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Gzip);
    }

    // No Accept-Encoding header means no compression.
    #[test]
    fn from_accept_encoding_defaults_to_none() {
        let headers = http::HeaderMap::new();
        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::None);
    }

    // Small uncompressed regular message survives an encode/decode round trip.
    #[test]
    fn regular_session_message_round_trips() {
        let proto = TestProto::new(vec![1, 2, 3, 4]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::None);
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Terminal message status and body survive an encode/decode round trip.
    #[test]
    fn terminal_session_message_round_trips() {
        let terminal = TerminalMessage {
            status: 418,
            body: "short-circuit".to_string(),
        };
        let msg = SessionMessage::from(terminal.clone());
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(decoded_terminal) => {
                assert_eq!(decoded_terminal, terminal);
            }
        }
    }

    // A frame missing its last byte yields None; appending that byte completes the decode and
    // drains the buffer.
    #[test]
    fn frame_decoder_waits_for_complete_frame() {
        let proto = TestProto::new(vec![9, 9, 9]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let mut decoder = FrameDecoder;

        let split_idx = encoded.len() - 1;
        let mut buf = BytesMut::from(&encoded[..split_idx]);
        assert!(decoder.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(&encoded[split_idx..]);
        let decoded = decoder.decode(&mut buf).unwrap().unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
        assert!(buf.is_empty());
    }

    // A length prefix one byte over MAX_FRAME_BYTES is rejected from the prefix alone.
    #[test]
    fn frame_decoder_rejects_frames_exceeding_decode_limit() {
        let length = MAX_FRAME_BYTES + 1;
        let prefix = [
            ((length >> 16) & 0xFF) as u8,
            ((length >> 8) & 0xFF) as u8,
            (length & 0xFF) as u8,
        ];
        let mut buf = BytesMut::from(prefix.as_slice());
        let mut decoder = FrameDecoder;
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    // MAX_FRAME_BYTES of payload plus the flag byte exceeds the limit, so encode panics.
    #[test]
    #[should_panic(expected = "encoder limit")]
    fn session_message_encode_rejects_frames_over_limit() {
        let data = CompressedData {
            compression: CompressionAlgorithm::None,
            payload: Bytes::from(vec![0u8; MAX_FRAME_BYTES]),
        };
        let msg = SessionMessage::from(data);
        let _ = msg.encode();
    }

    // Flag 0x60 sets both compression bits (ordinal 3), which maps to no algorithm.
    #[test]
    fn frame_decoder_rejects_unknown_compression() {
        let mut raw = vec![0, 0, 1];
        raw.push(0x60);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A one-byte terminal frame has no room for the 2-byte status code.
    #[test]
    fn frame_decoder_rejects_terminal_without_status() {
        let mut raw = vec![0, 0, 1];
        raw.push(FLAG_TERMINAL);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A zero-length frame lacks even the flag byte and is reported as UnexpectedEof.
    #[test]
    fn frame_decoder_handles_empty_payload() {
        let raw = vec![0, 0, 0];
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    // A highly repetitive ~1.2 MB payload is gzip-compressed, shrinks, and round-trips.
    #[test]
    fn compressed_data_round_trip_gzip() {
        let payload = vec![42; 1_200_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Gzip, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Gzip);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // A highly repetitive ~1.1 MB payload is zstd-compressed, shrinks, and round-trips.
    #[test]
    fn compressed_data_round_trip_zstd() {
        let payload = vec![7; 1_100_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Zstd, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Zstd);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Bypasses `compress` to build payloads that fit in a frame on the wire but inflate past
    // MAX_DECOMPRESSED_PAYLOAD_BYTES; decompression must refuse them.
    #[test]
    fn decompression_rejects_payloads_exceeding_limit() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1];
        let proto = TestProto::new(payload);
        let encoded = proto.encode_to_vec();

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            let compressed = match algo {
                CompressionAlgorithm::Gzip => {
                    let mut out = Vec::new();
                    let mut encoder = GzEncoder::new(&mut out, Compression::default());
                    encoder.write_all(encoded.as_slice()).unwrap();
                    encoder.finish().unwrap();
                    out
                }
                CompressionAlgorithm::Zstd => {
                    let mut out = Vec::new();
                    zstd::stream::copy_encode(encoded.as_slice(), &mut out, 0).unwrap();
                    out
                }
                CompressionAlgorithm::None => unreachable!("explicitly excluded in test"),
            };

            let data = CompressedData {
                compression: algo,
                payload: Bytes::from(compressed),
            };
            assert!(data.payload.len() <= MAX_FRAME_PAYLOAD_BYTES);

            let err = data.try_into_proto::<TestProto>().expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(
                err.to_string()
                    .contains("decompressed payload exceeds limit")
            );
        }
    }

    // Input larger than the decompressed limit is refused before any compression happens.
    #[test]
    fn compress_rejects_payloads_exceeding_decompressed_limit() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1];
        let proto = TestProto::new(payload);

        let err = CompressedData::compress(CompressionAlgorithm::Gzip, proto.encode_to_vec())
            .expect_err("should fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(
            err.to_string()
                .contains("payload exceeds decompressed limit")
        );
    }

    // A payload of exactly MAX_DECOMPRESSED_PAYLOAD_BYTES fills a maximal frame without
    // tripping the encoder's assertion.
    #[test]
    fn compress_allows_payload_at_exact_limit_without_encode_panic() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES];
        let data = CompressedData::compress(CompressionAlgorithm::None, payload).unwrap();
        let encoded = SessionMessage::from(data).encode();
        assert_eq!(encoded.len(), LENGTH_PREFIX_SIZE + MAX_FRAME_BYTES);
    }

    // xorshift-filled (pseudo-random) bytes at the size limit do not compress, and the
    // compressed output overflows the frame payload cap, which must be reported as an error.
    #[test]
    fn compress_rejects_incompressible_payload_that_exceeds_frame_limit_after_compression() {
        let mut payload = vec![0u8; MAX_DECOMPRESSED_PAYLOAD_BYTES];
        let mut x = 0x1234_5678u32;
        for byte in &mut payload {
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            *byte = (x & 0xFF) as u8;
        }

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            let err = CompressedData::compress(algo, payload.clone()).expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
            assert!(
                err.to_string()
                    .contains("compressed payload exceeds frame limit")
            );
        }
    }

    // An inner error becomes a terminal frame. The Ok item queued after it is never emitted,
    // so only two frames come out.
    #[test]
    fn framed_message_stream_yields_terminal_on_error() {
        let proto = TestProto::new(vec![1, 2, 3]);
        let items = vec![
            Ok(proto.clone()),
            Err(TestError {
                status: 500,
                body: "boom",
            }),
            Ok(proto.clone()),
        ];

        let stream = futures::stream::iter(items);
        let framed = FramedMessageStream::new(CompressionAlgorithm::None, stream);
        let outputs = futures::executor::block_on(async {
            framed.collect::<Vec<std::io::Result<Bytes>>>().await
        });

        assert_eq!(outputs.len(), 2);

        let first = outputs[0].as_ref().expect("first frame ok");
        match decode_once(first).unwrap() {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }

        let second = outputs[1].as_ref().expect("second frame ok");
        match decode_once(second).unwrap() {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(term) => {
                assert_eq!(term.status, 500);
                assert_eq!(term.body, "bad".replace("bad", "boom"));
            }
        }
    }

    // Polls manually with a no-op waker. Expected sequence: a regular frame, then the
    // terminal frame, then Ready(None).
    #[test]
    fn framed_message_stream_stops_after_termination() {
        let mut stream = FramedMessageStream::new(
            CompressionAlgorithm::None,
            futures::stream::iter(vec![
                Ok(TestProto::new(vec![0])),
                Err(TestError {
                    status: 400,
                    body: "bad",
                }),
            ]),
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Regular(_) => {}
                SessionMessage::Terminal(_) => panic!("expected regular message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Terminal(term) => {
                    assert_eq!(term.status, 400);
                    assert_eq!(term.body, "bad");
                }
                SessionMessage::Regular(_) => panic!("expected terminal message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate, got {other:?}"),
        }
    }
}