Skip to main content

shape_wire/transport/
framing.rs

1//! Transparent zstd compression for wire frames.
2//!
3//! Frame format: `[flags: u8] [body...]`
4//! - `flags & 0x01` = body is zstd compressed
5//! - Payloads < COMPRESSION_THRESHOLD bytes: stored uncompressed
6//! - Payloads >= threshold: compressed with zstd level 3, used only if smaller
7
8use super::TransportError;
9
/// Payloads shorter than this many bytes are framed without attempting compression.
pub const COMPRESSION_THRESHOLD: usize = 256;

/// Upper bound on a decompressed payload (256 MiB); guards against decompression bombs.
pub const MAX_DECOMPRESSED_SIZE: usize = 256 << 20;

/// zstd level used when compressing (3: good ratio at fast speed).
const ZSTD_LEVEL: i32 = 3;

/// Bit of the flags byte that marks the frame body as zstd-compressed.
const FLAG_COMPRESSED: u8 = 0b0000_0001;
20
21/// Encode data into a framed payload: `[flags: u8] [body...]`
22///
23/// If data is >= COMPRESSION_THRESHOLD bytes and compresses smaller,
24/// the body is zstd-compressed and FLAG_COMPRESSED is set.
25pub fn encode_framed(data: &[u8]) -> Vec<u8> {
26    if data.len() < COMPRESSION_THRESHOLD {
27        let mut out = Vec::with_capacity(1 + data.len());
28        out.push(0x00); // no compression
29        out.extend_from_slice(data);
30        return out;
31    }
32
33    match zstd::stream::encode_all(data, ZSTD_LEVEL) {
34        Ok(compressed) if compressed.len() < data.len() => {
35            let mut out = Vec::with_capacity(1 + compressed.len());
36            out.push(FLAG_COMPRESSED);
37            out.extend_from_slice(&compressed);
38            out
39        }
40        _ => {
41            // Compression failed or didn't help — store raw
42            let mut out = Vec::with_capacity(1 + data.len());
43            out.push(0x00);
44            out.extend_from_slice(data);
45            out
46        }
47    }
48}
49
50/// Decode a framed payload: read flags byte, decompress if needed.
51pub fn decode_framed(data: &[u8]) -> Result<Vec<u8>, TransportError> {
52    if data.is_empty() {
53        return Err(TransportError::ReceiveFailed(
54            "empty framed payload".to_string(),
55        ));
56    }
57
58    let flags = data[0];
59    let body = &data[1..];
60
61    if flags & FLAG_COMPRESSED != 0 {
62        let decompressed = zstd::stream::decode_all(body)
63            .map_err(|e| TransportError::ReceiveFailed(format!("zstd decompress: {}", e)))?;
64
65        if decompressed.len() > MAX_DECOMPRESSED_SIZE {
66            return Err(TransportError::PayloadTooLarge {
67                size: decompressed.len(),
68                max: MAX_DECOMPRESSED_SIZE,
69            });
70        }
71
72        Ok(decompressed)
73    } else {
74        Ok(body.to_vec())
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic high-entropy bytes from a xorshift64 generator.
    ///
    /// Used where a test needs data that zstd cannot shrink, without pulling
    /// in a random-number dependency or making the test non-reproducible.
    fn pseudo_random_bytes(len: usize) -> Vec<u8> {
        let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
        (0..len)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                // Truncation is intentional: keep one byte from the upper half.
                (state >> 32) as u8
            })
            .collect()
    }

    #[test]
    fn test_small_payload_no_compression() {
        let data = b"hello";
        let framed = encode_framed(data);
        assert_eq!(framed[0], 0x00, "small payload should not be compressed");
        assert_eq!(&framed[1..], data);
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_compressible_payload() {
        // Highly compressible: repeated pattern
        let data = vec![0x42u8; 4096];
        let framed = encode_framed(&data);
        assert_eq!(
            framed[0] & FLAG_COMPRESSED,
            FLAG_COMPRESSED,
            "large compressible payload should be compressed"
        );
        assert!(framed.len() < data.len(), "compressed should be smaller");
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_incompressible_payload() {
        // High-entropy data above the threshold: zstd output is not smaller,
        // so the encoder must fall back to storing the payload raw.
        let data = pseudo_random_bytes(4096);
        let framed = encode_framed(&data);
        assert_eq!(
            framed[0], 0x00,
            "incompressible payload should be stored raw"
        );
        assert_eq!(&framed[1..], data.as_slice());
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_structured_payload_roundtrip() {
        // Sequential little-endian counters; only the round-trip is checked,
        // whichever storage form the encoder picks.
        let mut data = Vec::with_capacity(1024 * 4);
        for i in 0..1024u32 {
            data.extend_from_slice(&i.to_le_bytes());
        }
        let framed = encode_framed(&data);
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_empty_framed_payload_error() {
        assert!(decode_framed(&[]).is_err());
    }

    #[test]
    fn test_flags_only_frame_decodes_to_empty() {
        // A frame holding just an uncompressed flags byte carries an empty body.
        let decoded = decode_framed(&[0x00]).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_corrupt_compressed_body_error() {
        // The compressed flag is set but the body is not a zstd frame.
        let framed = [
            FLAG_COMPRESSED,
            0xDE,
            0xAD,
            0xBE,
            0xEF,
            0x00,
            0x01,
            0x02,
            0x03,
        ];
        assert!(decode_framed(&framed).is_err());
    }

    #[test]
    fn test_roundtrip_at_threshold_boundary() {
        let data = vec![0xAA; COMPRESSION_THRESHOLD];
        let framed = encode_framed(&data);
        // The threshold is inclusive: exactly COMPRESSION_THRESHOLD bytes of a
        // repeated value must take the compressed path.
        assert_eq!(
            framed[0] & FLAG_COMPRESSED,
            FLAG_COMPRESSED,
            "payload at threshold should be compressed"
        );
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_just_below_threshold() {
        let data = vec![0xBB; COMPRESSION_THRESHOLD - 1];
        let framed = encode_framed(&data);
        assert_eq!(framed[0], 0x00, "below threshold should not compress");
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }
}