1use std::{
2 io::{Read, Write},
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use enum_ordinalize::Ordinalize;
9use flate2::{Compression, read::GzDecoder, write::GzEncoder};
10use futures::Stream;
11
/// Width in bytes of the big-endian length prefix that precedes every frame body.
const LENGTH_PREFIX_SIZE: usize = 3;
/// Width in bytes of the big-endian status code carried by terminal messages.
const STATUS_CODE_SIZE: usize = 2;
/// Encoded payloads shorter than this are always sent uncompressed.
const COMPRESSION_THRESHOLD_BYTES: usize = 1024;
/// Upper bound on a frame body (flag byte + payload), excluding the length prefix.
const MAX_FRAME_BYTES: usize = 2 * 1024 * 1024;
/// Width in bytes of the flag byte that starts every frame body.
const FLAG_TOTAL_SIZE: usize = 1;
/// Largest payload that fits in a frame body after the flag byte.
const MAX_FRAME_PAYLOAD_BYTES: usize = MAX_FRAME_BYTES - FLAG_TOTAL_SIZE;
/// Largest payload accepted before compression and produced by decompression.
const MAX_DECOMPRESSED_PAYLOAD_BYTES: usize = MAX_FRAME_PAYLOAD_BYTES;
/// Flag bit marking a terminal (status + body) message.
const FLAG_TERMINAL: u8 = 0b1000_0000;
/// Flag bits holding the compression algorithm ordinal of a regular message.
const FLAG_COMPRESSION_MASK: u8 = 0b0110_0000;
/// Bit offset of the compression ordinal within the flag byte.
const FLAG_COMPRESSION_SHIFT: u8 = 5;

// The length prefix is an unsigned 24-bit integer; a larger frame limit would be
// silently truncated when the prefix is written.
const _: () = assert!(MAX_FRAME_BYTES < (1usize << (8 * LENGTH_PREFIX_SIZE)));
// The compression mask must be exactly the two bits reached by the shift.
const _: () = assert!(FLAG_COMPRESSION_MASK == (0b11u8 << FLAG_COMPRESSION_SHIFT));
// The terminal bit must not overlap the compression bits.
const _: () = assert!((FLAG_TERMINAL & FLAG_COMPRESSION_MASK) == 0);
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Ordinalize)]
59#[repr(u8)]
60pub enum CompressionAlgorithm {
61 None = 0,
62 Zstd = 1,
63 Gzip = 2,
64}
65
66impl CompressionAlgorithm {
67 pub fn from_accept_encoding(headers: &http::HeaderMap) -> Self {
68 let mut gzip = false;
69 for header_value in headers.get_all(http::header::ACCEPT_ENCODING) {
70 if let Ok(value) = header_value.to_str() {
71 for encoding in value.split(',') {
72 let encoding = encoding.trim().split(';').next().unwrap_or("").trim();
73 if encoding.eq_ignore_ascii_case("zstd") {
74 return Self::Zstd;
75 } else if encoding.eq_ignore_ascii_case("gzip") {
76 gzip = true;
77 }
78 }
79 }
80 }
81 if gzip { Self::Gzip } else { Self::None }
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq)]
86pub struct CompressedData {
87 compression: CompressionAlgorithm,
88 payload: Bytes,
89}
90
91impl CompressedData {
92 pub fn for_proto(
93 compression: CompressionAlgorithm,
94 proto: &impl prost::Message,
95 ) -> std::io::Result<Self> {
96 Self::compress(compression, proto.encode_to_vec())
97 }
98
99 fn compress(compression: CompressionAlgorithm, data: Vec<u8>) -> std::io::Result<Self> {
100 if data.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
101 return Err(std::io::Error::new(
102 std::io::ErrorKind::InvalidInput,
103 "payload exceeds decompressed limit",
104 ));
105 }
106
107 if compression == CompressionAlgorithm::None || data.len() < COMPRESSION_THRESHOLD_BYTES {
108 return Ok(Self {
109 compression: CompressionAlgorithm::None,
110 payload: data.into(),
111 });
112 }
113 let mut buf = Vec::with_capacity(data.len());
114 match compression {
115 CompressionAlgorithm::Gzip => {
116 let mut encoder = GzEncoder::new(buf, Compression::default());
117 encoder.write_all(data.as_slice())?;
118 buf = encoder.finish()?;
119 }
120 CompressionAlgorithm::Zstd => {
121 zstd::stream::copy_encode(data.as_slice(), &mut buf, 0)?;
122 }
123 CompressionAlgorithm::None => unreachable!("handled above"),
124 };
125 let payload = Bytes::from(buf.into_boxed_slice());
126 if payload.len() > MAX_FRAME_PAYLOAD_BYTES {
127 return Err(std::io::Error::new(
128 std::io::ErrorKind::InvalidInput,
129 "compressed payload exceeds frame limit",
130 ));
131 }
132 Ok(Self {
133 compression,
134 payload,
135 })
136 }
137
138 fn decompressed(self) -> std::io::Result<Bytes> {
139 let initial_capacity = self
140 .payload
141 .len()
142 .saturating_mul(2)
143 .clamp(COMPRESSION_THRESHOLD_BYTES, MAX_DECOMPRESSED_PAYLOAD_BYTES);
144
145 fn read_to_end_limited(
147 mut reader: impl Read,
148 initial_capacity: usize,
149 ) -> std::io::Result<Bytes> {
150 let mut limited = reader
151 .by_ref()
152 .take((MAX_DECOMPRESSED_PAYLOAD_BYTES + 1) as u64);
153 let mut buf = Vec::with_capacity(initial_capacity);
154 limited.read_to_end(&mut buf)?;
155 if buf.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
156 return Err(std::io::Error::new(
157 std::io::ErrorKind::InvalidData,
158 "decompressed payload exceeds limit",
159 ));
160 }
161 Ok(Bytes::from(buf.into_boxed_slice()))
162 }
163
164 match self.compression {
165 CompressionAlgorithm::None => {
166 if self.payload.len() > MAX_DECOMPRESSED_PAYLOAD_BYTES {
167 return Err(std::io::Error::new(
168 std::io::ErrorKind::InvalidData,
169 "decompressed payload exceeds limit",
170 ));
171 }
172 Ok(self.payload)
173 }
174 CompressionAlgorithm::Gzip => {
175 let mut decoder = GzDecoder::new(&self.payload[..]);
176 read_to_end_limited(&mut decoder, initial_capacity)
177 }
178 CompressionAlgorithm::Zstd => {
179 let mut decoder = zstd::stream::Decoder::new(&self.payload[..])?;
180 read_to_end_limited(&mut decoder, initial_capacity)
181 }
182 }
183 }
184
185 pub fn try_into_proto<P: prost::Message + Default>(self) -> std::io::Result<P> {
186 let payload = self.decompressed()?;
187 P::decode(payload.as_ref())
188 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
189 }
190}
191
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub struct TerminalMessage {
194 pub status: u16,
195 pub body: String,
196}
197
198#[derive(Debug, Clone, PartialEq, Eq)]
199pub enum SessionMessage {
200 Regular(CompressedData),
201 Terminal(TerminalMessage),
202}
203
204impl From<CompressedData> for SessionMessage {
205 fn from(data: CompressedData) -> Self {
206 Self::Regular(data)
207 }
208}
209
210impl From<TerminalMessage> for SessionMessage {
211 fn from(msg: TerminalMessage) -> Self {
212 Self::Terminal(msg)
213 }
214}
215
216impl SessionMessage {
217 pub fn regular(
218 compression: CompressionAlgorithm,
219 proto: &impl prost::Message,
220 ) -> std::io::Result<Self> {
221 Ok(Self::Regular(CompressedData::for_proto(
222 compression,
223 proto,
224 )?))
225 }
226
227 pub fn encode(&self) -> Bytes {
228 let encoded_size = FLAG_TOTAL_SIZE + self.payload_size();
229 assert!(
230 encoded_size <= MAX_FRAME_BYTES,
231 "payload exceeds encoder limit"
232 );
233 let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + encoded_size);
234 buf.put_uint(encoded_size as u64, 3);
235 match self {
236 Self::Regular(msg) => {
237 let flag =
238 (msg.compression.ordinal() << FLAG_COMPRESSION_SHIFT) & FLAG_COMPRESSION_MASK;
239 buf.put_u8(flag);
240 buf.extend_from_slice(&msg.payload);
241 }
242 Self::Terminal(msg) => {
243 buf.put_u8(FLAG_TERMINAL);
244 buf.put_u16(msg.status);
245 buf.extend_from_slice(msg.body.as_bytes());
246 }
247 }
248 buf.freeze()
249 }
250
251 fn decode_message(mut buf: Bytes) -> std::io::Result<Self> {
252 if buf.is_empty() {
253 return Err(std::io::Error::new(
254 std::io::ErrorKind::UnexpectedEof,
255 "empty frame payload",
256 ));
257 }
258 let flag = buf.get_u8();
259
260 let is_terminal = (flag & FLAG_TERMINAL) != 0;
261 if is_terminal {
262 if buf.len() < STATUS_CODE_SIZE {
263 return Err(std::io::Error::new(
264 std::io::ErrorKind::InvalidData,
265 "terminal message missing status code",
266 ));
267 }
268 let status = buf.get_u16();
269 let body = String::from_utf8(buf.into()).map_err(|_| {
270 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid utf-8")
271 })?;
272 return Ok(TerminalMessage { status, body }.into());
273 }
274
275 let compression_bits = (flag & FLAG_COMPRESSION_MASK) >> FLAG_COMPRESSION_SHIFT;
276 let Some(compression) = CompressionAlgorithm::from_ordinal(compression_bits) else {
277 return Err(std::io::Error::new(
278 std::io::ErrorKind::InvalidData,
279 "unknown compression algorithm",
280 ));
281 };
282
283 Ok(CompressedData {
284 compression,
285 payload: buf,
286 }
287 .into())
288 }
289
290 fn payload_size(&self) -> usize {
291 match self {
292 Self::Regular(msg) => msg.payload.len(),
293 Self::Terminal(msg) => STATUS_CODE_SIZE + msg.body.len(),
294 }
295 }
296}
297
298pub struct FramedMessageStream<S> {
299 inner: S,
300 compression: CompressionAlgorithm,
301 terminated: bool,
302}
303
304impl<S> FramedMessageStream<S> {
305 pub fn new(compression: CompressionAlgorithm, inner: S) -> Self {
306 Self {
307 inner,
308 compression,
309 terminated: false,
310 }
311 }
312}
313
314impl<S, P, E> Stream for FramedMessageStream<S>
315where
316 S: Stream<Item = Result<P, E>> + Unpin,
317 P: prost::Message,
318 E: Into<TerminalMessage>,
319{
320 type Item = std::io::Result<Bytes>;
321
322 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323 if self.terminated {
324 return Poll::Ready(None);
325 }
326
327 match Pin::new(&mut self.inner).poll_next(cx) {
328 Poll::Ready(Some(Ok(item))) => {
329 let bytes =
330 SessionMessage::regular(self.compression, &item).map(|msg| msg.encode());
331 Poll::Ready(Some(bytes))
332 }
333 Poll::Ready(Some(Err(e))) => {
334 self.terminated = true;
335 let bytes = SessionMessage::Terminal(e.into()).encode();
336 Poll::Ready(Some(Ok(bytes)))
337 }
338 Poll::Ready(None) => {
339 self.terminated = true;
340 Poll::Ready(None)
341 }
342 Poll::Pending => Poll::Pending,
343 }
344 }
345}
346
347pub struct FrameDecoder;
348
349impl tokio_util::codec::Decoder for FrameDecoder {
350 type Item = SessionMessage;
351 type Error = std::io::Error;
352
353 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
354 if src.len() < LENGTH_PREFIX_SIZE {
355 return Ok(None);
356 }
357
358 let length = ((src[0] as usize) << 16) | ((src[1] as usize) << 8) | (src[2] as usize);
359
360 if length > MAX_FRAME_BYTES {
361 return Err(std::io::Error::new(
362 std::io::ErrorKind::InvalidInput,
363 "frame exceeds decode limit",
364 ));
365 }
366
367 let total_size = LENGTH_PREFIX_SIZE + length;
368 if src.len() < total_size {
369 return Ok(None);
370 }
371
372 src.advance(LENGTH_PREFIX_SIZE);
373 let frame_bytes = src.split_to(length).freeze();
374 Ok(Some(SessionMessage::decode_message(frame_bytes)?))
375 }
376}
377
// Unit and property tests for the frame codec, compression and the framing stream.
#[cfg(test)]
mod test {
    use std::{
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::BytesMut;
    use futures::StreamExt;
    use http::HeaderValue;
    use proptest::{collection::vec, prelude::*};
    use prost::Message;
    use tokio_util::codec::Decoder;

    use super::*;

    /// Minimal protobuf message wrapping an opaque byte payload (field tag 1).
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestProto {
        #[prost(bytes, tag = "1")]
        payload: Vec<u8>,
    }

    impl TestProto {
        fn new(payload: Vec<u8>) -> Self {
            Self { payload }
        }
    }

    /// Stream error type; converts into a `TerminalMessage` with the same status and body.
    #[derive(Debug, Clone)]
    struct TestError {
        status: u16,
        body: &'static str,
    }

    impl From<TestError> for TerminalMessage {
        fn from(val: TestError) -> Self {
            TerminalMessage {
                status: val.status,
                body: val.body.to_string(),
            }
        }
    }

    /// Decodes a single frame from `bytes` with a fresh decoder; an incomplete frame
    /// is reported as `UnexpectedEof`.
    fn decode_once(bytes: &Bytes) -> io::Result<SessionMessage> {
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(bytes.as_ref());
        decoder
            .decode(&mut buf)?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "frame incomplete"))
    }

    /// Proptest strategy that picks any of the three compression algorithms.
    fn compression_strategy() -> impl proptest::strategy::Strategy<Value = CompressionAlgorithm> {
        prop_oneof![
            Just(CompressionAlgorithm::None),
            Just(CompressionAlgorithm::Gzip),
            Just(CompressionAlgorithm::Zstd),
        ]
    }

    /// Splits `data` into consecutive chunks.
    ///
    /// Each hint in `pattern` yields a chunk of `hint % remaining + 1` bytes, capped at
    /// what remains; any leftover becomes a final chunk. Empty `data` yields a single
    /// empty chunk.
    fn chunk_bytes(data: &Bytes, pattern: &[usize]) -> Vec<Bytes> {
        let mut chunks = Vec::new();
        let mut offset = 0;
        for &hint in pattern {
            if offset >= data.len() {
                break;
            }
            let remaining = data.len() - offset;
            let take = (hint % remaining).saturating_add(1).min(remaining);
            chunks.push(data.slice(offset..offset + take));
            offset += take;
        }
        if offset < data.len() {
            chunks.push(data.slice(offset..));
        }
        if chunks.is_empty() {
            chunks.push(data.clone());
        }
        chunks
    }

    proptest! {
        // Any payload round-trips through encode/decode. Compression is applied only
        // when requested and the encoded proto reaches the threshold.
        #[test]
        fn regular_session_message_round_trips_proptest(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4)
        ) {
            let proto = TestProto::new(payload.clone());
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let decoded = decode_once(&encoded).unwrap();

            prop_assert!(matches!(decoded, SessionMessage::Regular(_)));
            let SessionMessage::Regular(data) = decoded else { unreachable!() };

            let expected_compression = if algo == CompressionAlgorithm::None || proto.encoded_len() < COMPRESSION_THRESHOLD_BYTES {
                CompressionAlgorithm::None
            } else {
                algo
            };
            let actual_compression = data.compression;

            let restored = data.try_into_proto::<TestProto>().unwrap();
            prop_assert_eq!(restored.payload, payload);
            prop_assert_eq!(actual_compression, expected_compression);
        }

        // Feeding a frame in arbitrary chunks yields nothing until the last chunk.
        // Then it yields the same message as a one-shot decode and leaves the buffer empty.
        #[test]
        fn frame_decoder_handles_chunked_frames(
            algo in compression_strategy(),
            payload in vec(any::<u8>(), 0..=COMPRESSION_THRESHOLD_BYTES * 4),
            chunk_pattern in vec(0usize..=16, 0..=16)
        ) {
            let proto = TestProto::new(payload);
            let msg = SessionMessage::regular(algo, &proto).unwrap();
            let encoded = msg.encode();
            let expected = decode_once(&encoded).unwrap();

            let chunks = chunk_bytes(&encoded, &chunk_pattern);
            prop_assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), encoded.len());

            let mut decoder = FrameDecoder;
            let mut buf = BytesMut::new();
            let mut decoded = None;

            for (idx, chunk) in chunks.iter().enumerate() {
                buf.extend_from_slice(chunk.as_ref());
                let result = decoder.decode(&mut buf).expect("decode invocation failed");
                if idx < chunks.len() - 1 {
                    prop_assert!(result.is_none());
                } else {
                    let message = result.expect("final chunk should produce frame");
                    prop_assert!(buf.is_empty());
                    decoded = Some(message);
                }
            }

            let decoded = decoded.expect("decoder never emitted frame");
            prop_assert_eq!(decoded, expected);
        }
    }

    // zstd wins over gzip regardless of listing order.
    #[test]
    fn from_accept_encoding_prefers_zstd() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip, zstd, br"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Zstd);
    }

    // gzip is chosen when zstd is absent, even with a non-zero q parameter.
    #[test]
    fn from_accept_encoding_falls_back_to_gzip() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_ENCODING,
            HeaderValue::from_static("gzip;q=0.8, deflate"),
        );

        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::Gzip);
    }

    // No Accept-Encoding header means no compression.
    #[test]
    fn from_accept_encoding_defaults_to_none() {
        let headers = http::HeaderMap::new();
        let algo = CompressionAlgorithm::from_accept_encoding(&headers);
        assert_eq!(algo, CompressionAlgorithm::None);
    }

    // A small uncompressed message survives an encode/decode round trip.
    #[test]
    fn regular_session_message_round_trips() {
        let proto = TestProto::new(vec![1, 2, 3, 4]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::None);
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Status and body of a terminal message survive a round trip.
    #[test]
    fn terminal_session_message_round_trips() {
        let terminal = TerminalMessage {
            status: 418,
            body: "short-circuit".to_string(),
        };
        let msg = SessionMessage::from(terminal.clone());
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(decoded_terminal) => {
                assert_eq!(decoded_terminal, terminal);
            }
        }
    }

    // A frame missing its last byte is not decoded; appending that byte completes it.
    #[test]
    fn frame_decoder_waits_for_complete_frame() {
        let proto = TestProto::new(vec![9, 9, 9]);
        let msg = SessionMessage::regular(CompressionAlgorithm::None, &proto).unwrap();
        let encoded = msg.encode();
        let mut decoder = FrameDecoder;

        let split_idx = encoded.len() - 1;
        let mut buf = BytesMut::from(&encoded[..split_idx]);
        assert!(decoder.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(&encoded[split_idx..]);
        let decoded = decoder.decode(&mut buf).unwrap().unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
        assert!(buf.is_empty());
    }

    // A length prefix one past MAX_FRAME_BYTES is rejected from the prefix alone.
    #[test]
    fn frame_decoder_rejects_frames_exceeding_decode_limit() {
        let length = MAX_FRAME_BYTES + 1;
        let prefix = [
            ((length >> 16) & 0xFF) as u8,
            ((length >> 8) & 0xFF) as u8,
            (length & 0xFF) as u8,
        ];
        let mut buf = BytesMut::from(prefix.as_slice());
        let mut decoder = FrameDecoder;
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    // A payload of MAX_FRAME_BYTES plus the flag byte exceeds the encoder limit and panics.
    #[test]
    #[should_panic(expected = "encoder limit")]
    fn session_message_encode_rejects_frames_over_limit() {
        let data = CompressedData {
            compression: CompressionAlgorithm::None,
            payload: Bytes::from(vec![0u8; MAX_FRAME_BYTES]),
        };
        let msg = SessionMessage::from(data);
        let _ = msg.encode();
    }

    // Flag 0x60 sets compression ordinal 3, which has no variant.
    #[test]
    fn frame_decoder_rejects_unknown_compression() {
        let mut raw = vec![0, 0, 1];
        raw.push(0x60);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A terminal frame consisting only of the flag byte lacks the 2-byte status.
    #[test]
    fn frame_decoder_rejects_terminal_without_status() {
        let mut raw = vec![0, 0, 1];
        raw.push(FLAG_TERMINAL);
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A zero-length frame body (no flag byte) is reported as UnexpectedEof.
    #[test]
    fn frame_decoder_handles_empty_payload() {
        let raw = vec![0, 0, 0];
        let mut decoder = FrameDecoder;
        let mut buf = BytesMut::from(raw.as_slice());
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    // A large, highly compressible payload is gzip-compressed and restored intact.
    #[test]
    fn compressed_data_round_trip_gzip() {
        let payload = vec![42; 1_200_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Gzip, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Gzip);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // A large, highly compressible payload is zstd-compressed and restored intact.
    #[test]
    fn compressed_data_round_trip_zstd() {
        let payload = vec![7; 1_100_000];
        let proto = TestProto::new(payload.clone());
        let msg = SessionMessage::regular(CompressionAlgorithm::Zstd, &proto).unwrap();
        let encoded = msg.encode();
        let decoded = decode_once(&encoded).unwrap();

        match decoded {
            SessionMessage::Regular(data) => {
                assert_eq!(data.compression, CompressionAlgorithm::Zstd);
                assert!(data.payload.len() < proto.encode_to_vec().len());
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored.payload, payload);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }
    }

    // Compressed input that fits in a frame but inflates past the decompressed limit
    // (a "zip bomb") is rejected during decompression.
    #[test]
    fn decompression_rejects_payloads_exceeding_limit() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1];
        let proto = TestProto::new(payload);
        let encoded = proto.encode_to_vec();

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            // Compress directly, bypassing `CompressedData::compress`, which would
            // reject the oversized input up front.
            let compressed = match algo {
                CompressionAlgorithm::Gzip => {
                    let mut out = Vec::new();
                    let mut encoder = GzEncoder::new(&mut out, Compression::default());
                    encoder.write_all(encoded.as_slice()).unwrap();
                    encoder.finish().unwrap();
                    out
                }
                CompressionAlgorithm::Zstd => {
                    let mut out = Vec::new();
                    zstd::stream::copy_encode(encoded.as_slice(), &mut out, 0).unwrap();
                    out
                }
                CompressionAlgorithm::None => unreachable!("explicitly excluded in test"),
            };

            let data = CompressedData {
                compression: algo,
                payload: Bytes::from(compressed),
            };
            assert!(data.payload.len() <= MAX_FRAME_PAYLOAD_BYTES);

            let err = data.try_into_proto::<TestProto>().expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(
                err.to_string()
                    .contains("decompressed payload exceeds limit")
            );
        }
    }

    // Input larger than the decompressed limit is refused before compression.
    #[test]
    fn compress_rejects_payloads_exceeding_decompressed_limit() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES + 1];
        let proto = TestProto::new(payload);

        let err = CompressedData::compress(CompressionAlgorithm::Gzip, proto.encode_to_vec())
            .expect_err("should fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(
            err.to_string()
                .contains("payload exceeds decompressed limit")
        );
    }

    // An uncompressed payload of exactly the limit encodes to a maximal frame without panicking.
    #[test]
    fn compress_allows_payload_at_exact_limit_without_encode_panic() {
        let payload = vec![0; MAX_DECOMPRESSED_PAYLOAD_BYTES];
        let data = CompressedData::compress(CompressionAlgorithm::None, payload).unwrap();
        let encoded = SessionMessage::from(data).encode();
        assert_eq!(encoded.len(), LENGTH_PREFIX_SIZE + MAX_FRAME_BYTES);
    }

    // Pseudo-random (xorshift32) data of maximal size grows under compression, so
    // the compressed result is refused for exceeding the frame payload limit.
    #[test]
    fn compress_rejects_incompressible_payload_that_exceeds_frame_limit_after_compression() {
        let mut payload = vec![0u8; MAX_DECOMPRESSED_PAYLOAD_BYTES];
        let mut x = 0x1234_5678u32;
        for byte in &mut payload {
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            *byte = (x & 0xFF) as u8;
        }

        for algo in [CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd] {
            let err = CompressedData::compress(algo, payload.clone()).expect_err("should fail");
            assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
            assert!(
                err.to_string()
                    .contains("compressed payload exceeds frame limit")
            );
        }
    }

    // The first error becomes a terminal frame; items after it are never emitted.
    #[test]
    fn framed_message_stream_yields_terminal_on_error() {
        let proto = TestProto::new(vec![1, 2, 3]);
        let items = vec![
            Ok(proto.clone()),
            Err(TestError {
                status: 500,
                body: "boom",
            }),
            Ok(proto.clone()),
        ];

        let stream = futures::stream::iter(items);
        let framed = FramedMessageStream::new(CompressionAlgorithm::None, stream);
        let outputs = futures::executor::block_on(async {
            framed.collect::<Vec<std::io::Result<Bytes>>>().await
        });

        assert_eq!(outputs.len(), 2);

        let first = outputs[0].as_ref().expect("first frame ok");
        match decode_once(first).unwrap() {
            SessionMessage::Regular(data) => {
                let restored = data.try_into_proto::<TestProto>().unwrap();
                assert_eq!(restored, proto);
            }
            SessionMessage::Terminal(_) => panic!("expected regular message"),
        }

        let second = outputs[1].as_ref().expect("second frame ok");
        match decode_once(second).unwrap() {
            SessionMessage::Regular(_) => panic!("expected terminal message"),
            SessionMessage::Terminal(term) => {
                assert_eq!(term.status, 500);
                assert_eq!(term.body, "boom");
            }
        }
    }

    // Polled by hand: regular frame, then terminal frame, then `Ready(None)`.
    #[test]
    fn framed_message_stream_stops_after_termination() {
        let mut stream = FramedMessageStream::new(
            CompressionAlgorithm::None,
            futures::stream::iter(vec![
                Ok(TestProto::new(vec![0])),
                Err(TestError {
                    status: 400,
                    body: "bad",
                }),
            ]),
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Regular(_) => {}
                SessionMessage::Terminal(_) => panic!("expected regular message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(bytes))) => match decode_once(&bytes).unwrap() {
                SessionMessage::Terminal(term) => {
                    assert_eq!(term.status, 400);
                    assert_eq!(term.body, "bad");
                }
                SessionMessage::Regular(_) => panic!("expected terminal message"),
            },
            other => panic!("unexpected poll result: {other:?}"),
        }

        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(None) => {}
            other => panic!("expected stream to terminate, got {other:?}"),
        }
    }
}