// s2_api/v1/stream/s2s.rs

1use std::{
2    io::{Read, Write},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use enum_ordinalize::Ordinalize;
9use flate2::{Compression, read::GzDecoder, write::GzEncoder};
10use futures::Stream;
11use zstd::zstd_safe::WriteBuf;
12
/*
  REGULAR MESSAGE:
  ┌─────────────┬────────────┬─────────────────────────────┐
  │   LENGTH    │   FLAGS    │        PAYLOAD DATA         │
  │  (3 bytes)  │  (1 byte)  │     (variable length)       │
  ├─────────────┼────────────┼─────────────────────────────┤
  │ 0x00 00 XX  │ 0 CA XXXXX │  Compressed proto message   │
  └─────────────┴────────────┴─────────────────────────────┘

  TERMINAL MESSAGE:
  ┌─────────────┬────────────┬─────────────┬───────────────┐
  │   LENGTH    │   FLAGS    │ STATUS CODE │   JSON BODY   │
  │  (3 bytes)  │  (1 byte)  │  (2 bytes)  │  (variable)   │
  ├─────────────┼────────────┼─────────────┼───────────────┤
  │ 0x00 00 XX  │ 1 CA XXXXX │   HTTP Code │   JSON data   │
  └─────────────┴────────────┴─────────────┴───────────────┘

  LENGTH = size of (FLAGS + PAYLOAD) as a big-endian 24-bit integer; it does
  NOT include the length header itself. For terminal messages the payload is
  STATUS CODE + JSON BODY.
  A regular message's payload is compressed only when CA is non-zero (see the
  flag byte layout below). Terminal messages are never compressed: the encoder
  writes CA = 00 and the decoder ignores those bits.
  STATUS CODE is a big-endian u16.
  Implemented limit: 2 MiB (smaller than the 24-bit protocol maximum of
  16 MiB - 1 byte).
*/
33
/// Width in bytes of the big-endian frame length header.
const LENGTH_PREFIX_SIZE: usize = 3;
/// Width in bytes of the status code carried by terminal frames.
const STATUS_CODE_SIZE: usize = 2;
/// Encoded protobufs shorter than this (1 KiB) are sent uncompressed.
const COMPRESSION_THRESHOLD_BYTES: usize = 1 << 10;
/// Largest accepted FLAGS + PAYLOAD size (2 MiB), enforced on both encode and decode.
const MAX_FRAME_BYTES: usize = 2 << 20;
38
/*
Flag byte layout:
  ┌───┬───┬───┬───┬───┬───┬───┬───┐
  │ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │  Bit positions
  ├───┼───┴───┼───┴───┴───┴───┴───┤
  │ T │  C C  │   Reserved (0s)   │  Purpose
  └───┴───────┴───────────────────┘

  T = Terminal flag (1 bit)
  C = Compression (2 bits): 0 = none, 1 = zstd, 2 = gzip;
      3 is unassigned and rejected by the decoder
  Reserved bits are written as 0 by the encoder and are not checked by the decoder.
*/
50
/// The flag header is a single byte.
const FLAG_TOTAL_SIZE: usize = 1;
/// Bit offset of the 2-bit compression field inside the flag byte.
const FLAG_COMPRESSION_SHIFT: u8 = 5;
/// Selects bits 6..=5 of the flag byte (the compression field).
const FLAG_COMPRESSION_MASK: u8 = 0b11 << FLAG_COMPRESSION_SHIFT;
/// Bit 7 of the flag byte: set on terminal frames.
const FLAG_TERMINAL: u8 = 1 << 7;
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Ordinalize)]
57#[repr(u8)]
58pub enum CompressionAlgorithm {
59    None = 0,
60    Zstd = 1,
61    Gzip = 2,
62}
63
64impl CompressionAlgorithm {
65    pub fn from_accept_encoding(headers: &http::HeaderMap) -> Self {
66        let mut gzip = false;
67        for header_value in headers.get_all(http::header::ACCEPT_ENCODING) {
68            if let Ok(value) = header_value.to_str() {
69                for encoding in value.split(',') {
70                    let encoding = encoding.trim().split(';').next().unwrap_or("").trim();
71                    if encoding.eq_ignore_ascii_case("zstd") {
72                        return Self::Zstd;
73                    } else if encoding.eq_ignore_ascii_case("gzip") {
74                        gzip = true;
75                    }
76                }
77            }
78        }
79        if gzip { Self::Gzip } else { Self::None }
80    }
81}
82
83#[derive(Debug, Clone, PartialEq, Eq)]
84pub struct CompressedData {
85    compression: CompressionAlgorithm,
86    payload: Bytes,
87}
88
89impl CompressedData {
90    pub fn for_proto(
91        compression: CompressionAlgorithm,
92        proto: &impl prost::Message,
93    ) -> std::io::Result<Self> {
94        Self::compress(compression, proto.encode_to_vec())
95    }
96
97    fn compress(compression: CompressionAlgorithm, data: Vec<u8>) -> std::io::Result<Self> {
98        if compression == CompressionAlgorithm::None || data.len() < COMPRESSION_THRESHOLD_BYTES {
99            return Ok(Self {
100                compression: CompressionAlgorithm::None,
101                payload: data.into(),
102            });
103        }
104        let mut buf = Vec::with_capacity(data.len());
105        match compression {
106            CompressionAlgorithm::Gzip => {
107                let mut encoder = GzEncoder::new(buf, Compression::default());
108                encoder.write_all(data.as_slice())?;
109                buf = encoder.finish()?;
110            }
111            CompressionAlgorithm::Zstd => {
112                zstd::stream::copy_encode(data.as_slice(), &mut buf, 0)?;
113            }
114            CompressionAlgorithm::None => unreachable!("handled above"),
115        };
116        let payload = Bytes::from(buf.into_boxed_slice());
117        Ok(Self {
118            compression,
119            payload,
120        })
121    }
122
123    fn decompressed(self) -> std::io::Result<Bytes> {
124        match self.compression {
125            CompressionAlgorithm::None => Ok(self.payload),
126            CompressionAlgorithm::Gzip => {
127                let mut decoder = GzDecoder::new(self.payload.as_slice());
128                let mut buf = Vec::with_capacity(self.payload.len() * 2);
129                decoder.read_to_end(&mut buf)?;
130                Ok(Bytes::from(buf.into_boxed_slice()))
131            }
132            CompressionAlgorithm::Zstd => {
133                let mut buf = Vec::with_capacity(self.payload.len() * 2);
134                zstd::stream::copy_decode(self.payload.as_slice(), &mut buf)?;
135                Ok(Bytes::from(buf.into_boxed_slice()))
136            }
137        }
138    }
139
140    pub fn try_into_proto<P: prost::Message + Default>(self) -> std::io::Result<P> {
141        let payload = self.decompressed()?;
142        P::decode(payload.as_ref())
143            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
144    }
145}
146
/// Final message of a session: a status code plus a textual body.
///
/// The wire-format diagram describes the body as JSON; this type does not
/// validate that.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TerminalMessage {
    /// HTTP-style status code.
    pub status: u16,
    /// Body text, sent as UTF-8 bytes.
    pub body: String,
}
152
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub enum SessionMessage {
155    Regular(CompressedData),
156    Terminal(TerminalMessage),
157}
158
159impl From<CompressedData> for SessionMessage {
160    fn from(data: CompressedData) -> Self {
161        Self::Regular(data)
162    }
163}
164
165impl From<TerminalMessage> for SessionMessage {
166    fn from(msg: TerminalMessage) -> Self {
167        Self::Terminal(msg)
168    }
169}
170
171impl SessionMessage {
172    pub fn regular(
173        compression: CompressionAlgorithm,
174        proto: &impl prost::Message,
175    ) -> std::io::Result<Self> {
176        Ok(Self::Regular(CompressedData::for_proto(
177            compression,
178            proto,
179        )?))
180    }
181
182    pub fn encode(&self) -> Bytes {
183        let encoded_size = FLAG_TOTAL_SIZE + self.payload_size();
184        assert!(
185            encoded_size <= MAX_FRAME_BYTES,
186            "payload exceeds encoder limit"
187        );
188        let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + encoded_size);
189        buf.put_uint(encoded_size as u64, 3);
190        match self {
191            Self::Regular(msg) => {
192                let flag =
193                    (msg.compression.ordinal() << FLAG_COMPRESSION_SHIFT) & FLAG_COMPRESSION_MASK;
194                buf.put_u8(flag);
195                buf.extend_from_slice(&msg.payload);
196            }
197            Self::Terminal(msg) => {
198                buf.put_u8(FLAG_TERMINAL);
199                buf.put_u16(msg.status);
200                buf.extend_from_slice(msg.body.as_bytes());
201            }
202        }
203        buf.freeze()
204    }
205
206    fn decode_message(mut buf: Bytes) -> std::io::Result<Self> {
207        if buf.is_empty() {
208            return Err(std::io::Error::new(
209                std::io::ErrorKind::UnexpectedEof,
210                "empty frame payload",
211            ));
212        }
213        let flag = buf.get_u8();
214
215        let is_terminal = (flag & FLAG_TERMINAL) != 0;
216        if is_terminal {
217            if buf.len() < STATUS_CODE_SIZE {
218                return Err(std::io::Error::new(
219                    std::io::ErrorKind::InvalidData,
220                    "terminal message missing status code",
221                ));
222            }
223            let status = buf.get_u16();
224            let body = String::from_utf8(buf.into()).map_err(|_| {
225                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid utf-8")
226            })?;
227            return Ok(TerminalMessage { status, body }.into());
228        }
229
230        let compression_bits = (flag & FLAG_COMPRESSION_MASK) >> FLAG_COMPRESSION_SHIFT;
231        let Some(compression) = CompressionAlgorithm::from_ordinal(compression_bits) else {
232            return Err(std::io::Error::new(
233                std::io::ErrorKind::InvalidData,
234                "unknown compression algorithm",
235            ));
236        };
237
238        Ok(CompressedData {
239            compression,
240            payload: buf,
241        }
242        .into())
243    }
244
245    fn payload_size(&self) -> usize {
246        match self {
247            Self::Regular(msg) => msg.payload.len(),
248            Self::Terminal(msg) => STATUS_CODE_SIZE + msg.body.len(),
249        }
250    }
251}
252
253pub struct FramedMessageStream<S> {
254    inner: S,
255    compression: CompressionAlgorithm,
256    terminated: bool,
257}
258
259impl<S> FramedMessageStream<S> {
260    pub fn new(compression: CompressionAlgorithm, inner: S) -> Self {
261        Self {
262            inner,
263            compression,
264            terminated: false,
265        }
266    }
267}
268
269impl<S, P, E> Stream for FramedMessageStream<S>
270where
271    S: Stream<Item = Result<P, E>> + Unpin,
272    P: prost::Message,
273    E: Into<TerminalMessage>,
274{
275    type Item = std::io::Result<Bytes>;
276
277    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
278        if self.terminated {
279            return Poll::Ready(None);
280        }
281
282        match Pin::new(&mut self.inner).poll_next(cx) {
283            Poll::Ready(Some(Ok(item))) => {
284                let bytes =
285                    SessionMessage::regular(self.compression, &item).map(|msg| msg.encode());
286                Poll::Ready(Some(bytes))
287            }
288            Poll::Ready(Some(Err(e))) => {
289                self.terminated = true;
290                let bytes = SessionMessage::Terminal(e.into()).encode();
291                Poll::Ready(Some(Ok(bytes)))
292            }
293            Poll::Ready(None) => {
294                self.terminated = true;
295                Poll::Ready(None)
296            }
297            Poll::Pending => Poll::Pending,
298        }
299    }
300}
301
/// Stateless `tokio_util` decoder that splits a byte stream into
/// [`SessionMessage`] frames.
#[derive(Debug, Clone, Copy, Default)]
pub struct FrameDecoder;
303
304impl tokio_util::codec::Decoder for FrameDecoder {
305    type Item = SessionMessage;
306    type Error = std::io::Error;
307
308    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
309        if src.len() < LENGTH_PREFIX_SIZE {
310            return Ok(None);
311        }
312
313        let length = ((src[0] as usize) << 16) | ((src[1] as usize) << 8) | (src[2] as usize);
314
315        if length > MAX_FRAME_BYTES {
316            return Err(std::io::Error::new(
317                std::io::ErrorKind::InvalidInput,
318                "frame exceeds decode limit",
319            ));
320        }
321
322        let total_size = LENGTH_PREFIX_SIZE + length;
323        if src.len() < total_size {
324            return Ok(None);
325        }
326
327        src.advance(LENGTH_PREFIX_SIZE);
328        let frame_bytes = src.split_to(length).freeze();
329        Ok(Some(SessionMessage::decode_message(frame_bytes)?))
330    }
331}
332
/// Unit and property tests for the session framing: compression negotiation,
/// encode/decode round trips, chunked decoding, limit and malformed-frame
/// rejection, and `FramedMessageStream` termination behavior.
#[cfg(test)]
mod test {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use futures::StreamExt;
    use http::HeaderValue;
    use proptest::{collection::vec, prelude::*};
    use prost::Message;
    use tokio_util::codec::Decoder;

    use super::*;

    /// Minimal protobuf message with a single bytes field, used as the payload.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestProto {
        #[prost(bytes, tag = "1")]
        payload: Vec<u8>,
    }

    impl TestProto {
        fn new(payload: Vec<u8>) -> Self {
            Self { payload }
        }
    }

    /// Error type fed through `FramedMessageStream`; converts into a terminal message.
    #[derive(Debug, Clone)]
    struct TestError {
        status: u16,
        body: &'static str,
    }

    impl From<TestError> for TerminalMessage {
        fn from(val: TestError) -> Self {
            TerminalMessage {
                status: val.status,
                body: val.body.to_string(),
            }
        }
    }

    /// Decodes exactly one frame from `bytes`, treating an incomplete frame as
    /// an `UnexpectedEof` error.
    fn decode_once(bytes: &Bytes) -> io::Result<SessionMessage> {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(bytes.as_ref());
        decoder
            .decode(&mut buf)?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "frame incomplete"))
    }

    /// Proptest strategy yielding every `CompressionAlgorithm` variant.
    fn compression_strategy() -> impl proptest::strategy::Strategy<Value = CompressionAlgorithm> {
        prop_oneof![
            Just(CompressionAlgorithm::None),
            Just(CompressionAlgorithm::Gzip),
            Just(CompressionAlgorithm::Zstd),
        ]
    }

    /// Splits `data` into consecutive non-empty slices whose sizes are derived
    /// from `pattern`; any remainder becomes a final chunk, and empty `data`
    /// yields a single empty chunk. The chunks always concatenate back to `data`.
    fn chunk_bytes(data: &Bytes, pattern: &[usize]) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        let mut offset = 0;
        for &hint in pattern {
            if offset >= data.len() {
                break;
            }
            let remaining = data.len() - offset;
            // +1 keeps every chunk non-empty; min() keeps it in bounds.
            let take = (hint % remaining).saturating_add(1).min(remaining);
            chunks.push(data.slice(offset..offset + take));
            offset += take;
        }
        if offset < data.len() {
            chunks.push(data.slice(offset..));
        }
        if chunks.is_empty() {
            chunks.push(data.clone());
        }
        chunks
    }

    proptest! {
        // Payload sizes span up to 4x the compression threshold, so both the
        // "too small to compress" and the compressed paths are exercised.
        #[test]
        fn regular_session_message_round_trips_proptest(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4)
        ) {
            let proto = TestProto::new(payload.clone());
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let decoded = decode_once(&encoded).unwrap();

            prop_assert!(matches!(decoded, SessionMessage::Regular(_)));
            let SessionMessage::Regular(data) = decoded else { unreachable!() };

            // Mirrors the threshold rule in `CompressedData::compress`.
            let expected_compression = if algo == CompressionAlgorithm::None || proto.encoded_len() < COMPRESSION_THRESHOLD_BYTES {
                CompressionAlgorithm::None
            } else {
                algo
            };
            let actual_compression = data.compression;

            let restored = data.try_into_proto::<TestProto>().unwrap();
            prop_assert_eq!(restored.payload, payload);
            prop_assert_eq!(actual_compression, expected_compression);
        }

        // Feeding a frame in arbitrary chunks must yield nothing until the last
        // chunk arrives, then exactly the same message as a one-shot decode.
        #[test]
        fn frame_decoder_handles_chunked_frames(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4),
            chunk_pattern in vec(0usize..=16, 0..=16)
        ) {
            let proto = TestProto::new(payload);
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let expected = decode_once(&encoded).unwrap();

            let chunks = chunk_bytes(&encoded, &chunk_pattern);
            prop_assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), encoded.len());

            let mut decoder = FrameDecoder;
            let mut buf = BytesMut::new();
            let mut decoded = None;

            for (idx, chunk) in chunks.iter().enumerate() {
                buf.extend_from_slice(chunk.as_ref());
                let result = decoder.decode(&mut buf).expect("decode invocation failed");
                if idx < chunks.len() - 1 {
                    prop_assert!(result.is_none());
                } else {
                    let message = result.expect("final chunk should produce frame");
                    prop_assert!(buf.is_empty());
                    decoded = Some(message);
                }
            }

            let decoded = decoded.expect("decoder never emitted frame");
            prop_assert_eq!(decoded, expected);
        }
    }

    // zstd wins over gzip regardless of header order.
    #[test]
    fn from_accept_encoding_prefers_zstd() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip, zstd, br"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Zstd);
    }

    // Parameters after ';' are stripped before matching the coding name.
    #[test]
    fn from_accept_encoding_falls_back_to_gzip() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip;q=0.8, deflate"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Gzip);
    }

    #[test]
    fn from_accept_encoding_defaults_to_none() {
        let headers = http::HeaderMap::new();
        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::None);
    }

    #[test]
    fn regular_session_message_round_trips() {
        let proto = TestProto::new(vec![1, 2, 3, 4]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::None);
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    #[test]
    fn terminal_session_message_round_trips() {
        let terminal = TerminalMessage {
            status: 418,
            body: "short-circuit".to_string(),
        };
        let msg = SessionMessage::from(terminal.clone());
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(decoded_terminal) => {
                assert_eq!(decoded_terminal, terminal);
            }
        }
    }

    // A frame missing its last byte must not decode; appending it completes
    // the frame and drains the buffer.
    #[test]
    fn frame_decoder_waits_for_complete_frame() {
        let proto = TestProto::new(vec![9, 9, 9]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let mut decoder = FrameDecoder;

        let split_idx = encoded.len() - 1;
        let mut buf = BytesMut::from(&encoded[..split_idx]);
        assert!(decoder.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(&encoded[split_idx..]);
        let decoded = decoder.decode(&mut buf).unwrap().unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
        assert!(buf.is_empty());
    }

    // Only the 3-byte prefix is supplied: the limit must be enforced before
    // waiting for the body.
    #[test]
    fn frame_decoder_rejects_frames_exceeding_decode_limit() {
        let length = MAX_FRAME_BYTES + 1;
        let prefix = [
            ((length >> 16) & 0xFF) as u8,
            ((length >> 8) & 0xFF) as u8,
            (length & 0xFF) as u8,
        ];
        let mut buf = BytesMut::from(prefix.as_slice());
        let mut decoder = FrameDecoder;
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    // A payload of exactly MAX_FRAME_BYTES plus the flag byte is one byte over
    // the limit, so `encode` must panic.
    #[test]
    #[should_panic(expected = "encoder limit")]
    fn session_message_encode_rejects_frames_over_limit() {
        let data = CompressedData {
            compression: CompressionAlgorithm::None,
            payload: Bytes::from(vec![0u8; MAX_FRAME_BYTES]),
        };
        let msg = SessionMessage::from(data);
        let _ = msg.encode();
    }

    // Flag 0x60 sets both compression bits (value 3), which has no variant.
    #[test]
    fn frame_decoder_rejects_unknown_compression() {
        let mut raw = vec![0, 0, 1];
        raw.push(0x60);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // Terminal frame whose length covers only the flag byte (no status code).
    #[test]
    fn frame_decoder_rejects_terminal_without_status() {
        let mut raw = vec![0, 0, 1];
        raw.push(FLAG_TERMINAL);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A zero length prefix leaves no flag byte to read.
    #[test]
    fn frame_decoder_handles_empty_payload() {
        let raw = vec![0, 0, 0];
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    // Large, highly repetitive payload: gzip must be applied and must shrink it.
    #[test]
    fn compressed_data_round_trip_gzip() {
        let payload = vec![42; 1_200_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Gzip, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Gzip);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Large, highly repetitive payload: zstd must be applied and must shrink it.
    #[test]
    fn compressed_data_round_trip_zstd() {
        let payload = vec![7; 1_100_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Zstd, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Zstd);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // The item after the error must never be emitted: output is one regular
    // frame followed by one terminal frame.
    #[test]
    fn framed_message_stream_yields_terminal_on_error() {
        let proto = TestProto::new(vec![1, 2, 3]);
        let items = vec![
            Ok(proto.clone()),
            Err(TestError {
                status: 500,
                body: "boom",
            }),
            Ok(proto.clone()),
        ];

        let stream = futures::stream::iter(items);
        let framed = FramedMessageStream::new(CompressionAlgorithm::None, stream);
        let outputs = futures::executor::block_on(async {
            framed.collect::<Vec<std::io::Result<Bytes>>>().await
        });

        assert_eq!(outputs.len(), 2);

        let first = outputs[0].as_ref().expect("first frame ok");
        match decode_once(first).unwrap() {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }

        let second = outputs[1].as_ref().expect("second frame ok");
        match decode_once(second).unwrap() {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(term) => {
                assert_eq!(term.status, 500);
                assert_eq!(term.body, "boom");
            }
        }
    }

    // Polls manually to check the exact sequence: regular, terminal, then None.
    #[test]
    fn framed_message_stream_stops_after_termination() {
        let mut stream = FramedMessageStream::new(
            CompressionAlgorithm::None,
            futures::stream::iter(vec![
                Ok(TestProto::new(vec![0])),
                Err(TestError {
                    status: 400,
                    body: "bad",
                }),
            ]),
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Regular(_) => {}
                SessionMessage::Terminal(_) => panic!("expected regular message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Terminal(term) => {
                    assert_eq!(term.status, 400);
                    assert_eq!(term.body, "bad");
                }
                SessionMessage::Regular(_) => panic!("expected terminal message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate, got {other:?}"),
        }
    }
}