Skip to main content

s4_server/
streaming_checksum.rs

1//! v0.9 #106 — true streaming PUT checksum verify (tee-into-hasher).
2//!
3//! Wraps an inbound [`StreamingBlob`] so that every chunk pulled by a
4//! downstream consumer (typically [`streaming_compress_to_frames`]) is
5//! **also** fed into one or more client-declared digesters
6//! (`Content-MD5`, `x-amz-checksum-{crc32, crc32c, sha1, sha256, crc64nvme}`).
7//! When the upstream stream reaches EOF, the wrapper finalises every active
8//! hasher and compares it against the client-supplied (base64) value; a
9//! mismatch is surfaced as a [`std::io::Error`] tagged with
10//! [`StreamingChecksumError`] so the s4-server PUT handler can translate it
11//! into a typed `400 BadDigest` response instead of letting the bytes
12//! reach the backend.
13//!
14//! ## Why this exists
15//!
16//! Before v0.9 #106, S4 only verified client-supplied whole-body
17//! checksums on the **buffered** PUT path
18//! (`crates/s4-server/src/service.rs::verify_client_body_checksums`),
19//! because that path already held the entire body in memory. The
20//! streaming-framed PUT path (CPU-zstd / passthrough single-PUT) accepted
21//! `x-amz-checksum-*` headers and silently passed them through without
22//! verifying them — the v0.8.13 #127 attempt to "force buffered when a
23//! checksum is supplied" regressed sidecar correctness for AWS-SDK PUTs
24//! (which auto-add `x-amz-checksum-crc32` by default), see v0.8.14 #129.
25//!
26//! True streaming verify computes the digests **as the bytes flow through
27//! the codec pipeline**, so the buffered fallback is no longer required to
28//! get integrity coverage. The wrapper preserves the streaming property:
29//! it only holds one chunk's worth of bytes in flight (whatever the
30//! upstream blob yields) plus the constant-size hasher state.
31//!
32//! ## Scope
33//!
34//! Wired in for **single-PUT, cpu-zstd / passthrough, non-multipart**.
35//! Multipart `UploadPart` keeps the buffered per-part verify it already
36//! had (the per-part body is already in memory there for the framing /
37//! padding step, so there's nothing to win from streaming verify on that
38//! branch). GPU codecs are bytes-buffered today; their verify happens on
39//! the buffered fallback like before.
40//!
41//! ## Failure model
42//!
//! - Header malformed (bad base64 / wrong byte length) →
//!   [`ClientChecksums::from_request_fields`] returns an `Err(S3Error)`
//!   with code `InvalidDigest` **before** the wrapper is built, matching
//!   the buffered path's pre-stream validation.
47//! - Stream errors mid-flight → propagated unchanged; we don't compare
48//!   digests (the bytes the client intended never landed). The PUT
49//!   eventually surfaces as a `TruncatedStream` / I/O error from the
50//!   compressor, not a `BadDigest`.
//! - Stream completes but digest mismatch → wrapper emits one synthetic
//!   `io::ErrorKind::InvalidData` error carrying [`StreamingChecksumError`]
//!   **in place of** the end-of-stream `None`, i.e. on the `poll_next`
//!   call that observes upstream EOF. The compressor sees this as an I/O
//!   error mid-read and returns `CodecError::Io(...)`. The PUT handler
//!   then downcasts the inner error chain to recover the
//!   `StreamingChecksumError` and maps to `BadDigest`.
57
58use std::pin::Pin;
59use std::sync::Arc;
60use std::task::{Context, Poll};
61
62use base64::Engine as _;
63use bytes::Bytes;
64use crc32fast::Hasher as Crc32Hasher;
65use futures::{Stream, StreamExt};
66use md5::{Digest as Md5Digest, Md5};
67use s3s::dto::StreamingBlob;
68use s3s::stream::{ByteStream, RemainingLength};
69use s3s::{S3Error, S3ErrorCode, S3Result};
70use sha1::Sha1;
71use sha2::Sha256;
72use std::sync::Mutex;
73
/// Sentinel error carried as the payload of an
/// `io::ErrorKind::InvalidData` [`std::io::Error`] when the streaming
/// wrapper detects a client-vs-actual digest mismatch at EOF. The PUT
/// handler downcasts the error chain to recover this and emits a typed
/// `BadDigest` S3 response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamingChecksumError {
    /// Human-readable name of the checksum algorithm that failed
    /// (`Content-MD5`, `x-amz-checksum-crc32c`, ...). Used verbatim in
    /// the `BadDigest` message so operators see which header was wrong.
    pub algorithm: &'static str,
}

impl std::fmt::Display for StreamingChecksumError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "client-supplied {} did not match the streamed body",
            self.algorithm
        )
    }
}

/// Leaf error: it has no `source()`, so callers downcasting an
/// `io::Error` payload stop here.
impl std::error::Error for StreamingChecksumError {}
97
98/// Parsed, byte-length-validated client checksum claims for a single PUT
99/// request. Built once at request entry from the raw header strings;
100/// then handed to [`tee_into_hashers`] to drive the streaming verifier.
101///
102/// `from_request_fields` performs the same pre-stream validation the
103/// buffered path's `verify_client_body_checksums` did inline (base64
104/// decodes, byte-length checks) so a malformed header fails with
105/// `InvalidDigest` **before** the body ever flows. After this returns
106/// `Ok`, every present claim has the correct decoded byte length.
107#[derive(Debug, Default, Clone)]
108pub struct ClientChecksums {
109    content_md5: Option<[u8; 16]>,
110    crc32: Option<[u8; 4]>,
111    crc32c: Option<[u8; 4]>,
112    sha1: Option<[u8; 20]>,
113    sha256: Option<[u8; 32]>,
114    crc64nvme: Option<[u8; 8]>,
115}
116
117impl ClientChecksums {
118    /// Returns true when at least one checksum claim is set — i.e. the
119    /// streaming wrapper has work to do. When this returns false the
120    /// caller should skip the wrapper entirely (zero-cost: no hashers
121    /// allocated, no per-chunk update path) so non-checksummed PUTs
122    /// keep their pre-#106 throughput.
123    pub fn any(&self) -> bool {
124        self.content_md5.is_some()
125            || self.crc32.is_some()
126            || self.crc32c.is_some()
127            || self.sha1.is_some()
128            || self.sha256.is_some()
129            || self.crc64nvme.is_some()
130    }
131
132    /// Parse the six AWS-spec checksum header values supplied on a single
133    /// PUT request. Each argument is the base64-encoded header value (or
134    /// `None` when the header was absent). Returns `Err(InvalidDigest)`
135    /// when any present value is malformed (bad base64 or wrong decoded
136    /// length); identical pre-stream behaviour to the buffered path's
137    /// inline validation.
138    pub fn from_request_fields(
139        content_md5: Option<&str>,
140        crc32: Option<&str>,
141        crc32c: Option<&str>,
142        sha1: Option<&str>,
143        sha256: Option<&str>,
144        crc64nvme: Option<&str>,
145    ) -> S3Result<Self> {
146        let b64 = base64::engine::general_purpose::STANDARD;
147        let decode_fixed = |val: &str, expected_len: usize, label: &str| -> S3Result<Vec<u8>> {
148            let v = b64.decode(val).map_err(|_| {
149                S3Error::with_message(S3ErrorCode::InvalidDigest, format!("malformed {label}"))
150            })?;
151            if v.len() != expected_len {
152                return Err(S3Error::with_message(
153                    S3ErrorCode::InvalidDigest,
154                    format!("{label} must decode to {expected_len} bytes"),
155                ));
156            }
157            Ok(v)
158        };
159        let mut out = ClientChecksums::default();
160        if let Some(v) = content_md5 {
161            let bytes = decode_fixed(v, 16, "Content-MD5")?;
162            let mut arr = [0u8; 16];
163            arr.copy_from_slice(&bytes);
164            out.content_md5 = Some(arr);
165        }
166        if let Some(v) = crc32 {
167            let bytes = decode_fixed(v, 4, "x-amz-checksum-crc32")?;
168            let mut arr = [0u8; 4];
169            arr.copy_from_slice(&bytes);
170            out.crc32 = Some(arr);
171        }
172        if let Some(v) = crc32c {
173            let bytes = decode_fixed(v, 4, "x-amz-checksum-crc32c")?;
174            let mut arr = [0u8; 4];
175            arr.copy_from_slice(&bytes);
176            out.crc32c = Some(arr);
177        }
178        if let Some(v) = sha1 {
179            let bytes = decode_fixed(v, 20, "x-amz-checksum-sha1")?;
180            let mut arr = [0u8; 20];
181            arr.copy_from_slice(&bytes);
182            out.sha1 = Some(arr);
183        }
184        if let Some(v) = sha256 {
185            let bytes = decode_fixed(v, 32, "x-amz-checksum-sha256")?;
186            let mut arr = [0u8; 32];
187            arr.copy_from_slice(&bytes);
188            out.sha256 = Some(arr);
189        }
190        if let Some(v) = crc64nvme {
191            let bytes = decode_fixed(v, 8, "x-amz-checksum-crc64nvme")?;
192            let mut arr = [0u8; 8];
193            arr.copy_from_slice(&bytes);
194            out.crc64nvme = Some(arr);
195        }
196        Ok(out)
197    }
198}
199
/// Which hashers the tee should drive on every chunk. Derived from
/// both the parsed request headers ([`ClientChecksums`]) AND any
/// algorithms the client announced via `x-amz-trailer` (the chunked /
/// SigV4-streaming SDK case, where the actual digest value arrives in
/// the request trailers after the body is fully consumed).
///
/// Each flag is on iff the corresponding algorithm must be computed —
/// avoiding the per-chunk cost of (in particular) the SHA family when
/// the client only wants a CRC.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct WhichHashers {
    pub content_md5: bool,
    pub crc32: bool,
    pub crc32c: bool,
    pub sha1: bool,
    pub sha256: bool,
    pub crc64nvme: bool,
}

impl WhichHashers {
    /// True when at least one hasher is selected.
    pub fn any(&self) -> bool {
        self.content_md5 || self.crc32 || self.crc32c || self.sha1 || self.sha256 || self.crc64nvme
    }

    /// Union: enable a hasher in the result if it was on in either side.
    pub fn or(self, other: Self) -> Self {
        Self {
            content_md5: self.content_md5 || other.content_md5,
            crc32: self.crc32 || other.crc32,
            crc32c: self.crc32c || other.crc32c,
            sha1: self.sha1 || other.sha1,
            sha256: self.sha256 || other.sha256,
            crc64nvme: self.crc64nvme || other.crc64nvme,
        }
    }

    /// Build a selector that **enables** every hasher whose algorithm
    /// name appears in the comma-separated `x-amz-trailer` header value.
    /// Each name is trimmed and matched case-insensitively against the
    /// AWS spec names (`x-amz-checksum-{crc32,crc32c,sha1,sha256,crc64nvme}`).
    /// Unrecognised names leave every flag untouched.
    pub fn from_trailer_header(value: &str) -> Self {
        let mut out = Self::default();
        for raw in value.split(',') {
            let name = raw.trim();
            if name.eq_ignore_ascii_case("x-amz-checksum-crc32") {
                out.crc32 = true;
            } else if name.eq_ignore_ascii_case("x-amz-checksum-crc32c") {
                out.crc32c = true;
            } else if name.eq_ignore_ascii_case("x-amz-checksum-sha1") {
                out.sha1 = true;
            } else if name.eq_ignore_ascii_case("x-amz-checksum-sha256") {
                out.sha256 = true;
            } else if name.eq_ignore_ascii_case("x-amz-checksum-crc64nvme") {
                out.crc64nvme = true;
            }
            // Other trailers (`x-amz-trailer-signature`, custom)
            // do not request hashing; they are ignored here.
        }
        out
    }
}
260
261impl ClientChecksums {
262    /// Project the parsed claim set onto the boolean hasher-selector
263    /// used by [`WhichHashers`]. A header-supplied claim implies the
264    /// hasher must run (for EOF eager-compare).
265    pub fn which_hashers(&self) -> WhichHashers {
266        WhichHashers {
267            content_md5: self.content_md5.is_some(),
268            crc32: self.crc32.is_some(),
269            crc32c: self.crc32c.is_some(),
270            sha1: self.sha1.is_some(),
271            sha256: self.sha256.is_some(),
272            crc64nvme: self.crc64nvme.is_some(),
273        }
274    }
275}
276
277/// Finalised digest values for every algorithm whose hasher was
278/// active. Populated by the tee on the EOF poll and exposed via the
279/// [`DigestHandle`] returned by [`tee_into_hashers_with_handle`]; the
280/// PUT handler reads it after body consumption to compare against
281/// request **trailer** values (the chunked / SigV4-streaming SDK case
282/// where the checksum is delivered post-body rather than as a header).
283#[derive(Debug, Default, Clone)]
284pub struct ComputedDigests {
285    pub content_md5: Option<[u8; 16]>,
286    pub crc32_be: Option<[u8; 4]>,
287    pub crc32c_be: Option<[u8; 4]>,
288    pub sha1: Option<[u8; 20]>,
289    pub sha256: Option<[u8; 32]>,
290    pub crc64nvme_be: Option<[u8; 8]>,
291}
292
293impl ComputedDigests {
294    /// Compare one finalised digest against a base64-encoded
295    /// trailer-supplied claim. Returns `Err(BadDigest)` on mismatch,
296    /// `Err(InvalidDigest)` on malformed input, `Ok(())` on match.
297    /// `algorithm` is the wire header name used in the error message;
298    /// the match is **case-insensitive** because HTTP header field
299    /// names are case-insensitive per RFC 9110 §5.1 and AWS SDKs may
300    /// announce trailers as `X-Amz-Checksum-Crc32c` (or any other
301    /// casing) — we keep the original casing in error messages for
302    /// fidelity but normalise for the dispatch.
303    pub fn compare_b64(&self, algorithm: &str, claim_b64: &str) -> S3Result<()> {
304        let b64 = base64::engine::general_purpose::STANDARD;
305        let want = b64.decode(claim_b64).map_err(|_| {
306            S3Error::with_message(S3ErrorCode::InvalidDigest, format!("malformed {algorithm}"))
307        })?;
308        let bad = || {
309            let code =
310                S3ErrorCode::from_bytes(b"BadDigest").unwrap_or(S3ErrorCode::InvalidArgument);
311            S3Error::with_message(
312                code,
313                format!("client-supplied {algorithm} did not match the received body"),
314            )
315        };
316        let len_err = |expected: usize| {
317            S3Error::with_message(
318                S3ErrorCode::InvalidDigest,
319                format!("{algorithm} must decode to {expected} bytes"),
320            )
321        };
322        // Lowercase only for dispatch — header field names are
323        // case-insensitive (RFC 9110 §5.1) but we keep the
324        // client-supplied form for the surface text so operators see
325        // what the client actually sent.
326        let lc = algorithm.to_ascii_lowercase();
327        match lc.as_str() {
328            "content-md5" => {
329                if want.len() != 16 {
330                    return Err(len_err(16));
331                }
332                if let Some(got) = self.content_md5
333                    && got[..] == want[..]
334                {
335                    return Ok(());
336                }
337                Err(bad())
338            }
339            "x-amz-checksum-crc32" => {
340                if want.len() != 4 {
341                    return Err(len_err(4));
342                }
343                if let Some(got) = self.crc32_be
344                    && got[..] == want[..]
345                {
346                    return Ok(());
347                }
348                Err(bad())
349            }
350            "x-amz-checksum-crc32c" => {
351                if want.len() != 4 {
352                    return Err(len_err(4));
353                }
354                if let Some(got) = self.crc32c_be
355                    && got[..] == want[..]
356                {
357                    return Ok(());
358                }
359                Err(bad())
360            }
361            "x-amz-checksum-sha1" => {
362                if want.len() != 20 {
363                    return Err(len_err(20));
364                }
365                if let Some(got) = self.sha1
366                    && got[..] == want[..]
367                {
368                    return Ok(());
369                }
370                Err(bad())
371            }
372            "x-amz-checksum-sha256" => {
373                if want.len() != 32 {
374                    return Err(len_err(32));
375                }
376                if let Some(got) = self.sha256
377                    && got[..] == want[..]
378                {
379                    return Ok(());
380                }
381                Err(bad())
382            }
383            "x-amz-checksum-crc64nvme" => {
384                if want.len() != 8 {
385                    return Err(len_err(8));
386                }
387                if let Some(got) = self.crc64nvme_be
388                    && got[..] == want[..]
389                {
390                    return Ok(());
391                }
392                Err(bad())
393            }
394            _ => Err(S3Error::with_message(
395                S3ErrorCode::InvalidArgument,
396                format!("unknown checksum trailer: {algorithm}"),
397            )),
398        }
399    }
400}
401
402/// Shared, post-EOF-readable digest container. The tee's `poll_next`
403/// EOF branch deposits the finalised [`ComputedDigests`] here; the PUT
404/// handler reads it after the body has been fully consumed by the
405/// codec to compare against any trailer-supplied claims.
406pub type DigestHandle = Arc<Mutex<Option<ComputedDigests>>>;
407
408/// Internal hasher state. Each variant maintains a rolling digest fed
409/// chunk-by-chunk from the wrapper's `poll_next`. Wrapped in a `Mutex`
410/// on the wrapper side because pin-projection-friendly interior
411/// mutability is the cleanest way to keep the wrapper `Send + Sync`
412/// (the codec dispatcher holds the blob inside an `Arc`-cloned closure
413/// in places).
414struct HasherSet {
415    expected: ClientChecksums,
416    which: WhichHashers,
417    // CRC32 (IEEE) and CRC32C use accumulators, not the `Digest` trait.
418    crc32: Crc32Hasher,
419    crc32c_acc: u32,
420    crc64nvme_acc: u64,
421    md5: Md5,
422    sha1: Sha1,
423    sha256: Sha256,
424}
425
426impl HasherSet {
427    fn new(expected: ClientChecksums, which: WhichHashers) -> Self {
428        Self {
429            expected,
430            which,
431            crc32: Crc32Hasher::new(),
432            crc32c_acc: 0,
433            crc64nvme_acc: !0u64,
434            md5: Md5::new(),
435            sha1: Sha1::new(),
436            sha256: Sha256::new(),
437        }
438    }
439
440    fn update(&mut self, chunk: &[u8]) {
441        // Only feed hashers whose flag is on — saves the per-byte cost
442        // on PUTs that supply (say) only crc32c and not sha256.
443        // CRC32 / CRC32C are cheap (SIMD on modern CPUs); the SHA
444        // family is the expensive one to skip.
445        if self.which.crc32 {
446            self.crc32.update(chunk);
447        }
448        if self.which.crc32c {
449            self.crc32c_acc = crc32c::crc32c_append(self.crc32c_acc, chunk);
450        }
451        if self.which.crc64nvme {
452            self.crc64nvme_acc = crc64_nvme_append(self.crc64nvme_acc, chunk);
453        }
454        if self.which.content_md5 {
455            self.md5.update(chunk);
456        }
457        if self.which.sha1 {
458            self.sha1.update(chunk);
459        }
460        if self.which.sha256 {
461            self.sha256.update(chunk);
462        }
463    }
464
465    /// Finalise every active hasher and produce a [`ComputedDigests`]
466    /// snapshot. Consumes self — the hasher state is destructively
467    /// finalised by the underlying crates' `.finalize()` calls.
468    fn finalize(self) -> ComputedDigests {
469        let mut out = ComputedDigests::default();
470        if self.which.content_md5 {
471            let d = self.md5.finalize();
472            let mut arr = [0u8; 16];
473            arr.copy_from_slice(&d);
474            out.content_md5 = Some(arr);
475        }
476        if self.which.crc32 {
477            out.crc32_be = Some(self.crc32.finalize().to_be_bytes());
478        }
479        if self.which.crc32c {
480            out.crc32c_be = Some(self.crc32c_acc.to_be_bytes());
481        }
482        if self.which.sha1 {
483            let d = self.sha1.finalize();
484            let mut arr = [0u8; 20];
485            arr.copy_from_slice(&d);
486            out.sha1 = Some(arr);
487        }
488        if self.which.sha256 {
489            let d = self.sha256.finalize();
490            let mut arr = [0u8; 32];
491            arr.copy_from_slice(&d);
492            out.sha256 = Some(arr);
493        }
494        if self.which.crc64nvme {
495            out.crc64nvme_be = Some((!self.crc64nvme_acc).to_be_bytes());
496        }
497        out
498    }
499
500    /// Eager EOF-time comparison against the **header-supplied** claim
501    /// set captured at request entry. Returns `Err(StreamingChecksumError)`
502    /// on the first mismatch (deterministic order: Content-MD5, CRC32,
503    /// CRC32C, SHA-1, SHA-256, CRC64-NVME — mirrors the buffered path's
504    /// `verify_client_body_checksums` order so error messages are
505    /// reproducible across the two paths). Header claims are checked
506    /// at EOF so a streaming body fails fast at the codec layer;
507    /// trailer claims are checked after EOF by the PUT handler via
508    /// [`ComputedDigests::compare_b64`].
509    fn compare_header_claims(
510        digests: &ComputedDigests,
511        expected: &ClientChecksums,
512    ) -> Result<(), StreamingChecksumError> {
513        if let (Some(want), Some(got)) = (expected.content_md5, digests.content_md5)
514            && got != want
515        {
516            return Err(StreamingChecksumError {
517                algorithm: "Content-MD5",
518            });
519        }
520        if let (Some(want), Some(got)) = (expected.crc32, digests.crc32_be)
521            && got != want
522        {
523            return Err(StreamingChecksumError {
524                algorithm: "x-amz-checksum-crc32",
525            });
526        }
527        if let (Some(want), Some(got)) = (expected.crc32c, digests.crc32c_be)
528            && got != want
529        {
530            return Err(StreamingChecksumError {
531                algorithm: "x-amz-checksum-crc32c",
532            });
533        }
534        if let (Some(want), Some(got)) = (expected.sha1, digests.sha1)
535            && got != want
536        {
537            return Err(StreamingChecksumError {
538                algorithm: "x-amz-checksum-sha1",
539            });
540        }
541        if let (Some(want), Some(got)) = (expected.sha256, digests.sha256)
542            && got != want
543        {
544            return Err(StreamingChecksumError {
545                algorithm: "x-amz-checksum-sha256",
546            });
547        }
548        if let (Some(want), Some(got)) = (expected.crc64nvme, digests.crc64nvme_be)
549            && got != want
550        {
551            return Err(StreamingChecksumError {
552                algorithm: "x-amz-checksum-crc64nvme",
553            });
554        }
555        Ok(())
556    }
557}
558
/// Rolling CRC-64/NVMe accumulator (matches the buffered-path table in
/// `service.rs::crc64_nvme`). The full byte-for-byte digest is
/// `!crc64_nvme_append(!0u64, bytes)` — the `!0u64` init is bit-flipped
/// on output to match the NVMe spec (`xorout = 0xffff_ffff_ffff_ffff`).
/// Appending chunk by chunk gives the same result as one call over the
/// concatenated bytes.
fn crc64_nvme_append(init: u64, bytes: &[u8]) -> u64 {
    // Byte-wise lookup table, evaluated at compile time (no runtime
    // initialisation or per-call init check). Reflected polynomial
    // (bit-reverse of 0xad93d23594c93659) — identical constant to
    // `service.rs::crc64_nvme`; intentionally duplicated rather than
    // re-exported so the streaming wrapper has no cross-module dependency
    // on the buffered helper (the table is 2 KiB total).
    static TABLE: [u64; 256] = {
        const POLY_REFLECTED: u64 = 0x9a6c_9329_ac4b_c9b5;
        let mut table = [0u64; 256];
        let mut i = 0usize;
        while i < 256 {
            let mut crc = i as u64;
            let mut bit = 0;
            while bit < 8 {
                crc = if crc & 1 != 0 {
                    (crc >> 1) ^ POLY_REFLECTED
                } else {
                    crc >> 1
                };
                bit += 1;
            }
            table[i] = crc;
            i += 1;
        }
        table
    };
    // Reflected table-driven update: the low register byte (intentional
    // truncation) is XORed with the input byte to pick the table entry.
    bytes
        .iter()
        .fold(init, |crc, &b| (crc >> 8) ^ TABLE[usize::from((crc as u8) ^ b)])
}
598
599/// Wraps `inner` so every chunk yielded by `poll_next` is **also** fed
600/// into a [`HasherSet`] before being passed downstream. On the EOF poll
601/// the hashers are finalised and compared against `expected`; a
602/// mismatch is emitted as a fresh [`io::Error`] (`InvalidData`) carrying
603/// [`StreamingChecksumError`] in its source chain.
604///
605/// The wrapper holds the [`HasherSet`] in a `Mutex` so it stays
606/// `Send + Sync` (the s3s `StreamingBlob` is wrapped in a `Sync`-
607/// erased trait object and any non-`Sync` field would force a deeper
608/// rework of the blob constructor). Lock contention is zero in
609/// practice — only one task polls a given stream at a time.
610struct TeeStream {
611    inner: StreamingBlob,
612    state: Arc<Mutex<TeeState>>,
613    /// Cloned to the PUT handler so it can read the post-EOF
614    /// `ComputedDigests` and run any trailer-supplied comparisons that
615    /// only arrive after the body is consumed.
616    digests_out: DigestHandle,
617}
618
619/// `Some(HasherSet)` while the stream is live; `None` once finalised
620/// (EOF reached AND comparison performed) so we don't double-finalise
621/// if the downstream consumer keeps polling past EOF.
622struct TeeState {
623    hashers: Option<HasherSet>,
624}
625
626impl Stream for TeeStream {
627    type Item = Result<Bytes, s3s::StdError>;
628
629    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
630        let this = self.get_mut();
631        match this.inner.poll_next_unpin(cx) {
632            Poll::Ready(Some(Ok(chunk))) => {
633                // Feed the hashers before yielding the chunk
634                // downstream. Holding the lock across update() is fine
635                // — single-polling-task contract on a Stream.
636                let mut guard = this.state.lock().expect("tee hasher lock poisoned");
637                if let Some(h) = guard.hashers.as_mut() {
638                    h.update(&chunk);
639                }
640                Poll::Ready(Some(Ok(chunk)))
641            }
642            Poll::Ready(Some(Err(e))) => {
643                // Drop the hashers on upstream error so any subsequent
644                // polls don't fire the EOF comparison (the bytes that
645                // were supposed to land never did — comparing a partial
646                // digest would yield a meaningless mismatch).
647                let mut guard = this.state.lock().expect("tee hasher lock poisoned");
648                guard.hashers = None;
649                Poll::Ready(Some(Err(e)))
650            }
651            Poll::Ready(None) => {
652                let mut guard = this.state.lock().expect("tee hasher lock poisoned");
653                if let Some(hashers) = guard.hashers.take() {
654                    let expected_header_claims = hashers.expected.clone();
655                    let digests = hashers.finalize();
656                    // Stash the finalised digests so the PUT handler
657                    // can run trailer comparisons after the body has
658                    // been consumed.
659                    *this
660                        .digests_out
661                        .lock()
662                        .expect("digest handle lock poisoned") = Some(digests.clone());
663                    // Eager EOF-time check against header claims:
664                    // surface a mismatch as a synthetic
665                    // `InvalidData` I/O error carrying
666                    // `StreamingChecksumError`. The PUT handler
667                    // downcasts back to map to `BadDigest`. Trailer
668                    // claims aren't reachable here (they arrive
669                    // post-stream); the PUT handler runs them
670                    // separately via `ComputedDigests::compare_b64`.
671                    if let Err(mismatch) =
672                        HasherSet::compare_header_claims(&digests, &expected_header_claims)
673                    {
674                        let io_err = std::io::Error::new(std::io::ErrorKind::InvalidData, mismatch);
675                        let boxed: s3s::StdError = Box::new(io_err);
676                        return Poll::Ready(Some(Err(boxed)));
677                    }
678                }
679                Poll::Ready(None)
680            }
681            Poll::Pending => Poll::Pending,
682        }
683    }
684
685    fn size_hint(&self) -> (usize, Option<usize>) {
686        self.inner.size_hint()
687    }
688}
689
690impl ByteStream for TeeStream {
691    fn remaining_length(&self) -> RemainingLength {
692        // The wrapper is a 1:1 byte pass-through; defer to inner.
693        self.inner.remaining_length()
694    }
695}
696
697/// Build a [`StreamingBlob`] that streams the underlying body through
698/// `inner`, simultaneously feeding every chunk into the hashers
699/// selected by `which`. On EOF the wrapper finalises every hasher,
700/// (a) deposits the result into the returned [`DigestHandle`] for
701/// post-body trailer comparisons, and (b) eagerly compares against
702/// every claim already present in `expected` (the
703/// header-supplied set); a mismatch surfaces as a synthetic
704/// `io::Error` carrying [`StreamingChecksumError`].
705///
706/// **Caller MUST consume the returned blob to completion** for the
707/// verification step to fire. Dropping the blob mid-stream is treated
708/// as a stream abort (same as any other stream consumer) and skips
709/// the final comparison — the PUT itself will already have failed by
710/// then, so there's nothing useful to compare against.
711pub fn tee_into_hashers_with_handle(
712    inner: StreamingBlob,
713    expected: ClientChecksums,
714    which: WhichHashers,
715) -> (StreamingBlob, DigestHandle) {
716    let digests_out: DigestHandle = Arc::new(Mutex::new(None));
717    let state = Arc::new(Mutex::new(TeeState {
718        hashers: Some(HasherSet::new(expected, which)),
719    }));
720    let tee = TeeStream {
721        inner,
722        state,
723        digests_out: Arc::clone(&digests_out),
724    };
725    (StreamingBlob::new(tee), digests_out)
726}
727
728/// Convenience wrapper for the common case: only header-supplied
729/// claims are present, no trailer expectation. Discards the digest
730/// handle. Kept for the existing call sites and unit tests; new code
731/// that needs trailer support should use
732/// [`tee_into_hashers_with_handle`].
733pub fn tee_into_hashers(inner: StreamingBlob, expected: ClientChecksums) -> StreamingBlob {
734    let which = expected.which_hashers();
735    let (blob, _handle) = tee_into_hashers_with_handle(inner, expected, which);
736    blob
737}
738
739/// v0.9 #106-audit-R2 P2-INT-2: compute the algorithms named by `which`
740/// over `body` in one shot, producing a [`ComputedDigests`] suitable
741/// for trailer comparison via [`ComputedDigests::compare_b64`]. This is
742/// the **buffered-path** counterpart of the streaming tee: the body is
743/// already in memory (e.g. GPU codec branch or non-streaming-framed
744/// PUT), so we don't need the chunk-by-chunk wrapper — a single
745/// in-place hash run suffices.
746///
747/// Lives in this module so the trailer-verify logic on both paths
748/// (streaming-framed and buffered) calls the same finaliser
749/// (`compare_b64`) and the test surface that covers it
750/// (`computed_digests_compare_b64_*`) keeps both paths honest.
751pub fn compute_digests(body: &[u8], which: WhichHashers) -> ComputedDigests {
752    let mut hashers = HasherSet::new(ClientChecksums::default(), which);
753    hashers.update(body);
754    hashers.finalize()
755}
756
757/// Walk an `io::Error` source chain looking for a
758/// [`StreamingChecksumError`]. Returns the algorithm name when found.
759/// Used by the PUT handler to decide whether a `CodecError::Io` was
760/// the streaming verifier's mismatch (→ `BadDigest`) or a genuine
761/// transport-layer I/O failure (→ `InternalError`).
762///
763/// The chain we have to walk:
764///
765/// ```text
766///   CodecError::Io(outer)                            ← service.rs sees this
767///     outer = io::Error::other(boxed_std_err)        ← blob_to_async_read wraps
768///       boxed_std_err = Box<io::Error(InvalidData, StreamingChecksumError)>
769///                                                    ← tee_into_hashers emits
770/// ```
771///
772/// `outer.get_ref()` returns `&dyn Error` pointing at `boxed_std_err`'s
773/// inner. We try the direct downcast first (covers tests / callers
774/// that pass the inner io::Error directly), then peel one Error::other
775/// wrapper to recover the nested io::Error built by the tee.
776pub fn extract_streaming_checksum_error(err: &std::io::Error) -> Option<&'static str> {
777    // 1. Direct: the io::Error we were given carries the
778    //    StreamingChecksumError as its inner. This happens in unit
779    //    tests that construct the error themselves and in any future
780    //    caller that doesn't add the StreamReader wrapper.
781    if let Some(inner) = err.get_ref()
782        && let Some(s) = inner.downcast_ref::<StreamingChecksumError>()
783    {
784        return Some(s.algorithm);
785    }
786    // 2. One-deep: the StreamReader → io::Error::other(StdError) wrap
787    //    added by `blob_to_async_read`. The inner is a
788    //    `Box<dyn Error + Send + Sync>` which we built ourselves
789    //    around an `io::Error`; recover that nested io::Error and
790    //    repeat the lookup.
791    if let Some(inner) = err.get_ref()
792        && let Some(nested_io) = inner.downcast_ref::<std::io::Error>()
793        && let Some(deeper) = nested_io.get_ref()
794        && let Some(s) = deeper.downcast_ref::<StreamingChecksumError>()
795    {
796        return Some(s.algorithm);
797    }
798    // 3. Fallback: best-effort walk of the conventional source chain
799    //    for any future re-wrap that uses `Error::source` properly.
800    let mut src: Option<&dyn std::error::Error> = std::error::Error::source(err);
801    while let Some(e) = src {
802        if let Some(s) = e.downcast_ref::<StreamingChecksumError>() {
803            return Some(s.algorithm);
804        }
805        src = e.source();
806    }
807    None
808}
809
810#[cfg(test)]
811mod tests {
812    use super::*;
813    use bytes::Bytes;
814    use futures::stream;
815
816    fn b64encode(b: &[u8]) -> String {
817        base64::engine::general_purpose::STANDARD.encode(b)
818    }
819
820    fn make_chunked_blob(chunks: Vec<Bytes>) -> StreamingBlob {
821        let stream = stream::iter(chunks.into_iter().map(Ok::<_, std::io::Error>));
822        StreamingBlob::wrap(stream)
823    }
824
825    async fn drain(blob: StreamingBlob) -> Result<Vec<u8>, String> {
826        let mut s = blob;
827        let mut out = Vec::new();
828        while let Some(chunk) = s.next().await {
829            let chunk = chunk.map_err(|e| format!("{e}"))?;
830            out.extend_from_slice(&chunk);
831        }
832        Ok(out)
833    }
834
835    /// Plain pass-through (no claims set) yields the original bytes
836    /// unchanged and never errors at EOF.
837    #[tokio::test]
838    async fn tee_with_no_claims_is_passthrough() {
839        let body = Bytes::from_static(b"hello streaming s4");
840        let blob = make_chunked_blob(vec![body.clone()]);
841        let wrapped = tee_into_hashers(blob, ClientChecksums::default());
842        let got = drain(wrapped).await.unwrap();
843        assert_eq!(got, body.to_vec());
844    }
845
846    /// crc32c claim matches → drain succeeds, no synthetic error.
847    #[tokio::test]
848    async fn crc32c_match_yields_full_body() {
849        let body: Vec<u8> = (0..50_000u32).map(|i| i as u8).collect();
850        let crc = crc32c::crc32c(&body).to_be_bytes();
851        let claims = ClientChecksums::from_request_fields(
852            None,
853            None,
854            Some(&b64encode(&crc)),
855            None,
856            None,
857            None,
858        )
859        .unwrap();
860        let blob = make_chunked_blob(vec![
861            Bytes::copy_from_slice(&body[..20_000]),
862            Bytes::copy_from_slice(&body[20_000..]),
863        ]);
864        let wrapped = tee_into_hashers(blob, claims);
865        let got = drain(wrapped).await.unwrap();
866        assert_eq!(got, body);
867    }
868
869    /// crc32c claim mismatched (off-by-one byte) → EOF poll emits a
870    /// synthetic InvalidData carrying StreamingChecksumError.
871    #[tokio::test]
872    async fn crc32c_mismatch_fires_at_eof() {
873        let body: Vec<u8> = vec![b'a'; 4096];
874        // Use a deliberately wrong CRC.
875        let wrong_crc = (crc32c::crc32c(&body) ^ 0xFFFF_FFFF).to_be_bytes();
876        let claims = ClientChecksums::from_request_fields(
877            None,
878            None,
879            Some(&b64encode(&wrong_crc)),
880            None,
881            None,
882            None,
883        )
884        .unwrap();
885        let blob = make_chunked_blob(vec![Bytes::copy_from_slice(&body)]);
886        let wrapped = tee_into_hashers(blob, claims);
887        let err = drain(wrapped).await.unwrap_err();
888        assert!(
889            err.contains("x-amz-checksum-crc32c"),
890            "error must name the failing algorithm, got: {err}"
891        );
892    }
893
894    /// sha256 claim matches → drain succeeds.
895    #[tokio::test]
896    async fn sha256_match_succeeds_across_many_small_chunks() {
897        let body: Vec<u8> = (0..123_456u32).map(|i| (i ^ 0x5a) as u8).collect();
898        let digest = {
899            let mut h = Sha256::new();
900            h.update(&body);
901            h.finalize()
902        };
903        let claims = ClientChecksums::from_request_fields(
904            None,
905            None,
906            None,
907            None,
908            Some(&b64encode(&digest)),
909            None,
910        )
911        .unwrap();
912        // Split body into many small chunks to exercise the per-chunk
913        // update path.
914        let chunks: Vec<Bytes> = body.chunks(1024).map(Bytes::copy_from_slice).collect();
915        let blob = make_chunked_blob(chunks);
916        let wrapped = tee_into_hashers(blob, claims);
917        let got = drain(wrapped).await.unwrap();
918        assert_eq!(got, body);
919    }
920
921    /// Multiple algorithms set together — all must verify; mismatch in
922    /// any one fires.
923    #[tokio::test]
924    async fn multi_algorithm_one_wrong_fires() {
925        let body = vec![0u8; 8192];
926        let crc32c_be = crc32c::crc32c(&body).to_be_bytes();
927        let mut sha = Sha256::new();
928        sha.update(&body);
929        let sha_correct = sha.finalize();
930        // Flip a byte in the SHA-256 claim.
931        let mut sha_wrong = sha_correct.to_vec();
932        sha_wrong[0] ^= 0xFF;
933        let claims = ClientChecksums::from_request_fields(
934            None,
935            None,
936            Some(&b64encode(&crc32c_be)),
937            None,
938            Some(&b64encode(&sha_wrong)),
939            None,
940        )
941        .unwrap();
942        let blob = make_chunked_blob(vec![Bytes::copy_from_slice(&body)]);
943        let wrapped = tee_into_hashers(blob, claims);
944        let err = drain(wrapped).await.unwrap_err();
945        assert!(
946            err.contains("x-amz-checksum-sha256"),
947            "expected sha256 mismatch, got: {err}"
948        );
949    }
950
951    /// Malformed base64 in the header → InvalidDigest BEFORE any body
952    /// flows.
953    #[test]
954    fn from_request_fields_rejects_malformed_base64() {
955        let err = ClientChecksums::from_request_fields(
956            None,
957            None,
958            Some("not-base-64!!!"),
959            None,
960            None,
961            None,
962        )
963        .unwrap_err();
964        assert_eq!(err.code(), &S3ErrorCode::InvalidDigest);
965    }
966
967    /// Correct base64 but wrong decoded length → InvalidDigest.
968    #[test]
969    fn from_request_fields_rejects_wrong_length() {
970        // base64 of 3 bytes — crc32c demands 4.
971        let too_short = base64::engine::general_purpose::STANDARD.encode([1u8, 2, 3]);
972        let err =
973            ClientChecksums::from_request_fields(None, None, Some(&too_short), None, None, None)
974                .unwrap_err();
975        assert_eq!(err.code(), &S3ErrorCode::InvalidDigest);
976    }
977
978    /// CRC-64/NVME table cross-check against the buffered helper:
979    /// `crc64_nvme_append(!0u64, b"")` xor-out should be 0 (empty
980    /// input has crc 0 by NVMe spec — init xor xorout).
981    #[test]
982    fn crc64_nvme_empty_input_is_zero() {
983        let crc = crc64_nvme_append(!0u64, b"");
984        assert_eq!(!crc, 0u64, "NVMe empty-input CRC must be 0");
985    }
986
987    /// extract_streaming_checksum_error round-trips for an error we
988    /// constructed the same way the wrapper does.
989    #[test]
990    fn extract_recovers_algorithm() {
991        let mismatch = StreamingChecksumError {
992            algorithm: "x-amz-checksum-crc32c",
993        };
994        let io = std::io::Error::new(std::io::ErrorKind::InvalidData, mismatch);
995        assert_eq!(
996            extract_streaming_checksum_error(&io),
997            Some("x-amz-checksum-crc32c")
998        );
999    }
1000
1001    /// Returns None for an unrelated io error.
1002    #[test]
1003    fn extract_returns_none_for_unrelated_io_error() {
1004        let io = std::io::Error::other("unrelated");
1005        assert_eq!(extract_streaming_checksum_error(&io), None);
1006    }
1007
1008    /// `WhichHashers::from_trailer_header` recognises every AWS
1009    /// checksum trailer name (case-insensitive) and ignores
1010    /// unrelated entries.
1011    #[test]
1012    fn which_hashers_from_trailer_header_parses_all_known_names() {
1013        let w = WhichHashers::from_trailer_header(
1014            "x-amz-checksum-crc32, X-Amz-Checksum-Crc32c, x-amz-trailer-signature",
1015        );
1016        assert!(w.crc32);
1017        assert!(w.crc32c);
1018        assert!(!w.sha1);
1019        assert!(!w.sha256);
1020        assert!(!w.crc64nvme);
1021        assert!(!w.content_md5);
1022
1023        let w2 = WhichHashers::from_trailer_header("x-amz-checksum-sha256");
1024        assert!(w2.sha256);
1025        assert!(!w2.crc32c);
1026        let w3 = WhichHashers::from_trailer_header("x-amz-checksum-crc64nvme");
1027        assert!(w3.crc64nvme);
1028    }
1029
1030    /// Trailer-deferred path: tee runs the hasher because
1031    /// `x-amz-trailer` announced it; the header claim itself is
1032    /// empty; the digest handle exposes the finalised value for the
1033    /// PUT handler to compare against the actual trailer value.
1034    #[tokio::test]
1035    async fn tee_with_handle_stashes_digests_for_trailer_compare() {
1036        let body: Vec<u8> = vec![7u8; 9000];
1037        // No header claim — only the trailer hasher selector.
1038        let which = WhichHashers {
1039            crc32c: true,
1040            sha256: true,
1041            ..Default::default()
1042        };
1043        let blob = make_chunked_blob(vec![Bytes::copy_from_slice(&body)]);
1044        let (wrapped, handle) =
1045            tee_into_hashers_with_handle(blob, ClientChecksums::default(), which);
1046        let got = drain(wrapped).await.unwrap();
1047        assert_eq!(got, body);
1048
1049        let computed = handle.lock().unwrap().clone().expect("digests stashed");
1050        let expected_crc32c = crc32c::crc32c(&body).to_be_bytes();
1051        assert_eq!(computed.crc32c_be, Some(expected_crc32c));
1052        let expected_sha256 = {
1053            let mut h = Sha256::new();
1054            h.update(&body);
1055            h.finalize()
1056        };
1057        assert_eq!(computed.sha256.unwrap(), expected_sha256[..]);
1058    }
1059
1060    /// `ComputedDigests::compare_b64` matches a correct claim and
1061    /// rejects a mismatched / malformed claim with the right S3 code.
1062    #[test]
1063    fn computed_digests_compare_b64_match_and_mismatch() {
1064        let body = b"sample-bytes";
1065        let d = ComputedDigests {
1066            crc32c_be: Some(crc32c::crc32c(body).to_be_bytes()),
1067            ..Default::default()
1068        };
1069        // Correct
1070        d.compare_b64(
1071            "x-amz-checksum-crc32c",
1072            &b64encode(&crc32c::crc32c(body).to_be_bytes()),
1073        )
1074        .expect("match must succeed");
1075        // Mismatch
1076        let err = d
1077            .compare_b64("x-amz-checksum-crc32c", &b64encode(&[0u8; 4]))
1078            .unwrap_err();
1079        assert_eq!(err.code().as_str(), "BadDigest");
1080        // Malformed base64
1081        let err = d
1082            .compare_b64("x-amz-checksum-crc32c", "@@@not-b64@@@")
1083            .unwrap_err();
1084        assert_eq!(err.code(), &S3ErrorCode::InvalidDigest);
1085        // Wrong length
1086        let err = d
1087            .compare_b64("x-amz-checksum-crc32c", &b64encode(&[0u8; 8]))
1088            .unwrap_err();
1089        assert_eq!(err.code(), &S3ErrorCode::InvalidDigest);
1090    }
1091
1092    /// `compare_b64` against an algorithm the tee never hashed (slot
1093    /// is `None`) returns `BadDigest` — we cannot verify what the
1094    /// client promised, so we must refuse the PUT rather than
1095    /// silently accept it.
1096    #[test]
1097    fn computed_digests_compare_b64_against_unhashed_algorithm_rejects() {
1098        let d = ComputedDigests::default(); // nothing hashed
1099        let err = d
1100            .compare_b64("x-amz-checksum-sha256", &b64encode(&[0u8; 32]))
1101            .unwrap_err();
1102        assert_eq!(err.code().as_str(), "BadDigest");
1103    }
1104
1105    /// `compare_b64` accepts any-casing trailer names (HTTP header
1106    /// names are case-insensitive per RFC 9110 §5.1; AWS SDKs may
1107    /// announce `X-Amz-Checksum-Crc32c` or `x-amz-checksum-crc32c`
1108    /// interchangeably).
1109    #[test]
1110    fn computed_digests_compare_b64_case_insensitive_algorithm() {
1111        let body = b"sample";
1112        let d = ComputedDigests {
1113            crc32c_be: Some(crc32c::crc32c(body).to_be_bytes()),
1114            ..Default::default()
1115        };
1116        let want = b64encode(&crc32c::crc32c(body).to_be_bytes());
1117        for variant in [
1118            "x-amz-checksum-crc32c",
1119            "X-Amz-Checksum-Crc32c",
1120            "X-AMZ-CHECKSUM-CRC32C",
1121            "x-AMZ-checksum-CRC32C",
1122        ] {
1123            d.compare_b64(variant, &want)
1124                .unwrap_or_else(|e| panic!("variant {variant} must match, got {e:?}"));
1125        }
1126    }
1127}