Skip to main content

zamsync_network/protocol/
frame.rs

1use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
2use std::io::{Read, Write};
3use zamsync_core::{ZamError, ZamResult};
4
/// Upper bound, in bytes, on the 4-byte length prefix of a frame (flag byte + body).
///
/// `write_frame` refuses payloads of this size or larger, so the prefix it emits
/// never exceeds this value; `read_frame` rejects any prefix above it.
pub const MAX_FRAME_SIZE: u32 = 64 * 1024 * 1024;

/// Payloads shorter than this many bytes skip compression entirely -- the zstd
/// framing overhead would outweigh any savings.
const COMPRESS_THRESHOLD: usize = 64;

/// Compression flag: the frame body is the payload as-is.
const FLAG_RAW: u8 = 0x00;
/// Compression flag: the frame body is the zstd-compressed payload.
const FLAG_ZSTD: u8 = 0x01;
13
14/// Wire format:
15///   [4 bytes] uint32 big-endian  -- total byte count that follows (flag + body)
16///   [1 byte]  compression flag   -- 0x00 raw, 0x01 zstd
17///   [N bytes] body               -- raw payload or zstd-compressed payload
18pub fn write_frame(writer: &mut impl Write, payload: &[u8]) -> ZamResult<()> {
19    if payload.len() as u64 >= MAX_FRAME_SIZE as u64 {
20        return Err(ZamError::Protocol(format!(
21            "frame payload too large: {} bytes (max {})",
22            payload.len(),
23            MAX_FRAME_SIZE - 1
24        )));
25    }
26
27    let (flag, body): (u8, Vec<u8>) = if payload.len() >= COMPRESS_THRESHOLD {
28        let compressed = zstd::encode_all(payload, 3)
29            .map_err(|e| ZamError::Protocol(format!("zstd compress: {e}")))?;
30        if compressed.len() < payload.len() {
31            (FLAG_ZSTD, compressed)
32        } else {
33            (FLAG_RAW, payload.to_vec())
34        }
35    } else {
36        (FLAG_RAW, payload.to_vec())
37    };
38
39    let total_len = 1u32 + body.len() as u32;
40    writer.write_u32::<BigEndian>(total_len)?;
41    writer.write_u8(flag)?;
42    writer.write_all(&body)?;
43    Ok(())
44}
45
46pub fn read_frame(reader: &mut impl Read) -> ZamResult<Vec<u8>> {
47    let total_len = reader.read_u32::<BigEndian>()?;
48    if total_len as u64 > MAX_FRAME_SIZE as u64 {
49        return Err(ZamError::Protocol(format!(
50            "received frame too large: {} bytes (max {})",
51            total_len, MAX_FRAME_SIZE
52        )));
53    }
54
55    if total_len == 0 {
56        return Ok(vec![]);
57    }
58
59    let flag = reader.read_u8()?;
60    let body_len = (total_len - 1) as usize;
61    let mut body = vec![0u8; body_len];
62    reader.read_exact(&mut body)?;
63
64    match flag {
65        FLAG_RAW => Ok(body),
66        FLAG_ZSTD => zstd::decode_all(body.as_slice())
67            .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}"))),
68        other => Err(ZamError::Protocol(format!(
69            "unknown frame flag: 0x{other:02x}"
70        ))),
71    }
72}
73
#[cfg(test)]
mod tests {
    //! Round-trip and rejection tests for the frame codec.

    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_frame_roundtrip_small() {
        let payload = b"hello world"; // < COMPRESS_THRESHOLD, sent raw
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_roundtrip_empty() {
        let mut buf = Vec::new();
        write_frame(&mut buf, &[]).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_frame_compression_roundtrip() {
        // Highly repetitive payload (a 10-byte pattern repeated) that compresses well
        let payload: Vec<u8> = (0..512).map(|i| b"abcdefghij"[i % 10]).collect();
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();

        // The whole wire frame, header included, must be smaller than the raw payload
        assert!(
            buf.len() < payload.len(),
            "compressed frame ({} bytes) should be smaller than raw payload ({} bytes)",
            buf.len(),
            payload.len()
        );

        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_compression_flag_raw() {
        // Small payload -- flag byte must be FLAG_RAW
        let payload = b"hi";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        // Index 4 (right after the 4-byte length prefix) is the flag byte
        assert_eq!(buf[4], FLAG_RAW);
    }

    #[test]
    fn test_frame_compression_flag_zstd() {
        // Large repetitive payload -- flag byte must be FLAG_ZSTD
        let payload: Vec<u8> = vec![b'x'; 1024];
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert_eq!(buf[4], FLAG_ZSTD);
    }

    #[test]
    fn test_write_frame_rejects_payload_at_max_size() {
        // Payload exactly at MAX_FRAME_SIZE must be rejected (the check is >=).
        // write_frame checks length before any allocation or I/O, so this returns
        // immediately even though we allocate a large Vec here.
        let huge = vec![0u8; MAX_FRAME_SIZE as usize];
        let mut buf = Vec::new();
        let result = write_frame(&mut buf, &huge);
        assert!(
            result.is_err(),
            "payload at MAX_FRAME_SIZE must be rejected"
        );
        assert!(buf.is_empty(), "no bytes must be written on rejection");
    }

    #[test]
    fn test_read_frame_rejects_oversized_length_field() {
        // Only the length prefix is supplied: the size check must fire before
        // read_frame tries to read (or allocate for) the flag and body.
        let wire = (MAX_FRAME_SIZE + 1).to_be_bytes();
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "oversized length field must be rejected");
    }

    #[test]
    fn test_read_frame_rejects_unknown_flag() {
        // Length prefix of 1 (flag only), followed by a flag value that is
        // neither FLAG_RAW nor FLAG_ZSTD.
        let wire = [0x00, 0x00, 0x00, 0x01, 0x7f];
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "unknown flag byte must be rejected");
    }

    #[test]
    fn test_read_frame_rejects_truncated_body() {
        // Length prefix claims flag + 4 body bytes, but only 2 body bytes follow.
        let wire = [0x00, 0x00, 0x00, 0x05, FLAG_RAW, 0xaa, 0xbb];
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "truncated frame must be rejected");
    }

    #[test]
    fn test_try_consume_frame_rejects_oversized_length_field() {
        use super::super::frame_buf::FrameBuffer;

        // Craft a wire frame whose 4-byte length field claims MAX_FRAME_SIZE + 1.
        // The FrameBuffer must return an error, not try to allocate that much memory.
        let oversized_len = (MAX_FRAME_SIZE as u64 + 1) as u32;
        let mut wire = Vec::new();
        wire.extend_from_slice(&oversized_len.to_be_bytes()); // length field
        wire.push(0x00); // flag byte (won't be reached)
                         // No actual payload bytes -- the error fires before the payload is read.

        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "oversized length field must be rejected");
    }
}