Skip to main content

zamsync_network/protocol/
frame_buf.rs

1use super::frame::MAX_FRAME_SIZE;
2use std::io::{ErrorKind, Read};
3use zamsync_core::{ZamError, ZamResult};
4use zstd;
5
/// Per-connection receive buffer.
///
/// The 50ms `read_timeout` on TCP sockets is used to poll multiple peers without
/// blocking forever, but it means a `read_exact` inside `read_frame` can be
/// interrupted mid-frame on very slow links (e.g. 3 KB/s). When that happens,
/// the partial bytes already pulled from the kernel buffer are lost. That shifts
/// every subsequent frame by some number of bytes and breaks the length-prefix
/// framing entirely.
///
/// `FrameBuffer` fixes this by accumulating all received bytes in `buf` and
/// returning a frame only once all of its bytes are present. Partial reads
/// caused by a timeout simply leave their bytes in the buffer for the next poll
/// cycle. Because the buffered bytes belong to one peer's byte stream, use one
/// `FrameBuffer` per connection.
pub struct FrameBuffer {
    /// Bytes received from the stream that have not yet been consumed as part
    /// of a complete frame. The front of the buffer is always a frame boundary.
    buf: Vec<u8>,
}

// Hand-written instead of derived so that logging a connection reports how much
// is buffered rather than dumping up to a full frame of raw bytes.
impl std::fmt::Debug for FrameBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FrameBuffer")
            .field("buffered", &self.buf.len())
            .finish()
    }
}
21
22impl Default for FrameBuffer {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl FrameBuffer {
29    pub fn new() -> Self {
30        Self { buf: Vec::new() }
31    }
32
33    /// Try to return one complete decoded frame.
34    ///
35    /// Reads as many bytes as the stream offers (stopping on `WouldBlock` or
36    /// `TimedOut`), then checks whether the accumulated buffer contains a full
37    /// length-prefixed frame.
38    ///
39    /// Returns:
40    /// * `Ok(Some(payload))` -- a complete, decompressed frame is ready.
41    /// * `Ok(None)` -- not enough bytes yet; call again after the next read opportunity.
42    /// * `Err(_)` -- a real I/O or protocol error occurred.
43    pub fn try_read_frame(&mut self, stream: &mut impl Read) -> ZamResult<Option<Vec<u8>>> {
44        // Fast path: if the buffer already holds a complete frame, return it
45        // without touching the stream at all. This handles the case where a
46        // previous read delivered two frames in one syscall.
47        if let Some(frame) = self.try_consume_frame()? {
48            return Ok(Some(frame));
49        }
50
51        // Drain whatever bytes the stream has right now into our buffer.
52        let mut tmp = [0u8; 8192];
53        let mut got_new_bytes = false;
54        loop {
55            match stream.read(&mut tmp) {
56                Ok(0) => {
57                    // The peer sent us EOF (connection closed cleanly).
58                    // Only treat it as a real EOF if we received no new bytes
59                    // in this call. If we did receive some bytes and *then*
60                    // got Ok(0), the OS drained the kernel buffer -- we process
61                    // what we have and let the next poll see another Ok(0).
62                    if !got_new_bytes {
63                        return Err(ZamError::Io(std::io::Error::new(
64                            ErrorKind::UnexpectedEof,
65                            "connection closed by peer",
66                        )));
67                    }
68                    break;
69                }
70                Ok(n) => {
71                    self.buf.extend_from_slice(&tmp[..n]);
72                    got_new_bytes = true;
73                }
74                Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
75                    // Nothing more to read right now -- stop draining.
76                    break;
77                }
78                Err(e) => return Err(ZamError::Io(e)),
79            }
80        }
81
82        // Check the buffer again after draining.
83        self.try_consume_frame()
84    }
85
86    /// Attempt to extract and decode one complete frame from `self.buf`.
87    /// Returns `Ok(None)` if fewer than `4 + total_len` bytes are present.
88    fn try_consume_frame(&mut self) -> ZamResult<Option<Vec<u8>>> {
89        // Check whether we have accumulated a full frame.
90        // Wire format: [4 bytes big-endian u32 = total_len] [1 byte flag] [body]
91        if self.buf.len() < 4 {
92            return Ok(None);
93        }
94        let total_len =
95            u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]) as usize;
96
97        if total_len == 0 {
98            // Empty frame -- consume the 4-byte header and return an empty payload.
99            self.buf.drain(..4);
100            return Ok(Some(vec![]));
101        }
102
103        if total_len as u64 > MAX_FRAME_SIZE as u64 {
104            return Err(ZamError::Protocol(format!(
105                "received frame too large: {} bytes (max {})",
106                total_len, MAX_FRAME_SIZE
107            )));
108        }
109
110        let frame_end = 4 + total_len;
111        if self.buf.len() < frame_end {
112            // Not enough bytes yet -- wait for more.
113            return Ok(None);
114        }
115
116        // We have a complete frame; decode it.
117        let flag = self.buf[4];
118        let body = self.buf[5..frame_end].to_vec();
119        self.buf.drain(..frame_end);
120
121        const FLAG_RAW: u8 = 0x00;
122        const FLAG_ZSTD: u8 = 0x01;
123
124        let payload = match flag {
125            FLAG_RAW => body,
126            FLAG_ZSTD => zstd::decode_all(body.as_slice())
127                .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}")))?,
128            other => {
129                return Err(ZamError::Protocol(format!(
130                    "unknown frame flag: 0x{other:02x}"
131                )))
132            }
133        };
134
135        Ok(Some(payload))
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::frame::write_frame;
    use std::io::{empty, Cursor};

    /// Encodes `payload` with the production `write_frame`, so every test
    /// exercises the exact wire format the sender produces.
    fn encode(payload: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        write_frame(&mut wire, payload).unwrap();
        wire
    }

    /// Offers `bytes` to `fb` through a single in-memory reader and unwraps
    /// the result.
    fn feed(fb: &mut FrameBuffer, bytes: &[u8]) -> Option<Vec<u8>> {
        fb.try_read_frame(&mut Cursor::new(bytes)).unwrap()
    }

    #[test]
    fn test_complete_frame_at_once() {
        let payload = b"hello from bhutan";
        let mut fb = FrameBuffer::new();
        // One read delivers the whole frame...
        assert_eq!(feed(&mut fb, &encode(payload)), Some(payload.to_vec()));
        // ...and nothing is left behind.
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_two_frames_back_to_back() {
        let combined: Vec<u8> = [encode(b"frame-one"), encode(b"frame-two")].concat();
        let mut fb = FrameBuffer::new();
        assert_eq!(feed(&mut fb, &combined), Some(b"frame-one".to_vec()));
        // The second frame is already buffered, so a reader with no data at
        // all must still yield it.
        let second = fb.try_read_frame(&mut empty()).unwrap();
        assert_eq!(second, Some(b"frame-two".to_vec()));
    }

    #[test]
    fn test_partial_header_returns_none() {
        let wire = encode(b"some data");
        let mut fb = FrameBuffer::new();
        // Only 2 of the 4 length-prefix bytes arrive.
        assert!(feed(&mut fb, &wire[..2]).is_none());
        assert_eq!(fb.buf.len(), 2);
    }

    #[test]
    fn test_partial_body_returns_none() {
        let wire = encode(b"some longer payload that has many bytes");
        let truncated = &wire[..wire.len() - 5];
        let mut fb = FrameBuffer::new();
        assert!(feed(&mut fb, truncated).is_none());
        // Every received byte is retained for the next poll.
        assert_eq!(fb.buf.len(), truncated.len());
    }

    #[test]
    fn test_split_delivery_reassembles_frame() {
        let payload = b"patient-record-from-rural-bhutan";
        let wire = encode(payload);
        let (head, tail) = wire.split_at(wire.len() / 2);
        let mut fb = FrameBuffer::new();
        // The first half alone is not a frame; the second half completes it.
        assert!(feed(&mut fb, head).is_none());
        assert_eq!(feed(&mut fb, tail), Some(payload.to_vec()));
    }

    #[test]
    fn test_empty_reader_on_empty_buffer_is_eof() {
        // An empty buffer plus a reader that immediately returns Ok(0) means the
        // peer closed the connection. This is how the sync session's
        // graceful-close loop detects the end.
        let mut fb = FrameBuffer::new();
        let outcome = fb.try_read_frame(&mut empty());
        assert!(matches!(outcome, Err(ZamError::Io(_))));
    }
}