1use std::{
2 io::{Read, Write},
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use enum_ordinalize::Ordinalize;
9use flate2::{Compression, read::GzDecoder, write::GzEncoder};
10use futures::Stream;
11
/// Size in bytes of the big-endian length prefix that precedes every frame.
const LENGTH_PREFIX_SIZE: usize = 3;
/// Size in bytes of the big-endian status code carried by terminal frames.
const STATUS_CODE_SIZE: usize = 2;
/// Payloads shorter than this are always sent uncompressed.
const COMPRESSION_THRESHOLD_BYTES: usize = 1024;
/// Largest frame body (flag byte + payload), not counting the length prefix.
const MAX_FRAME_BYTES: usize = 2 * 1024 * 1024;
/// Size in bytes of the flag byte that starts every frame body.
const FLAG_TOTAL_SIZE: usize = 1;
/// Largest payload that fits in a frame after the flag byte.
const MAX_FRAME_PAYLOAD_BYTES: usize = MAX_FRAME_BYTES - FLAG_TOTAL_SIZE;
/// Largest payload accepted before compression and produced by decompression.
const MAX_DECOMPRESSED_PAYLOAD_BYTES: usize = MAX_FRAME_PAYLOAD_BYTES;
/// Flag bit marking a terminal (status + body) frame.
const FLAG_TERMINAL: u8 = 0b1000_0000;
/// Flag bits holding the compression ordinal of a regular frame.
const FLAG_COMPRESSION_MASK: u8 = 0b0110_0000;
/// Bit offset of the compression field inside the flag byte.
const FLAG_COMPRESSION_SHIFT: u8 = 5;

// Compile-time guards for the wire format. The encoder writes the frame length
// with a fixed-width `put_uint`, which would silently truncate an oversized
// value, so the frame limit must be representable in the prefix.
const _: () = {
    assert!(
        MAX_FRAME_BYTES < (1usize << (8 * LENGTH_PREFIX_SIZE)),
        "MAX_FRAME_BYTES must fit in the length prefix"
    );
    assert!(
        FLAG_TERMINAL & FLAG_COMPRESSION_MASK == 0,
        "terminal flag must not overlap the compression bits"
    );
    assert!(
        FLAG_COMPRESSION_MASK.trailing_zeros() == FLAG_COMPRESSION_SHIFT as u32,
        "compression shift must align with the compression mask"
    );
};
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Ordinalize)]
59#[repr(u8)]
60pub enum CompressionAlgorithm {
61 None = 0,
62 Zstd = 1,
63 Gzip = 2,
64}
65
66impl CompressionAlgorithm {
67 pub fn from_accept_encoding(headers: &http::HeaderMap) -> Self {
68 let mut gzip = false;
69 for header_value in headers.get_all(http::header::ACCEPT_ENCODING) {
70 if let Ok(value) = header_value.to_str() {
71 for encoding in value.split(',') {
72 let encoding = encoding.trim().split(';').next().unwrap_or("").trim();
73 if encoding.eq_ignore_ascii_case("zstd") {
74 return Self::Zstd;
75 } else if encoding.eq_ignore_ascii_case("gzip") {
76 gzip = true;
77 }
78 }
79 }
80 }
81 if gzip { Self::Gzip } else { Self::None }
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq)]
86pub struct CompressedData {
87 compression: CompressionAlgorithm,
88 payload: Bytes,
89}
90
91impl CompressedData {
92 pub fn for_proto(
93 compression: CompressionAlgorithm,
94 proto: &impl prost::Message,
95 ) -> std::io::Result<Self> {
96 Self::compress(compression, proto.encode_to_vec())
97 }
98
99 fn compress(compression: CompressionAlgorithm, data: Vec<u8>) -> std::io::Result<Self> {
100 if data.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
101 return Err(std::io::Error::new(
102 std::io::ErrorKind::InvalidInput,
103 "payload exceeds decompressed limit",
104 ));
105 }
106
107 if compression == CompressionAlgorithm::None || data.len() < COMPRESSION_THRESHOLD_BYTES {
108 return Ok(Self {
109 compression: CompressionAlgorithm::None,
110 payload: data.into(),
111 });
112 }
113 let mut buf = Vec::with_capacity(data.len());
114 match compression {
115 CompressionAlgorithm::Gzip => {
116 let mut encoder = GzEncoder::new(buf, Compression::default());
117 encoder.write_all(data.as_slice())?;
118 buf = encoder.finish()?;
119 }
120 CompressionAlgorithm::Zstd => {
121 zstd::stream::copy_encode(data.as_slice(), &mut buf, 0)?;
122 }
123 CompressionAlgorithm::None => unreachable!("handled above"),
124 };
125 let payload = Bytes::from(buf.into_boxed_slice());
126 if payload.len() > MAX_FRAME_PAYLOAD_BYTES {
127 return Err(std::io::Error::new(
128 std::io::ErrorKind::InvalidInput,
129 "compressed payload exceeds frame limit",
130 ));
131 }
132 Ok(Self {
133 compression,
134 payload,
135 })
136 }
137
138 fn decompressed(self) -> std::io::Result<Bytes> {
139 let initial_capacity = self
140 .payload
141 .len()
142 .saturating_mul(2)
143 .clamp(COMPRESSION_THRESHOLD_BYTES, MAX_DECOMPRESSED_PAYLOAD_BYTES);
144
145 fn read_to_end_limited(
147 mut reader: impl Read,
148 initial_capacity: usize,
149 ) -> std::io::Result<Bytes> {
150 let mut limited = reader
151 .by_ref()
152 .take((MAX_DECOMPRESSED_PAYLOAD_BYTES + 1) as u64);
153 let mut buf = Vec::with_capacity(initial_capacity);
154 limited.read_to_end(&mut buf)?;
155 if buf.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
156 return Err(std::io::Error::new(
157 std::io::ErrorKind::InvalidData,
158 "decompressed payload exceeds limit",
159 ));
160 }
161 Ok(Bytes::from(buf.into_boxed_slice()))
162 }
163
164 match self.compression {
165 CompressionAlgorithm::None => {
166 if self.payload.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
167 return Err(std::io::Error::new(
168 std::io::ErrorKind::InvalidData,
169 "decompressed payload exceeds limit",
170 ));
171 }
172 Ok(self.payload)
173 }
174 CompressionAlgorithm::Gzip => {
175 let mut decoder = GzDecoder::new(&self.payload[..]);
176 read_to_end_limited(&mut decoder, initial_capacity)
177 }
178 CompressionAlgorithm::Zstd => {
179 let mut decoder = zstd::stream::Decoder::new(&self.payload[..])?;
180 read_to_end_limited(&mut decoder, initial_capacity)
181 }
182 }
183 }
184
185 pub fn try_into_proto<P: prost::Message + Default>(self) -> std::io::Result<P> {
186 let payload = self.decompressed()?;
187 P::decode(payload.as_ref())
188 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
189 }
190}
191
/// Frame that ends a session stream: a status code plus a UTF-8 body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TerminalMessage {
    /// Status code, sent on the wire as a big-endian `u16`.
    pub status: u16,
    /// Message body; must be valid UTF-8 when decoded.
    pub body: String,
}
197
198#[derive(Debug, Clone, PartialEq, Eq)]
199pub enum SessionMessage {
200 Regular(CompressedData),
201 Terminal(TerminalMessage),
202}
203
204impl From<CompressedData> for SessionMessage {
205 fn from(data: CompressedData) -> Self {
206 Self::Regular(data)
207 }
208}
209
210impl From<TerminalMessage> for SessionMessage {
211 fn from(msg: TerminalMessage) -> Self {
212 Self::Terminal(msg)
213 }
214}
215
216impl SessionMessage {
217 pub fn regular(
218 compression: CompressionAlgorithm,
219 proto: &impl prost::Message,
220 ) -> std::io::Result<Self> {
221 Ok(Self::Regular(CompressedData::for_proto(
222 compression,
223 proto,
224 )?))
225 }
226
227 pub fn encode(&self) -> Bytes {
228 let encoded_size = FLAG_TOTAL_SIZE + self.payload_size();
229 assert!(
230 encoded_size <= MAX_FRAME_BYTES,
231 "payload exceeds encoder limit"
232 );
233 let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + encoded_size);
234 buf.put_uint(encoded_size as u64, 3);
235 match self {
236 Self::Regular(msg) => {
237 let flag =
238 (msg.compression.ordinal() << FLAG_COMPRESSION_SHIFT) & FLAG_COMPRESSION_MASK;
239 buf.put_u8(flag);
240 buf.extend_from_slice(&msg.payload);
241 }
242 Self::Terminal(msg) => {
243 buf.put_u8(FLAG_TERMINAL);
244 buf.put_u16(msg.status);
245 buf.extend_from_slice(msg.body.as_bytes());
246 }
247 }
248 buf.freeze()
249 }
250
251 fn decode_message(mut buf: Bytes) -> std::io::Result<Self> {
252 if buf.is_empty() {
253 return Err(std::io::Error::new(
254 std::io::ErrorKind::UnexpectedEof,
255 "empty frame payload",
256 ));
257 }
258 let flag = buf.get_u8();
259
260 let is_terminal = (flag & FLAG_TERMINAL) != 0;
261 if is_terminal {
262 if buf.len() < STATUS_CODE_SIZE {
263 return Err(std::io::Error::new(
264 std::io::ErrorKind::InvalidData,
265 "terminal message missing status code",
266 ));
267 }
268 let status = buf.get_u16();
269 let body = String::from_utf8(buf.into()).map_err(|_| {
270 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid utf-8")
271 })?;
272 return Ok(TerminalMessage { status, body }.into());
273 }
274
275 let compression_bits = (flag & FLAG_COMPRESSION_MASK) >> FLAG_COMPRESSION_SHIFT;
276 let Some(compression) = CompressionAlgorithm::from_ordinal(compression_bits) else {
277 return Err(std::io::Error::new(
278 std::io::ErrorKind::InvalidData,
279 "unknown compression algorithm",
280 ));
281 };
282
283 Ok(CompressedData {
284 compression,
285 payload: buf,
286 }
287 .into())
288 }
289
290 fn payload_size(&self) -> usize {
291 match self {
292 Self::Regular(msg) => msg.payload.len(),
293 Self::Terminal(msg) => STATUS_CODE_SIZE + msg.body.len(),
294 }
295 }
296}
297
298pub struct FramedMessageStream<S> {
299 inner: S,
300 compression: CompressionAlgorithm,
301 terminated: bool,
302}
303
304impl<S> FramedMessageStream<S> {
305 pub fn new(compression: CompressionAlgorithm, inner: S) -> Self {
306 Self {
307 inner,
308 compression,
309 terminated: false,
310 }
311 }
312}
313
314impl<S, P, E> Stream for FramedMessageStream<S>
315where
316 S: Stream<Item = Result<P, E>> + Unpin,
317 P: prost::Message,
318 E: Into<TerminalMessage>,
319{
320 type Item = std::io::Result<Bytes>;
321
322 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323 if self.terminated {
324 return Poll::Ready(None);
325 }
326
327 match Pin::new(&mut self.inner).poll_next(cx) {
328 Poll::Ready(Some(Ok(item))) => match SessionMessage::regular(self.compression, &item) {
329 Ok(msg) => Poll::Ready(Some(Ok(msg.encode()))),
330 Err(err) => {
331 self.terminated = true;
332 Poll::Ready(Some(Err(err)))
333 }
334 },
335 Poll::Ready(Some(Err(e))) => {
336 self.terminated = true;
337 let bytes = SessionMessage::Terminal(e.into()).encode();
338 Poll::Ready(Some(Ok(bytes)))
339 }
340 Poll::Ready(None) => {
341 self.terminated = true;
342 Poll::Ready(None)
343 }
344 Poll::Pending => Poll::Pending,
345 }
346 }
347}
348
/// Stateless `tokio_util` decoder that splits a byte stream into [`SessionMessage`] frames.
#[derive(Debug, Clone, Copy, Default)]
pub struct FrameDecoder;
350
351impl tokio_util::codec::Decoder for FrameDecoder {
352 type Item = SessionMessage;
353 type Error = std::io::Error;
354
355 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
356 if src.len() < LENGTH_PREFIX_SIZE {
357 return Ok(None);
358 }
359
360 let length = ((src[0] as usize) << 16) | ((src[1] as usize) << 8) | (src[2] as usize);
361
362 if length > MAX_FRAME_BYTES {
363 return Err(std::io::Error::new(
364 std::io::ErrorKind::InvalidInput,
365 "frame exceeds decode limit",
366 ));
367 }
368
369 let total_size = LENGTH_PREFIX_SIZE + length;
370 if src.len() < total_size {
371 return Ok(None);
372 }
373
374 src.advance(LENGTH_PREFIX_SIZE);
375 let frame_bytes = src.split_to(length).freeze();
376 Ok(Some(SessionMessage::decode_message(frame_bytes)?))
377 }
378}
379
#[cfg(test)]
mod test {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use futures::StreamExt;
    use http::HeaderValue;
    use proptest::{collection::vec, prelude::*};
    use prost::Message;
    use tokio_util::codec::Decoder;

    use super::*;

    /// Minimal protobuf message wrapping an opaque byte payload.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestProto {
        #[prost(bytes, tag = "1")]
        payload: Vec<u8>,
    }

    impl TestProto {
        fn new(payload: Vec<u8>) -> Self {
            Self { payload }
        }
    }

    /// Error fed through `FramedMessageStream`; converts into a terminal frame.
    #[derive(Debug, Clone)]
    struct TestError {
        status: u16,
        body: &'static str,
    }

    impl From<TestError> for TerminalMessage {
        fn from(err: TestError) -> Self {
            Self {
                status: err.status,
                body: err.body.to_owned(),
            }
        }
    }

    /// Feeds `bytes` to a fresh decoder once; a partial frame is reported as an error.
    fn decode_once(bytes: &Bytes) -> io::Result<SessionMessage> {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(&bytes[..]);
        match decoder.decode(&mut buf)? {
            Some(message) => Ok(message),
            None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "frame incomplete")),
        }
    }

    /// Decodes a hand-built frame that must be rejected and returns the error.
    fn decode_err(raw: &[u8]) -> io::Error {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw);
        decoder.decode(&mut buf).unwrap_err()
    }

    /// Encodes `proto` as a regular frame with the requested compression.
    fn encode_regular(algo: CompressionAlgorithm, proto: &TestProto) -> Bytes {
        SessionMessage::regular(algo, proto).unwrap().encode()
    }

    fn expect_regular(message: SessionMessage) -> CompressedData {
        match message {
            SessionMessage::Regular(data) => data,
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    fn expect_terminal(message: SessionMessage) -> TerminalMessage {
        match message {
            SessionMessage::Terminal(terminal) => terminal,
            SessionMessage::Regular(_) => panic!("expected terminal message"),
        }
    }

    /// Polls `stream` exactly once.
    fn poll_once<S: Stream + Unpin>(
        stream: &mut S,
        cx: &mut Context<'_>,
    ) -> Poll<Option<S::Item>> {
        Pin::new(stream).poll_next(cx)
    }

    /// Negotiates compression for a request carrying the given `Accept-Encoding`.
    fn algo_for(accept_encoding: Option<&'static str>) -> CompressionAlgorithm {
        let mut headers = http::HeaderMap::new();
        if let Some(value) = accept_encoding {
            headers.insert(
                http::header::ACCEPT_ENCODING,
                HeaderValue::from_static(value),
            );
        }
        CompressionAlgorithm::from_accept_encoding(&headers)
    }

    fn compression_strategy() -> impl Strategy<Value = CompressionAlgorithm> {
        prop_oneof![
            Just(CompressionAlgorithm::None),
            Just(CompressionAlgorithm::Gzip),
            Just(CompressionAlgorithm::Zstd),
        ]
    }

    /// Splits `data` into consecutive non-empty slices sized by `pattern`.
    ///
    /// Any leftover tail becomes the final slice, and at least one slice is
    /// always returned.
    fn chunk_bytes(data: &Bytes, pattern: &[usize]) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        let mut start = 0;
        for &hint in pattern {
            let remaining = data.len() - start;
            if remaining == 0 {
                break;
            }
            let len = (hint % remaining + 1).min(remaining);
            chunks.push(data.slice(start..start + len));
            start += len;
        }
        if start < data.len() || chunks.is_empty() {
            chunks.push(data.slice(start..));
        }
        chunks
    }

    proptest! {
        #[test]
        fn regular_session_message_round_trips_proptest(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4)
        ) {
            let proto = TestProto::new(payload.clone());
            let data = expect_regular(decode_once(&encode_regular(algo, &proto)).unwrap());

            let compressed = algo != CompressionAlgorithm::None
                && proto.encoded_len() >= COMPRESSION_THRESHOLD_BYTES;
            let expected_compression = if compressed { algo } else { CompressionAlgorithm::None };
            let actual_compression = data.compression;

            let restored = data.try_into_proto::<TestProto>().unwrap();
            prop_assert_eq!(restored.payload, payload);
            prop_assert_eq!(actual_compression, expected_compression);
        }

        #[test]
        fn frame_decoder_handles_chunked_frames(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4),
            chunk_pattern in vec(0usize..=16, 0..=16)
        ) {
            let encoded = encode_regular(algo, &TestProto::new(payload));
            let expected = decode_once(&encoded).unwrap();

            let chunks = chunk_bytes(&encoded, &chunk_pattern);
            let total: usize = chunks.iter().map(Bytes::len).sum();
            prop_assert_eq!(total, encoded.len());

            let (last, leading) = chunks
                .split_last()
                .expect("chunk_bytes never returns an empty list");
            let mut decoder = FrameDecoder;
            let mut buf = BytesMut::new();

            for chunk in leading {
                buf.extend_from_slice(chunk);
                let partial = decoder.decode(&mut buf).expect("decode invocation failed");
                prop_assert!(partial.is_none());
            }

            buf.extend_from_slice(last);
            let decoded = decoder
                .decode(&mut buf)
                .expect("decode invocation failed")
                .expect("final chunk should produce frame");
            prop_assert!(buf.is_empty());
            prop_assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn from_accept_encoding_prefers_zstd() {
        assert_eq!(algo_for(Some("gzip, zstd, br")), CompressionAlgorithm::Zstd);
    }

    #[test]
    fn from_accept_encoding_falls_back_to_gzip() {
        assert_eq!(algo_for(Some("gzip;q=0.8, deflate")), CompressionAlgorithm::Gzip);
    }

    #[test]
    fn from_accept_encoding_defaults_to_none() {
        assert_eq!(algo_for(None), CompressionAlgorithm::None);
    }

    #[test]
    fn regular_session_message_round_trips() {
        let proto = TestProto::new(vec![1, 2, 3, 4]);
        let encoded = encode_regular(CompressionAlgorithm::None, &proto);
        let data = expect_regular(decode_once(&encoded).unwrap());

        assert_eq!(data.compression, CompressionAlgorithm::None);
        assert_eq!(data.try_into_proto::<TestProto>().unwrap(), proto);
    }

    #[test]
    fn terminal_session_message_round_trips() {
        let terminal = TerminalMessage {
            status: 418,
            body: "short-circuit".to_string(),
        };
        let encoded = SessionMessage::from(terminal.clone()).encode();
        assert_eq!(expect_terminal(decode_once(&encoded).unwrap()), terminal);
    }

    #[test]
    fn frame_decoder_waits_for_complete_frame() {
        let proto = TestProto::new(vec![9, 9, 9]);
        let encoded = encode_regular(CompressionAlgorithm::None, &proto);
        let (head, tail) = encoded.split_at(encoded.len() - 1);

        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(head);
        assert!(decoder.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(tail);
        let data = expect_regular(decoder.decode(&mut buf).unwrap().unwrap());
        assert_eq!(data.try_into_proto::<TestProto>().unwrap(), proto);
        assert!(buf.is_empty());
    }

    #[test]
    fn frame_decoder_rejects_frames_exceeding_decode_limit() {
        let prefix = u32::try_from(MAX_FRAME_BYTES + 1).unwrap().to_be_bytes();
        let err = decode_err(&prefix[1..]);
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    #[should_panic(expected = "encoder limit")]
    fn session_message_encode_rejects_frames_over_limit() {
        let oversized = CompressedData {
            compression: CompressionAlgorithm::None,
            payload: Bytes::from(vec![0u8; MAX_FRAME_BYTES]),
        };
        let _ = SessionMessage::Regular(oversized).encode();
    }

    #[test]
    fn frame_decoder_rejects_unknown_compression() {
        let err = decode_err(&[0, 0, 1, 0x60]);
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn frame_decoder_rejects_terminal_without_status() {
        let err = decode_err(&[0, 0, 1, FLAG_TERMINAL]);
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn frame_decoder_handles_empty_payload() {
        let err = decode_err(&[0, 0, 0]);
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// Round-trips a large, highly compressible payload through `algo`.
    fn assert_compressed_round_trip(algo: CompressionAlgorithm, payload: Vec<u8>) {
        let proto = TestProto::new(payload);
        let data = expect_regular(decode_once(&encode_regular(algo, &proto)).unwrap());

        assert_eq!(data.compression, algo);
        assert!(data.payload.len() < proto.encoded_len());
        assert_eq!(data.try_into_proto::<TestProto>().unwrap(), proto);
    }

    #[test]
    fn compressed_data_round_trip_gzip() {
        assert_compressed_round_trip(CompressionAlgorithm::Gzip, vec![42; 1_200_000]);
    }

    #[test]
    fn compressed_data_round_trip_zstd() {
        assert_compressed_round_trip(CompressionAlgorithm::Zstd, vec![7; 1_100_000]);
    }

    #[test]
    fn decompression_rejects_payloads_exceeding_limit() {
        let encoded = TestProto::new(vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1]).encode_to_vec();

        let gzip = {
            let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
            encoder.write_all(&encoded).unwrap();
            encoder.finish().unwrap()
        };
        let zstd = {
            let mut out = Vec::new();
            zstd::stream::copy_encode(&encoded[..], &mut out, 0).unwrap();
            out
        };

        for (algo, compressed) in [
            (CompressionAlgorithm::Gzip, gzip),
            (CompressionAlgorithm::Zstd, zstd),
        ] {
            let data = CompressedData {
                compression: algo,
                payload: Bytes::from(compressed),
            };
            assert!(data.payload.len() <= MAX_FRAME_PAYLOAD_BYTES);

            let err = data.try_into_proto::<TestProto>().expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(err.to_string().contains("decompressed payload exceeds limit"));
        }
    }

    #[test]
    fn compress_rejects_payloads_exceeding_decompressed_limit() {
        let encoded = TestProto::new(vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1]).encode_to_vec();

        let err = CompressedData::compress(CompressionAlgorithm::Gzip, encoded)
            .expect_err("should fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(err.to_string().contains("payload exceeds decompressed limit"));
    }

    #[test]
    fn compress_allows_payload_at_exact_limit_without_encode_panic() {
        let data = CompressedData::compress(
            CompressionAlgorithm::None,
            vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES],
        )
        .unwrap();
        let encoded = SessionMessage::from(data).encode();
        assert_eq!(encoded.len(), LENGTH_PREFIX_SIZE + MAX_FRAME_BYTES);
    }

    #[test]
    fn compress_rejects_incompressible_payload_that_exceeds_frame_limit_after_compression() {
        // xorshift32 noise: effectively incompressible.
        let mut state = 0x1234_5678u32;
        let payload: Vec<u8> = (0..MAX_DECOMPRESSED_PAYLOAD_BYTES)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 17;
                state ^= state << 5;
                state as u8
            })
            .collect();

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            let err = CompressedData::compress(algo, payload.clone()).expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
            assert!(err.to_string().contains("compressed payload exceeds frame limit"));
        }
    }

    #[test]
    fn framed_message_stream_yields_terminal_on_error() {
        let proto = TestProto::new(vec![1, 2, 3]);
        let items = vec![
            Ok(proto.clone()),
            Err(TestError {
                status: 500,
                body: "boom",
            }),
            Ok(proto.clone()),
        ];

        let framed =
            FramedMessageStream::new(CompressionAlgorithm::None, futures::stream::iter(items));
        let outputs = futures::executor::block_on(framed.collect::<Vec<io::Result<Bytes>>>());

        assert_eq!(outputs.len(), 2);

        let first = outputs[0].as_ref().expect("first frame ok");
        let data = expect_regular(decode_once(first).unwrap());
        assert_eq!(data.try_into_proto::<TestProto>().unwrap(), proto);

        let second = outputs[1].as_ref().expect("second frame ok");
        let term = expect_terminal(decode_once(second).unwrap());
        assert_eq!(term.status, 500);
        assert_eq!(term.body, "boom");
    }

    #[test]
    fn framed_message_stream_stops_after_termination() {
        let items = vec![
            Ok(TestProto::new(vec![0])),
            Err(TestError {
                status: 400,
                body: "bad",
            }),
        ];
        let mut stream =
            FramedMessageStream::new(CompressionAlgorithm::None, futures::stream::iter(items));
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match poll_once(&mut stream, &mut cx) {
            Poll::Ready(Some(Ok(bytes))) => {
                expect_regular(decode_once(&bytes).unwrap());
            }
            other => panic!("unexpected poll result: {other:?}"),
        }

        match poll_once(&mut stream, &mut cx) {
            Poll::Ready(Some(Ok(bytes))) => {
                let term = expect_terminal(decode_once(&bytes).unwrap());
                assert_eq!(term.status, 400);
                assert_eq!(term.body, "bad");
            }
            other => panic!("unexpected poll result: {other:?}"),
        }

        match poll_once(&mut stream, &mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate, got {other:?}"),
        }
    }

    #[test]
    fn framed_message_stream_terminates_after_encoding_error() {
        let oversized = MAX_DECOMPRESSED_PAYLOAD_BYTES + 1;
        let items: Vec<Result<TestProto, TestError>> = vec![
            Ok(TestProto::new(vec![0u8; oversized])),
            Ok(TestProto::new(vec![1u8; oversized])),
        ];
        let mut stream =
            FramedMessageStream::new(CompressionAlgorithm::None, futures::stream::iter(items));
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match poll_once(&mut stream, &mut cx) {
            Poll::Ready(Some(Err(err))) => {
                assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
                assert!(err.to_string().contains("payload exceeds decompressed limit"));
            }
            other => panic!("expected encoding error, got {other:?}"),
        }

        match poll_once(&mut stream, &mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate after encoding error, got {other:?}"),
        }
    }
}