1use bytes::{Buf, BufMut, Bytes, BytesMut};
38use thiserror::Error;
39
40use crate::CodecKind;
41
/// Four-byte tag that opens every data frame.
pub const FRAME_MAGIC: &[u8; 4] = b"S4F2";

/// Four-byte tag that opens every padding record.
pub const PADDING_MAGIC: &[u8; 4] = b"S4P1";

/// Size in bytes of a serialized frame header:
/// magic, codec id (`u32`), original size (`u64`), compressed size (`u64`)
/// and CRC32C (`u32`).
pub const FRAME_HEADER_BYTES: usize = FRAME_MAGIC.len()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u32>();

/// Size in bytes of a padding record header: magic plus a `u64` length.
pub const PADDING_HEADER_BYTES: usize = PADDING_MAGIC.len() + std::mem::size_of::<u64>();

/// 5 MiB. NOTE(review): by its name this is the S3 multipart minimum part
/// size; it is not referenced in this part of the file — confirm the caller.
pub const S3_MULTIPART_MIN_PART_BYTES: usize = 5 << 20;
61
/// Decoded form of the fixed-size header that precedes each frame payload.
///
/// On the wire the fields follow `FRAME_MAGIC` in declaration order, each
/// encoded little-endian.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
    /// Codec of the payload; serialized as the `u32` returned by `codec.id()`.
    pub codec: CodecKind,
    /// Stored and returned verbatim by this module.
    /// Presumably the payload's size before compression — confirm with the codec layer.
    pub original_size: u64,
    /// Number of payload bytes that follow the header.
    pub compressed_size: u64,
    /// Checksum value carried alongside the payload. This module neither
    /// computes nor verifies it; which bytes it covers is not visible here.
    pub crc32c: u32,
}
69
/// Failures reported while parsing frames or padding records.
#[derive(Debug, Error)]
pub enum FrameError {
    /// The buffer ended before a complete header could be read; carries the
    /// number of bytes that were available.
    #[error("frame too short: need at least {FRAME_HEADER_BYTES} bytes, have {0}")]
    TooShort(usize),
    /// The first four bytes were not `FRAME_MAGIC`.
    #[error("bad frame magic: expected {expected:?}, got {got:?}")]
    BadMagic { expected: [u8; 4], got: [u8; 4] },
    /// The declared length is larger than the bytes left after the header.
    /// Also used for padding records, with the padding length in
    /// `compressed_size`.
    #[error("frame compressed_size {compressed_size} exceeds remaining buffer {remaining}")]
    PayloadTruncated {
        compressed_size: u64,
        remaining: usize,
    },
    /// `CodecKind::from_id` did not recognise the codec id in the header.
    #[error("unknown codec id {0} in frame header (decoder out of date?)")]
    UnknownCodec(u32),
    /// The declared `u64` length does not fit in `usize` on this target.
    #[error("frame payload size {0} exceeds usize on this target")]
    PayloadTooLarge(u64),
}
93
94pub fn write_frame(dst: &mut BytesMut, header: FrameHeader, payload: &[u8]) {
96 debug_assert_eq!(payload.len() as u64, header.compressed_size);
97 dst.reserve(FRAME_HEADER_BYTES + payload.len());
98 dst.put_slice(FRAME_MAGIC);
99 dst.put_u32_le(header.codec.id());
100 dst.put_u64_le(header.original_size);
101 dst.put_u64_le(header.compressed_size);
102 dst.put_u32_le(header.crc32c);
103 dst.put_slice(payload);
104}
105
106pub fn pad_to_minimum(dst: &mut BytesMut, min_total: usize) {
124 if dst.len() >= min_total {
125 return;
126 }
127 let need = min_total - dst.len();
129 let payload_len = need.saturating_sub(PADDING_HEADER_BYTES);
130 if payload_len > 0 {
135 dst.reserve(PADDING_HEADER_BYTES + payload_len);
136 }
137 dst.put_slice(PADDING_MAGIC);
138 dst.put_u64_le(payload_len as u64);
139 dst.put_bytes(0, payload_len);
141}
142
143pub fn read_frame(mut input: Bytes) -> Result<(FrameHeader, Bytes, Bytes), FrameError> {
145 if input.len() < FRAME_HEADER_BYTES {
146 return Err(FrameError::TooShort(input.len()));
147 }
148 let mut magic = [0u8; 4];
149 magic.copy_from_slice(&input[..4]);
150 if &magic != FRAME_MAGIC {
151 return Err(FrameError::BadMagic {
152 expected: *FRAME_MAGIC,
153 got: magic,
154 });
155 }
156 input.advance(4);
157 let codec_id = input.get_u32_le();
158 let codec = CodecKind::from_id(codec_id).ok_or(FrameError::UnknownCodec(codec_id))?;
159 let original_size = input.get_u64_le();
160 let compressed_size = input.get_u64_le();
161 let crc32c = input.get_u32_le();
162 let compressed_size_usize = usize::try_from(compressed_size)
168 .map_err(|_| FrameError::PayloadTooLarge(compressed_size))?;
169 if compressed_size_usize > input.len() {
170 return Err(FrameError::PayloadTruncated {
171 compressed_size,
172 remaining: input.len(),
173 });
174 }
175 let payload = input.split_to(compressed_size_usize);
176 Ok((
177 FrameHeader {
178 codec,
179 original_size,
180 compressed_size,
181 crc32c,
182 },
183 payload,
184 input,
185 ))
186}
187
/// Iterator over the data frames in a buffer, skipping padding records.
pub struct FrameIter {
    /// Bytes not yet consumed.
    rest: Bytes,
    /// Set after the first error; once set, `next` only returns `None`.
    fused: bool,
}
200
201impl FrameIter {
202 pub fn new(input: Bytes) -> Self {
203 Self {
204 rest: input,
205 fused: false,
206 }
207 }
208}
209
210impl Iterator for FrameIter {
211 type Item = Result<(FrameHeader, Bytes), FrameError>;
212 fn next(&mut self) -> Option<Self::Item> {
213 if self.fused {
214 return None;
215 }
216 loop {
217 if self.rest.is_empty() {
218 return None;
219 }
220 if self.rest.len() < 4 {
221 self.fused = true;
222 return Some(Err(FrameError::TooShort(self.rest.len())));
223 }
224 let mut magic = [0u8; 4];
225 magic.copy_from_slice(&self.rest[..4]);
226 if &magic == PADDING_MAGIC {
227 if self.rest.len() < PADDING_HEADER_BYTES {
229 self.fused = true;
230 return Some(Err(FrameError::TooShort(self.rest.len())));
231 }
232 self.rest.advance(4);
233 let pad_len = self.rest.get_u64_le();
234 let pad_len_usize = match usize::try_from(pad_len) {
238 Ok(n) => n,
239 Err(_) => {
240 self.fused = true;
241 return Some(Err(FrameError::PayloadTooLarge(pad_len)));
242 }
243 };
244 if pad_len_usize > self.rest.len() {
245 self.fused = true;
246 return Some(Err(FrameError::PayloadTruncated {
247 compressed_size: pad_len,
248 remaining: self.rest.len(),
249 }));
250 }
251 self.rest.advance(pad_len_usize);
252 continue;
253 }
254 return match read_frame(std::mem::take(&mut self.rest)) {
256 Ok((hdr, payload, remainder)) => {
257 self.rest = remainder;
258 Some(Ok((hdr, payload)))
259 }
260 Err(e) => {
261 self.fused = true;
262 Some(Err(e))
263 }
264 };
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header whose `compressed_size` matches `payload`.
    fn header_for(
        codec: CodecKind,
        original_size: u64,
        crc32c: u32,
        payload: &[u8],
    ) -> FrameHeader {
        FrameHeader {
            codec,
            original_size,
            compressed_size: payload.len() as u64,
            crc32c,
        }
    }

    /// Serializes a bare header (no payload, CRC field zero) with arbitrary
    /// magic, codec id and sizes, for feeding malformed input to the reader.
    fn raw_header(
        magic: &[u8; 4],
        codec_id: u32,
        original_size: u64,
        compressed_size: u64,
    ) -> Bytes {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_BYTES);
        buf.put_slice(magic);
        buf.put_u32_le(codec_id);
        buf.put_u64_le(original_size);
        buf.put_u64_le(compressed_size);
        buf.put_u32_le(0);
        buf.freeze()
    }

    // A single frame survives a write/read cycle unchanged.
    #[test]
    fn frame_roundtrip_single() {
        let payload = Bytes::from_static(b"hello frame payload");
        let header = header_for(CodecKind::CpuZstd, 999, 0xdead_beef, &payload);

        let mut buf = BytesMut::new();
        write_frame(&mut buf, header, &payload);
        assert_eq!(buf.len(), FRAME_HEADER_BYTES + payload.len());

        let (got_header, got_payload, rest) = read_frame(buf.freeze()).unwrap();
        assert_eq!(got_header, header);
        assert_eq!(got_payload, payload);
        assert!(rest.is_empty());
    }

    // Back-to-back frames with different codecs are all yielded, in order.
    #[test]
    fn frame_iter_walks_all_frames_with_mixed_codecs() {
        let codecs = [
            CodecKind::Passthrough,
            CodecKind::CpuZstd,
            CodecKind::NvcompZstd,
            CodecKind::NvcompBitcomp,
            CodecKind::DietGpuAns,
        ];

        let mut buf = BytesMut::new();
        for (i, &codec) in codecs.iter().enumerate() {
            let payload = vec![i as u8; (i + 1) * 4];
            let h = header_for(codec, 100 + i as u64, i as u32, &payload);
            write_frame(&mut buf, h, &payload);
        }

        let total = FrameIter::new(buf.freeze())
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(total.len(), 5);

        for (i, (h, payload)) in total.iter().enumerate() {
            assert_eq!(h.codec, codecs[i], "codec must be preserved per frame");
            assert_eq!(h.original_size, 100 + i as u64);
            assert_eq!(h.crc32c, i as u32);
            assert_eq!(payload.len(), (i + 1) * 4);
        }
    }

    // A full-length header with the wrong magic is rejected.
    #[test]
    fn frame_bad_magic_rejected() {
        let err = read_frame(raw_header(b"BAD!", 0, 0, 0)).unwrap_err();
        assert!(matches!(err, FrameError::BadMagic { .. }));
    }

    // A header that declares 100 payload bytes but has none following it.
    #[test]
    fn frame_truncated_rejected() {
        let input = raw_header(FRAME_MAGIC, CodecKind::CpuZstd.id(), 100, 100);
        let err = read_frame(input).unwrap_err();
        assert!(matches!(err, FrameError::PayloadTruncated { .. }));
    }

    // Codec id 99 is not a known codec.
    #[test]
    fn frame_unknown_codec_rejected() {
        let err = read_frame(raw_header(FRAME_MAGIC, 99, 0, 0)).unwrap_err();
        assert!(matches!(err, FrameError::UnknownCodec(99)));
    }

    // Input shorter than a header is rejected before the magic is examined.
    #[test]
    fn frame_too_short_for_header_rejected() {
        let err = read_frame(Bytes::from_static(b"shortdata")).unwrap_err();
        assert!(matches!(err, FrameError::TooShort(_)));
    }

    // Padding between two frames is invisible to the iterator.
    #[test]
    fn padding_skipped_by_iter() {
        let p1 = Bytes::from_static(b"first frame");
        let p2 = Bytes::from_static(b"second frame");

        let mut buf = BytesMut::new();
        write_frame(&mut buf, header_for(CodecKind::CpuZstd, 11, 0, &p1), &p1);
        pad_to_minimum(&mut buf, 1024);
        assert!(buf.len() >= 1024);
        write_frame(&mut buf, header_for(CodecKind::CpuZstd, 12, 0, &p2), &p2);

        let frames: Vec<_> = FrameIter::new(buf.freeze())
            .collect::<Result<_, _>>()
            .unwrap();
        assert_eq!(
            frames.len(),
            2,
            "padding must be skipped, only data yielded"
        );
        assert_eq!(frames[0].1, p1);
        assert_eq!(frames[1].1, p2);
    }

    // A buffer already above the minimum is left untouched.
    #[test]
    fn pad_to_minimum_is_noop_when_already_above() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&[0u8; 1024]);
        pad_to_minimum(&mut buf, 100);
        assert_eq!(buf.len(), 1024);
    }

    // Padding reaches the target size without a large overshoot.
    #[test]
    fn pad_to_minimum_grows_to_target() {
        let mut buf = BytesMut::new();
        write_frame(&mut buf, header_for(CodecKind::Passthrough, 0, 0, &[]), &[]);
        let before = buf.len();

        pad_to_minimum(&mut buf, 5_000_000);
        assert!(buf.len() >= 5_000_000);
        assert!(buf.len() < 5_000_000 + 64, "no excessive overshoot");
        assert!(buf.len() > before);
    }
}