1use bytes::{Buf, BufMut, Bytes, BytesMut};
38use thiserror::Error;
39
40use crate::CodecKind;
41
/// Magic bytes written at the start of every data frame.
pub const FRAME_MAGIC: &[u8; 4] = b"S4F2";

/// Magic bytes written at the start of every padding record.
pub const PADDING_MAGIC: &[u8; 4] = b"S4P1";

/// Size in bytes of a serialized frame header, in the order the fields are written.
pub const FRAME_HEADER_BYTES: usize = 4 // magic
    + 4 // codec id (u32)
    + 8 // original_size (u64)
    + 8 // compressed_size (u64)
    + 4; // crc32c (u32)

/// Size in bytes of a padding record header: the magic followed by a `u64` payload length.
pub const PADDING_HEADER_BYTES: usize = 4 + 8;

/// 5 MiB.
///
/// NOTE(review): presumably the minimum part size accepted by S3 multipart uploads and the
/// intended `min_total` for [`pad_to_minimum`] — confirm against the caller.
pub const S3_MULTIPART_MIN_PART_BYTES: usize = 5 * 1024 * 1024;
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq)]
63pub struct FrameHeader {
64 pub codec: CodecKind,
65 pub original_size: u64,
66 pub compressed_size: u64,
67 pub crc32c: u32,
68}
69
70#[derive(Debug, Error)]
71pub enum FrameError {
72 #[error("frame too short: need at least {FRAME_HEADER_BYTES} bytes, have {0}")]
73 TooShort(usize),
74 #[error("bad frame magic: expected {expected:?}, got {got:?}")]
75 BadMagic { expected: [u8; 4], got: [u8; 4] },
76 #[error("frame compressed_size {compressed_size} exceeds remaining buffer {remaining}")]
77 PayloadTruncated {
78 compressed_size: u64,
79 remaining: usize,
80 },
81 #[error("unknown codec id {0} in frame header (decoder out of date?)")]
82 UnknownCodec(u32),
83}
84
85pub fn write_frame(dst: &mut BytesMut, header: FrameHeader, payload: &[u8]) {
87 debug_assert_eq!(payload.len() as u64, header.compressed_size);
88 dst.reserve(FRAME_HEADER_BYTES + payload.len());
89 dst.put_slice(FRAME_MAGIC);
90 dst.put_u32_le(header.codec.id());
91 dst.put_u64_le(header.original_size);
92 dst.put_u64_le(header.compressed_size);
93 dst.put_u32_le(header.crc32c);
94 dst.put_slice(payload);
95}
96
97pub fn pad_to_minimum(dst: &mut BytesMut, min_total: usize) {
103 if dst.len() >= min_total {
104 return;
105 }
106 let need = min_total - dst.len();
108 let payload_len = need.saturating_sub(PADDING_HEADER_BYTES);
109 dst.reserve(PADDING_HEADER_BYTES + payload_len);
110 dst.put_slice(PADDING_MAGIC);
111 dst.put_u64_le(payload_len as u64);
112 dst.put_bytes(0, payload_len);
114}
115
116pub fn read_frame(mut input: Bytes) -> Result<(FrameHeader, Bytes, Bytes), FrameError> {
118 if input.len() < FRAME_HEADER_BYTES {
119 return Err(FrameError::TooShort(input.len()));
120 }
121 let mut magic = [0u8; 4];
122 magic.copy_from_slice(&input[..4]);
123 if &magic != FRAME_MAGIC {
124 return Err(FrameError::BadMagic {
125 expected: *FRAME_MAGIC,
126 got: magic,
127 });
128 }
129 input.advance(4);
130 let codec_id = input.get_u32_le();
131 let codec = CodecKind::from_id(codec_id).ok_or(FrameError::UnknownCodec(codec_id))?;
132 let original_size = input.get_u64_le();
133 let compressed_size = input.get_u64_le();
134 let crc32c = input.get_u32_le();
135 if (compressed_size as usize) > input.len() {
136 return Err(FrameError::PayloadTruncated {
137 compressed_size,
138 remaining: input.len(),
139 });
140 }
141 let payload = input.split_to(compressed_size as usize);
142 Ok((
143 FrameHeader {
144 codec,
145 original_size,
146 compressed_size,
147 crc32c,
148 },
149 payload,
150 input,
151 ))
152}
153
154pub struct FrameIter {
163 rest: Bytes,
164 fused: bool,
165}
166
167impl FrameIter {
168 pub fn new(input: Bytes) -> Self {
169 Self {
170 rest: input,
171 fused: false,
172 }
173 }
174}
175
176impl Iterator for FrameIter {
177 type Item = Result<(FrameHeader, Bytes), FrameError>;
178 fn next(&mut self) -> Option<Self::Item> {
179 if self.fused {
180 return None;
181 }
182 loop {
183 if self.rest.is_empty() {
184 return None;
185 }
186 if self.rest.len() < 4 {
187 self.fused = true;
188 return Some(Err(FrameError::TooShort(self.rest.len())));
189 }
190 let mut magic = [0u8; 4];
191 magic.copy_from_slice(&self.rest[..4]);
192 if &magic == PADDING_MAGIC {
193 if self.rest.len() < PADDING_HEADER_BYTES {
195 self.fused = true;
196 return Some(Err(FrameError::TooShort(self.rest.len())));
197 }
198 self.rest.advance(4);
199 let pad_len = self.rest.get_u64_le();
200 if (pad_len as usize) > self.rest.len() {
201 self.fused = true;
202 return Some(Err(FrameError::PayloadTruncated {
203 compressed_size: pad_len,
204 remaining: self.rest.len(),
205 }));
206 }
207 self.rest.advance(pad_len as usize);
208 continue;
209 }
210 return match read_frame(std::mem::take(&mut self.rest)) {
212 Ok((hdr, payload, remainder)) => {
213 self.rest = remainder;
214 Some(Ok((hdr, payload)))
215 }
216 Err(e) => {
217 self.fused = true;
218 Some(Err(e))
219 }
220 };
221 }
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header whose `compressed_size` matches `payload`.
    fn header_for(
        codec: CodecKind,
        original_size: u64,
        crc32c: u32,
        payload: &[u8],
    ) -> FrameHeader {
        FrameHeader {
            codec,
            original_size,
            compressed_size: payload.len() as u64,
            crc32c,
        }
    }

    /// Serializes a bare frame header (crc32c = 0, no payload) field by field, so tests can
    /// craft malformed input that `write_frame` would never produce.
    fn raw_header(
        magic: &[u8; 4],
        codec_id: u32,
        original_size: u64,
        compressed_size: u64,
    ) -> Bytes {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
        buf.put_slice(magic);
        buf.put_u32_le(codec_id);
        buf.put_u64_le(original_size);
        buf.put_u64_le(compressed_size);
        buf.put_u32_le(0);
        buf.freeze()
    }

    #[test]
    fn frame_roundtrip_single() {
        let payload = Bytes::from_static(b"hello frame payload");
        let header = header_for(CodecKind::CpuZstd, 999, 0xdead_beef, &payload);

        let mut buf = BytesMut::new();
        write_frame(&mut buf, header, &payload);
        assert_eq!(buf.len(), FRAME_HEADER_BYTES + payload.len());

        let (got_header, got_payload, rest) = read_frame(buf.freeze()).unwrap();
        assert_eq!(got_header, header);
        assert_eq!(got_payload, payload);
        assert!(rest.is_empty());
    }

    #[test]
    fn frame_iter_walks_all_frames_with_mixed_codecs() {
        let codecs = [
            CodecKind::Passthrough,
            CodecKind::CpuZstd,
            CodecKind::NvcompZstd,
            CodecKind::NvcompBitcomp,
            CodecKind::DietGpuAns,
        ];

        let mut buf = BytesMut::new();
        for (i, &codec) in codecs.iter().enumerate() {
            let payload = vec![i as u8; (i + 1) * 4];
            let header = header_for(codec, 100 + i as u64, i as u32, &payload);
            write_frame(&mut buf, header, &payload);
        }

        let frames: Vec<_> = FrameIter::new(buf.freeze())
            .collect::<Result<_, _>>()
            .unwrap();
        assert_eq!(frames.len(), 5);
        for (i, (header, payload)) in frames.iter().enumerate() {
            assert_eq!(header.codec, codecs[i], "codec must be preserved per frame");
            assert_eq!(header.original_size, 100 + i as u64);
            assert_eq!(header.crc32c, i as u32);
            assert_eq!(payload.len(), (i + 1) * 4);
        }
    }

    #[test]
    fn frame_bad_magic_rejected() {
        let err = read_frame(raw_header(b"BAD!", 0, 0, 0)).unwrap_err();
        assert!(matches!(err, FrameError::BadMagic { .. }));
    }

    #[test]
    fn frame_truncated_rejected() {
        // Header claims 100 payload bytes but none follow.
        let input = raw_header(FRAME_MAGIC, CodecKind::CpuZstd.id(), 100, 100);
        let err = read_frame(input).unwrap_err();
        assert!(matches!(err, FrameError::PayloadTruncated { .. }));
    }

    #[test]
    fn frame_unknown_codec_rejected() {
        let err = read_frame(raw_header(FRAME_MAGIC, 99, 0, 0)).unwrap_err();
        assert!(matches!(err, FrameError::UnknownCodec(99)));
    }

    #[test]
    fn frame_too_short_for_header_rejected() {
        let err = read_frame(Bytes::from_static(b"shortdata")).unwrap_err();
        assert!(matches!(err, FrameError::TooShort(_)));
    }

    #[test]
    fn padding_skipped_by_iter() {
        let first = Bytes::from_static(b"first frame");
        let second = Bytes::from_static(b"second frame");

        let mut buf = BytesMut::new();
        write_frame(
            &mut buf,
            header_for(CodecKind::CpuZstd, 11, 0, &first),
            &first,
        );
        pad_to_minimum(&mut buf, 1024);
        assert!(buf.len() >= 1024);
        write_frame(
            &mut buf,
            header_for(CodecKind::CpuZstd, 12, 0, &second),
            &second,
        );

        let frames: Vec<_> = FrameIter::new(buf.freeze())
            .collect::<Result<_, _>>()
            .unwrap();
        assert_eq!(
            frames.len(),
            2,
            "padding must be skipped, only data yielded"
        );
        assert_eq!(frames[0].1, first);
        assert_eq!(frames[1].1, second);
    }

    #[test]
    fn pad_to_minimum_is_noop_when_already_above() {
        let mut buf = BytesMut::new();
        buf.put_bytes(0, 1024);
        pad_to_minimum(&mut buf, 100);
        assert_eq!(buf.len(), 1024);
    }

    #[test]
    fn pad_to_minimum_grows_to_target() {
        let mut buf = BytesMut::new();
        write_frame(
            &mut buf,
            header_for(CodecKind::Passthrough, 0, 0, &[]),
            &[],
        );
        let before = buf.len();

        pad_to_minimum(&mut buf, 5_000_000);
        assert!(buf.len() >= 5_000_000);
        assert!(buf.len() < 5_000_000 + 64, "no excessive overshoot");
        assert!(buf.len() > before);
    }
}