use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use async_compression::Level;
use async_compression::tokio::bufread::ZstdDecoder;
use async_compression::tokio::write::ZstdEncoder;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use s3s::StdError;
use s3s::dto::StreamingBlob;
use s3s::stream::{ByteStream, RemainingLength};
use s4_codec::multipart::{FrameHeader, write_frame};
use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio_util::io::{ReaderStream, StreamReader};

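/// Adapts an s3s `StreamingBlob` into a tokio `AsyncRead` by mapping the
/// stream's boxed errors into `io::Error` and wrapping it in a `StreamReader`.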
pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync {
    let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
    StreamReader::new(mapped)
}

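/// Wraps any `AsyncRead` back into a `StreamingBlob`, boxing I/O errors into
/// the `StdError` type s3s expects.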
pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
    reader: R,
) -> StreamingBlob {
    let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
    StreamingBlob::new(StreamWrapper { inner: stream })
}

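// `StreamingBlob::new` requires a `ByteStream`; this wrapper supplies that
// impl (with an unknown remaining length) for an arbitrary byte stream.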
pin_project_lite::pin_project! {
    struct StreamWrapper<S> {
        #[pin]
        inner: S,
    }
}

impl<S> Stream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>>,
{
    type Item = Result<Bytes, StdError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<S> ByteStream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}

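/// Decompresses a zstd-compressed blob on the CPU as it streams through.
/// `multiple_members(true)` keeps the decoder reading across concatenated
/// zstd frames, so chunked uploads decode as one continuous stream.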
pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
    let read = blob_to_async_read(body);
    let mut decoder = ZstdDecoder::new(BufReader::new(read));
    decoder.multiple_members(true);
    async_read_to_blob(decoder)
}

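/// Whether `codec` can be decompressed in streaming fashion. NvcompZstd
/// output is zstd-compatible, so it is decodable on the CPU path as well.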
pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
    matches!(
        codec,
        CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
    )
}

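/// Whether `codec` can be compressed in streaming fashion. NvcompZstd is
/// only available when the `nvcomp-gpu` feature is enabled.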
pub fn supports_streaming_compress(codec: CodecKind) -> bool {
    #[cfg(feature = "nvcomp-gpu")]
    {
        matches!(
            codec,
            CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
        )
    }
    #[cfg(not(feature = "nvcomp-gpu"))]
    {
        matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
    }
}

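/// Streams the body through a CPU zstd encoder at the given `level`,
/// returning the compressed bytes plus a manifest carrying the original
/// size and the CRC32C computed over the uncompressed input.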
pub async fn streaming_compress_cpu_zstd(
    body: StreamingBlob,
    level: i32,
) -> Result<(Bytes, ChunkManifest), CodecError> {
    let mut read = blob_to_async_read(body);
    let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
    let mut crc: u32 = 0;
    let mut total_in: u64 = 0;
    let mut in_buf = vec![0u8; 64 * 1024];

    {
        let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
        loop {
            let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
            if n == 0 {
                break;
            }
            crc = crc32c::crc32c_append(crc, &in_buf[..n]);
            total_in += n as u64;
            encoder
                .write_all(&in_buf[..n])
                .await
                .map_err(CodecError::Io)?;
        }
        encoder.shutdown().await.map_err(CodecError::Io)?;
    }

    let compressed_len = compressed_buf.len() as u64;
    Ok((
        Bytes::from(compressed_buf),
        ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: total_in,
            compressed_size: compressed_len,
            crc32c: crc,
        },
    ))
}

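/// Default chunk size (4 MiB) used when framing a body into s4f2 chunks.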
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;

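/// Picks an s4f2 chunk size from the declared content length: 1 MiB for
/// bodies up to 1 MiB, 4 MiB up to 100 MiB, and 16 MiB beyond that. Unknown
/// lengths fall back to the 4 MiB default.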
pub fn pick_chunk_size(content_length: Option<u64>) -> usize {
    match content_length {
        None => DEFAULT_S4F2_CHUNK_SIZE,
        Some(len) if len <= 1024 * 1024 => 1024 * 1024,
        Some(len) if len <= 100 * 1024 * 1024 => DEFAULT_S4F2_CHUNK_SIZE,
        Some(_) => 16 * 1024 * 1024,
    }
}

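/// Default number of chunk compressions kept in flight while framing.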
pub const DEFAULT_S4F2_INFLIGHT: usize = 3;

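/// Frames the body into compressed s4f2 chunks using the default in-flight
/// window. See [`streaming_compress_to_frames_with`].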
pub async fn streaming_compress_to_frames(
    body: StreamingBlob,
    registry: Arc<CodecRegistry>,
    codec_kind: CodecKind,
    chunk_size: usize,
) -> Result<(Bytes, ChunkManifest), CodecError> {
    streaming_compress_to_frames_with(
        body,
        registry,
        codec_kind,
        chunk_size,
        DEFAULT_S4F2_INFLIGHT,
    )
    .await
}

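/// Reads the body in `chunk_size` slices, compresses each slice through the
/// registry's codec, and writes the results as s4f2 frames in input order.
/// Up to `inflight` chunk compressions run concurrently. The returned
/// manifest carries the rolling CRC32C and total size of the original data;
/// its `compressed_size` is the length of the framed output.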
pub async fn streaming_compress_to_frames_with(
    body: StreamingBlob,
    registry: Arc<CodecRegistry>,
    codec_kind: CodecKind,
    chunk_size: usize,
    inflight: usize,
) -> Result<(Bytes, ChunkManifest), CodecError> {
    use bytes::BytesMut;
    use futures::StreamExt as _;
    use futures::stream::FuturesOrdered;

    let inflight = inflight.max(1);
    let mut read = blob_to_async_read(body);
    let mut framed = BytesMut::with_capacity(chunk_size);
    let mut rolling_crc: u32 = 0;
    let mut total_in: u64 = 0;
    let mut chunk_buf = vec![0u8; chunk_size];

    type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
    let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
    let mut eof = false;

    loop {
        // Keep up to `inflight` chunk compressions queued before draining.
        while !eof && queue.len() < inflight {
            // Fill the chunk buffer until it is full or the body is exhausted.
            let mut filled = 0;
            while filled < chunk_size {
                let n = read
                    .read(&mut chunk_buf[filled..])
                    .await
                    .map_err(CodecError::Io)?;
                if n == 0 {
                    break;
                }
                filled += n;
            }
            if filled == 0 {
                eof = true;
                break;
            }

            let chunk_slice = &chunk_buf[..filled];
            let chunk_crc = crc32c::crc32c(chunk_slice);
            rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
            total_in += filled as u64;

            let header = FrameHeader {
                codec: codec_kind,
                original_size: filled as u64,
                compressed_size: 0,
                crc32c: chunk_crc,
            };
            let original_chunk = Bytes::copy_from_slice(chunk_slice);
            let registry = Arc::clone(&registry);
            queue.push_back(Box::pin(async move {
                let (compressed_chunk, _per_chunk_manifest) =
                    registry.compress(original_chunk, codec_kind).await?;
                let mut header = header;
                header.compressed_size = compressed_chunk.len() as u64;
                Ok::<_, CodecError>((header, compressed_chunk))
            }));
        }

        // Drain the oldest completed chunk; `FuturesOrdered` preserves order.
        match queue.next().await {
            Some(Ok((header, compressed_chunk))) => {
                write_frame(&mut framed, header, &compressed_chunk);
            }
            Some(Err(e)) => return Err(e),
            None => break,
        }
    }

    let total_framed = framed.len() as u64;
    Ok((
        framed.freeze(),
        ChunkManifest {
            codec: codec_kind,
            original_size: total_in,
            compressed_size: total_framed,
            crc32c: rolling_crc,
        },
    ))
}

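/// Buffers the body unchanged while computing its CRC32C and length, yielding
/// a passthrough manifest whose original and compressed sizes are equal.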
pub async fn streaming_passthrough(
    body: StreamingBlob,
) -> Result<(Bytes, ChunkManifest), CodecError> {
    let mut read = blob_to_async_read(body);
    let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
    let mut crc: u32 = 0;
    let mut total: u64 = 0;
    let mut chunk = vec![0u8; 64 * 1024];
    loop {
        let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
        if n == 0 {
            break;
        }
        crc = crc32c::crc32c_append(crc, &chunk[..n]);
        total += n as u64;
        buf.extend_from_slice(&chunk[..n]);
    }
    let len = buf.len() as u64;
    Ok((
        Bytes::from(buf),
        ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: total,
            compressed_size: len,
            crc32c: crc,
        },
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    #[test]
    fn pick_chunk_size_thresholds() {
        assert_eq!(pick_chunk_size(None), DEFAULT_S4F2_CHUNK_SIZE);
        assert_eq!(pick_chunk_size(Some(0)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(64 * 1024)), 1024 * 1024);
        assert_eq!(pick_chunk_size(Some(1024 * 1024)), 1024 * 1024);
        assert_eq!(
            pick_chunk_size(Some(1024 * 1024 + 1)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(50 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024)),
            DEFAULT_S4F2_CHUNK_SIZE
        );
        assert_eq!(
            pick_chunk_size(Some(100 * 1024 * 1024 + 1)),
            16 * 1024 * 1024
        );
        assert_eq!(
            pick_chunk_size(Some(10 * 1024 * 1024 * 1024)),
            16 * 1024 * 1024
        );
    }

    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut buf = BytesMut::new();
        let mut s = blob;
        while let Some(chunk) = s.next().await {
            buf.extend_from_slice(&chunk.unwrap());
        }
        buf.freeze()
    }

    fn make_blob(b: Bytes) -> StreamingBlob {
        let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
        StreamingBlob::wrap(stream)
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let blob = make_blob(Bytes::from(compressed));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let original = Bytes::from(vec![b'x'; 1_000_000]);
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let mut chunks = Vec::new();
        for chunk in compressed.chunks(1024) {
            chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
        }
        let in_stream = stream::iter(chunks);
        let blob = StreamingBlob::wrap(in_stream);
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let original = Bytes::from_static(b"hello");
        let blob = make_blob(original.clone());
        let out_blob = async_read_to_blob(blob_to_async_read(blob));
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let original = Bytes::from(vec![b'q'; 200_000]);
        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
        assert!(
            compressed.len() < original.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        let chunk_a = Bytes::from(vec![b'a'; 50_000]);
        let chunk_b = Bytes::from(vec![b'b'; 50_000]);
        let chunk_c = Bytes::from(vec![b'c'; 50_000]);

        let frame_a = zstd::stream::encode_all(chunk_a.as_ref(), 3).unwrap();
        let frame_b = zstd::stream::encode_all(chunk_b.as_ref(), 3).unwrap();
        let frame_c = zstd::stream::encode_all(chunk_c.as_ref(), 3).unwrap();

        let mut concatenated: Vec<u8> = Vec::new();
        concatenated.extend_from_slice(&frame_a);
        concatenated.extend_from_slice(&frame_b);
        concatenated.extend_from_slice(&frame_c);

        let expected: Vec<u8> = chunk_a
            .iter()
            .chain(chunk_b.iter())
            .chain(chunk_c.iter())
            .copied()
            .collect();

        let blob = make_blob(Bytes::from(concatenated));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, Bytes::from(expected));
    }

    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        // Test-local helper: compress fixed-size chunks as independent zstd
        // frames and concatenate them, mirroring the chunked pipeline.
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut read = blob_to_async_read(body);
            let mut compressed_buf: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut crc: u32 = 0;
            let mut total_in: u64 = 0;
            let mut chunk_buf = vec![0u8; chunk_size];
            loop {
                let mut filled = 0;
                while filled < chunk_size {
                    let n = read
                        .read(&mut chunk_buf[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if n == 0 {
                        break;
                    }
                    filled += n;
                }
                if filled == 0 {
                    break;
                }
                crc = crc32c::crc32c_append(crc, &chunk_buf[..filled]);
                total_in += filled as u64;
                let compressed_chunk =
                    zstd::stream::encode_all(&chunk_buf[..filled], 3).map_err(CodecError::Io)?;
                compressed_buf.extend_from_slice(&compressed_chunk);
            }
            let compressed_len = compressed_buf.len() as u64;
            Ok((
                Bytes::from(compressed_buf),
                ChunkManifest {
                    codec: CodecKind::CpuZstd,
                    original_size: total_in,
                    compressed_size: compressed_len,
                    crc32c: crc,
                },
            ))
        }

        // 256 KiB of little-endian counter bytes: varied but compressible.
        let original = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(original.len(), 262_144);

        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_chunked_cpu_zstd(blob, 32 * 1024)
            .await
            .unwrap();

        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let original = Bytes::from_static(b"hello world");
        let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
            .await
            .unwrap();
        assert_eq!(out, original);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, original.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
    }
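
    // The capability helpers are pure `matches!` tables over `CodecKind`, so
    // their answers for the codec kinds named above can be pinned down
    // directly from the function bodies.
    #[test]
    fn capability_helpers_cover_known_codecs() {
        assert!(supports_streaming_decompress(CodecKind::Passthrough));
        assert!(supports_streaming_decompress(CodecKind::CpuZstd));
        assert!(supports_streaming_decompress(CodecKind::NvcompZstd));
        assert!(supports_streaming_compress(CodecKind::Passthrough));
        assert!(supports_streaming_compress(CodecKind::CpuZstd));
        // NvcompZstd compression support tracks the `nvcomp-gpu` feature.
        assert_eq!(
            supports_streaming_compress(CodecKind::NvcompZstd),
            cfg!(feature = "nvcomp-gpu")
        );
    }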
}