zamsync_network/protocol/
frame_buf.rs1use super::frame::MAX_FRAME_SIZE;
2use std::io::{ErrorKind, Read};
3use zamsync_core::{ZamError, ZamResult};
4use zstd;
5
/// Accumulates bytes received from a stream until whole frames can be cut out of them.
///
/// The buffer persists between calls, so a frame that arrives split across several
/// reads is reassembled, and surplus bytes belonging to the next frame are retained.
#[derive(Debug)]
pub struct FrameBuffer {
    /// Bytes received so far that have not yet been returned as part of a frame.
    buf: Vec<u8>,
}
21
22impl Default for FrameBuffer {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl FrameBuffer {
29 pub fn new() -> Self {
30 Self { buf: Vec::new() }
31 }
32
33 pub fn try_read_frame(&mut self, stream: &mut impl Read) -> ZamResult<Option<Vec<u8>>> {
44 if let Some(frame) = self.try_consume_frame()? {
48 return Ok(Some(frame));
49 }
50
51 let mut tmp = [0u8; 8192];
53 let mut got_new_bytes = false;
54 loop {
55 match stream.read(&mut tmp) {
56 Ok(0) => {
57 if !got_new_bytes {
63 return Err(ZamError::Io(std::io::Error::new(
64 ErrorKind::UnexpectedEof,
65 "connection closed by peer",
66 )));
67 }
68 break;
69 }
70 Ok(n) => {
71 self.buf.extend_from_slice(&tmp[..n]);
72 got_new_bytes = true;
73 }
74 Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
75 break;
77 }
78 Err(e) => return Err(ZamError::Io(e)),
79 }
80 }
81
82 self.try_consume_frame()
84 }
85
86 fn try_consume_frame(&mut self) -> ZamResult<Option<Vec<u8>>> {
89 if self.buf.len() < 4 {
92 return Ok(None);
93 }
94 let total_len =
95 u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]) as usize;
96
97 if total_len == 0 {
98 self.buf.drain(..4);
100 return Ok(Some(vec![]));
101 }
102
103 if total_len as u64 > MAX_FRAME_SIZE as u64 {
104 return Err(ZamError::Protocol(format!(
105 "received frame too large: {} bytes (max {})",
106 total_len, MAX_FRAME_SIZE
107 )));
108 }
109
110 let frame_end = 4 + total_len;
111 if self.buf.len() < frame_end {
112 return Ok(None);
114 }
115
116 let flag = self.buf[4];
118 let body = self.buf[5..frame_end].to_vec();
119 self.buf.drain(..frame_end);
120
121 const FLAG_RAW: u8 = 0x00;
122 const FLAG_ZSTD: u8 = 0x01;
123
124 let payload = match flag {
125 FLAG_RAW => body,
126 FLAG_ZSTD => zstd::decode_all(body.as_slice())
127 .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}")))?,
128 other => {
129 return Err(ZamError::Protocol(format!(
130 "unknown frame flag: 0x{other:02x}"
131 )))
132 }
133 };
134
135 Ok(Some(payload))
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::frame::write_frame;
    use std::io::Cursor;

    /// Encodes `payload` into its on-wire form using the crate's frame writer.
    fn make_frame(payload: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        write_frame(&mut wire, payload).unwrap();
        wire
    }

    /// Performs one `try_read_frame` call against a fresh in-memory reader over `bytes`.
    fn feed(fb: &mut FrameBuffer, bytes: &[u8]) -> ZamResult<Option<Vec<u8>>> {
        fb.try_read_frame(&mut Cursor::new(bytes))
    }

    #[test]
    fn test_complete_frame_at_once() {
        let payload: &[u8] = b"hello from bhutan";
        let mut fb = FrameBuffer::new();

        let decoded = feed(&mut fb, &make_frame(payload)).unwrap();

        assert_eq!(decoded.as_deref(), Some(payload));
        // Nothing may be left behind once the only frame has been consumed.
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_two_frames_back_to_back() {
        let combined = [make_frame(b"frame-one"), make_frame(b"frame-two")].concat();
        let mut fb = FrameBuffer::new();

        let first = feed(&mut fb, &combined).unwrap();
        assert_eq!(first.as_deref(), Some(&b"frame-one"[..]));

        // The second frame is already buffered, so an empty reader must not matter.
        let second = feed(&mut fb, &[]).unwrap();
        assert_eq!(second.as_deref(), Some(&b"frame-two"[..]));
    }

    #[test]
    fn test_partial_header_returns_none() {
        let wire = make_frame(b"some data");
        // Only half of the 4-byte length prefix is delivered.
        let partial = &wire[..2];
        let mut fb = FrameBuffer::new();

        let outcome = feed(&mut fb, partial).unwrap();

        assert!(outcome.is_none());
        assert_eq!(fb.buf.len(), 2);
    }

    #[test]
    fn test_partial_body_returns_none() {
        let wire = make_frame(b"some longer payload that has many bytes");
        // Hold back the final five bytes of the frame.
        let partial = &wire[..wire.len() - 5];
        let mut fb = FrameBuffer::new();

        let outcome = feed(&mut fb, partial).unwrap();

        assert!(outcome.is_none());
        assert_eq!(fb.buf.len(), partial.len());
    }

    #[test]
    fn test_split_delivery_reassembles_frame() {
        let payload: &[u8] = b"patient-record-from-rural-bhutan";
        let wire = make_frame(payload);
        let (head, tail) = wire.split_at(wire.len() / 2);
        let mut fb = FrameBuffer::new();

        // First half alone is not enough to yield a frame.
        assert!(feed(&mut fb, head).unwrap().is_none());

        // The remainder completes it.
        let decoded = feed(&mut fb, tail).unwrap();
        assert_eq!(decoded.as_deref(), Some(payload));
    }

    #[test]
    fn test_empty_reader_on_empty_buffer_is_eof() {
        let mut fb = FrameBuffer::new();

        let outcome = feed(&mut fb, &[]);

        assert!(matches!(outcome, Err(ZamError::Io(_))));
    }
}