1use bytes::{Buf, BufMut, Bytes, BytesMut};
38use thiserror::Error;
39
40use crate::CodecKind;
41
/// Magic bytes that open every data frame.
pub const FRAME_MAGIC: &[u8; 4] = b"S4F2";

/// Magic bytes that open every padding record (skipped by `FrameIter`).
pub const PADDING_MAGIC: &[u8; 4] = b"S4P1";

/// Serialized size of a data-frame header:
/// magic + codec id (`u32`) + original size (`u64`) + compressed size (`u64`) + CRC32C (`u32`).
pub const FRAME_HEADER_BYTES: usize = FRAME_MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u32>();

/// Serialized size of a padding-record header: magic + zero-fill length (`u64`).
pub const PADDING_HEADER_BYTES: usize = PADDING_MAGIC.len() + std::mem::size_of::<u64>();

/// 5 MiB. Presumably S3's minimum size for a non-final multipart-upload part;
/// NOTE(review): not referenced in the code visible here — confirm its consumers.
pub const S3_MULTIPART_MIN_PART_BYTES: usize = 5 * 1024 * 1024;
61
/// Decoded header of a single data frame.
///
/// On the wire `write_frame` emits `FRAME_MAGIC` followed by these fields as
/// little-endian integers in declaration order: codec id (`u32`),
/// `original_size` (`u64`), `compressed_size` (`u64`), `crc32c` (`u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
    /// Codec that produced the payload; serialized through `CodecKind::id`.
    pub codec: CodecKind,
    /// Presumably the payload's length before compression. This module stores
    /// it verbatim and never checks it.
    pub original_size: u64,
    /// Exact byte length of the payload that follows the header. Readers rely on
    /// it alone to find where the next record starts.
    pub compressed_size: u64,
    /// Checksum carried with the frame. It is written and read here but never
    /// verified; NOTE(review): which bytes it covers is decided by callers — confirm.
    pub crc32c: u32,
}
69
/// Errors produced while parsing frames with [`read_frame`] or [`FrameIter`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FrameError {
    /// Input ended before a complete header.
    ///
    /// `read_frame` returns this when fewer than `FRAME_HEADER_BYTES` bytes are
    /// available. `FrameIter` also returns it when fewer than 4 bytes (one magic)
    /// remain, or when a padding magic is followed by fewer than
    /// `PADDING_HEADER_BYTES` total bytes. In those iterator cases, the
    /// `FRAME_HEADER_BYTES` figure in the message is not the real requirement.
    #[error("frame too short: need at least {FRAME_HEADER_BYTES} bytes, have {0}")]
    TooShort(usize),
    /// The first four bytes were not `FRAME_MAGIC`. This includes a padding
    /// record handed directly to `read_frame`.
    #[error("bad frame magic: expected {expected:?}, got {got:?}")]
    BadMagic { expected: [u8; 4], got: [u8; 4] },
    /// A declared length runs past the end of the input.
    ///
    /// `FrameIter` reuses this for padding records, in which case
    /// `compressed_size` holds the padding length.
    #[error("frame compressed_size {compressed_size} exceeds remaining buffer {remaining}")]
    PayloadTruncated {
        compressed_size: u64,
        remaining: usize,
    },
    /// `CodecKind::from_id` did not recognise the codec id in the header.
    #[error("unknown codec id {0} in frame header (decoder out of date?)")]
    UnknownCodec(u32),
    /// A declared length (payload or padding) does not fit in `usize` on this target.
    #[error("frame payload size {0} exceeds usize on this target")]
    PayloadTooLarge(u64),
}
97
98pub fn write_frame(dst: &mut BytesMut, header: FrameHeader, payload: &[u8]) {
100 debug_assert_eq!(payload.len() as u64, header.compressed_size);
101 dst.reserve(FRAME_HEADER_BYTES + payload.len());
102 dst.put_slice(FRAME_MAGIC);
103 dst.put_u32_le(header.codec.id());
104 dst.put_u64_le(header.original_size);
105 dst.put_u64_le(header.compressed_size);
106 dst.put_u32_le(header.crc32c);
107 dst.put_slice(payload);
108}
109
110pub fn pad_to_minimum(dst: &mut BytesMut, min_total: usize) {
128 if dst.len() >= min_total {
129 return;
130 }
131 let need = min_total - dst.len();
133 let payload_len = need.saturating_sub(PADDING_HEADER_BYTES);
134 if payload_len > 0 {
139 dst.reserve(PADDING_HEADER_BYTES + payload_len);
140 }
141 dst.put_slice(PADDING_MAGIC);
142 dst.put_u64_le(payload_len as u64);
143 dst.put_bytes(0, payload_len);
145}
146
147pub fn read_frame(mut input: Bytes) -> Result<(FrameHeader, Bytes, Bytes), FrameError> {
149 if input.len() < FRAME_HEADER_BYTES {
150 return Err(FrameError::TooShort(input.len()));
151 }
152 let mut magic = [0u8; 4];
153 magic.copy_from_slice(&input[..4]);
154 if &magic != FRAME_MAGIC {
155 return Err(FrameError::BadMagic {
156 expected: *FRAME_MAGIC,
157 got: magic,
158 });
159 }
160 input.advance(4);
161 let codec_id = input.get_u32_le();
162 let codec = CodecKind::from_id(codec_id).ok_or(FrameError::UnknownCodec(codec_id))?;
163 let original_size = input.get_u64_le();
164 let compressed_size = input.get_u64_le();
165 let crc32c = input.get_u32_le();
166 let compressed_size_usize = usize::try_from(compressed_size)
172 .map_err(|_| FrameError::PayloadTooLarge(compressed_size))?;
173 if compressed_size_usize > input.len() {
174 return Err(FrameError::PayloadTruncated {
175 compressed_size,
176 remaining: input.len(),
177 });
178 }
179 let payload = input.split_to(compressed_size_usize);
180 Ok((
181 FrameHeader {
182 codec,
183 original_size,
184 compressed_size,
185 crc32c,
186 },
187 payload,
188 input,
189 ))
190}
191
192pub struct FrameIter {
201 rest: Bytes,
202 fused: bool,
203}
204
205impl FrameIter {
206 pub fn new(input: Bytes) -> Self {
207 Self {
208 rest: input,
209 fused: false,
210 }
211 }
212}
213
214impl Iterator for FrameIter {
215 type Item = Result<(FrameHeader, Bytes), FrameError>;
216 fn next(&mut self) -> Option<Self::Item> {
217 if self.fused {
218 return None;
219 }
220 loop {
221 if self.rest.is_empty() {
222 return None;
223 }
224 if self.rest.len() < 4 {
225 self.fused = true;
226 return Some(Err(FrameError::TooShort(self.rest.len())));
227 }
228 let mut magic = [0u8; 4];
229 magic.copy_from_slice(&self.rest[..4]);
230 if &magic == PADDING_MAGIC {
231 if self.rest.len() < PADDING_HEADER_BYTES {
233 self.fused = true;
234 return Some(Err(FrameError::TooShort(self.rest.len())));
235 }
236 self.rest.advance(4);
237 let pad_len = self.rest.get_u64_le();
238 let pad_len_usize = match usize::try_from(pad_len) {
242 Ok(n) => n,
243 Err(_) => {
244 self.fused = true;
245 return Some(Err(FrameError::PayloadTooLarge(pad_len)));
246 }
247 };
248 if pad_len_usize > self.rest.len() {
249 self.fused = true;
250 return Some(Err(FrameError::PayloadTruncated {
251 compressed_size: pad_len,
252 remaining: self.rest.len(),
253 }));
254 }
255 self.rest.advance(pad_len_usize);
256 continue;
257 }
258 return match read_frame(std::mem::take(&mut self.rest)) {
260 Ok((hdr, payload, remainder)) => {
261 self.rest = remainder;
262 Some(Ok((hdr, payload)))
263 }
264 Err(e) => {
265 self.fused = true;
266 Some(Err(e))
267 }
268 };
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn frame_roundtrip_single() {
        let payload = Bytes::from_static(b"hello frame payload");
        let header = FrameHeader {
            codec: CodecKind::CpuZstd,
            original_size: 999,
            compressed_size: payload.len() as u64,
            crc32c: 0xdead_beef,
        };
        let mut buf = BytesMut::new();
        write_frame(&mut buf, header, &payload);
        assert_eq!(buf.len(), FRAME_HEADER_BYTES + payload.len());
        let bytes = buf.freeze();
        let (got_header, got_payload, rest) = read_frame(bytes).unwrap();
        assert_eq!(got_header, header);
        assert_eq!(got_payload, payload);
        assert!(rest.is_empty());
    }

    #[test]
    fn frame_iter_walks_all_frames_with_mixed_codecs() {
        let codecs = [
            CodecKind::Passthrough,
            CodecKind::CpuZstd,
            CodecKind::NvcompZstd,
            CodecKind::NvcompBitcomp,
            CodecKind::DietGpuAns,
        ];
        let mut buf = BytesMut::new();
        for (i, codec) in codecs.iter().enumerate() {
            let payload = vec![i as u8; (i + 1) * 4];
            let h = FrameHeader {
                codec: *codec,
                original_size: 100 + i as u64,
                compressed_size: payload.len() as u64,
                crc32c: i as u32,
            };
            write_frame(&mut buf, h, &payload);
        }
        let total = FrameIter::new(buf.freeze())
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(total.len(), 5);
        for (i, (h, payload)) in total.iter().enumerate() {
            assert_eq!(h.codec, codecs[i], "codec must be preserved per frame");
            assert_eq!(h.original_size, 100 + i as u64);
            assert_eq!(h.crc32c, i as u32);
            assert_eq!(payload.len(), (i + 1) * 4);
        }
    }

    #[test]
    fn frame_bad_magic_rejected() {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
        buf.put_slice(b"BAD!");
        buf.put_u32_le(0);
        buf.put_u64_le(0);
        buf.put_u64_le(0);
        buf.put_u32_le(0);
        let err = read_frame(buf.freeze()).unwrap_err();
        assert!(matches!(err, FrameError::BadMagic { .. }));
    }

    #[test]
    fn frame_truncated_rejected() {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
        buf.put_slice(FRAME_MAGIC);
        buf.put_u32_le(CodecKind::CpuZstd.id());
        buf.put_u64_le(100);
        buf.put_u64_le(100);
        buf.put_u32_le(0);
        let err = read_frame(buf.freeze()).unwrap_err();
        assert!(matches!(err, FrameError::PayloadTruncated { .. }));
    }

    #[test]
    fn frame_unknown_codec_rejected() {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
        buf.put_slice(FRAME_MAGIC);
        buf.put_u32_le(99);
        buf.put_u64_le(0);
        buf.put_u64_le(0);
        buf.put_u32_le(0);
        let err = read_frame(buf.freeze()).unwrap_err();
        assert!(matches!(err, FrameError::UnknownCodec(99)));
    }

    #[test]
    fn frame_too_short_for_header_rejected() {
        let buf = Bytes::from_static(b"shortdata");
        let err = read_frame(buf).unwrap_err();
        assert!(matches!(err, FrameError::TooShort(_)));
    }

    #[test]
    fn read_frame_returns_following_bytes_as_remainder() {
        let mut buf = BytesMut::new();
        let first = Bytes::from_static(b"one");
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::Passthrough,
                original_size: 3,
                compressed_size: first.len() as u64,
                crc32c: 1,
            },
            &first,
        );
        let first_len = buf.len();
        let second = Bytes::from_static(b"two!");
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::CpuZstd,
                original_size: 4,
                compressed_size: second.len() as u64,
                crc32c: 2,
            },
            &second,
        );
        let all = buf.freeze();
        let (_, payload, rest) = read_frame(all.clone()).unwrap();
        assert_eq!(payload, first);
        assert_eq!(rest, all.slice(first_len..));
    }

    #[test]
    fn read_frame_rejects_padding_record() {
        let mut buf = BytesMut::new();
        pad_to_minimum(&mut buf, 64);
        let err = read_frame(buf.freeze()).unwrap_err();
        assert!(matches!(err, FrameError::BadMagic { .. }));
    }

    #[test]
    fn padding_skipped_by_iter() {
        let mut buf = BytesMut::new();
        let p1 = Bytes::from_static(b"first frame");
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::CpuZstd,
                original_size: 11,
                compressed_size: p1.len() as u64,
                crc32c: 0,
            },
            &p1,
        );
        pad_to_minimum(&mut buf, 1024);
        assert!(buf.len() >= 1024);
        let p2 = Bytes::from_static(b"second frame");
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::CpuZstd,
                original_size: 12,
                compressed_size: p2.len() as u64,
                crc32c: 0,
            },
            &p2,
        );

        let frames: Vec<_> = FrameIter::new(buf.freeze())
            .collect::<Result<_, _>>()
            .unwrap();
        assert_eq!(
            frames.len(),
            2,
            "padding must be skipped, only data yielded"
        );
        assert_eq!(frames[0].1, p1);
        assert_eq!(frames[1].1, p2);
    }

    #[test]
    fn frame_iter_empty_input_yields_nothing() {
        let mut iter = FrameIter::new(Bytes::new());
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }

    #[test]
    fn frame_iter_skips_leading_and_trailing_padding() {
        let mut buf = BytesMut::new();
        pad_to_minimum(&mut buf, 64);
        assert_eq!(buf.len(), 64);
        let payload = Bytes::from_static(b"payload");
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::Passthrough,
                original_size: payload.len() as u64,
                compressed_size: payload.len() as u64,
                crc32c: 7,
            },
            &payload,
        );
        let target = buf.len() + 100;
        pad_to_minimum(&mut buf, target);
        let frames = FrameIter::new(buf.freeze())
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].0.crc32c, 7);
        assert_eq!(frames[0].1, payload);
    }

    #[test]
    fn frame_iter_truncated_padding_payload_errors_then_fuses() {
        let mut buf = BytesMut::new();
        buf.put_slice(PADDING_MAGIC);
        buf.put_u64_le(100);
        buf.put_bytes(0, 3);
        let mut iter = FrameIter::new(buf.freeze());
        assert!(matches!(
            iter.next(),
            Some(Err(FrameError::PayloadTruncated {
                compressed_size: 100,
                remaining: 3
            }))
        ));
        assert!(iter.next().is_none());
    }

    #[test]
    fn frame_iter_truncated_padding_header_errors() {
        let mut buf = BytesMut::new();
        buf.put_slice(PADDING_MAGIC);
        // Only 8 of the 12 padding-header bytes.
        buf.put_u32_le(0);
        let mut iter = FrameIter::new(buf.freeze());
        assert!(matches!(iter.next(), Some(Err(FrameError::TooShort(8)))));
        assert!(iter.next().is_none());
    }

    #[test]
    fn frame_iter_short_tail_errors_after_valid_frame() {
        let mut buf = BytesMut::new();
        let payload = Bytes::from_static(b"ok");
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::CpuZstd,
                original_size: 2,
                compressed_size: payload.len() as u64,
                crc32c: 0,
            },
            &payload,
        );
        buf.put_slice(b"S4");
        let mut iter = FrameIter::new(buf.freeze());
        let (_, got) = iter.next().unwrap().unwrap();
        assert_eq!(got, payload);
        assert!(matches!(iter.next(), Some(Err(FrameError::TooShort(2)))));
        assert!(iter.next().is_none());
    }

    #[test]
    fn pad_to_minimum_is_noop_when_already_above() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&[0u8; 1024]);
        pad_to_minimum(&mut buf, 100);
        assert_eq!(buf.len(), 1024);
    }

    #[test]
    fn pad_to_minimum_small_gap_writes_whole_header() {
        let mut buf = BytesMut::new();
        pad_to_minimum(&mut buf, 5);
        assert_eq!(
            buf.len(),
            PADDING_HEADER_BYTES,
            "padding header is never split, so small gaps overshoot"
        );
        let frames = FrameIter::new(buf.freeze())
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert!(frames.is_empty());
    }

    #[test]
    fn pad_to_minimum_grows_to_target() {
        let mut buf = BytesMut::new();
        write_frame(
            &mut buf,
            FrameHeader {
                codec: CodecKind::Passthrough,
                original_size: 0,
                compressed_size: 0,
                crc32c: 0,
            },
            &[],
        );
        let before = buf.len();
        pad_to_minimum(&mut buf, 5_000_000);
        assert!(buf.len() >= 5_000_000);
        assert!(buf.len() < 5_000_000 + 64, "no excessive overshoot");
        assert!(buf.len() > before);
    }
}
443}