1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::multipart::{FrameHeader, write_frame};
35use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
36use std::sync::Arc;
37use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
38use tokio_util::io::{ReaderStream, StreamReader};
39
40pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync + 'static {
47 let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
48 StreamReader::new(mapped)
49}
50
51pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
53 reader: R,
54) -> StreamingBlob {
55 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
56 StreamingBlob::new(StreamWrapper { inner: stream })
57}
58
pin_project_lite::pin_project! {
    // Newtype over an inner chunk stream so we can implement s3s's
    // `ByteStream` trait for it; `#[pin]` enables structural pinning so
    // `poll_next` can delegate to the inner stream.
    struct StreamWrapper<S> { #[pin] inner: S }
}
64
65impl<S> Stream for StreamWrapper<S>
66where
67 S: Stream<Item = Result<Bytes, StdError>>,
68{
69 type Item = Result<Bytes, StdError>;
70 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 self.project().inner.poll_next(cx)
72 }
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 self.inner.size_hint()
75 }
76}
77
impl<S> ByteStream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
    /// The wrapped stream's total length cannot be known up front (e.g.
    /// mid-decompression), so always report it as unknown.
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}
87
88pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
98 let read = blob_to_async_read(body);
99 let mut decoder = ZstdDecoder::new(BufReader::new(read));
100 decoder.multiple_members(true);
101 async_read_to_blob(decoder)
102}
103
/// `AsyncRead` wrapper that hashes bytes as they pass through and, at
/// end-of-stream, fails with `InvalidData` if the total byte count or the
/// CRC32C does not match the expected values (see the `AsyncRead` impl).
pub struct Crc32cVerifyingReader<R> {
    /// Underlying source of bytes.
    inner: R,
    /// CRC32C the full stream must hash to.
    expected_crc: u32,
    /// Total number of bytes the stream must yield.
    expected_size: u64,
    /// CRC32C of everything read so far.
    rolling_crc: u32,
    /// Count of bytes read so far.
    bytes_read: u64,
    /// Set after a verification error so later reads report EOF instead
    /// of repeating the error.
    failed: bool,
}
136
impl<R> Crc32cVerifyingReader<R> {
    /// Wraps `inner`, recording the CRC32C and byte count the stream is
    /// expected to have produced by the time it reaches EOF.
    pub fn new(inner: R, expected_crc: u32, expected_size: u64) -> Self {
        Self {
            inner,
            expected_crc,
            expected_size,
            rolling_crc: 0,
            bytes_read: 0,
            failed: false,
        }
    }

    /// CRC32C accumulated so far (test-only introspection).
    #[cfg(test)]
    pub fn rolling_crc(&self) -> u32 {
        self.rolling_crc
    }

    /// Total bytes read so far (test-only introspection).
    #[cfg(test)]
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }
}
161
impl<R> AsyncRead for Crc32cVerifyingReader<R>
where
    R: AsyncRead + Unpin,
{
    /// Forwards reads to the inner reader, folding every new byte into the
    /// rolling CRC. A successful read that adds no bytes is treated as EOF
    /// (the tokio `poll_read` convention); at that point the accumulated
    /// size and CRC are checked and a mismatch surfaces as
    /// `io::ErrorKind::InvalidData`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // After a verification failure was already reported, behave as EOF
        // rather than repeating the error on subsequent polls.
        if self.failed {
            return Poll::Ready(Ok(()));
        }
        let pre_filled = buf.filled().len();
        match Pin::new(&mut self.inner).poll_read(cx, buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {
                let new_filled = buf.filled().len();
                if new_filled > pre_filled {
                    // Data arrived: hash exactly the newly filled region.
                    let chunk = &buf.filled()[pre_filled..new_filled];
                    self.rolling_crc = crc32c::crc32c_append(self.rolling_crc, chunk);
                    self.bytes_read = self.bytes_read.saturating_add(chunk.len() as u64);
                    Poll::Ready(Ok(()))
                } else {
                    // EOF: verify totals. Size is checked before CRC so a
                    // short/long stream reports the more specific error.
                    if self.bytes_read != self.expected_size {
                        self.failed = true;
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!(
                                "S4 streaming GET size mismatch: \
                                 expected {} bytes, got {}",
                                self.expected_size, self.bytes_read
                            ),
                        )));
                    }
                    if self.rolling_crc != self.expected_crc {
                        self.failed = true;
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!(
                                "S4 streaming GET crc32c mismatch: \
                                 expected {:#010x}, got {:#010x}",
                                self.expected_crc, self.rolling_crc
                            ),
                        )));
                    }
                    Poll::Ready(Ok(()))
                }
            }
        }
    }
}
223
224pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
226 matches!(
230 codec,
231 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
232 )
233}
234
235pub fn supports_streaming_compress(codec: CodecKind) -> bool {
236 #[cfg(feature = "nvcomp-gpu")]
237 {
238 matches!(
239 codec,
240 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
241 )
242 }
243 #[cfg(not(feature = "nvcomp-gpu"))]
244 {
245 matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
246 }
247}
248
249pub async fn streaming_compress_cpu_zstd(
259 body: StreamingBlob,
260 level: i32,
261) -> Result<(Bytes, ChunkManifest), CodecError> {
262 let mut read = blob_to_async_read(body);
263 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
264 let mut crc: u32 = 0;
265 let mut total_in: u64 = 0;
266 let mut in_buf = vec![0u8; 64 * 1024];
267
268 {
269 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
270 loop {
271 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
272 if n == 0 {
273 break;
274 }
275 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
276 total_in += n as u64;
277 encoder
278 .write_all(&in_buf[..n])
279 .await
280 .map_err(CodecError::Io)?;
281 }
282 encoder.shutdown().await.map_err(CodecError::Io)?;
283 }
284
285 let compressed_len = compressed_buf.len() as u64;
286 Ok((
287 Bytes::from(compressed_buf),
288 ChunkManifest {
289 codec: CodecKind::CpuZstd,
290 original_size: total_in,
291 compressed_size: compressed_len,
292 crc32c: crc,
293 },
294 ))
295}
296
/// Default frame chunk size for S4F2 bodies: 4 MiB.
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;

/// Picks a chunk size from the (optional) advertised content length.
///
/// * unknown length -> the 4 MiB default
/// * <= 1 MiB       -> 1 MiB (one small chunk)
/// * <= 100 MiB     -> the 4 MiB default
/// * larger         -> 16 MiB (fewer, bigger frames)
pub fn pick_chunk_size(content_length: Option<u64>) -> usize {
    let Some(len) = content_length else {
        return DEFAULT_S4F2_CHUNK_SIZE;
    };
    if len <= 1024 * 1024 {
        1024 * 1024
    } else if len <= 100 * 1024 * 1024 {
        DEFAULT_S4F2_CHUNK_SIZE
    } else {
        16 * 1024 * 1024
    }
}
324
/// Default number of chunk compressions pipelined concurrently by
/// [`streaming_compress_to_frames`].
pub const DEFAULT_S4F2_INFLIGHT: usize = 3;
338
/// Compresses `body` into ordered frames with the default pipeline depth.
///
/// Convenience wrapper around [`streaming_compress_to_frames_with`] using
/// [`DEFAULT_S4F2_INFLIGHT`] concurrent chunk compressions.
pub async fn streaming_compress_to_frames(
    body: StreamingBlob,
    registry: Arc<CodecRegistry>,
    codec_kind: CodecKind,
    chunk_size: usize,
    expected_size: Option<u64>,
) -> Result<(Bytes, ChunkManifest), CodecError> {
    streaming_compress_to_frames_with(
        body,
        registry,
        codec_kind,
        chunk_size,
        DEFAULT_S4F2_INFLIGHT,
        expected_size,
    )
    .await
}
386
/// Reads `body` in `chunk_size` pieces, compresses up to `inflight`
/// chunks concurrently via `registry`, and appends each result as a
/// frame (header + payload) in input order.
///
/// Returns the framed bytes plus a whole-object [`ChunkManifest`] whose
/// `original_size`/`crc32c` describe the *uncompressed* input and whose
/// `compressed_size` is the total framed length.
///
/// # Errors
/// * [`CodecError::Io`] / codec errors from reading or compressing.
/// * [`CodecError::TruncatedStream`] when `expected_size` is given and the
///   body produced fewer bytes than advertised.
pub async fn streaming_compress_to_frames_with(
    body: StreamingBlob,
    registry: Arc<CodecRegistry>,
    codec_kind: CodecKind,
    chunk_size: usize,
    inflight: usize,
    expected_size: Option<u64>,
) -> Result<(Bytes, ChunkManifest), CodecError> {
    use bytes::BytesMut;
    use futures::StreamExt as _;
    use futures::stream::FuturesOrdered;

    // Guard against a zero-concurrency request.
    let inflight = inflight.max(1);
    let mut read = blob_to_async_read(body);
    let mut framed = BytesMut::with_capacity(chunk_size);
    let mut rolling_crc: u32 = 0;
    let mut total_in: u64 = 0;
    let mut chunk_buf = vec![0u8; chunk_size];

    // FuturesOrdered yields results in push order, so frames are written
    // in input order even when compressions complete out of order.
    type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
    let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
    let mut eof = false;

    loop {
        // Top up the pipeline: read full chunks until EOF or `inflight`.
        while !eof && queue.len() < inflight {
            let mut filled = 0;
            // Accumulate short reads until the chunk buffer is full or the
            // reader reports EOF (n == 0).
            while filled < chunk_size {
                let n = read
                    .read(&mut chunk_buf[filled..])
                    .await
                    .map_err(CodecError::Io)?;
                if n == 0 {
                    break;
                }
                filled += n;
            }
            if filled == 0 {
                eof = true;
                break;
            }

            let chunk_slice = &chunk_buf[..filled];
            // Per-chunk CRC goes in the frame header; the rolling CRC
            // covers the whole uncompressed object for the manifest.
            let chunk_crc = crc32c::crc32c(chunk_slice);
            rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
            total_in += filled as u64;

            let header = FrameHeader {
                codec: codec_kind,
                original_size: filled as u64,
                // Placeholder; patched with the real size after compression.
                compressed_size: 0, crc32c: chunk_crc,
            };
            // Own a copy of the chunk so the spawned future can be 'static.
            let original_chunk = Bytes::copy_from_slice(chunk_slice);
            let registry = Arc::clone(&registry);
            queue.push_back(Box::pin(async move {
                let (compressed_chunk, _per_chunk_manifest) =
                    registry.compress(original_chunk, codec_kind).await?;
                let mut header = header;
                header.compressed_size = compressed_chunk.len() as u64;
                Ok::<_, CodecError>((header, compressed_chunk))
            }));
        }

        // Drain one completed compression; None means queue empty + EOF.
        match queue.next().await {
            Some(Ok((header, compressed_chunk))) => {
                write_frame(&mut framed, header, &compressed_chunk);
            }
            Some(Err(e)) => return Err(e),
            None => break,
        }
    }

    // Only an under-run errors here; a body longer than advertised is
    // accepted as-is — NOTE(review): confirm that asymmetry is intended.
    if let Some(expected) = expected_size
        && total_in < expected
    {
        return Err(CodecError::TruncatedStream {
            expected,
            got: total_in,
        });
    }

    let total_framed = framed.len() as u64;
    Ok((
        framed.freeze(),
        ChunkManifest {
            codec: codec_kind,
            original_size: total_in,
            compressed_size: total_framed,
            crc32c: rolling_crc,
        },
    ))
}
496
497pub async fn streaming_passthrough(
499 body: StreamingBlob,
500) -> Result<(Bytes, ChunkManifest), CodecError> {
501 let mut read = blob_to_async_read(body);
502 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
503 let mut crc: u32 = 0;
504 let mut total: u64 = 0;
505 let mut chunk = vec![0u8; 64 * 1024];
506 loop {
507 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
508 if n == 0 {
509 break;
510 }
511 crc = crc32c::crc32c_append(crc, &chunk[..n]);
512 total += n as u64;
513 buf.extend_from_slice(&chunk[..n]);
514 }
515 let len = buf.len() as u64;
516 Ok((
517 Bytes::from(buf),
518 ChunkManifest {
519 codec: CodecKind::Passthrough,
520 original_size: total,
521 compressed_size: len,
522 crc32c: crc,
523 },
524 ))
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    // Exercises every boundary of the chunk-size selection table.
    #[test]
    fn pick_chunk_size_thresholds() {
        assert_eq!(pick_chunk_size(None), DEFAULT_S4F2_CHUNK_SIZE);
        assert_eq!(pick_chunk_size(Some(0)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(64 * 1024)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(1024 * 1024)), 1024 * 1024);
        assert_eq!(
            pick_chunk_size(Some(1024 * 1024 + 1)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(50 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024 + 1)),
            16 * 1024 * 1024
        );
        assert_eq!(
            pick_chunk_size(Some(10 * 1024 * 1024 * 1024)),
            16 * 1024 * 1024
        );
    }

    // Drains a blob to completion, panicking on any chunk error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut buf = BytesMut::new();
        let mut s = blob;
        while let Some(chunk) = s.next().await {
            buf.extend_from_slice(&chunk.unwrap());
        }
        buf.freeze()
    }

    // Builds a single-chunk blob from `b`.
    fn make_blob(b: Bytes) -> StreamingBlob {
        let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
        StreamingBlob::wrap(stream)
    }

    // Round-trip: compress with the `zstd` crate, decompress via the
    // streaming CPU path.
    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let blob = make_blob(Bytes::from(compressed));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // The decoder must cope with the compressed body arriving in many
    // small network-sized chunks rather than one buffer.
    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let original = Bytes::from(vec![b'x'; 1_000_000]);
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let mut chunks = Vec::new();
        for chunk in compressed.chunks(1024) {
            chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
        }
        let in_stream = stream::iter(chunks);
        let blob = StreamingBlob::wrap(in_stream);
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // blob -> AsyncRead -> blob must be lossless.
    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let original = Bytes::from_static(b"hello");
        let blob = make_blob(original.clone());
        let out_blob = async_read_to_blob(blob_to_async_read(blob));
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // Full compress/decompress cycle; also checks the manifest fields
    // describe the uncompressed input.
    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let original = Bytes::from(vec![b'q'; 200_000]);
        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
        assert!(
            compressed.len() < original.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Documents the property the chunked compress path relies on:
    // back-to-back zstd frames decode as one stream when
    // `multiple_members(true)` is set.
    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        let chunk_a = Bytes::from(vec![b'a'; 50_000]);
        let chunk_b = Bytes::from(vec![b'b'; 50_000]);
        let chunk_c = Bytes::from(vec![b'c'; 50_000]);

        let frame_a = zstd::stream::encode_all(chunk_a.as_ref(), 3).unwrap();
        let frame_b = zstd::stream::encode_all(chunk_b.as_ref(), 3).unwrap();
        let frame_c = zstd::stream::encode_all(chunk_c.as_ref(), 3).unwrap();

        let mut concatenated: Vec<u8> = Vec::new();
        concatenated.extend_from_slice(&frame_a);
        concatenated.extend_from_slice(&frame_b);
        concatenated.extend_from_slice(&frame_c);

        let expected: Vec<u8> = chunk_a
            .iter()
            .chain(chunk_b.iter())
            .chain(chunk_c.iter())
            .copied()
            .collect();

        let blob = make_blob(Bytes::from(concatenated));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, Bytes::from(expected));
    }

    // Models the chunk-at-a-time compress pipeline with a local helper
    // and verifies the frame-concatenation result still round-trips.
    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        // Reference implementation: read fixed-size chunks, zstd-encode
        // each independently, and concatenate the frames.
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut read = blob_to_async_read(body);
            let mut compressed_buf: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut crc: u32 = 0;
            let mut total_in: u64 = 0;
            let mut chunk_buf = vec![0u8; chunk_size];
            loop {
                let mut filled = 0;
                while filled < chunk_size {
                    let n = read
                        .read(&mut chunk_buf[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if n == 0 {
                        break;
                    }
                    filled += n;
                }
                if filled == 0 {
                    break;
                }
                crc = crc32c::crc32c_append(crc, &chunk_buf[..filled]);
                total_in += filled as u64;
                let compressed_chunk =
                    zstd::stream::encode_all(&chunk_buf[..filled], 3).map_err(CodecError::Io)?;
                compressed_buf.extend_from_slice(&compressed_chunk);
            }
            let compressed_len = compressed_buf.len() as u64;
            Ok((
                Bytes::from(compressed_buf),
                ChunkManifest {
                    codec: CodecKind::CpuZstd,
                    original_size: total_in,
                    compressed_size: compressed_len,
                    crc32c: crc,
                },
            ))
        }

        // Non-repeating payload so chunk boundaries matter.
        let original = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(original.len(), 262_144);

        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_chunked_cpu_zstd(blob, 32 * 1024)
            .await
            .unwrap();

        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Passthrough keeps bytes intact and records size == compressed size.
    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let original = Bytes::from_static(b"hello world");
        let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
            .await
            .unwrap();
        assert_eq!(out, original);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, original.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
    }

    // A clean stream with matching CRC and size reads through unchanged.
    #[tokio::test]
    async fn crc32c_verifying_reader_passes_correct_crc() {
        use tokio::io::AsyncReadExt as _;
        let original = Bytes::from(vec![0xa3u8; 17_000]);
        let crc = crc32c::crc32c(&original);
        let inner = blob_to_async_read(make_blob(original.clone()));
        let mut verifier = Crc32cVerifyingReader::new(inner, crc, original.len() as u64);
        let mut out = Vec::new();
        verifier
            .read_to_end(&mut out)
            .await
            .expect("clean stream must read cleanly");
        assert_eq!(out, original.as_ref());
        assert_eq!(verifier.rolling_crc(), crc);
        assert_eq!(verifier.bytes_read(), original.len() as u64);
    }

    // A CRC mismatch must surface as an InvalidData io::Error at EOF.
    #[tokio::test]
    async fn crc32c_verifying_reader_detects_corruption() {
        use tokio::io::AsyncReadExt as _;
        let original = Bytes::from_static(b"clean payload bytes");
        let real_crc = crc32c::crc32c(&original);
        let bogus_expected_crc = real_crc.wrapping_add(1);
        let inner = blob_to_async_read(make_blob(original.clone()));
        let mut verifier =
            Crc32cVerifyingReader::new(inner, bogus_expected_crc, original.len() as u64);
        let mut out = Vec::new();
        let err = verifier
            .read_to_end(&mut out)
            .await
            .expect_err("CRC mismatch must surface as io::Error");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("crc32c mismatch"),
            "error must mention CRC mismatch, got `{msg}`"
        );
        // Bytes read before the EOF-time check still land in `out`.
        assert_eq!(out, original.as_ref());
    }

    // A body shorter than the advertised Content-Length must error with
    // TruncatedStream, reporting both counts.
    #[tokio::test]
    async fn streaming_compress_truncated_input_returns_truncated_stream_error() {
        use s4_codec::cpu_zstd::CpuZstd;
        let registry =
            Arc::new(CodecRegistry::new(CodecKind::CpuZstd).with(Arc::new(CpuZstd::default())));
        let actual = Bytes::from(vec![b'z'; 4096]);
        let advertised: u64 = 16 * 1024;
        let blob = make_blob(actual.clone());
        let err = streaming_compress_to_frames(
            blob,
            registry,
            CodecKind::CpuZstd,
            1024,
            Some(advertised),
        )
        .await
        .expect_err("truncated stream must error");
        match err {
            CodecError::TruncatedStream { expected, got } => {
                assert_eq!(expected, advertised);
                assert_eq!(got, actual.len() as u64);
            }
            other => panic!("expected TruncatedStream, got {other:?}"),
        }
    }
}