1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::multipart::{FrameHeader, write_frame};
35use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
36use std::sync::Arc;
37use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
38use tokio_util::io::{ReaderStream, StreamReader};
39
40pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync + 'static {
47 let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
48 StreamReader::new(mapped)
49}
50
51pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
53 reader: R,
54) -> StreamingBlob {
55 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
56 StreamingBlob::new(StreamWrapper { inner: stream })
57}
58
59pin_project_lite::pin_project! {
60 struct StreamWrapper<S> { #[pin] inner: S }
63}
64
65impl<S> Stream for StreamWrapper<S>
66where
67 S: Stream<Item = Result<Bytes, StdError>>,
68{
69 type Item = Result<Bytes, StdError>;
70 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 self.project().inner.poll_next(cx)
72 }
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 self.inner.size_hint()
75 }
76}
77
78impl<S> ByteStream for StreamWrapper<S>
79where
80 S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
81{
82 fn remaining_length(&self) -> RemainingLength {
83 RemainingLength::unknown()
85 }
86}
87
88pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
98 let read = blob_to_async_read(body);
99 let mut decoder = ZstdDecoder::new(BufReader::new(read));
100 decoder.multiple_members(true);
101 async_read_to_blob(decoder)
102}
103
/// Reader adapter that tracks a running CRC32C and byte count of everything
/// read through it, alongside the values the stream is expected to have.
///
/// The comparison against the expected values is made by this type's
/// `AsyncRead` implementation once the wrapped reader reports end-of-stream.
pub struct Crc32cVerifyingReader<R> {
    /// Wrapped reader that supplies the bytes.
    inner: R,
    /// CRC32C the complete stream must have.
    expected_crc: u32,
    /// Number of bytes the complete stream must contain.
    expected_size: u64,
    /// CRC32C of the bytes read so far.
    rolling_crc: u32,
    /// Count of the bytes read so far.
    bytes_read: u64,
    /// Set once a size or CRC mismatch has been reported.
    failed: bool,
}

impl<R> Crc32cVerifyingReader<R> {
    /// Creates a verifier around `inner` with zeroed counters and no failure
    /// recorded.
    pub fn new(inner: R, expected_crc: u32, expected_size: u64) -> Self {
        Self {
            bytes_read: 0,
            rolling_crc: 0,
            failed: false,
            inner,
            expected_crc,
            expected_size,
        }
    }

    /// CRC32C accumulated over the bytes read so far (test-only accessor).
    #[cfg(test)]
    pub fn rolling_crc(&self) -> u32 {
        self.rolling_crc
    }

    /// Number of bytes read so far (test-only accessor).
    #[cfg(test)]
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }
}
161
162impl<R> AsyncRead for Crc32cVerifyingReader<R>
163where
164 R: AsyncRead + Unpin,
165{
166 fn poll_read(
167 mut self: Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 buf: &mut ReadBuf<'_>,
170 ) -> Poll<io::Result<()>> {
171 if self.failed {
172 return Poll::Ready(Ok(()));
178 }
179 let pre_filled = buf.filled().len();
180 match Pin::new(&mut self.inner).poll_read(cx, buf) {
181 Poll::Pending => Poll::Pending,
182 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
183 Poll::Ready(Ok(())) => {
184 let new_filled = buf.filled().len();
185 if new_filled > pre_filled {
186 let chunk = &buf.filled()[pre_filled..new_filled];
187 self.rolling_crc = crc32c::crc32c_append(self.rolling_crc, chunk);
188 self.bytes_read = self.bytes_read.saturating_add(chunk.len() as u64);
189 Poll::Ready(Ok(()))
190 } else {
191 if self.bytes_read != self.expected_size {
196 self.failed = true;
197 return Poll::Ready(Err(io::Error::new(
198 io::ErrorKind::InvalidData,
199 format!(
200 "S4 streaming GET size mismatch: \
201 expected {} bytes, got {}",
202 self.expected_size, self.bytes_read
203 ),
204 )));
205 }
206 if self.rolling_crc != self.expected_crc {
207 self.failed = true;
208 return Poll::Ready(Err(io::Error::new(
209 io::ErrorKind::InvalidData,
210 format!(
211 "S4 streaming GET crc32c mismatch: \
212 expected {:#010x}, got {:#010x}",
213 self.expected_crc, self.rolling_crc
214 ),
215 )));
216 }
217 Poll::Ready(Ok(()))
218 }
219 }
220 }
221 }
222}
223
224pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
226 matches!(
230 codec,
231 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
232 )
233}
234
235pub fn supports_streaming_compress(codec: CodecKind) -> bool {
236 #[cfg(feature = "nvcomp-gpu")]
237 {
238 matches!(
239 codec,
240 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
241 )
242 }
243 #[cfg(not(feature = "nvcomp-gpu"))]
244 {
245 matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
246 }
247}
248
249pub async fn streaming_compress_cpu_zstd(
259 body: StreamingBlob,
260 level: i32,
261) -> Result<(Bytes, ChunkManifest), CodecError> {
262 let mut read = blob_to_async_read(body);
263 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
264 let mut crc: u32 = 0;
265 let mut total_in: u64 = 0;
266 let mut in_buf = vec![0u8; 64 * 1024];
267
268 {
269 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
270 loop {
271 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
272 if n == 0 {
273 break;
274 }
275 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
276 total_in += n as u64;
277 encoder
278 .write_all(&in_buf[..n])
279 .await
280 .map_err(CodecError::Io)?;
281 }
282 encoder.shutdown().await.map_err(CodecError::Io)?;
283 }
284
285 let compressed_len = compressed_buf.len() as u64;
286 Ok((
287 Bytes::from(compressed_buf),
288 ChunkManifest {
289 codec: CodecKind::CpuZstd,
290 original_size: total_in,
291 compressed_size: compressed_len,
292 crc32c: crc,
293 },
294 ))
295}
296
/// Default chunk size (4 MiB) used when nothing better is known.
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;

/// Picks a chunk size from the (optional) total content length.
///
/// | content length            | chunk size                  |
/// |---------------------------|-----------------------------|
/// | unknown (`None`)          | `DEFAULT_S4F2_CHUNK_SIZE`   |
/// | `<= 1 MiB`                | 1 MiB                       |
/// | `<= 100 MiB`              | `DEFAULT_S4F2_CHUNK_SIZE`   |
/// | larger                    | 16 MiB                      |
pub fn pick_chunk_size(content_length: Option<u64>) -> usize {
    const MIB: usize = 1024 * 1024;
    const SMALL_OBJECT_MAX: u64 = 1024 * 1024;
    const MEDIUM_OBJECT_MAX: u64 = 100 * 1024 * 1024;

    content_length.map_or(DEFAULT_S4F2_CHUNK_SIZE, |len| {
        if len <= SMALL_OBJECT_MAX {
            MIB
        } else if len <= MEDIUM_OBJECT_MAX {
            DEFAULT_S4F2_CHUNK_SIZE
        } else {
            16 * MIB
        }
    })
}
324
325pub const DEFAULT_S4F2_INFLIGHT: usize = 3;
338
339pub async fn streaming_compress_to_frames(
370 body: StreamingBlob,
371 registry: Arc<CodecRegistry>,
372 codec_kind: CodecKind,
373 chunk_size: usize,
374 expected_size: Option<u64>,
375) -> Result<(Bytes, ChunkManifest), CodecError> {
376 streaming_compress_to_frames_with(
377 body,
378 registry,
379 codec_kind,
380 chunk_size,
381 DEFAULT_S4F2_INFLIGHT,
382 expected_size,
383 )
384 .await
385}
386
387pub async fn streaming_compress_to_frames_with(
392 body: StreamingBlob,
393 registry: Arc<CodecRegistry>,
394 codec_kind: CodecKind,
395 chunk_size: usize,
396 inflight: usize,
397 expected_size: Option<u64>,
398) -> Result<(Bytes, ChunkManifest), CodecError> {
399 use bytes::BytesMut;
400 use futures::StreamExt as _;
401 use futures::stream::FuturesOrdered;
402
403 let inflight = inflight.max(1);
404 let mut read = blob_to_async_read(body);
405 let mut framed = BytesMut::with_capacity(chunk_size);
406 let mut rolling_crc: u32 = 0;
407 let mut total_in: u64 = 0;
408 let mut chunk_buf = vec![0u8; chunk_size];
409
410 type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
414 let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
415 let mut eof = false;
416
417 loop {
418 while !eof && queue.len() < inflight {
420 let mut filled = 0;
421 while filled < chunk_size {
422 let n = read
423 .read(&mut chunk_buf[filled..])
424 .await
425 .map_err(CodecError::Io)?;
426 if n == 0 {
427 break;
428 }
429 filled += n;
430 }
431 if filled == 0 {
432 eof = true;
433 break;
434 }
435
436 let chunk_slice = &chunk_buf[..filled];
437 let chunk_crc = crc32c::crc32c(chunk_slice);
438 rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
439 total_in += filled as u64;
440
441 let header = FrameHeader {
442 codec: codec_kind,
443 original_size: filled as u64,
444 compressed_size: 0, crc32c: chunk_crc,
446 };
447 let original_chunk = Bytes::copy_from_slice(chunk_slice);
448 let registry = Arc::clone(®istry);
449 queue.push_back(Box::pin(async move {
450 let (compressed_chunk, _per_chunk_manifest) =
451 registry.compress(original_chunk, codec_kind).await?;
452 let mut header = header;
453 header.compressed_size = compressed_chunk.len() as u64;
454 Ok::<_, CodecError>((header, compressed_chunk))
455 }));
456 }
457
458 match queue.next().await {
460 Some(Ok((header, compressed_chunk))) => {
461 write_frame(&mut framed, header, &compressed_chunk);
462 }
463 Some(Err(e)) => return Err(e),
464 None => break,
465 }
466 }
467
468 if let Some(expected) = expected_size {
477 if total_in < expected {
478 return Err(CodecError::TruncatedStream {
479 expected,
480 got: total_in,
481 });
482 }
483 if total_in > expected {
493 return Err(CodecError::OverlengthStream {
494 expected,
495 got: total_in,
496 });
497 }
498 }
499
500 let total_framed = framed.len() as u64;
501 Ok((
502 framed.freeze(),
503 ChunkManifest {
504 codec: codec_kind,
505 original_size: total_in,
506 compressed_size: total_framed,
507 crc32c: rolling_crc,
508 },
509 ))
510}
511
512pub async fn streaming_passthrough(
514 body: StreamingBlob,
515) -> Result<(Bytes, ChunkManifest), CodecError> {
516 let mut read = blob_to_async_read(body);
517 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
518 let mut crc: u32 = 0;
519 let mut total: u64 = 0;
520 let mut chunk = vec![0u8; 64 * 1024];
521 loop {
522 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
523 if n == 0 {
524 break;
525 }
526 crc = crc32c::crc32c_append(crc, &chunk[..n]);
527 total += n as u64;
528 buf.extend_from_slice(&chunk[..n]);
529 }
530 let len = buf.len() as u64;
531 Ok((
532 Bytes::from(buf),
533 ChunkManifest {
534 codec: CodecKind::Passthrough,
535 original_size: total,
536 compressed_size: len,
537 crc32c: crc,
538 },
539 ))
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    /// Pins every size class of `pick_chunk_size`, including both boundaries.
    #[test]
    fn pick_chunk_size_thresholds() {
        const MIB: usize = 1024 * 1024;
        let cases: [(Option<u64>, usize); 9] = [
            (None, DEFAULT_S4F2_CHUNK_SIZE),
            (Some(0), MIB),
            (Some(64 * 1024), MIB),
            (Some(1024 * 1024), MIB),
            (Some(1024 * 1024 + 1), DEFAULT_S4F2_CHUNK_SIZE),
            (Some(50 * 1024 * 1024), DEFAULT_S4F2_CHUNK_SIZE),
            (Some(100 * 1024 * 1024), DEFAULT_S4F2_CHUNK_SIZE),
            (Some(100 * 1024 * 1024 + 1), 16 * MIB),
            (Some(10 * 1024 * 1024 * 1024), 16 * MIB),
        ];
        for (content_length, want) in cases {
            assert_eq!(
                pick_chunk_size(content_length),
                want,
                "content_length = {content_length:?}"
            );
        }
    }

    /// Drains a blob into one contiguous buffer, panicking on a stream error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut stream = blob;
        let mut gathered = BytesMut::new();
        while let Some(item) = stream.next().await {
            let bytes = item.unwrap();
            gathered.extend_from_slice(&bytes);
        }
        gathered.freeze()
    }

    /// Builds a blob that yields `b` as a single chunk.
    fn make_blob(b: Bytes) -> StreamingBlob {
        StreamingBlob::wrap(stream::once(async move { Ok::<_, std::io::Error>(b) }))
    }

    /// Runs `compressed` through `cpu_zstd_decompress_stream` as a one-chunk
    /// blob and collects the output.
    async fn decompress_all(compressed: Bytes) -> Bytes {
        collect(cpu_zstd_decompress_stream(make_blob(compressed))).await
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let plain = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let packed = zstd::stream::encode_all(plain.as_ref(), 3).unwrap();

        let restored = decompress_all(Bytes::from(packed)).await;

        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let plain = Bytes::from(vec![b'x'; 1_000_000]);
        let packed = zstd::stream::encode_all(plain.as_ref(), 3).unwrap();

        // Feed the compressed bytes in 1 KiB pieces rather than one chunk.
        let pieces: Vec<_> = packed
            .chunks(1024)
            .map(|piece| Ok::<_, std::io::Error>(Bytes::copy_from_slice(piece)))
            .collect();
        let blob = StreamingBlob::wrap(stream::iter(pieces));

        let restored = collect(cpu_zstd_decompress_stream(blob)).await;

        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let plain = Bytes::from_static(b"hello");

        let reader = blob_to_async_read(make_blob(plain.clone()));
        let restored = collect(async_read_to_blob(reader)).await;

        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let plain = Bytes::from(vec![b'q'; 200_000]);

        let (packed, manifest) = streaming_compress_cpu_zstd(make_blob(plain.clone()), 3)
            .await
            .unwrap();

        assert!(
            packed.len() < plain.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, plain.len() as u64);
        assert_eq!(manifest.compressed_size, packed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&plain));

        let restored = decompress_all(packed).await;
        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        // Three independently encoded frames, glued together back to back.
        let mut concatenated: Vec<u8> = Vec::new();
        let mut expected: Vec<u8> = Vec::new();
        for fill in [b'a', b'b', b'c'] {
            let part = Bytes::from(vec![fill; 50_000]);
            let frame = zstd::stream::encode_all(part.as_ref(), 3).unwrap();
            concatenated.extend_from_slice(&frame);
            expected.extend_from_slice(&part);
        }

        let restored = decompress_all(Bytes::from(concatenated)).await;

        assert_eq!(restored, Bytes::from(expected));
    }

    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        /// Test-local pipeline: reads `body` in `chunk_size` pieces, encodes
        /// each piece as its own zstd frame and concatenates the frames.
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut source = blob_to_async_read(body);
            let mut packed: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut checksum = 0u32;
            let mut consumed = 0u64;
            let mut scratch = vec![0u8; chunk_size];

            loop {
                // Fill the scratch buffer completely unless the input ends.
                let mut filled = 0;
                while filled < chunk_size {
                    let got = source
                        .read(&mut scratch[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if got == 0 {
                        break;
                    }
                    filled += got;
                }
                if filled == 0 {
                    break;
                }

                let piece = &scratch[..filled];
                checksum = crc32c::crc32c_append(checksum, piece);
                consumed += filled as u64;
                let frame = zstd::stream::encode_all(piece, 3).map_err(CodecError::Io)?;
                packed.extend_from_slice(&frame);
            }

            let manifest = ChunkManifest {
                codec: CodecKind::CpuZstd,
                original_size: consumed,
                compressed_size: packed.len() as u64,
                crc32c: checksum,
            };
            Ok((Bytes::from(packed), manifest))
        }

        let plain = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(plain.len(), 262_144);

        let (packed, manifest) =
            streaming_compress_chunked_cpu_zstd(make_blob(plain.clone()), 32 * 1024)
                .await
                .unwrap();

        assert_eq!(manifest.original_size, plain.len() as u64);
        assert_eq!(manifest.compressed_size, packed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&plain));

        let restored = decompress_all(packed).await;
        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let plain = Bytes::from_static(b"hello world");

        let (stored, manifest) = streaming_passthrough(make_blob(plain.clone()))
            .await
            .unwrap();

        assert_eq!(stored, plain);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, plain.len() as u64);
        assert_eq!(manifest.compressed_size, plain.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&plain));
    }

    #[tokio::test]
    async fn crc32c_verifying_reader_passes_correct_crc() {
        use tokio::io::AsyncReadExt as _;

        let plain = Bytes::from(vec![0xa3u8; 17_000]);
        let want_crc = crc32c::crc32c(&plain);
        let want_len = plain.len() as u64;

        let source = blob_to_async_read(make_blob(plain.clone()));
        let mut verifier = Crc32cVerifyingReader::new(source, want_crc, want_len);
        let mut received = Vec::new();
        verifier
            .read_to_end(&mut received)
            .await
            .expect("clean stream must read cleanly");

        assert_eq!(received, plain.as_ref());
        assert_eq!(verifier.rolling_crc(), want_crc);
        assert_eq!(verifier.bytes_read(), want_len);
    }

    #[tokio::test]
    async fn crc32c_verifying_reader_detects_corruption() {
        use tokio::io::AsyncReadExt as _;

        let plain = Bytes::from_static(b"clean payload bytes");
        // Right length, deliberately wrong checksum.
        let bogus_expected_crc = crc32c::crc32c(&plain).wrapping_add(1);

        let source = blob_to_async_read(make_blob(plain.clone()));
        let mut verifier =
            Crc32cVerifyingReader::new(source, bogus_expected_crc, plain.len() as u64);
        let mut received = Vec::new();
        let err = verifier
            .read_to_end(&mut received)
            .await
            .expect_err("CRC mismatch must surface as io::Error");

        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("crc32c mismatch"),
            "error must mention CRC mismatch, got `{msg}`"
        );
        // The payload itself was still delivered before the error surfaced.
        assert_eq!(received, plain.as_ref());
    }

    #[tokio::test]
    async fn streaming_compress_truncated_input_returns_truncated_stream_error() {
        use s4_codec::cpu_zstd::CpuZstd;

        let registry =
            Arc::new(CodecRegistry::new(CodecKind::CpuZstd).with(Arc::new(CpuZstd::default())));
        // The body holds 4 KiB but 16 KiB is advertised.
        let actual = Bytes::from(vec![b'z'; 4096]);
        let advertised: u64 = 16 * 1024;

        let outcome = streaming_compress_to_frames(
            make_blob(actual.clone()),
            registry,
            CodecKind::CpuZstd,
            1024,
            Some(advertised),
        )
        .await;

        match outcome.expect_err("truncated stream must error") {
            CodecError::TruncatedStream { expected, got } => {
                assert_eq!(expected, advertised);
                assert_eq!(got, actual.len() as u64);
            }
            other => panic!("expected TruncatedStream, got {other:?}"),
        }
    }
}