zamsync_network/protocol/
frame.rs1use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
2use std::io::{Read, Write};
3use zamsync_core::{ZamError, ZamResult};

/// Upper bound, in bytes, on the value carried by a frame's 4-byte length prefix.
///
/// The prefix counts the 1-byte flag plus the body, so `write_frame` only
/// accepts payloads strictly shorter than this, and `read_frame` rejects any
/// length prefix greater than it. Equal to 64 MiB.
pub const MAX_FRAME_SIZE: u32 = 64 << 20;

/// Payloads shorter than this many bytes are always sent uncompressed.
const COMPRESS_THRESHOLD: usize = 64;

/// Flag byte: the frame body is the payload verbatim.
const FLAG_RAW: u8 = 0x00;
/// Flag byte: the frame body is a zstd-compressed payload.
const FLAG_ZSTD: u8 = 0x01;

14pub fn write_frame(writer: &mut impl Write, payload: &[u8]) -> ZamResult<()> {
19 if payload.len() as u64 >= MAX_FRAME_SIZE as u64 {
20 return Err(ZamError::Protocol(format!(
21 "frame payload too large: {} bytes (max {})",
22 payload.len(),
23 MAX_FRAME_SIZE - 1
24 )));
25 }
26
27 let (flag, body): (u8, Vec<u8>) = if payload.len() >= COMPRESS_THRESHOLD {
28 let compressed = zstd::encode_all(payload, 3)
29 .map_err(|e| ZamError::Protocol(format!("zstd compress: {e}")))?;
30 if compressed.len() < payload.len() {
31 (FLAG_ZSTD, compressed)
32 } else {
33 (FLAG_RAW, payload.to_vec())
34 }
35 } else {
36 (FLAG_RAW, payload.to_vec())
37 };
38
39 let total_len = 1u32 + body.len() as u32;
40 writer.write_u32::<BigEndian>(total_len)?;
41 writer.write_u8(flag)?;
42 writer.write_all(&body)?;
43 Ok(())
44}
45
46pub fn read_frame(reader: &mut impl Read) -> ZamResult<Vec<u8>> {
47 let total_len = reader.read_u32::<BigEndian>()?;
48 if total_len as u64 > MAX_FRAME_SIZE as u64 {
49 return Err(ZamError::Protocol(format!(
50 "received frame too large: {} bytes (max {})",
51 total_len, MAX_FRAME_SIZE
52 )));
53 }
54
55 if total_len == 0 {
56 return Ok(vec![]);
57 }
58
59 let flag = reader.read_u8()?;
60 let body_len = (total_len - 1) as usize;
61 let mut body = vec![0u8; body_len];
62 reader.read_exact(&mut body)?;
63
64 match flag {
65 FLAG_RAW => Ok(body),
66 FLAG_ZSTD => zstd::decode_all(body.as_slice())
67 .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}"))),
68 other => Err(ZamError::Protocol(format!(
69 "unknown frame flag: 0x{other:02x}"
70 ))),
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Frames `payload` into a fresh buffer and returns the raw wire bytes.
    fn encode(payload: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        write_frame(&mut wire, payload).unwrap();
        wire
    }

    /// Parses a single frame from the start of `wire` and returns its payload.
    fn decode(wire: &[u8]) -> Vec<u8> {
        read_frame(&mut Cursor::new(wire)).unwrap()
    }

    #[test]
    fn test_frame_roundtrip_small() {
        let message = b"hello world";
        assert_eq!(decode(&encode(message)), message);
    }

    #[test]
    fn test_frame_roundtrip_empty() {
        let wire = encode(&[]);
        assert!(decode(&wire).is_empty());
    }

    #[test]
    fn test_frame_compression_roundtrip() {
        // 512 bytes cycling through "abcdefghij": highly compressible.
        let repetitive: Vec<u8> = b"abcdefghij".iter().copied().cycle().take(512).collect();
        let wire = encode(&repetitive);

        // The whole frame, header included, must undercut the raw payload.
        assert!(
            wire.len() < repetitive.len(),
            "compressed frame ({} bytes) should be smaller than raw payload ({} bytes)",
            wire.len(),
            repetitive.len()
        );

        assert_eq!(decode(&wire), repetitive);
    }

    #[test]
    fn test_frame_compression_flag_raw() {
        // Byte 4 is the flag, right after the 4-byte length prefix.
        let wire = encode(b"hi");
        assert_eq!(wire[4], FLAG_RAW);
    }

    #[test]
    fn test_frame_compression_flag_zstd() {
        let wire = encode(&b"x".repeat(1024));
        assert_eq!(wire[4], FLAG_ZSTD);
    }

    #[test]
    fn test_write_frame_rejects_payload_at_max_size() {
        let oversized = vec![0u8; MAX_FRAME_SIZE as usize];
        let mut sink = Vec::new();
        let outcome = write_frame(&mut sink, &oversized);
        assert!(
            outcome.is_err(),
            "payload at MAX_FRAME_SIZE must be rejected"
        );
        assert!(sink.is_empty(), "no bytes must be written on rejection");
    }

    #[test]
    fn test_try_consume_frame_rejects_oversized_length_field() {
        use super::super::frame_buf::FrameBuffer;

        // A length prefix one past the limit, followed by a raw flag byte.
        let declared_len = MAX_FRAME_SIZE + 1;
        let mut wire = declared_len.to_be_bytes().to_vec();
        wire.push(0x00);

        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "oversized length field must be rejected");
    }
}