// zamsync_network/protocol/frame.rs
frame.rs1use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
2use std::io::{Read, Write};
3use zamsync_core::{ZamError, ZamResult};
4
/// Largest value accepted in a frame's 4-byte big-endian length prefix (64 MiB).
///
/// The prefix counts the flag byte plus the body, so the largest payload that
/// [`write_frame`] will emit is `MAX_FRAME_SIZE - 1` bytes.
pub const MAX_FRAME_SIZE: u32 = 64 << 20;

/// Payloads shorter than this many bytes are sent raw without attempting
/// compression.
const COMPRESS_THRESHOLD: usize = 64;

/// Flag byte: the frame body is the payload verbatim.
const FLAG_RAW: u8 = 0x00;
/// Flag byte: the frame body is a zstd-compressed payload.
const FLAG_ZSTD: u8 = 0x01;
13
14pub fn write_frame(writer: &mut impl Write, payload: &[u8]) -> ZamResult<usize> {
21 if payload.len() as u64 >= MAX_FRAME_SIZE as u64 {
22 return Err(ZamError::Protocol(format!(
23 "frame payload too large: {} bytes (max {})",
24 payload.len(),
25 MAX_FRAME_SIZE - 1
26 )));
27 }
28
29 let (flag, body): (u8, Vec<u8>) = if payload.len() >= COMPRESS_THRESHOLD {
30 let compressed = zstd::encode_all(payload, 3)
31 .map_err(|e| ZamError::Protocol(format!("zstd compress: {e}")))?;
32 if compressed.len() < payload.len() {
33 (FLAG_ZSTD, compressed)
34 } else {
35 (FLAG_RAW, payload.to_vec())
36 }
37 } else {
38 (FLAG_RAW, payload.to_vec())
39 };
40
41 let total_len = 1u32 + body.len() as u32;
42 writer.write_u32::<BigEndian>(total_len)?;
43 writer.write_u8(flag)?;
44 writer.write_all(&body)?;
45 Ok(4 + 1 + body.len())
47}
48
49pub fn read_frame(reader: &mut impl Read) -> ZamResult<Vec<u8>> {
50 let total_len = reader.read_u32::<BigEndian>()?;
51 if total_len as u64 > MAX_FRAME_SIZE as u64 {
52 return Err(ZamError::Protocol(format!(
53 "received frame too large: {} bytes (max {})",
54 total_len, MAX_FRAME_SIZE
55 )));
56 }
57
58 if total_len == 0 {
59 return Ok(vec![]);
60 }
61
62 let flag = reader.read_u8()?;
63 let body_len = (total_len - 1) as usize;
64 let mut body = vec![0u8; body_len];
65 reader.read_exact(&mut body)?;
66
67 match flag {
68 FLAG_RAW => Ok(body),
69 FLAG_ZSTD => zstd::decode_all(body.as_slice())
70 .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}"))),
71 other => Err(ZamError::Protocol(format!(
72 "unknown frame flag: 0x{other:02x}"
73 ))),
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_frame_roundtrip_small() {
        let payload = b"hello world";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_roundtrip_empty() {
        let mut buf = Vec::new();
        write_frame(&mut buf, &[]).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_frame_compression_roundtrip() {
        let payload: Vec<u8> = (0..512).map(|i| b"abcdefghij"[i % 10]).collect();
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();

        // The whole frame, header included, must be smaller than the raw
        // payload; otherwise compression was not applied.
        assert!(
            buf.len() < payload.len(),
            "compressed frame ({} bytes) should be smaller than raw payload ({} bytes)",
            buf.len(),
            payload.len()
        );

        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_compression_flag_raw() {
        // Below COMPRESS_THRESHOLD, so the payload is sent uncompressed.
        let payload = b"hi";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        // Byte 4 is the flag, right after the 4-byte length prefix.
        assert_eq!(buf[4], FLAG_RAW);
    }

    #[test]
    fn test_frame_compression_flag_zstd() {
        let payload: Vec<u8> = vec![b'x'; 1024];
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert_eq!(buf[4], FLAG_ZSTD);
    }

    #[test]
    fn test_write_frame_rejects_payload_at_max_size() {
        let huge = vec![0u8; MAX_FRAME_SIZE as usize];
        let mut buf = Vec::new();
        let result = write_frame(&mut buf, &huge);
        assert!(
            result.is_err(),
            "payload at MAX_FRAME_SIZE must be rejected"
        );
        assert!(buf.is_empty(), "no bytes must be written on rejection");
    }

    #[test]
    fn test_read_frame_zero_length_is_empty() {
        let wire = [0x00, 0x00, 0x00, 0x00];
        let decoded = read_frame(&mut Cursor::new(&wire[..])).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_read_frame_rejects_unknown_flag() {
        // Length 1 covers only the flag byte, whose value is neither raw nor zstd.
        let wire = [0x00, 0x00, 0x00, 0x01, 0x7f];
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "unknown flag byte must be rejected");
    }

    #[test]
    fn test_read_frame_rejects_truncated_body() {
        // The header promises 10 bytes (flag + 9-byte body), but only 4 follow.
        let wire = [0x00, 0x00, 0x00, 0x0a, FLAG_RAW, 1, 2, 3];
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "truncated frame must be rejected");
    }

    #[test]
    fn test_try_consume_frame_rejects_oversized_length_field() {
        use super::super::frame_buf::FrameBuffer;

        let oversized_len = (MAX_FRAME_SIZE as u64 + 1) as u32;
        let mut wire = Vec::new();
        wire.extend_from_slice(&oversized_len.to_be_bytes());
        wire.push(0x00);

        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "oversized length field must be rejected");
    }
}