1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::multipart::{FrameHeader, write_frame};
35use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
36use std::sync::Arc;
37use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
38use tokio_util::io::{ReaderStream, StreamReader};
39
40pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync + 'static {
54 let mapped = blob.map(|chunk| chunk.map_err(io::Error::other));
55 StreamReader::new(mapped)
56}
57
58pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
60 reader: R,
61) -> StreamingBlob {
62 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
63 StreamingBlob::new(StreamWrapper { inner: stream })
64}
65
66pin_project_lite::pin_project! {
67 struct StreamWrapper<S> { #[pin] inner: S }
70}
71
72impl<S> Stream for StreamWrapper<S>
73where
74 S: Stream<Item = Result<Bytes, StdError>>,
75{
76 type Item = Result<Bytes, StdError>;
77 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 self.project().inner.poll_next(cx)
79 }
80 fn size_hint(&self) -> (usize, Option<usize>) {
81 self.inner.size_hint()
82 }
83}
84
85impl<S> ByteStream for StreamWrapper<S>
86where
87 S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
88{
89 fn remaining_length(&self) -> RemainingLength {
90 RemainingLength::unknown()
92 }
93}
94
95pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
105 let read = blob_to_async_read(body);
106 let mut decoder = ZstdDecoder::new(BufReader::new(read));
107 decoder.multiple_members(true);
108 async_read_to_blob(decoder)
109}
110
111pub struct Crc32cVerifyingReader<R> {
132 inner: R,
133 expected_crc: u32,
134 expected_size: u64,
135 rolling_crc: u32,
136 bytes_read: u64,
137 failed: bool,
142}
143
144impl<R> Crc32cVerifyingReader<R> {
145 pub fn new(inner: R, expected_crc: u32, expected_size: u64) -> Self {
146 Self {
147 inner,
148 expected_crc,
149 expected_size,
150 rolling_crc: 0,
151 bytes_read: 0,
152 failed: false,
153 }
154 }
155
156 #[cfg(test)]
159 pub fn rolling_crc(&self) -> u32 {
160 self.rolling_crc
161 }
162
163 #[cfg(test)]
164 pub fn bytes_read(&self) -> u64 {
165 self.bytes_read
166 }
167}
168
169impl<R> AsyncRead for Crc32cVerifyingReader<R>
170where
171 R: AsyncRead + Unpin,
172{
173 fn poll_read(
174 mut self: Pin<&mut Self>,
175 cx: &mut Context<'_>,
176 buf: &mut ReadBuf<'_>,
177 ) -> Poll<io::Result<()>> {
178 if self.failed {
179 return Poll::Ready(Ok(()));
185 }
186 let pre_filled = buf.filled().len();
187 match Pin::new(&mut self.inner).poll_read(cx, buf) {
188 Poll::Pending => Poll::Pending,
189 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
190 Poll::Ready(Ok(())) => {
191 let new_filled = buf.filled().len();
192 if new_filled > pre_filled {
193 let chunk = &buf.filled()[pre_filled..new_filled];
194 self.rolling_crc = crc32c::crc32c_append(self.rolling_crc, chunk);
195 self.bytes_read = self.bytes_read.saturating_add(chunk.len() as u64);
196 Poll::Ready(Ok(()))
197 } else {
198 if self.bytes_read != self.expected_size {
203 self.failed = true;
204 return Poll::Ready(Err(io::Error::new(
205 io::ErrorKind::InvalidData,
206 format!(
207 "S4 streaming GET size mismatch: \
208 expected {} bytes, got {}",
209 self.expected_size, self.bytes_read
210 ),
211 )));
212 }
213 if self.rolling_crc != self.expected_crc {
214 self.failed = true;
215 return Poll::Ready(Err(io::Error::new(
216 io::ErrorKind::InvalidData,
217 format!(
218 "S4 streaming GET crc32c mismatch: \
219 expected {:#010x}, got {:#010x}",
220 self.expected_crc, self.rolling_crc
221 ),
222 )));
223 }
224 Poll::Ready(Ok(()))
225 }
226 }
227 }
228 }
229}
230
231pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
233 matches!(
237 codec,
238 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
239 )
240}
241
242pub fn supports_streaming_compress(codec: CodecKind) -> bool {
243 #[cfg(feature = "nvcomp-gpu")]
244 {
245 matches!(
246 codec,
247 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
248 )
249 }
250 #[cfg(not(feature = "nvcomp-gpu"))]
251 {
252 matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
253 }
254}
255
256pub async fn streaming_compress_cpu_zstd(
266 body: StreamingBlob,
267 level: i32,
268) -> Result<(Bytes, ChunkManifest), CodecError> {
269 let mut read = blob_to_async_read(body);
270 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
271 let mut crc: u32 = 0;
272 let mut total_in: u64 = 0;
273 let mut in_buf = vec![0u8; 64 * 1024];
274
275 {
276 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
277 loop {
278 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
279 if n == 0 {
280 break;
281 }
282 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
283 total_in += n as u64;
284 encoder
285 .write_all(&in_buf[..n])
286 .await
287 .map_err(CodecError::Io)?;
288 }
289 encoder.shutdown().await.map_err(CodecError::Io)?;
290 }
291
292 let compressed_len = compressed_buf.len() as u64;
293 Ok((
294 Bytes::from(compressed_buf),
295 ChunkManifest {
296 codec: CodecKind::CpuZstd,
297 original_size: total_in,
298 compressed_size: compressed_len,
299 crc32c: crc,
300 },
301 ))
302}
303
/// Default S4F2 chunk size (4 MiB). Used when the content length is unknown
/// or falls in the mid-size band of [`pick_chunk_size`].
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;

/// Chooses a chunk size for S4F2 framing from the advertised content length.
///
/// * unknown length → [`DEFAULT_S4F2_CHUNK_SIZE`]
/// * up to 1 MiB (inclusive) → 1 MiB
/// * up to 100 MiB (inclusive) → [`DEFAULT_S4F2_CHUNK_SIZE`]
/// * anything larger → 16 MiB
pub fn pick_chunk_size(content_length: Option<u64>) -> usize {
    const MIB: u64 = 1024 * 1024;
    const SMALL_CHUNK: usize = 1024 * 1024;
    const LARGE_CHUNK: usize = 16 * 1024 * 1024;

    let Some(len) = content_length else {
        return DEFAULT_S4F2_CHUNK_SIZE;
    };
    if len <= MIB {
        SMALL_CHUNK
    } else if len <= 100 * MIB {
        DEFAULT_S4F2_CHUNK_SIZE
    } else {
        LARGE_CHUNK
    }
}
331
332pub const DEFAULT_S4F2_INFLIGHT: usize = 3;
345
346pub async fn streaming_compress_to_frames(
377 body: StreamingBlob,
378 registry: Arc<CodecRegistry>,
379 codec_kind: CodecKind,
380 chunk_size: usize,
381 expected_size: Option<u64>,
382) -> Result<(Bytes, ChunkManifest), CodecError> {
383 streaming_compress_to_frames_with(
384 body,
385 registry,
386 codec_kind,
387 chunk_size,
388 DEFAULT_S4F2_INFLIGHT,
389 expected_size,
390 )
391 .await
392}
393
394pub async fn streaming_compress_to_frames_with(
399 body: StreamingBlob,
400 registry: Arc<CodecRegistry>,
401 codec_kind: CodecKind,
402 chunk_size: usize,
403 inflight: usize,
404 expected_size: Option<u64>,
405) -> Result<(Bytes, ChunkManifest), CodecError> {
406 use bytes::BytesMut;
407 use futures::StreamExt as _;
408 use futures::stream::FuturesOrdered;
409
410 let inflight = inflight.max(1);
411 let mut read = blob_to_async_read(body);
412 let mut framed = BytesMut::with_capacity(chunk_size);
413 let mut rolling_crc: u32 = 0;
414 let mut total_in: u64 = 0;
415 let mut chunk_buf = vec![0u8; chunk_size];
416
417 type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
421 let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
422 let mut eof = false;
423
424 loop {
425 while !eof && queue.len() < inflight {
427 let mut filled = 0;
428 while filled < chunk_size {
429 let n = read
430 .read(&mut chunk_buf[filled..])
431 .await
432 .map_err(CodecError::Io)?;
433 if n == 0 {
434 break;
435 }
436 filled += n;
437 }
438 if filled == 0 {
439 eof = true;
440 break;
441 }
442
443 let chunk_slice = &chunk_buf[..filled];
444 let chunk_crc = crc32c::crc32c(chunk_slice);
445 rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
446 total_in += filled as u64;
447 if let Some(expected) = expected_size
457 && total_in > expected
458 {
459 return Err(CodecError::OverlengthStream {
460 expected,
461 got: total_in,
462 });
463 }
464
465 let header = FrameHeader {
466 codec: codec_kind,
467 original_size: filled as u64,
468 compressed_size: 0, crc32c: chunk_crc,
470 };
471 let original_chunk = Bytes::copy_from_slice(chunk_slice);
472 let registry = Arc::clone(®istry);
473 queue.push_back(Box::pin(async move {
474 let (compressed_chunk, _per_chunk_manifest) =
475 registry.compress(original_chunk, codec_kind).await?;
476 let mut header = header;
477 header.compressed_size = compressed_chunk.len() as u64;
478 Ok::<_, CodecError>((header, compressed_chunk))
479 }));
480 }
481
482 match queue.next().await {
484 Some(Ok((header, compressed_chunk))) => {
485 write_frame(&mut framed, header, &compressed_chunk);
486 }
487 Some(Err(e)) => return Err(e),
488 None => break,
489 }
490 }
491
492 if let Some(expected) = expected_size {
501 if total_in < expected {
502 return Err(CodecError::TruncatedStream {
503 expected,
504 got: total_in,
505 });
506 }
507 if total_in > expected {
517 return Err(CodecError::OverlengthStream {
518 expected,
519 got: total_in,
520 });
521 }
522 }
523
524 let total_framed = framed.len() as u64;
525 Ok((
526 framed.freeze(),
527 ChunkManifest {
528 codec: codec_kind,
529 original_size: total_in,
530 compressed_size: total_framed,
531 crc32c: rolling_crc,
532 },
533 ))
534}
535
536pub async fn streaming_passthrough(
538 body: StreamingBlob,
539) -> Result<(Bytes, ChunkManifest), CodecError> {
540 let mut read = blob_to_async_read(body);
541 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
542 let mut crc: u32 = 0;
543 let mut total: u64 = 0;
544 let mut chunk = vec![0u8; 64 * 1024];
545 loop {
546 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
547 if n == 0 {
548 break;
549 }
550 crc = crc32c::crc32c_append(crc, &chunk[..n]);
551 total += n as u64;
552 buf.extend_from_slice(&chunk[..n]);
553 }
554 let len = buf.len() as u64;
555 Ok((
556 Bytes::from(buf),
557 ChunkManifest {
558 codec: CodecKind::Passthrough,
559 original_size: total,
560 compressed_size: len,
561 crc32c: crc,
562 },
563 ))
564}
565
// Unit tests for the streaming codec helpers: chunk-size selection, zstd
// stream decode, streaming compress, passthrough, CRC verification, and
// truncation detection.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    // Pins each threshold of `pick_chunk_size`, including both sides of the
    // 1 MiB and 100 MiB boundaries.
    #[test]
    fn pick_chunk_size_thresholds() {
        assert_eq!(pick_chunk_size(None), DEFAULT_S4F2_CHUNK_SIZE);
        assert_eq!(pick_chunk_size(Some(0)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(64 * 1024)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(1024 * 1024)), 1024 * 1024);
        assert_eq!(
            pick_chunk_size(Some(1024 * 1024 + 1)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(50 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024 + 1)),
            16 * 1024 * 1024
        );
        assert_eq!(
            pick_chunk_size(Some(10 * 1024 * 1024 * 1024)),
            16 * 1024 * 1024
        );
    }

    // Drains a blob into one contiguous buffer, panicking on any stream error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut buf = BytesMut::new();
        let mut s = blob;
        while let Some(chunk) = s.next().await {
            buf.extend_from_slice(&chunk.unwrap());
        }
        buf.freeze()
    }

    // Wraps `b` as a single-item blob.
    fn make_blob(b: Bytes) -> StreamingBlob {
        let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
        StreamingBlob::wrap(stream)
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let blob = make_blob(Bytes::from(compressed));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // Feeds the compressed stream in 1 KiB pieces so frame boundaries do not
    // line up with blob items.
    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let original = Bytes::from(vec![b'x'; 1_000_000]);
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let mut chunks = Vec::new();
        for chunk in compressed.chunks(1024) {
            chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
        }
        let in_stream = stream::iter(chunks);
        let blob = StreamingBlob::wrap(in_stream);
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // blob -> AsyncRead -> blob must preserve the bytes exactly.
    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let original = Bytes::from_static(b"hello");
        let blob = make_blob(original.clone());
        let out_blob = async_read_to_blob(blob_to_async_read(blob));
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let original = Bytes::from(vec![b'q'; 200_000]);
        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
        assert!(
            compressed.len() < original.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        // The manifest CRC covers the uncompressed input.
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Independently encoded zstd frames, concatenated, must decode as one
    // stream (relies on `multiple_members(true)` in the decoder).
    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        let chunk_a = Bytes::from(vec![b'a'; 50_000]);
        let chunk_b = Bytes::from(vec![b'b'; 50_000]);
        let chunk_c = Bytes::from(vec![b'c'; 50_000]);

        let frame_a = zstd::stream::encode_all(chunk_a.as_ref(), 3).unwrap();
        let frame_b = zstd::stream::encode_all(chunk_b.as_ref(), 3).unwrap();
        let frame_c = zstd::stream::encode_all(chunk_c.as_ref(), 3).unwrap();

        let mut concatenated: Vec<u8> = Vec::new();
        concatenated.extend_from_slice(&frame_a);
        concatenated.extend_from_slice(&frame_b);
        concatenated.extend_from_slice(&frame_c);

        let expected: Vec<u8> = chunk_a
            .iter()
            .chain(chunk_b.iter())
            .chain(chunk_c.iter())
            .copied()
            .collect();

        let blob = make_blob(Bytes::from(concatenated));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, Bytes::from(expected));
    }

    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        // Test-local reference pipeline: compress each `chunk_size` piece as
        // its own zstd frame and concatenate the frames (no S4F2 headers).
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut read = blob_to_async_read(body);
            let mut compressed_buf: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut crc: u32 = 0;
            let mut total_in: u64 = 0;
            let mut chunk_buf = vec![0u8; chunk_size];
            loop {
                // Fill the chunk fully unless the input ends first.
                let mut filled = 0;
                while filled < chunk_size {
                    let n = read
                        .read(&mut chunk_buf[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if n == 0 {
                        break;
                    }
                    filled += n;
                }
                if filled == 0 {
                    break;
                }
                crc = crc32c::crc32c_append(crc, &chunk_buf[..filled]);
                total_in += filled as u64;
                let compressed_chunk =
                    zstd::stream::encode_all(&chunk_buf[..filled], 3).map_err(CodecError::Io)?;
                compressed_buf.extend_from_slice(&compressed_chunk);
            }
            let compressed_len = compressed_buf.len() as u64;
            Ok((
                Bytes::from(compressed_buf),
                ChunkManifest {
                    codec: CodecKind::CpuZstd,
                    original_size: total_in,
                    compressed_size: compressed_len,
                    crc32c: crc,
                },
            ))
        }

        // Non-repeating input (little-endian counters) so each chunk differs.
        let original = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(original.len(), 262_144);

        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_chunked_cpu_zstd(blob, 32 * 1024)
            .await
            .unwrap();

        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let original = Bytes::from_static(b"hello world");
        let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
            .await
            .unwrap();
        assert_eq!(out, original);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, original.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
    }

    // Happy path: matching size and CRC read cleanly, and the rolling
    // counters end at the expected totals.
    #[tokio::test]
    async fn crc32c_verifying_reader_passes_correct_crc() {
        use tokio::io::AsyncReadExt as _;
        let original = Bytes::from(vec![0xa3u8; 17_000]);
        let crc = crc32c::crc32c(&original);
        let inner = blob_to_async_read(make_blob(original.clone()));
        let mut verifier = Crc32cVerifyingReader::new(inner, crc, original.len() as u64);
        let mut out = Vec::new();
        verifier
            .read_to_end(&mut out)
            .await
            .expect("clean stream must read cleanly");
        assert_eq!(out, original.as_ref());
        assert_eq!(verifier.rolling_crc(), crc);
        assert_eq!(verifier.bytes_read(), original.len() as u64);
    }

    // A wrong expected CRC must surface as InvalidData at EOF. The payload
    // bytes are still delivered before the error.
    #[tokio::test]
    async fn crc32c_verifying_reader_detects_corruption() {
        use tokio::io::AsyncReadExt as _;
        let original = Bytes::from_static(b"clean payload bytes");
        let real_crc = crc32c::crc32c(&original);
        let bogus_expected_crc = real_crc.wrapping_add(1);
        let inner = blob_to_async_read(make_blob(original.clone()));
        let mut verifier =
            Crc32cVerifyingReader::new(inner, bogus_expected_crc, original.len() as u64);
        let mut out = Vec::new();
        let err = verifier
            .read_to_end(&mut out)
            .await
            .expect_err("CRC mismatch must surface as io::Error");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("crc32c mismatch"),
            "error must mention CRC mismatch, got `{msg}`"
        );
        assert_eq!(out, original.as_ref());
    }

    // A body shorter than the advertised length must fail with
    // TruncatedStream carrying both the advertised and actual byte counts.
    #[tokio::test]
    async fn streaming_compress_truncated_input_returns_truncated_stream_error() {
        use s4_codec::cpu_zstd::CpuZstd;
        let registry =
            Arc::new(CodecRegistry::new(CodecKind::CpuZstd).with(Arc::new(CpuZstd::default())));
        let actual = Bytes::from(vec![b'z'; 4096]);
        let advertised: u64 = 16 * 1024;
        let blob = make_blob(actual.clone());
        let err = streaming_compress_to_frames(
            blob,
            registry,
            CodecKind::CpuZstd,
            1024,
            Some(advertised),
        )
        .await
        .expect_err("truncated stream must error");
        match err {
            CodecError::TruncatedStream { expected, got } => {
                assert_eq!(expected, advertised);
                assert_eq!(got, actual.len() as u64);
            }
            other => panic!("expected TruncatedStream, got {other:?}"),
        }
    }
}