// s2_api/v1/stream/s2s.rs

1use std::{
2    io::{Read, Write},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use enum_ordinalize::Ordinalize;
9use flate2::{Compression, read::GzDecoder, write::GzEncoder};
10use futures::Stream;
11
/*
  REGULAR MESSAGE:
  ┌─────────────┬────────────┬─────────────────────────────┐
  │   LENGTH    │   FLAGS    │        PAYLOAD DATA         │
  │  (3 bytes)  │  (1 byte)  │     (variable length)       │
  ├─────────────┼────────────┼─────────────────────────────┤
  │ 0x00 00 XX  │ 0 CA XXXXX │  Compressed proto message   │
  └─────────────┴────────────┴─────────────────────────────┘

  TERMINAL MESSAGE:
  ┌─────────────┬────────────┬─────────────┬───────────────┐
  │   LENGTH    │   FLAGS    │ STATUS CODE │   JSON BODY   │
  │  (3 bytes)  │  (1 byte)  │  (2 bytes)  │  (variable)   │
  ├─────────────┼────────────┼─────────────┼───────────────┤
  │ 0x00 00 XX  │ 1 CA XXXXX │   HTTP Code │   JSON data   │
  └─────────────┴────────────┴─────────────┴───────────────┘

  LENGTH is a 24-bit big-endian count of the FLAGS byte plus everything after it
  (for terminal messages: STATUS CODE + JSON BODY). It does NOT include the 3-byte
  length header itself.
  The regular payload is the encoded proto. It is compressed only when the CA
  (compression-algorithm) bits are non-zero; protos under 1 KiB are always sent
  uncompressed.
  STATUS CODE is a big-endian u16.
  Implemented limit: LENGTH <= 2 MiB, well below the 24-bit field's maximum of
  16 MiB - 1.
*/
32
/// Width in bytes of the big-endian length header that starts every frame.
const LENGTH_PREFIX_SIZE: usize = 3;
/// Width in bytes of the status code carried at the start of a terminal message.
const STATUS_CODE_SIZE: usize = std::mem::size_of::<u16>();
/// Encoded protos shorter than this many bytes are always sent uncompressed (1 KiB).
const COMPRESSION_THRESHOLD_BYTES: usize = 1 << 10;
/// Largest flag + payload size the encoder will produce or the decoder will accept (2 MiB).
const MAX_FRAME_BYTES: usize = 2 << 20;
37
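/*
Flag byte layout:
  ┌───┬───┬───┬───┬───┬───┬───┬───┐
  │ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │  Bit positions
  ├───┼───┴───┼───┴───┴───┴───┴───┤
  │ T │  C C  │   Reserved (0s)   │  Purpose
  └───┴───────┴───────────────────┘

  T = Terminal flag (1 bit).
  C = Compression algorithm (2 bits): 0 = none, 1 = zstd, 2 = gzip. The value 3
      is unassigned and is rejected by the decoder on regular messages.
  The encoder always writes C as 0 on terminal messages, and the decoder ignores
  C there. Reserved bits are written as 0 and are currently not checked on decode.
*/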
49
/// Size in bytes of the flag byte that follows the length prefix.
const FLAG_TOTAL_SIZE: usize = std::mem::size_of::<u8>();
/// Bit 7 of the flag byte: set on terminal messages.
const FLAG_TERMINAL: u8 = 1 << 7;
/// Shift that moves the compression bits down to bit 0.
const FLAG_COMPRESSION_SHIFT: u8 = 5;
/// Bits 5-6 of the flag byte: the compression algorithm ordinal.
const FLAG_COMPRESSION_MASK: u8 = 0b11 << FLAG_COMPRESSION_SHIFT;
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Ordinalize)]
56#[repr(u8)]
57pub enum CompressionAlgorithm {
58    None = 0,
59    Zstd = 1,
60    Gzip = 2,
61}
62
63impl CompressionAlgorithm {
64    pub fn from_accept_encoding(headers: &http::HeaderMap) -> Self {
65        let mut gzip = false;
66        for header_value in headers.get_all(http::header::ACCEPT_ENCODING) {
67            if let Ok(value) = header_value.to_str() {
68                for encoding in value.split(',') {
69                    let encoding = encoding.trim().split(';').next().unwrap_or("").trim();
70                    if encoding.eq_ignore_ascii_case("zstd") {
71                        return Self::Zstd;
72                    } else if encoding.eq_ignore_ascii_case("gzip") {
73                        gzip = true;
74                    }
75                }
76            }
77        }
78        if gzip { Self::Gzip } else { Self::None }
79    }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct CompressedData {
84    compression: CompressionAlgorithm,
85    payload: Bytes,
86}
87
88impl CompressedData {
89    pub fn for_proto(
90        compression: CompressionAlgorithm,
91        proto: &impl prost::Message,
92    ) -> std::io::Result<Self> {
93        Self::compress(compression, proto.encode_to_vec())
94    }
95
96    fn compress(compression: CompressionAlgorithm, data: Vec<u8>) -> std::io::Result<Self> {
97        if compression == CompressionAlgorithm::None || data.len() < COMPRESSION_THRESHOLD_BYTES {
98            return Ok(Self {
99                compression: CompressionAlgorithm::None,
100                payload: data.into(),
101            });
102        }
103        let mut buf = Vec::with_capacity(data.len());
104        match compression {
105            CompressionAlgorithm::Gzip => {
106                let mut encoder = GzEncoder::new(buf, Compression::default());
107                encoder.write_all(data.as_slice())?;
108                buf = encoder.finish()?;
109            }
110            CompressionAlgorithm::Zstd => {
111                zstd::stream::copy_encode(data.as_slice(), &mut buf, 0)?;
112            }
113            CompressionAlgorithm::None => unreachable!("handled above"),
114        };
115        let payload = Bytes::from(buf.into_boxed_slice());
116        Ok(Self {
117            compression,
118            payload,
119        })
120    }
121
122    fn decompressed(self) -> std::io::Result<Bytes> {
123        match self.compression {
124            CompressionAlgorithm::None => Ok(self.payload),
125            CompressionAlgorithm::Gzip => {
126                let mut decoder = GzDecoder::new(&self.payload[..]);
127                let mut buf = Vec::with_capacity(self.payload.len() * 2);
128                decoder.read_to_end(&mut buf)?;
129                Ok(Bytes::from(buf.into_boxed_slice()))
130            }
131            CompressionAlgorithm::Zstd => {
132                let mut buf = Vec::with_capacity(self.payload.len() * 2);
133                zstd::stream::copy_decode(&self.payload[..], &mut buf)?;
134                Ok(Bytes::from(buf.into_boxed_slice()))
135            }
136        }
137    }
138
139    pub fn try_into_proto<P: prost::Message + Default>(self) -> std::io::Result<P> {
140        let payload = self.decompressed()?;
141        P::decode(payload.as_ref())
142            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
143    }
144}
145
/// Final message of a session, carrying a status code and a textual body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TerminalMessage {
    /// Status code, written to the wire as a big-endian `u16`.
    pub status: u16,
    /// Body text, written to the wire as UTF-8 bytes.
    pub body: String,
}
151
152#[derive(Debug, Clone, PartialEq, Eq)]
153pub enum SessionMessage {
154    Regular(CompressedData),
155    Terminal(TerminalMessage),
156}
157
158impl From<CompressedData> for SessionMessage {
159    fn from(data: CompressedData) -> Self {
160        Self::Regular(data)
161    }
162}
163
164impl From<TerminalMessage> for SessionMessage {
165    fn from(msg: TerminalMessage) -> Self {
166        Self::Terminal(msg)
167    }
168}
169
170impl SessionMessage {
171    pub fn regular(
172        compression: CompressionAlgorithm,
173        proto: &impl prost::Message,
174    ) -> std::io::Result<Self> {
175        Ok(Self::Regular(CompressedData::for_proto(
176            compression,
177            proto,
178        )?))
179    }
180
181    pub fn encode(&self) -> Bytes {
182        let encoded_size = FLAG_TOTAL_SIZE + self.payload_size();
183        assert!(
184            encoded_size <= MAX_FRAME_BYTES,
185            "payload exceeds encoder limit"
186        );
187        let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + encoded_size);
188        buf.put_uint(encoded_size as u64, 3);
189        match self {
190            Self::Regular(msg) => {
191                let flag =
192                    (msg.compression.ordinal() << FLAG_COMPRESSION_SHIFT) & FLAG_COMPRESSION_MASK;
193                buf.put_u8(flag);
194                buf.extend_from_slice(&msg.payload);
195            }
196            Self::Terminal(msg) => {
197                buf.put_u8(FLAG_TERMINAL);
198                buf.put_u16(msg.status);
199                buf.extend_from_slice(msg.body.as_bytes());
200            }
201        }
202        buf.freeze()
203    }
204
205    fn decode_message(mut buf: Bytes) -> std::io::Result<Self> {
206        if buf.is_empty() {
207            return Err(std::io::Error::new(
208                std::io::ErrorKind::UnexpectedEof,
209                "empty frame payload",
210            ));
211        }
212        let flag = buf.get_u8();
213
214        let is_terminal = (flag & FLAG_TERMINAL) != 0;
215        if is_terminal {
216            if buf.len() < STATUS_CODE_SIZE {
217                return Err(std::io::Error::new(
218                    std::io::ErrorKind::InvalidData,
219                    "terminal message missing status code",
220                ));
221            }
222            let status = buf.get_u16();
223            let body = String::from_utf8(buf.into()).map_err(|_| {
224                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid utf-8")
225            })?;
226            return Ok(TerminalMessage { status, body }.into());
227        }
228
229        let compression_bits = (flag & FLAG_COMPRESSION_MASK) >> FLAG_COMPRESSION_SHIFT;
230        let Some(compression) = CompressionAlgorithm::from_ordinal(compression_bits) else {
231            return Err(std::io::Error::new(
232                std::io::ErrorKind::InvalidData,
233                "unknown compression algorithm",
234            ));
235        };
236
237        Ok(CompressedData {
238            compression,
239            payload: buf,
240        }
241        .into())
242    }
243
244    fn payload_size(&self) -> usize {
245        match self {
246            Self::Regular(msg) => msg.payload.len(),
247            Self::Terminal(msg) => STATUS_CODE_SIZE + msg.body.len(),
248        }
249    }
250}
251
252pub struct FramedMessageStream<S> {
253    inner: S,
254    compression: CompressionAlgorithm,
255    terminated: bool,
256}
257
258impl<S> FramedMessageStream<S> {
259    pub fn new(compression: CompressionAlgorithm, inner: S) -> Self {
260        Self {
261            inner,
262            compression,
263            terminated: false,
264        }
265    }
266}
267
268impl<S, P, E> Stream for FramedMessageStream<S>
269where
270    S: Stream<Item = Result<P, E>> + Unpin,
271    P: prost::Message,
272    E: Into<TerminalMessage>,
273{
274    type Item = std::io::Result<Bytes>;
275
276    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277        if self.terminated {
278            return Poll::Ready(None);
279        }
280
281        match Pin::new(&mut self.inner).poll_next(cx) {
282            Poll::Ready(Some(Ok(item))) => {
283                let bytes =
284                    SessionMessage::regular(self.compression, &item).map(|msg| msg.encode());
285                Poll::Ready(Some(bytes))
286            }
287            Poll::Ready(Some(Err(e))) => {
288                self.terminated = true;
289                let bytes = SessionMessage::Terminal(e.into()).encode();
290                Poll::Ready(Some(Ok(bytes)))
291            }
292            Poll::Ready(None) => {
293                self.terminated = true;
294                Poll::Ready(None)
295            }
296            Poll::Pending => Poll::Pending,
297        }
298    }
299}
300
301pub struct FrameDecoder;
302
303impl tokio_util::codec::Decoder for FrameDecoder {
304    type Item = SessionMessage;
305    type Error = std::io::Error;
306
307    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
308        if src.len() < LENGTH_PREFIX_SIZE {
309            return Ok(None);
310        }
311
312        let length = ((src[0] as usize) << 16) | ((src[1] as usize) << 8) | (src[2] as usize);
313
314        if length > MAX_FRAME_BYTES {
315            return Err(std::io::Error::new(
316                std::io::ErrorKind::InvalidInput,
317                "frame exceeds decode limit",
318            ));
319        }
320
321        let total_size = LENGTH_PREFIX_SIZE + length;
322        if src.len() < total_size {
323            return Ok(None);
324        }
325
326        src.advance(LENGTH_PREFIX_SIZE);
327        let frame_bytes = src.split_to(length).freeze();
328        Ok(Some(SessionMessage::decode_message(frame_bytes)?))
329    }
330}
331
#[cfg(test)]
mod test {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use futures::StreamExt;
    use http::HeaderValue;
    use proptest::{collection::vec, prelude::*};
    use prost::Message;
    use tokio_util::codec::Decoder;

    use super::*;

    /// Minimal proto with a single `bytes` field, used as the message payload in tests.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestProto {
        #[prost(bytes, tag = "1")]
        payload: Vec<u8>,
    }

    impl TestProto {
        fn new(payload: Vec<u8>) -> Self {
            Self { payload }
        }
    }

    /// Stand-in upstream error type; converts into a `TerminalMessage` verbatim.
    #[derive(Debug, Clone)]
    struct TestError {
        status: u16,
        body: &'static str,
    }

    impl From<TestError> for TerminalMessage {
        fn from(val: TestError) -> Self {
            TerminalMessage {
                status: val.status,
                body: val.body.to_string(),
            }
        }
    }

    /// Decodes exactly one frame from `bytes`, failing with `UnexpectedEof` if the buffer
    /// does not contain a complete frame.
    fn decode_once(bytes: &Bytes) -> io::Result<SessionMessage> {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(bytes.as_ref());
        decoder
            .decode(&mut buf)?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "frame incomplete"))
    }

    /// Picks one of the three supported compression algorithms.
    fn compression_strategy() -> impl proptest::strategy::Strategy<Value = CompressionAlgorithm> {
        prop_oneof![
            Just(CompressionAlgorithm::None),
            Just(CompressionAlgorithm::Gzip),
            Just(CompressionAlgorithm::Zstd),
        ]
    }

    /// Splits `data` into consecutive non-empty chunks sized from `pattern` (each hint is
    /// reduced modulo the remaining length, plus one). Bytes left over once `pattern` is
    /// exhausted form a final chunk, so the chunks always concatenate back to `data`.
    /// An empty `data` yields a single empty chunk.
    fn chunk_bytes(data: &Bytes, pattern: &[usize]) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        let mut offset = 0;
        for &hint in pattern {
            if offset >= data.len() {
                break;
            }
            let remaining = data.len() - offset;
            let take = (hint % remaining).saturating_add(1).min(remaining);
            chunks.push(data.slice(offset..offset + take));
            offset += take;
        }
        if offset < data.len() {
            chunks.push(data.slice(offset..));
        }
        if chunks.is_empty() {
            chunks.push(data.clone());
        }
        chunks
    }

    proptest! {
        // Arbitrary payloads survive encode -> decode -> decompress -> proto decode, and the
        // frame reports compression only when it was requested and the encoded proto
        // reached COMPRESSION_THRESHOLD_BYTES.
        #[test]
        fn regular_session_message_round_trips_proptest(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4)
        ) {
            let proto = TestProto::new(payload.clone());
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let decoded = decode_once(&encoded).unwrap();

            prop_assert!(matches!(decoded, SessionMessage::Regular(_)));
            let SessionMessage::Regular(data) = decoded else { unreachable!() };

            let expected_compression = if algo == CompressionAlgorithm::None || proto.encoded_len() < COMPRESSION_THRESHOLD_BYTES {
                CompressionAlgorithm::None
            } else {
                algo
            };
            let actual_compression = data.compression;

            let restored = data.try_into_proto::<TestProto>().unwrap();
            prop_assert_eq!(restored.payload, payload);
            prop_assert_eq!(actual_compression, expected_compression);
        }

        // Feeding one encoded frame to the decoder in arbitrary chunks yields `None` for every
        // chunk but the last; the last produces the same message as a one-shot decode and
        // leaves the buffer empty.
        #[test]
        fn frame_decoder_handles_chunked_frames(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4),
            chunk_pattern in vec(0usize..=16, 0..=16)
        ) {
            let proto = TestProto::new(payload);
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let expected = decode_once(&encoded).unwrap();

            let chunks = chunk_bytes(&encoded, &chunk_pattern);
            prop_assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), encoded.len());

            let mut decoder = FrameDecoder;
            let mut buf = BytesMut::new();
            let mut decoded = None;

            for (idx, chunk) in chunks.iter().enumerate() {
                buf.extend_from_slice(chunk.as_ref());
                let result = decoder.decode(&mut buf).expect("decode invocation failed");
                if idx < chunks.len() - 1 {
                    prop_assert!(result.is_none());
                } else {
                    let message = result.expect("final chunk should produce frame");
                    prop_assert!(buf.is_empty());
                    decoded = Some(message);
                }
            }

            let decoded = decoded.expect("decoder never emitted frame");
            prop_assert_eq!(decoded, expected);
        }
    }

    // zstd is chosen whenever it is offered, even when gzip is listed first.
    #[test]
    fn from_accept_encoding_prefers_zstd() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip, zstd, br"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Zstd);
    }

    // gzip with a non-zero q parameter is still selected when zstd is absent.
    #[test]
    fn from_accept_encoding_falls_back_to_gzip() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip;q=0.8, deflate"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Gzip);
    }

    // No Accept-Encoding header means no compression.
    #[test]
    fn from_accept_encoding_defaults_to_none() {
        let headers = http::HeaderMap::new();
        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::None);
    }

    // Small uncompressed regular message survives encode/decode unchanged.
    #[test]
    fn regular_session_message_round_trips() {
        let proto = TestProto::new(vec![1, 2, 3, 4]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::None);
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Terminal status and body survive encode/decode unchanged.
    #[test]
    fn terminal_session_message_round_trips() {
        let terminal = TerminalMessage {
            status: 418,
            body: "short-circuit".to_string(),
        };
        let msg = SessionMessage::from(terminal.clone());
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(decoded_terminal) => {
                assert_eq!(decoded_terminal, terminal);
            }
        }
    }

    // A frame missing its final byte yields `None`; appending that byte completes it and
    // drains the buffer.
    #[test]
    fn frame_decoder_waits_for_complete_frame() {
        let proto = TestProto::new(vec![9, 9, 9]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let mut decoder = FrameDecoder;

        let split_idx = encoded.len() - 1;
        let mut buf = BytesMut::from(&encoded[..split_idx]);
        assert!(decoder.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(&encoded[split_idx..]);
        let decoded = decoder.decode(&mut buf).unwrap().unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
        assert!(buf.is_empty());
    }

    // A length prefix one byte over MAX_FRAME_BYTES is rejected from the prefix alone.
    #[test]
    fn frame_decoder_rejects_frames_exceeding_decode_limit() {
        let length = MAX_FRAME_BYTES + 1;
        let prefix = [
            ((length >> 16) & 0xFF) as u8,
            ((length >> 8) & 0xFF) as u8,
            (length & 0xFF) as u8,
        ];
        let mut buf = BytesMut::from(prefix.as_slice());
        let mut decoder = FrameDecoder;
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    // A payload of exactly MAX_FRAME_BYTES plus the flag byte is one byte over the limit.
    #[test]
    #[should_panic(expected = "encoder limit")]
    fn session_message_encode_rejects_frames_over_limit() {
        let data = CompressedData {
            compression: CompressionAlgorithm::None,
            payload: Bytes::from(vec![0u8; MAX_FRAME_BYTES]),
        };
        let msg = SessionMessage::from(data);
        let _ = msg.encode();
    }

    // Flag 0x60 sets both compression bits (ordinal 3), which names no algorithm.
    #[test]
    fn frame_decoder_rejects_unknown_compression() {
        let mut raw = vec![0, 0, 1];
        raw.push(0x60);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A 1-byte terminal frame leaves no room for the 2-byte status code.
    #[test]
    fn frame_decoder_rejects_terminal_without_status() {
        let mut raw = vec![0, 0, 1];
        raw.push(FLAG_TERMINAL);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A zero-length frame has no flag byte at all.
    #[test]
    fn frame_decoder_handles_empty_payload() {
        let raw = vec![0, 0, 0];
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    // A highly repetitive payload above the threshold is gzip-compressed, shrinks, and
    // round-trips.
    #[test]
    fn compressed_data_round_trip_gzip() {
        let payload = vec![42; 1_200_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Gzip, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Gzip);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Same as the gzip case, using zstd.
    #[test]
    fn compressed_data_round_trip_zstd() {
        let payload = vec![7; 1_100_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Zstd, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Zstd);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // An upstream error becomes a terminal frame, and the item after it is never emitted.
    #[test]
    fn framed_message_stream_yields_terminal_on_error() {
        let proto = TestProto::new(vec![1, 2, 3]);
        let items = vec![
            Ok(proto.clone()),
            Err(TestError {
                status: 500,
                body: "boom",
            }),
            Ok(proto.clone()),
        ];

        let stream = futures::stream::iter(items);
        let framed = FramedMessageStream::new(CompressionAlgorithm::None, stream);
        let outputs = futures::executor::block_on(async {
            framed.collect::<Vec<std::io::Result<Bytes>>>().await
        });

        assert_eq!(outputs.len(), 2);

        let first = outputs[0].as_ref().expect("first frame ok");
        match decode_once(first).unwrap() {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }

        let second = outputs[1].as_ref().expect("second frame ok");
        match decode_once(second).unwrap() {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(term) => {
                assert_eq!(term.status, 500);
                assert_eq!(term.body, "boom");
            }
        }
    }

    // Polls step by step: a regular frame, then the terminal frame, then end of stream.
    #[test]
    fn framed_message_stream_stops_after_termination() {
        let mut stream = FramedMessageStream::new(
            CompressionAlgorithm::None,
            futures::stream::iter(vec![
                Ok(TestProto::new(vec![0])),
                Err(TestError {
                    status: 400,
                    body: "bad",
                }),
            ]),
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Regular(_) => {}
                SessionMessage::Terminal(_) => panic!("expected regular message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Terminal(term) => {
                    assert_eq!(term.status, 400);
                    assert_eq!(term.body, "bad");
                }
                SessionMessage::Regular(_) => panic!("expected terminal message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate, got {other:?}"),
        }
    }
}