Skip to main content

zamsync_network/protocol/
frame.rs

1use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
2use std::io::{Read, Write};
3use zamsync_core::{ZamError, ZamResult};
4
/// Upper bound on frame sizes, in bytes.
///
/// [`write_frame`] refuses payloads of `MAX_FRAME_SIZE` bytes or more, and
/// [`read_frame`] refuses any frame whose length prefix (flag byte + body)
/// exceeds this value. The largest raw payload that round-trips is therefore
/// `MAX_FRAME_SIZE - 1` bytes.
pub const MAX_FRAME_SIZE: u32 = 64 * 1024 * 1024;

/// Payloads shorter than this are sent uncompressed: for inputs this small the
/// zstd framing overhead outweighs any savings.
const COMPRESS_THRESHOLD: usize = 64;

/// Flag byte: the frame body is the payload verbatim.
const FLAG_RAW: u8 = 0x00;
/// Flag byte: the frame body is a zstd-compressed payload.
const FLAG_ZSTD: u8 = 0x01;
13
14/// Wire format:
15///   [4 bytes] uint32 big-endian  -- total byte count that follows (flag + body)
16///   [1 byte]  compression flag   -- 0x00 raw, 0x01 zstd
17///   [N bytes] body               -- raw payload or zstd-compressed payload
18///
19/// Returns the number of bytes written to `writer` (length prefix + flag + body).
20pub fn write_frame(writer: &mut impl Write, payload: &[u8]) -> ZamResult<usize> {
21    if payload.len() as u64 >= MAX_FRAME_SIZE as u64 {
22        return Err(ZamError::Protocol(format!(
23            "frame payload too large: {} bytes (max {})",
24            payload.len(),
25            MAX_FRAME_SIZE - 1
26        )));
27    }
28
29    let (flag, body): (u8, Vec<u8>) = if payload.len() >= COMPRESS_THRESHOLD {
30        let compressed = zstd::encode_all(payload, 3)
31            .map_err(|e| ZamError::Protocol(format!("zstd compress: {e}")))?;
32        if compressed.len() < payload.len() {
33            (FLAG_ZSTD, compressed)
34        } else {
35            (FLAG_RAW, payload.to_vec())
36        }
37    } else {
38        (FLAG_RAW, payload.to_vec())
39    };
40
41    let total_len = 1u32 + body.len() as u32;
42    writer.write_u32::<BigEndian>(total_len)?;
43    writer.write_u8(flag)?;
44    writer.write_all(&body)?;
45    // 4-byte length prefix + 1-byte flag + body
46    Ok(4 + 1 + body.len())
47}
48
49pub fn read_frame(reader: &mut impl Read) -> ZamResult<Vec<u8>> {
50    let total_len = reader.read_u32::<BigEndian>()?;
51    if total_len as u64 > MAX_FRAME_SIZE as u64 {
52        return Err(ZamError::Protocol(format!(
53            "received frame too large: {} bytes (max {})",
54            total_len, MAX_FRAME_SIZE
55        )));
56    }
57
58    if total_len == 0 {
59        return Ok(vec![]);
60    }
61
62    let flag = reader.read_u8()?;
63    let body_len = (total_len - 1) as usize;
64    let mut body = vec![0u8; body_len];
65    reader.read_exact(&mut body)?;
66
67    match flag {
68        FLAG_RAW => Ok(body),
69        FLAG_ZSTD => zstd::decode_all(body.as_slice())
70            .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}"))),
71        other => Err(ZamError::Protocol(format!(
72            "unknown frame flag: 0x{other:02x}"
73        ))),
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a wire frame by hand: big-endian length prefix, flag byte, body.
    fn raw_wire(total_len: u32, flag: u8, body: &[u8]) -> Vec<u8> {
        let mut wire = Vec::with_capacity(5 + body.len());
        wire.extend_from_slice(&total_len.to_be_bytes());
        wire.push(flag);
        wire.extend_from_slice(body);
        wire
    }

    #[test]
    fn test_frame_roundtrip_small() {
        let payload = b"hello world"; // < COMPRESS_THRESHOLD, sent raw
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_roundtrip_empty() {
        let mut buf = Vec::new();
        write_frame(&mut buf, &[]).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_frame_compression_roundtrip() {
        // Repetitive payload that compresses well.
        let payload: Vec<u8> = (0..512).map(|i| b"abcdefghij"[i % 10]).collect();
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();

        // The whole frame, header included, must be smaller than the raw payload.
        assert!(
            buf.len() < payload.len(),
            "compressed frame ({} bytes) should be smaller than raw payload ({} bytes)",
            buf.len(),
            payload.len()
        );

        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_compression_flag_raw() {
        // Small payload -- flag byte must be FLAG_RAW
        let payload = b"hi";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        // bytes 4..5 is the flag
        assert_eq!(buf[4], FLAG_RAW);
    }

    #[test]
    fn test_frame_compression_flag_zstd() {
        // Large repetitive payload -- flag byte must be FLAG_ZSTD
        let payload: Vec<u8> = vec![b'x'; 1024];
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert_eq!(buf[4], FLAG_ZSTD);
    }

    #[test]
    fn test_frame_just_below_threshold_is_raw() {
        // Even highly compressible data is sent raw below COMPRESS_THRESHOLD.
        let payload = vec![b'x'; COMPRESS_THRESHOLD - 1];
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert_eq!(buf[4], FLAG_RAW);
        assert_eq!(read_frame(&mut Cursor::new(&buf)).unwrap(), payload);
    }

    #[test]
    fn test_incompressible_payload_is_sent_raw() {
        // xorshift64 output does not compress, so zstd's framing overhead would
        // make the "compressed" body larger; write_frame must fall back to raw.
        let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
        let payload: Vec<u8> = (0..256)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                (state >> 56) as u8
            })
            .collect();
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert_eq!(buf[4], FLAG_RAW);
        assert_eq!(buf.len(), 5 + payload.len());
        assert_eq!(read_frame(&mut Cursor::new(&buf)).unwrap(), payload);
    }

    #[test]
    fn test_write_frame_returns_bytes_written() {
        let small = b"hi".to_vec();
        let large = vec![b'x'; 1024];
        for payload in [small, large] {
            let mut buf = Vec::new();
            let written = write_frame(&mut buf, &payload).unwrap();
            assert_eq!(written, buf.len());
        }
    }

    #[test]
    fn test_write_frame_rejects_payload_at_max_size() {
        // Payload exactly at MAX_FRAME_SIZE must be rejected (the check is >=).
        // write_frame checks length before any allocation or I/O, so this returns
        // immediately even though we allocate a large Vec here.
        let huge = vec![0u8; MAX_FRAME_SIZE as usize];
        let mut buf = Vec::new();
        let result = write_frame(&mut buf, &huge);
        assert!(
            result.is_err(),
            "payload at MAX_FRAME_SIZE must be rejected"
        );
        assert!(buf.is_empty(), "no bytes must be written on rejection");
    }

    #[test]
    fn test_read_frame_zero_length_is_empty() {
        let wire = 0u32.to_be_bytes();
        let decoded = read_frame(&mut Cursor::new(&wire[..])).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_read_frame_rejects_oversized_length_field() {
        let wire = raw_wire(MAX_FRAME_SIZE + 1, FLAG_RAW, &[]);
        let result = read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "oversized length field must be rejected");
    }

    #[test]
    fn test_read_frame_rejects_unknown_flag() {
        let wire = raw_wire(1, 0x7f, &[]);
        let result = read_frame(&mut Cursor::new(&wire));
        assert!(matches!(result, Err(ZamError::Protocol(_))));
    }

    #[test]
    fn test_read_frame_rejects_truncated_body() {
        // The length prefix promises 9 body bytes but only 3 follow.
        let wire = raw_wire(10, FLAG_RAW, b"abc");
        let result = read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "truncated body must be rejected");
    }

    #[test]
    fn test_frame_buffer_rejects_oversized_length_field() {
        use super::super::frame_buf::FrameBuffer;

        // Craft a wire frame whose 4-byte length field claims MAX_FRAME_SIZE + 1.
        // The FrameBuffer must return an error, not try to allocate that much
        // memory. No body bytes follow -- the error fires before the body is read.
        let wire = raw_wire(MAX_FRAME_SIZE + 1, FLAG_RAW, &[]);

        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "oversized length field must be rejected");
    }
}