zamsync_network/protocol/
frame_buf.rs1use super::frame::MAX_FRAME_SIZE;
2use std::io::{ErrorKind, Read};
3use zstd;
4use zamsync_core::{ZamError, ZamResult};
5
/// Accumulates bytes read from a stream and splits them into length-prefixed frames.
///
/// Bytes that arrive past the end of the frame currently being returned stay in
/// `buf` and are served by later calls to [`FrameBuffer::try_read_frame`].
///
/// `Debug` is intentionally not derived: the buffer holds raw payload bytes.
#[derive(Default)]
pub struct FrameBuffer {
    /// Bytes received from the stream but not yet consumed as a complete frame.
    buf: Vec<u8>,
}
21
22impl FrameBuffer {
23 pub fn new() -> Self {
24 Self { buf: Vec::new() }
25 }
26
27 pub fn try_read_frame(&mut self, stream: &mut impl Read) -> ZamResult<Option<Vec<u8>>> {
39 if let Some(frame) = self.try_consume_frame()? {
43 return Ok(Some(frame));
44 }
45
46 let mut tmp = [0u8; 8192];
48 let mut got_new_bytes = false;
49 loop {
50 match stream.read(&mut tmp) {
51 Ok(0) => {
52 if !got_new_bytes {
58 return Err(ZamError::Io(std::io::Error::new(
59 ErrorKind::UnexpectedEof,
60 "connection closed by peer",
61 )));
62 }
63 break;
64 }
65 Ok(n) => {
66 self.buf.extend_from_slice(&tmp[..n]);
67 got_new_bytes = true;
68 }
69 Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
70 break;
72 }
73 Err(e) => return Err(ZamError::Io(e)),
74 }
75 }
76
77 self.try_consume_frame()
79 }
80
81 fn try_consume_frame(&mut self) -> ZamResult<Option<Vec<u8>>> {
84 if self.buf.len() < 4 {
87 return Ok(None);
88 }
89 let total_len =
90 u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]) as usize;
91
92 if total_len == 0 {
93 self.buf.drain(..4);
95 return Ok(Some(vec![]));
96 }
97
98 if total_len as u64 > MAX_FRAME_SIZE as u64 {
99 return Err(ZamError::Protocol(format!(
100 "received frame too large: {} bytes (max {})",
101 total_len, MAX_FRAME_SIZE
102 )));
103 }
104
105 let frame_end = 4 + total_len;
106 if self.buf.len() < frame_end {
107 return Ok(None);
109 }
110
111 let flag = self.buf[4];
113 let body = self.buf[5..frame_end].to_vec();
114 self.buf.drain(..frame_end);
115
116 const FLAG_RAW: u8 = 0x00;
117 const FLAG_ZSTD: u8 = 0x01;
118
119 let payload = match flag {
120 FLAG_RAW => body,
121 FLAG_ZSTD => zstd::decode_all(body.as_slice())
122 .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}")))?,
123 other => {
124 return Err(ZamError::Protocol(format!(
125 "unknown frame flag: 0x{other:02x}"
126 )))
127 }
128 };
129
130 Ok(Some(payload))
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::frame::write_frame;
    use std::io::Cursor;

    /// Encodes `payload` with the production `write_frame` encoder.
    fn make_frame(payload: &[u8]) -> Vec<u8> {
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        buf
    }

    #[test]
    fn test_complete_frame_at_once() {
        let payload = b"hello from bhutan";
        let wire = make_frame(payload);
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire)).unwrap();
        assert_eq!(result, Some(payload.to_vec()));
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_two_frames_back_to_back() {
        let wire1 = make_frame(b"frame-one");
        let wire2 = make_frame(b"frame-two");
        let mut combined = wire1.clone();
        combined.extend_from_slice(&wire2);

        let mut fb = FrameBuffer::new();
        let r1 = fb.try_read_frame(&mut Cursor::new(&combined)).unwrap();
        assert_eq!(r1, Some(b"frame-one".to_vec()));
        // The second frame must come from the internal buffer, not the (empty) stream.
        let r2 = fb.try_read_frame(&mut Cursor::new(&[])).unwrap();
        assert_eq!(r2, Some(b"frame-two".to_vec()));
    }

    #[test]
    fn test_partial_header_returns_none() {
        let wire = make_frame(b"some data");
        let partial = &wire[..2];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(partial)).unwrap();
        assert!(result.is_none());
        assert_eq!(fb.buf.len(), 2);
    }

    #[test]
    fn test_partial_body_returns_none() {
        let wire = make_frame(b"some longer payload that has many bytes");
        let partial = &wire[..wire.len() - 5];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(partial)).unwrap();
        assert!(result.is_none());
        assert_eq!(fb.buf.len(), partial.len());
    }

    #[test]
    fn test_split_delivery_reassembles_frame() {
        let payload = b"patient-record-from-rural-bhutan";
        let wire = make_frame(payload);
        let mid = wire.len() / 2;

        let mut fb = FrameBuffer::new();
        let r1 = fb.try_read_frame(&mut Cursor::new(&wire[..mid])).unwrap();
        assert!(r1.is_none());
        let r2 = fb.try_read_frame(&mut Cursor::new(&wire[mid..])).unwrap();
        assert_eq!(r2, Some(payload.to_vec()));
    }

    #[test]
    fn test_empty_reader_on_empty_buffer_is_eof() {
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&[]));
        assert!(matches!(result, Err(ZamError::Io(_))));
    }

    #[test]
    fn test_zero_length_frame_yields_empty_payload() {
        let wire = [0u8, 0, 0, 0];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire[..])).unwrap();
        assert_eq!(result, Some(Vec::new()));
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_unknown_flag_is_protocol_error_and_frame_is_discarded() {
        // Length 2 = one flag byte + one body byte; 0x7f is not a known encoding flag.
        let wire = [0u8, 0, 0, 2, 0x7f, 0xaa];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire[..]));
        assert!(matches!(result, Err(ZamError::Protocol(_))));
        // The bad frame is dropped so the buffer stays aligned on frame boundaries.
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_oversized_length_prefix_is_rejected() {
        // Only meaningful when a length above the limit fits in the 4-byte prefix.
        if let Ok(too_big) = u32::try_from((MAX_FRAME_SIZE as u64).saturating_add(1)) {
            let header = too_big.to_be_bytes();
            let mut fb = FrameBuffer::new();
            let result = fb.try_read_frame(&mut Cursor::new(&header[..]));
            assert!(matches!(result, Err(ZamError::Protocol(_))));
        }
    }
}