// s4_server/blob.rs

//! Conversion helpers between `s3s::dto::StreamingBlob` and `bytes::Bytes`.
//!
//! Phase 1 policy: collect the entire PUT body in memory before compressing.
//! Compared with streaming-aware chunk compression (planned for Phase 2) the
//! memory cost is higher, but this keeps roundtrip verification and manifest
//! generation simple. `max_bytes` caps the largest acceptable body size.

7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use bytes::{Bytes, BytesMut};
11use futures::{Stream, StreamExt};
12use s3s::StdError;
13use s3s::dto::StreamingBlob;
14use s3s::stream::{ByteStream, RemainingLength};
15
16/// `StreamingBlob` を Bytes に collect。`max_bytes` を超えたら早期に Err。
17pub async fn collect_blob(blob: StreamingBlob, max_bytes: usize) -> Result<Bytes, BlobError> {
18    let hint = blob.remaining_length().exact().unwrap_or(0).min(max_bytes);
19    let mut buf = BytesMut::with_capacity(hint);
20    let mut stream = blob;
21    while let Some(chunk) = stream.next().await {
22        let chunk = chunk.map_err(|e| BlobError::Read(format!("{e}")))?;
23        if buf.len().saturating_add(chunk.len()) > max_bytes {
24            return Err(BlobError::Oversized {
25                limit: max_bytes,
26                seen_at_least: buf.len() + chunk.len(),
27            });
28        }
29        buf.extend_from_slice(&chunk);
30    }
31    Ok(buf.freeze())
32}
33
34/// `Bytes` を 1 chunk の `StreamingBlob` に包む。
35///
36/// content-length を **既知** として返す ByteStream impl にすることが重要。
37/// `StreamingBlob::wrap` (futures::Stream 越し) だと remaining_length が unknown
38/// になり、aws-sdk-s3 が `AwsChunkedContentEncodingInterceptor` で
39/// `UnsizedRequestBody` エラーを返す。
40pub fn bytes_to_blob(bytes: Bytes) -> StreamingBlob {
41    StreamingBlob::new(SingleChunkBlob(Some(bytes)))
42}
43
44/// 単一の `Bytes` を 1 度だけ yield して終わる `ByteStream`。
45/// `remaining_length` を正確な byte 数として返すので、aws-sdk-s3 の chunked
46/// signing path がそのまま動く。
47struct SingleChunkBlob(Option<Bytes>);
48
49impl Stream for SingleChunkBlob {
50    type Item = Result<Bytes, StdError>;
51    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52        Poll::Ready(self.get_mut().0.take().map(Ok))
53    }
54    fn size_hint(&self) -> (usize, Option<usize>) {
55        match &self.0 {
56            Some(_) => (1, Some(1)),
57            None => (0, Some(0)),
58        }
59    }
60}
61
62impl ByteStream for SingleChunkBlob {
63    fn remaining_length(&self) -> RemainingLength {
64        match &self.0 {
65            Some(b) => RemainingLength::new_exact(b.len()),
66            None => RemainingLength::new_exact(0),
67        }
68    }
69}
70
71#[derive(Debug, thiserror::Error)]
72pub enum BlobError {
73    #[error("body exceeded configured limit ({limit} bytes); saw at least {seen_at_least}")]
74    Oversized { limit: usize, seen_at_least: usize },
75    #[error("error reading streaming body: {0}")]
76    Read(String),
77}
78
79/// `blob` の先頭から最大 `sample_bytes` を読み出して `(sample, rest_stream)` に分ける。
80/// `sample` は collected Bytes、`rest_stream` は残りの未消費 chunk ストリーム。
81/// stream 全体が `sample_bytes` 未満ならば `rest_stream` は空。
82pub async fn peek_sample(
83    mut blob: StreamingBlob,
84    sample_bytes: usize,
85) -> Result<(Bytes, StreamingBlob), BlobError> {
86    let mut sample = BytesMut::with_capacity(sample_bytes);
87    let mut leftover: Option<Bytes> = None;
88    while sample.len() < sample_bytes {
89        match blob.next().await {
90            Some(Ok(chunk)) => {
91                let remaining = sample_bytes.saturating_sub(sample.len());
92                if chunk.len() <= remaining {
93                    sample.extend_from_slice(&chunk);
94                } else {
95                    sample.extend_from_slice(&chunk[..remaining]);
96                    leftover = Some(chunk.slice(remaining..));
97                    break;
98                }
99            }
100            Some(Err(e)) => return Err(BlobError::Read(format!("{e}"))),
101            None => break,
102        }
103    }
104    let sample_bytes = sample.freeze();
105    let rest = chain_leftover_with_blob(leftover, blob);
106    Ok((sample_bytes, rest))
107}
108
109/// `peek_sample` で取り出した sample を rest stream の先頭に再 prepend して
110/// 1 本のストリームに戻す。
111pub fn chain_sample_with_rest(sample: Bytes, rest: StreamingBlob) -> StreamingBlob {
112    let head = futures::stream::once(async move { Ok::<_, std::io::Error>(sample) });
113    let tail = rest.map(|r| r.map_err(|e| std::io::Error::other(e.to_string())));
114    StreamingBlob::wrap(head.chain(tail))
115}
116
117fn chain_leftover_with_blob(leftover: Option<Bytes>, rest: StreamingBlob) -> StreamingBlob {
118    match leftover {
119        Some(b) => chain_sample_with_rest(b, rest),
120        None => rest,
121    }
122}
123
124/// `peek_sample` の結果を再度結合した上で全体を Bytes に collect。GPU codec 経路用。
125pub async fn collect_with_sample(
126    sample: Bytes,
127    rest: StreamingBlob,
128    max_bytes: usize,
129) -> Result<Bytes, BlobError> {
130    if sample.len() > max_bytes {
131        return Err(BlobError::Oversized {
132            limit: max_bytes,
133            seen_at_least: sample.len(),
134        });
135    }
136    let mut buf = BytesMut::with_capacity(sample.len() + 4096);
137    buf.extend_from_slice(&sample);
138    let mut stream = rest;
139    while let Some(chunk) = stream.next().await {
140        let chunk = chunk.map_err(|e| BlobError::Read(format!("{e}")))?;
141        if buf.len().saturating_add(chunk.len()) > max_bytes {
142            return Err(BlobError::Oversized {
143                limit: max_bytes,
144                seen_at_least: buf.len() + chunk.len(),
145            });
146        }
147        buf.extend_from_slice(&chunk);
148    }
149    Ok(buf.freeze())
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A blob produced by `bytes_to_blob` collects back to the original bytes.
    #[tokio::test]
    async fn collect_roundtrip() {
        let original = Bytes::from_static(b"hello squished s3");
        let blob = bytes_to_blob(original.clone());
        let collected = collect_blob(blob, 1024).await.unwrap();
        assert_eq!(collected, original);
    }

    /// Bodies larger than the limit are rejected with `Oversized`.
    #[tokio::test]
    async fn collect_rejects_oversized() {
        let big = Bytes::from(vec![0u8; 2048]);
        let blob = bytes_to_blob(big);
        let err = collect_blob(blob, 1024).await.unwrap_err();
        assert!(matches!(err, BlobError::Oversized { .. }));
    }
}
171}