sqry_daemon_protocol/
framing.rs1use std::io;
24
25use serde::{Serialize, de::DeserializeOwned};
26use thiserror::Error;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28
/// Largest frame body, in bytes, that this module will send or accept (64 MiB).
///
/// [`write_frame`] rejects bodies longer than this before writing anything, and
/// [`read_frame`] rejects a declared length above it before allocating the body buffer.
pub const MAX_FRAME_BYTES: usize = 64 << 20;
35
36#[derive(Debug, Error)]
38pub enum FrameError {
39 #[error("frame io: {0}")]
42 Io(#[from] io::Error),
43
44 #[error("frame json: {0}")]
47 Json(#[from] serde_json::Error),
48}
49
50pub async fn write_frame<W>(w: &mut W, body: &[u8]) -> io::Result<()>
53where
54 W: AsyncWrite + Unpin,
55{
56 if body.len() > MAX_FRAME_BYTES {
57 return Err(io::Error::new(
58 io::ErrorKind::InvalidInput,
59 format!(
60 "frame body length {} exceeds MAX_FRAME_BYTES ({MAX_FRAME_BYTES})",
61 body.len()
62 ),
63 ));
64 }
65 let len = u32::try_from(body.len()).expect("length bounded above by MAX_FRAME_BYTES");
66 w.write_all(&len.to_le_bytes()).await?;
67 w.write_all(body).await?;
68 w.flush().await?;
69 Ok(())
70}
71
72pub async fn read_frame<R>(r: &mut R) -> io::Result<Option<Vec<u8>>>
79where
80 R: AsyncRead + Unpin,
81{
82 let mut len_buf = [0u8; 4];
83 let mut filled = 0usize;
84 while filled < 4 {
85 match r.read(&mut len_buf[filled..]).await? {
86 0 if filled == 0 => return Ok(None),
87 0 => {
88 return Err(io::Error::new(
89 io::ErrorKind::UnexpectedEof,
90 format!("truncated frame: got {filled}/4 length bytes before EOF"),
91 ));
92 }
93 n => filled += n,
94 }
95 }
96 let len = u32::from_le_bytes(len_buf) as usize;
97 if len > MAX_FRAME_BYTES {
98 return Err(io::Error::new(
99 io::ErrorKind::InvalidData,
100 format!("frame len {len} exceeds MAX_FRAME_BYTES ({MAX_FRAME_BYTES})"),
101 ));
102 }
103 let mut body = vec![0u8; len];
104 match r.read_exact(&mut body).await {
105 Ok(_) => Ok(Some(body)),
106 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Err(io::Error::new(
107 io::ErrorKind::UnexpectedEof,
108 format!("truncated frame body: expected {len} bytes"),
109 )),
110 Err(e) => Err(e),
111 }
112}
113
114pub async fn write_frame_json<W, T>(w: &mut W, value: &T) -> Result<(), FrameError>
116where
117 W: AsyncWrite + Unpin,
118 T: Serialize + ?Sized,
119{
120 let body = serde_json::to_vec(value)?;
121 write_frame(w, &body).await?;
122 Ok(())
123}
124
125pub async fn read_frame_json<R, T>(r: &mut R) -> Result<Option<T>, FrameError>
128where
129 R: AsyncRead + Unpin,
130 T: DeserializeOwned,
131{
132 let Some(bytes) = read_frame(r).await? else {
133 return Ok(None);
134 };
135 let value = serde_json::from_slice(&bytes)?;
136 Ok(Some(value))
137}
138
/// Unit tests for the framing helpers, run over in-memory `tokio::io::duplex` pipes.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::duplex;

    /// A small body written with `write_frame` comes back unchanged from `read_frame`.
    #[tokio::test]
    async fn round_trip_small_frame() {
        let (mut a, mut b) = duplex(1024);
        let payload = br#"{"hello":"world"}"#;
        write_frame(&mut a, payload).await.expect("write");
        drop(a);
        let got = read_frame(&mut b).await.expect("read").expect("some");
        assert_eq!(got, payload);
    }

    /// Several frames, including a zero-length one, are delivered in order and
    /// followed by a clean end-of-stream.
    #[tokio::test]
    async fn consecutive_frames_are_read_in_order() {
        let (mut a, mut b) = duplex(1024);
        write_frame(&mut a, b"first").await.expect("write first");
        write_frame(&mut a, b"").await.expect("write empty");
        write_frame(&mut a, b"third").await.expect("write third");
        drop(a);

        let first = read_frame(&mut b).await.expect("read first").expect("some");
        assert_eq!(first, b"first");
        let empty = read_frame(&mut b).await.expect("read empty").expect("some");
        assert!(empty.is_empty());
        let third = read_frame(&mut b).await.expect("read third").expect("some");
        assert_eq!(third, b"third");

        let end = read_frame(&mut b).await.expect("no error on clean EOF");
        assert!(end.is_none());
    }

    /// A body one byte over the cap is refused with `InvalidInput`.
    #[tokio::test]
    async fn rejects_oversize_frame_on_write() {
        let (mut a, _b) = duplex(1024);
        let body = vec![0u8; MAX_FRAME_BYTES + 1];
        let err = write_frame(&mut a, &body)
            .await
            .expect_err("oversize must fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// EOF before any prefix byte is a clean end-of-stream, not an error.
    #[tokio::test]
    async fn clean_eof_at_frame_boundary_returns_ok_none() {
        let (a, mut b) = duplex(64);
        drop(a);
        let got = read_frame(&mut b).await.expect("no error on clean EOF");
        assert!(got.is_none());
    }

    /// EOF after only part of the 4-byte prefix is reported as `UnexpectedEof`.
    #[tokio::test]
    async fn truncated_prefix_is_error() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&[0x01, 0x00]).await.unwrap();
        drop(a);
        let err = read_frame(&mut b).await.expect_err("truncated prefix");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(
            err.to_string().contains("got 2/4 length bytes"),
            "got unexpected message: {err}"
        );
    }

    /// EOF before the declared number of body bytes is reported as `UnexpectedEof`.
    #[tokio::test]
    async fn truncated_body_is_error() {
        let (mut a, mut b) = duplex(64);
        let len = 16u32.to_le_bytes();
        a.write_all(&len).await.unwrap();
        a.write_all(b"short").await.unwrap();
        drop(a);
        let err = read_frame(&mut b).await.expect_err("truncated body");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(
            err.to_string().contains("truncated frame body"),
            "got unexpected message: {err}"
        );
    }

    /// A declared length above the cap is rejected with `InvalidData`.
    #[tokio::test]
    async fn oversize_read_is_rejected() {
        let (mut a, mut b) = duplex(64);
        let bad_len = (MAX_FRAME_BYTES as u32 + 1).to_le_bytes();
        a.write_all(&bad_len).await.unwrap();
        let err = read_frame(&mut b).await.expect_err("oversize claim");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    /// A value sent with `write_frame_json` decodes to an equal value, and the
    /// following read reports a clean end-of-stream.
    #[tokio::test]
    async fn json_round_trip() {
        let (mut a, mut b) = duplex(1024);
        let sent = vec![1u32, 2, 3];
        write_frame_json(&mut a, &sent).await.expect("write json");
        drop(a);

        let got: Vec<u32> = read_frame_json(&mut b)
            .await
            .expect("read json")
            .expect("some");
        assert_eq!(got, sent);

        let end: Option<Vec<u32>> = read_frame_json(&mut b).await.expect("clean EOF");
        assert!(end.is_none());
    }

    /// A well-framed body that is not valid JSON surfaces as `FrameError::Json`.
    #[tokio::test]
    async fn malformed_json_body_is_json_error() {
        let (mut a, mut b) = duplex(1024);
        write_frame(&mut a, b"not json").await.expect("write");
        drop(a);
        let err = read_frame_json::<_, Vec<u32>>(&mut b)
            .await
            .expect_err("invalid json");
        assert!(
            matches!(err, FrameError::Json(_)),
            "got unexpected error: {err}"
        );
    }

    /// A transport-level truncation surfaces from the JSON reader as `FrameError::Io`.
    #[tokio::test]
    async fn truncated_frame_is_io_error_for_json_reader() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&8u32.to_le_bytes()).await.unwrap();
        a.write_all(b"[1").await.unwrap();
        drop(a);
        let err = read_frame_json::<_, Vec<u32>>(&mut b)
            .await
            .expect_err("truncated body");
        assert!(
            matches!(err, FrameError::Io(_)),
            "got unexpected error: {err}"
        );
    }
}
209}