1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::multipart::{FrameHeader, write_frame};
35use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
36use std::sync::Arc;
37use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
38use tokio_util::io::{ReaderStream, StreamReader};
39
40pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync + 'static {
47 let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
48 StreamReader::new(mapped)
49}
50
51pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
53 reader: R,
54) -> StreamingBlob {
55 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
56 StreamingBlob::new(StreamWrapper { inner: stream })
57}
58
// Private newtype around a byte-chunk stream. It exists so the file can
// implement both `Stream` and `ByteStream` for the value handed to
// `StreamingBlob::new`. `inner` is structurally pinned (`#[pin]`) so that
// `poll_next` can be forwarded through `project()`.
pin_project_lite::pin_project! {
    struct StreamWrapper<S> { #[pin] inner: S }
}
64
65impl<S> Stream for StreamWrapper<S>
66where
67 S: Stream<Item = Result<Bytes, StdError>>,
68{
69 type Item = Result<Bytes, StdError>;
70 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 self.project().inner.poll_next(cx)
72 }
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 self.inner.size_hint()
75 }
76}
77
/// Marks [`StreamWrapper`] as an s3s [`ByteStream`] so it can back a
/// `StreamingBlob`.
impl<S> ByteStream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
    /// Always reports an unknown remaining length; the wrapper has no
    /// knowledge of how many bytes the inner stream will still produce.
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}
87
88pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
98 let read = blob_to_async_read(body);
99 let mut decoder = ZstdDecoder::new(BufReader::new(read));
100 decoder.multiple_members(true);
101 async_read_to_blob(decoder)
102}
103
/// Reader adapter that checks the bytes flowing through it against an
/// expected length and CRC32C.
///
/// Data is forwarded to the caller as it arrives; the comparison against the
/// expected values happens only when the inner reader reports end-of-stream
/// (see the `AsyncRead` implementation).
pub struct Crc32cVerifyingReader<R> {
    /// Wrapped reader that supplies the bytes.
    inner: R,
    /// CRC32C the complete stream must hash to.
    expected_crc: u32,
    /// Total number of bytes the complete stream must contain.
    expected_size: u64,
    /// CRC32C accumulated over every byte forwarded so far.
    rolling_crc: u32,
    /// Number of bytes forwarded so far.
    bytes_read: u64,
    /// Set once a size or CRC mismatch has been reported.
    failed: bool,
}
136
impl<R> Crc32cVerifyingReader<R> {
    /// Creates a verifier around `inner`.
    ///
    /// `expected_crc` and `expected_size` are the values the complete stream
    /// is checked against at end-of-stream. The running CRC and byte counter
    /// start at zero and the reader starts in the non-failed state.
    pub fn new(inner: R, expected_crc: u32, expected_size: u64) -> Self {
        Self {
            inner,
            expected_crc,
            expected_size,
            rolling_crc: 0,
            bytes_read: 0,
            failed: false,
        }
    }

    /// Test-only accessor for the CRC32C accumulated so far.
    #[cfg(test)]
    pub fn rolling_crc(&self) -> u32 {
        self.rolling_crc
    }

    /// Test-only accessor for the number of bytes forwarded so far.
    #[cfg(test)]
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }
}
161
162impl<R> AsyncRead for Crc32cVerifyingReader<R>
163where
164 R: AsyncRead + Unpin,
165{
166 fn poll_read(
167 mut self: Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 buf: &mut ReadBuf<'_>,
170 ) -> Poll<io::Result<()>> {
171 if self.failed {
172 return Poll::Ready(Ok(()));
178 }
179 let pre_filled = buf.filled().len();
180 match Pin::new(&mut self.inner).poll_read(cx, buf) {
181 Poll::Pending => Poll::Pending,
182 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
183 Poll::Ready(Ok(())) => {
184 let new_filled = buf.filled().len();
185 if new_filled > pre_filled {
186 let chunk = &buf.filled()[pre_filled..new_filled];
187 self.rolling_crc = crc32c::crc32c_append(self.rolling_crc, chunk);
188 self.bytes_read = self.bytes_read.saturating_add(chunk.len() as u64);
189 Poll::Ready(Ok(()))
190 } else {
191 if self.bytes_read != self.expected_size {
196 self.failed = true;
197 return Poll::Ready(Err(io::Error::new(
198 io::ErrorKind::InvalidData,
199 format!(
200 "S4 streaming GET size mismatch: \
201 expected {} bytes, got {}",
202 self.expected_size, self.bytes_read
203 ),
204 )));
205 }
206 if self.rolling_crc != self.expected_crc {
207 self.failed = true;
208 return Poll::Ready(Err(io::Error::new(
209 io::ErrorKind::InvalidData,
210 format!(
211 "S4 streaming GET crc32c mismatch: \
212 expected {:#010x}, got {:#010x}",
213 self.expected_crc, self.rolling_crc
214 ),
215 )));
216 }
217 Poll::Ready(Ok(()))
218 }
219 }
220 }
221 }
222}
223
224pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
226 matches!(
230 codec,
231 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
232 )
233}
234
235pub fn supports_streaming_compress(codec: CodecKind) -> bool {
236 #[cfg(feature = "nvcomp-gpu")]
237 {
238 matches!(
239 codec,
240 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
241 )
242 }
243 #[cfg(not(feature = "nvcomp-gpu"))]
244 {
245 matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
246 }
247}
248
249pub async fn streaming_compress_cpu_zstd(
259 body: StreamingBlob,
260 level: i32,
261) -> Result<(Bytes, ChunkManifest), CodecError> {
262 let mut read = blob_to_async_read(body);
263 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
264 let mut crc: u32 = 0;
265 let mut total_in: u64 = 0;
266 let mut in_buf = vec![0u8; 64 * 1024];
267
268 {
269 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
270 loop {
271 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
272 if n == 0 {
273 break;
274 }
275 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
276 total_in += n as u64;
277 encoder
278 .write_all(&in_buf[..n])
279 .await
280 .map_err(CodecError::Io)?;
281 }
282 encoder.shutdown().await.map_err(CodecError::Io)?;
283 }
284
285 let compressed_len = compressed_buf.len() as u64;
286 Ok((
287 Bytes::from(compressed_buf),
288 ChunkManifest {
289 codec: CodecKind::CpuZstd,
290 original_size: total_in,
291 compressed_size: compressed_len,
292 crc32c: crc,
293 },
294 ))
295}
296
/// Default chunk size in bytes (4 MiB).
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;

/// Chooses a chunk size in bytes from an optional content length.
///
/// | content length          | chunk size                          |
/// |-------------------------|-------------------------------------|
/// | unknown (`None`)        | [`DEFAULT_S4F2_CHUNK_SIZE`] (4 MiB) |
/// | `<= 1 MiB`              | 1 MiB                               |
/// | `<= 100 MiB`            | [`DEFAULT_S4F2_CHUNK_SIZE`] (4 MiB) |
/// | larger                  | 16 MiB                              |
pub fn pick_chunk_size(content_length: Option<u64>) -> usize {
    const MIB: usize = 1024 * 1024;
    const SMALL_OBJECT_MAX: u64 = 1024 * 1024;
    const MEDIUM_OBJECT_MAX: u64 = 100 * 1024 * 1024;

    let Some(len) = content_length else {
        return DEFAULT_S4F2_CHUNK_SIZE;
    };
    if len <= SMALL_OBJECT_MAX {
        MIB
    } else if len <= MEDIUM_OBJECT_MAX {
        DEFAULT_S4F2_CHUNK_SIZE
    } else {
        16 * MIB
    }
}
324
325pub const DEFAULT_S4F2_INFLIGHT: usize = 3;
338
339pub async fn streaming_compress_to_frames(
370 body: StreamingBlob,
371 registry: Arc<CodecRegistry>,
372 codec_kind: CodecKind,
373 chunk_size: usize,
374 expected_size: Option<u64>,
375) -> Result<(Bytes, ChunkManifest), CodecError> {
376 streaming_compress_to_frames_with(
377 body,
378 registry,
379 codec_kind,
380 chunk_size,
381 DEFAULT_S4F2_INFLIGHT,
382 expected_size,
383 )
384 .await
385}
386
387pub async fn streaming_compress_to_frames_with(
392 body: StreamingBlob,
393 registry: Arc<CodecRegistry>,
394 codec_kind: CodecKind,
395 chunk_size: usize,
396 inflight: usize,
397 expected_size: Option<u64>,
398) -> Result<(Bytes, ChunkManifest), CodecError> {
399 use bytes::BytesMut;
400 use futures::StreamExt as _;
401 use futures::stream::FuturesOrdered;
402
403 let inflight = inflight.max(1);
404 let mut read = blob_to_async_read(body);
405 let mut framed = BytesMut::with_capacity(chunk_size);
406 let mut rolling_crc: u32 = 0;
407 let mut total_in: u64 = 0;
408 let mut chunk_buf = vec![0u8; chunk_size];
409
410 type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
414 let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
415 let mut eof = false;
416
417 loop {
418 while !eof && queue.len() < inflight {
420 let mut filled = 0;
421 while filled < chunk_size {
422 let n = read
423 .read(&mut chunk_buf[filled..])
424 .await
425 .map_err(CodecError::Io)?;
426 if n == 0 {
427 break;
428 }
429 filled += n;
430 }
431 if filled == 0 {
432 eof = true;
433 break;
434 }
435
436 let chunk_slice = &chunk_buf[..filled];
437 let chunk_crc = crc32c::crc32c(chunk_slice);
438 rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
439 total_in += filled as u64;
440 if let Some(expected) = expected_size
450 && total_in > expected
451 {
452 return Err(CodecError::OverlengthStream {
453 expected,
454 got: total_in,
455 });
456 }
457
458 let header = FrameHeader {
459 codec: codec_kind,
460 original_size: filled as u64,
461 compressed_size: 0, crc32c: chunk_crc,
463 };
464 let original_chunk = Bytes::copy_from_slice(chunk_slice);
465 let registry = Arc::clone(®istry);
466 queue.push_back(Box::pin(async move {
467 let (compressed_chunk, _per_chunk_manifest) =
468 registry.compress(original_chunk, codec_kind).await?;
469 let mut header = header;
470 header.compressed_size = compressed_chunk.len() as u64;
471 Ok::<_, CodecError>((header, compressed_chunk))
472 }));
473 }
474
475 match queue.next().await {
477 Some(Ok((header, compressed_chunk))) => {
478 write_frame(&mut framed, header, &compressed_chunk);
479 }
480 Some(Err(e)) => return Err(e),
481 None => break,
482 }
483 }
484
485 if let Some(expected) = expected_size {
494 if total_in < expected {
495 return Err(CodecError::TruncatedStream {
496 expected,
497 got: total_in,
498 });
499 }
500 if total_in > expected {
510 return Err(CodecError::OverlengthStream {
511 expected,
512 got: total_in,
513 });
514 }
515 }
516
517 let total_framed = framed.len() as u64;
518 Ok((
519 framed.freeze(),
520 ChunkManifest {
521 codec: codec_kind,
522 original_size: total_in,
523 compressed_size: total_framed,
524 crc32c: rolling_crc,
525 },
526 ))
527}
528
529pub async fn streaming_passthrough(
531 body: StreamingBlob,
532) -> Result<(Bytes, ChunkManifest), CodecError> {
533 let mut read = blob_to_async_read(body);
534 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
535 let mut crc: u32 = 0;
536 let mut total: u64 = 0;
537 let mut chunk = vec![0u8; 64 * 1024];
538 loop {
539 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
540 if n == 0 {
541 break;
542 }
543 crc = crc32c::crc32c_append(crc, &chunk[..n]);
544 total += n as u64;
545 buf.extend_from_slice(&chunk[..n]);
546 }
547 let len = buf.len() as u64;
548 Ok((
549 Bytes::from(buf),
550 ChunkManifest {
551 codec: CodecKind::Passthrough,
552 original_size: total,
553 compressed_size: len,
554 crc32c: crc,
555 },
556 ))
557}
558
// Unit tests for the streaming helpers above. Inputs are built in memory and
// wrapped as `StreamingBlob`s; no network or filesystem access is involved.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    // Pins every branch of `pick_chunk_size`, including both sides of the
    // 1 MiB and 100 MiB boundaries.
    #[test]
    fn pick_chunk_size_thresholds() {
        assert_eq!(pick_chunk_size(None), DEFAULT_S4F2_CHUNK_SIZE);
        assert_eq!(pick_chunk_size(Some(0)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(64 * 1024)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(1024 * 1024)), 1024 * 1024);
        assert_eq!(
            pick_chunk_size(Some(1024 * 1024 + 1)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(50 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024 + 1)),
            16 * 1024 * 1024
        );
        assert_eq!(
            pick_chunk_size(Some(10 * 1024 * 1024 * 1024)),
            16 * 1024 * 1024
        );
    }

    // Drains a blob into one contiguous `Bytes`; panics on any chunk error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut buf = BytesMut::new();
        let mut s = blob;
        while let Some(chunk) = s.next().await {
            buf.extend_from_slice(&chunk.unwrap());
        }
        buf.freeze()
    }

    // Wraps `b` as a blob that yields it as a single chunk.
    fn make_blob(b: Bytes) -> StreamingBlob {
        let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
        StreamingBlob::wrap(stream)
    }

    // A zstd payload delivered as one chunk decompresses to the original.
    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let blob = make_blob(Bytes::from(compressed));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // The same holds when the compressed payload arrives in 1 KiB pieces.
    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let original = Bytes::from(vec![b'x'; 1_000_000]);
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let mut chunks = Vec::new();
        for chunk in compressed.chunks(1024) {
            chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
        }
        let in_stream = stream::iter(chunks);
        let blob = StreamingBlob::wrap(in_stream);
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // blob -> AsyncRead -> blob is lossless.
    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let original = Bytes::from_static(b"hello");
        let blob = make_blob(original.clone());
        let out_blob = async_read_to_blob(blob_to_async_read(blob));
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // `streaming_compress_cpu_zstd` output is readable by
    // `cpu_zstd_decompress_stream`, and its manifest matches the input.
    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let original = Bytes::from(vec![b'q'; 200_000]);
        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
        assert!(
            compressed.len() < original.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Three independently encoded zstd frames, concatenated, decode to the
    // concatenation of their plaintexts (relies on `multiple_members(true)`).
    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        let chunk_a = Bytes::from(vec![b'a'; 50_000]);
        let chunk_b = Bytes::from(vec![b'b'; 50_000]);
        let chunk_c = Bytes::from(vec![b'c'; 50_000]);

        let frame_a = zstd::stream::encode_all(chunk_a.as_ref(), 3).unwrap();
        let frame_b = zstd::stream::encode_all(chunk_b.as_ref(), 3).unwrap();
        let frame_c = zstd::stream::encode_all(chunk_c.as_ref(), 3).unwrap();

        let mut concatenated: Vec<u8> = Vec::new();
        concatenated.extend_from_slice(&frame_a);
        concatenated.extend_from_slice(&frame_b);
        concatenated.extend_from_slice(&frame_c);

        let expected: Vec<u8> = chunk_a
            .iter()
            .chain(chunk_b.iter())
            .chain(chunk_c.iter())
            .copied()
            .collect();

        let blob = make_blob(Bytes::from(concatenated));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, Bytes::from(expected));
    }

    // Compressing fixed-size chunks into back-to-back zstd frames produces a
    // stream that `cpu_zstd_decompress_stream` restores byte-for-byte.
    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        // Test-local reference pipeline: reads `chunk_size`-byte chunks and
        // appends one zstd frame per chunk to the output.
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut read = blob_to_async_read(body);
            let mut compressed_buf: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut crc: u32 = 0;
            let mut total_in: u64 = 0;
            let mut chunk_buf = vec![0u8; chunk_size];
            loop {
                // Fill the chunk completely unless the input ends first.
                let mut filled = 0;
                while filled < chunk_size {
                    let n = read
                        .read(&mut chunk_buf[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if n == 0 {
                        break;
                    }
                    filled += n;
                }
                if filled == 0 {
                    break;
                }
                crc = crc32c::crc32c_append(crc, &chunk_buf[..filled]);
                total_in += filled as u64;
                let compressed_chunk =
                    zstd::stream::encode_all(&chunk_buf[..filled], 3).map_err(CodecError::Io)?;
                compressed_buf.extend_from_slice(&compressed_chunk);
            }
            let compressed_len = compressed_buf.len() as u64;
            Ok((
                Bytes::from(compressed_buf),
                ChunkManifest {
                    codec: CodecKind::CpuZstd,
                    original_size: total_in,
                    compressed_size: compressed_len,
                    crc32c: crc,
                },
            ))
        }

        // 65,536 little-endian u32 counters = 262,144 bytes, i.e. eight
        // 32 KiB chunks.
        let original = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(original.len(), 262_144);

        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_chunked_cpu_zstd(blob, 32 * 1024)
            .await
            .unwrap();

        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Passthrough returns the bytes untouched with matching sizes and CRC.
    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let original = Bytes::from_static(b"hello world");
        let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
            .await
            .unwrap();
        assert_eq!(out, original);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, original.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
    }

    // With the correct CRC and size, the verifier is transparent.
    #[tokio::test]
    async fn crc32c_verifying_reader_passes_correct_crc() {
        use tokio::io::AsyncReadExt as _;
        let original = Bytes::from(vec![0xa3u8; 17_000]);
        let crc = crc32c::crc32c(&original);
        let inner = blob_to_async_read(make_blob(original.clone()));
        let mut verifier = Crc32cVerifyingReader::new(inner, crc, original.len() as u64);
        let mut out = Vec::new();
        verifier
            .read_to_end(&mut out)
            .await
            .expect("clean stream must read cleanly");
        assert_eq!(out, original.as_ref());
        assert_eq!(verifier.rolling_crc(), crc);
        assert_eq!(verifier.bytes_read(), original.len() as u64);
    }

    // A wrong expected CRC surfaces as `InvalidData` at end-of-stream; the
    // payload bytes read before the error are still present in `out`.
    #[tokio::test]
    async fn crc32c_verifying_reader_detects_corruption() {
        use tokio::io::AsyncReadExt as _;
        let original = Bytes::from_static(b"clean payload bytes");
        let real_crc = crc32c::crc32c(&original);
        let bogus_expected_crc = real_crc.wrapping_add(1);
        let inner = blob_to_async_read(make_blob(original.clone()));
        let mut verifier =
            Crc32cVerifyingReader::new(inner, bogus_expected_crc, original.len() as u64);
        let mut out = Vec::new();
        let err = verifier
            .read_to_end(&mut out)
            .await
            .expect_err("CRC mismatch must surface as io::Error");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("crc32c mismatch"),
            "error must mention CRC mismatch, got `{msg}`"
        );
        assert_eq!(out, original.as_ref());
    }

    // A body shorter than the advertised size (4 KiB delivered, 16 KiB
    // advertised) is rejected with `TruncatedStream` carrying both numbers.
    #[tokio::test]
    async fn streaming_compress_truncated_input_returns_truncated_stream_error() {
        use s4_codec::cpu_zstd::CpuZstd;
        let registry =
            Arc::new(CodecRegistry::new(CodecKind::CpuZstd).with(Arc::new(CpuZstd::default())));
        let actual = Bytes::from(vec![b'z'; 4096]);
        let advertised: u64 = 16 * 1024;
        let blob = make_blob(actual.clone());
        let err = streaming_compress_to_frames(
            blob,
            registry,
            CodecKind::CpuZstd,
            1024,
            Some(advertised),
        )
        .await
        .expect_err("truncated stream must error");
        match err {
            CodecError::TruncatedStream { expected, got } => {
                assert_eq!(expected, advertised);
                assert_eq!(got, actual.len() as u64);
            }
            other => panic!("expected TruncatedStream, got {other:?}"),
        }
    }
}