//! Incremental (streaming) zstd frame encoder — structured_zstd/encoding/streaming_encoder.rs.
1use alloc::format;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4use core::mem;
5
6use crate::common::MAX_BLOCK_SIZE;
7#[cfg(feature = "hash")]
8use core::hash::Hasher;
9#[cfg(feature = "hash")]
10use twox_hash::XxHash64;
11
12use crate::encoding::levels::compress_block_encoded;
13use crate::encoding::{
14    CompressionLevel, MatchGeneratorDriver, Matcher, block_header::BlockHeader,
15    frame_compressor::CompressState, frame_compressor::FseTables, frame_header::FrameHeader,
16};
17use crate::io::{Error, ErrorKind, Write};
18
/// Incremental frame encoder that implements [`Write`].
///
/// Data can be provided with multiple `write()` calls. Full blocks are compressed
/// automatically, `flush()` emits the currently buffered partial block as non-last,
/// and `finish()` closes the frame and returns the wrapped writer.
pub struct StreamingEncoder<W: Write, M: Matcher = MatchGeneratorDriver> {
    // Output sink; `Some` until `finish()` takes ownership and returns it.
    drain: Option<W>,
    // Level applied to every block written through this encoder.
    compression_level: CompressionLevel,
    // Matcher plus entropy-table state carried across blocks within one frame.
    state: CompressState<M>,
    // Buffered uncompressed bytes for the current (not yet emitted) block.
    pending: Vec<u8>,
    // Reusable buffer holding the serialized (header + payload) form of a block.
    encoded_scratch: Vec<u8>,
    // Sticky poison flag: once set, all further calls fail via sticky_error().
    errored: bool,
    // Kind/message of the *first* failure, replayed by sticky_error().
    last_error_kind: Option<ErrorKind>,
    last_error_message: Option<String>,
    // True once the frame header has been written to the drain.
    frame_started: bool,
    // Exact content size promised via set_pledged_content_size, if any.
    pledged_content_size: Option<u64>,
    // Total uncompressed bytes accepted so far (for pledge enforcement).
    bytes_consumed: u64,
    // Running XxHash64 of the uncompressed payload for the frame checksum.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
39
40impl<W: Write> StreamingEncoder<W, MatchGeneratorDriver> {
41    /// Creates a streaming encoder backed by the default match generator.
42    ///
43    /// The encoder writes compressed bytes into `drain` and applies `compression_level`
44    /// to all subsequently written blocks.
45    pub fn new(drain: W, compression_level: CompressionLevel) -> Self {
46        Self::new_with_matcher(
47            MatchGeneratorDriver::new(MAX_BLOCK_SIZE as usize, 1),
48            drain,
49            compression_level,
50        )
51    }
52}
53
54impl<W: Write, M: Matcher> StreamingEncoder<W, M> {
55    /// Creates a streaming encoder with an explicitly provided matcher implementation.
56    ///
57    /// This constructor is primarily intended for tests and advanced callers that need
58    /// custom match-window behavior.
59    pub fn new_with_matcher(matcher: M, drain: W, compression_level: CompressionLevel) -> Self {
60        Self {
61            drain: Some(drain),
62            compression_level,
63            state: CompressState {
64                matcher,
65                last_huff_table: None,
66                fse_tables: FseTables::new(),
67                offset_hist: [1, 4, 8],
68            },
69            pending: Vec::new(),
70            encoded_scratch: Vec::new(),
71            errored: false,
72            last_error_kind: None,
73            last_error_message: None,
74            frame_started: false,
75            pledged_content_size: None,
76            bytes_consumed: 0,
77            #[cfg(feature = "hash")]
78            hasher: XxHash64::with_seed(0),
79        }
80    }
81
82    /// Pledge the total uncompressed content size for this frame.
83    ///
84    /// When set, the frame header will include a `Frame_Content_Size` field.
85    /// This enables decoders to pre-allocate output buffers.
86    /// The pledged size is also forwarded as a source-size hint to the
87    /// matcher so small inputs can use smaller matching tables.
88    ///
89    /// Must be called **before** the first [`write`](Write::write) call;
90    /// calling it after the frame header has already been emitted returns an
91    /// error.
92    pub fn set_pledged_content_size(&mut self, size: u64) -> Result<(), Error> {
93        self.ensure_open()?;
94        if self.frame_started {
95            return Err(invalid_input_error(
96                "pledged content size must be set before the first write",
97            ));
98        }
99        self.pledged_content_size = Some(size);
100        // Also use pledged size as source-size hint so the matcher
101        // can select smaller tables for small inputs.
102        self.state.matcher.set_source_size_hint(size);
103        Ok(())
104    }
105
    /// Provide a hint about the total uncompressed size for the next frame.
    ///
    /// Unlike [`set_pledged_content_size`](Self::set_pledged_content_size),
    /// this does **not** enforce that exactly `size` bytes are written; it
    /// may reduce matcher tables, advertised frame window, and block sizing
    /// for small inputs. Must be called before the first
    /// [`write`](Write::write).
    ///
    /// # Errors
    ///
    /// Fails if the encoder is poisoned or the frame header was already emitted.
    pub fn set_source_size_hint(&mut self, size: u64) -> Result<(), Error> {
        self.ensure_open()?;
        if self.frame_started {
            return Err(invalid_input_error(
                "source size hint must be set before the first write",
            ));
        }
        self.state.matcher.set_source_size_hint(size);
        Ok(())
    }
123
124    /// Returns an immutable reference to the wrapped output drain.
125    ///
126    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
127    /// consumes the encoder and returns ownership of the drain.
128    pub fn get_ref(&self) -> &W {
129        self.drain
130            .as_ref()
131            .expect("streaming encoder drain is present until finish consumes self")
132    }
133
134    /// Returns a mutable reference to the wrapped output drain.
135    ///
136    /// It is inadvisable to directly write to the underlying writer, as doing
137    /// so would corrupt the zstd frame being assembled by the encoder.
138    ///
139    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
140    /// consumes the encoder and returns ownership of the drain.
141    pub fn get_mut(&mut self) -> &mut W {
142        self.drain
143            .as_mut()
144            .expect("streaming encoder drain is present until finish consumes self")
145    }
146
    /// Finalizes the current zstd frame and returns the wrapped output drain.
    ///
    /// If no payload was written yet, this still emits a valid empty frame.
    /// Calling this method consumes the encoder.
    ///
    /// # Errors
    ///
    /// Fails if the encoder is poisoned, if the written byte count differs
    /// from a pledged content size, or if writing the trailing block/checksum
    /// fails. On error the drain is dropped along with `self`.
    pub fn finish(mut self) -> Result<W, Error> {
        self.ensure_open()?;

        // Validate the pledge before finalizing the frame. If finish() is
        // called before any writes, this also avoids emitting a header with
        // an incorrect FCS into the drain on mismatch.
        if let Some(pledged) = self.pledged_content_size
            && self.bytes_consumed != pledged
        {
            return Err(invalid_input_error(
                "pledged content size does not match bytes consumed",
            ));
        }

        // Emits the frame header now if no write()/flush() did it earlier.
        self.ensure_frame_started()?;

        // Every frame must end with a block marked `last`; synthesize an
        // empty one when nothing is buffered.
        if self.pending.is_empty() {
            self.write_empty_last_block()
                .map_err(|err| self.fail(err))?;
        } else {
            self.emit_pending_block(true)?;
        }

        let mut drain = self
            .drain
            .take()
            .expect("streaming encoder drain must be present when finishing");

        // Append the 4-byte content checksum (truncated XxHash64, little
        // endian) that the frame header advertised when `hash` is enabled.
        #[cfg(feature = "hash")]
        {
            let checksum = self.hasher.finish() as u32;
            drain
                .write_all(&checksum.to_le_bytes())
                .map_err(|err| self.fail(err))?;
        }

        drain.flush().map_err(|err| self.fail(err))?;
        Ok(drain)
    }
190
191    fn ensure_open(&self) -> Result<(), Error> {
192        if self.errored {
193            return Err(self.sticky_error());
194        }
195        Ok(())
196    }
197
    // Cold path (only reached after poisoning). The format!() calls still allocate
    // in no_std even though error_with_kind_message/other_error_owned drop the
    // message; this is acceptable on an error recovery path to keep match arms simple.
    //
    // Reconstructs an error from whatever was captured by fail(): kind and/or
    // message, falling back to a generic message when neither is available.
    fn sticky_error(&self) -> Error {
        match (self.last_error_kind, self.last_error_message.as_deref()) {
            (Some(kind), Some(message)) => error_with_kind_message(
                kind,
                format!(
                    "streaming encoder is in an errored state due to previous {kind:?} failure: {message}"
                ),
            ),
            (Some(kind), None) => error_from_kind(kind),
            (None, Some(message)) => other_error_owned(format!(
                "streaming encoder is in an errored state: {message}"
            )),
            (None, None) => other_error("streaming encoder is in an errored state"),
        }
    }
216
217    fn drain_mut(&mut self) -> Result<&mut W, Error> {
218        self.drain
219            .as_mut()
220            .ok_or_else(|| other_error("streaming encoder has no active drain"))
221    }
222
    // Emits the frame header exactly once, resetting all per-frame state
    // (matcher, repeat offsets, cached entropy tables, checksum) beforehand.
    fn ensure_frame_started(&mut self) -> Result<(), Error> {
        if self.frame_started {
            return Ok(());
        }

        self.ensure_level_supported()?;
        // Start from a clean slate so nothing carries over from a previous
        // frame encoded with this state.
        self.state.matcher.reset(self.compression_level);
        // Initial repeat-offset history used by the zstd format.
        self.state.offset_hist = [1, 4, 8];
        self.state.last_huff_table = None;
        self.state.fse_tables.ll_previous = None;
        self.state.fse_tables.ml_previous = None;
        self.state.fse_tables.of_previous = None;
        #[cfg(feature = "hash")]
        {
            self.hasher = XxHash64::with_seed(0);
        }

        let window_size = self.state.matcher.window_size();
        if window_size == 0 {
            // A zero window would make block_capacity meaningless.
            return Err(invalid_input_error(
                "matcher reported window_size == 0, which is invalid",
            ));
        }

        let header = FrameHeader {
            frame_content_size: self.pledged_content_size,
            single_segment: false,
            // Advertise a trailing checksum only when one will be written.
            content_checksum: cfg!(feature = "hash"),
            dictionary_id: None,
            window_size: Some(window_size),
        };
        let mut encoded_header = Vec::new();
        header.serialize(&mut encoded_header);
        self.drain_mut()
            .and_then(|drain| drain.write_all(&encoded_header))
            .map_err(|err| self.fail(err))?;

        self.frame_started = true;
        Ok(())
    }
263
264    fn block_capacity(&self) -> usize {
265        let matcher_window = self.state.matcher.window_size() as usize;
266        core::cmp::max(1, core::cmp::min(matcher_window, MAX_BLOCK_SIZE as usize))
267    }
268
269    fn allocate_pending_space(&mut self, block_capacity: usize) -> Vec<u8> {
270        let mut space = match self.compression_level {
271            CompressionLevel::Fastest
272            | CompressionLevel::Default
273            | CompressionLevel::Better
274            | CompressionLevel::Best
275            | CompressionLevel::Level(_) => self.state.matcher.get_next_space(),
276            CompressionLevel::Uncompressed => Vec::new(),
277        };
278        space.clear();
279        if space.capacity() > block_capacity {
280            space.shrink_to(block_capacity);
281        }
282        if space.capacity() < block_capacity {
283            space.reserve(block_capacity - space.capacity());
284        }
285        space
286    }
287
    // If `pending` has reached exactly `block_capacity`, compress and emit it
    // as a non-last block.
    //
    // Returns:
    // - `None` when there was nothing to emit, or the block was emitted
    //   successfully (the caller keeps looping);
    // - `Some(Ok(consumed))` when emission failed but `consumed > 0` bytes of
    //   the caller's buffer were already accepted — partial-write semantics
    //   report progress now, and the stored error resurfaces (via poisoning)
    //   on the next call;
    // - `Some(Err(_))` when emission failed before any byte was accepted.
    fn emit_full_pending_block(
        &mut self,
        block_capacity: usize,
        consumed: usize,
    ) -> Option<Result<usize, Error>> {
        if self.pending.len() != block_capacity {
            return None;
        }

        // Swap in a fresh buffer so `pending` can keep accepting input while
        // the full block is handed to the block encoder.
        let new_pending = self.allocate_pending_space(block_capacity);
        let full_block = mem::replace(&mut self.pending, new_pending);
        if let Err((err, restored_block)) = self.encode_block(full_block, false) {
            // encode_block hands the uncompressed bytes back so no input is
            // lost; keep them buffered and poison the encoder.
            self.pending = restored_block;
            let err = self.fail(err);
            if consumed > 0 {
                return Some(Ok(consumed));
            }
            return Some(Err(err));
        }
        None
    }
309
310    fn emit_pending_block(&mut self, last_block: bool) -> Result<(), Error> {
311        let block = mem::take(&mut self.pending);
312        if let Err((err, restored_block)) = self.encode_block(block, last_block) {
313            self.pending = restored_block;
314            return Err(self.fail(err));
315        }
316        if !last_block {
317            let block_capacity = self.block_capacity();
318            self.pending = self.allocate_pending_space(block_capacity);
319        }
320        Ok(())
321    }
322
    // Exhaustive match kept intentionally: adding a new CompressionLevel
    // variant will produce a compile error here, forcing the developer to
    // decide whether the streaming encoder supports it before shipping.
    //
    // Every currently-known level is supported, so this always returns Ok.
    fn ensure_level_supported(&self) -> Result<(), Error> {
        match self.compression_level {
            CompressionLevel::Uncompressed
            | CompressionLevel::Fastest
            | CompressionLevel::Default
            | CompressionLevel::Better
            | CompressionLevel::Best
            | CompressionLevel::Level(_) => Ok(()),
        }
    }
336
    // Serializes one block (header + payload) into the scratch buffer and
    // writes it to the drain.
    //
    // Ownership contract: `uncompressed_data` is consumed. On success the
    // bytes end up either hashed directly (raw paths) or inside the matcher's
    // committed window (compressed path). On failure the uncompressed bytes
    // are handed back in the error tuple so the caller can restore `pending`
    // without losing input.
    fn encode_block(
        &mut self,
        uncompressed_data: Vec<u8>,
        last_block: bool,
    ) -> Result<(), (Error, Vec<u8>)> {
        let mut raw_block = Some(uncompressed_data);
        // Borrow the reusable scratch buffer for the serialized block; it is
        // swapped back (cleared) on every exit path below.
        let mut encoded = Vec::new();
        mem::swap(&mut encoded, &mut self.encoded_scratch);
        encoded.clear();
        // Worst case: a full raw payload plus the 3-byte block header.
        let needed_capacity = self.block_capacity() + 3;
        if encoded.capacity() < needed_capacity {
            encoded.reserve(needed_capacity.saturating_sub(encoded.len()));
        }
        let mut moved_into_matcher = false;
        if raw_block.as_ref().is_some_and(|block| block.is_empty()) {
            // Empty payload: zero-size raw block, used to close a frame that
            // ends exactly on a block boundary or was never written to.
            let header = BlockHeader {
                last_block,
                block_type: crate::blocks::block::BlockType::Raw,
                block_size: 0,
            };
            header.serialize(&mut encoded);
        } else {
            match self.compression_level {
                CompressionLevel::Uncompressed => {
                    // Store the payload verbatim behind a raw block header.
                    let block = raw_block.as_ref().expect("raw block missing");
                    let header = BlockHeader {
                        last_block,
                        block_type: crate::blocks::block::BlockType::Raw,
                        block_size: block.len() as u32,
                    };
                    header.serialize(&mut encoded);
                    encoded.extend_from_slice(block);
                }
                CompressionLevel::Fastest
                | CompressionLevel::Default
                | CompressionLevel::Better
                | CompressionLevel::Best
                | CompressionLevel::Level(_) => {
                    // The block encoder takes ownership; the bytes remain
                    // reachable afterwards via the matcher's last space.
                    let block = raw_block.take().expect("raw block missing");
                    debug_assert!(!block.is_empty(), "empty blocks handled above");
                    compress_block_encoded(&mut self.state, last_block, block, &mut encoded);
                    moved_into_matcher = true;
                }
            }
        }

        if let Err(err) = self.drain_mut().and_then(|drain| drain.write_all(&encoded)) {
            // Return the scratch buffer, then reconstruct the uncompressed
            // bytes (copying them back out of the matcher if needed) so the
            // caller can restore its pending buffer.
            encoded.clear();
            mem::swap(&mut encoded, &mut self.encoded_scratch);
            let restored = if moved_into_matcher {
                self.state.matcher.get_last_space().to_vec()
            } else {
                raw_block.unwrap_or_default()
            };
            return Err((err, restored));
        }

        // Hash only after the block was durably written, so a failed write
        // does not contaminate the running checksum.
        if moved_into_matcher {
            #[cfg(feature = "hash")]
            {
                self.hasher.write(self.state.matcher.get_last_space());
            }
        } else {
            self.hash_block(raw_block.as_deref().unwrap_or(&[]));
        }
        encoded.clear();
        mem::swap(&mut encoded, &mut self.encoded_scratch);
        Ok(())
    }
406
407    fn write_empty_last_block(&mut self) -> Result<(), Error> {
408        self.encode_block(Vec::new(), true).map_err(|(err, _)| err)
409    }
410
411    fn fail(&mut self, err: Error) -> Error {
412        self.errored = true;
413        if self.last_error_kind.is_none() {
414            self.last_error_kind = Some(err.kind());
415        }
416        if self.last_error_message.is_none() {
417            self.last_error_message = Some(err.to_string());
418        }
419        err
420    }
421
    /// Folds a block's uncompressed bytes into the frame checksum.
    #[cfg(feature = "hash")]
    fn hash_block(&mut self, uncompressed_data: &[u8]) {
        self.hasher.write(uncompressed_data);
    }

    /// No-op stand-in so call sites need no cfg when `hash` is disabled.
    #[cfg(not(feature = "hash"))]
    fn hash_block(&mut self, _uncompressed_data: &[u8]) {}
429}
430
431impl<W: Write, M: Matcher> Write for StreamingEncoder<W, M> {
432    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
433        self.ensure_open()?;
434        if buf.is_empty() {
435            return Ok(0);
436        }
437
438        // Check pledge before emitting the frame header so that a misuse
439        // like set_pledged_content_size(0) + write(non_empty) doesn't leave
440        // a partially-written header in the drain.
441        if let Some(pledged) = self.pledged_content_size
442            && self.bytes_consumed >= pledged
443        {
444            return Err(invalid_input_error(
445                "write would exceed pledged content size",
446            ));
447        }
448
449        self.ensure_frame_started()?;
450
451        // Enforce pledged upper bound: truncate the accepted slice to the
452        // remaining allowance so that partial-write semantics are honored
453        // (return Ok(n) with n < buf.len()) instead of failing the full call.
454        let buf = if let Some(pledged) = self.pledged_content_size {
455            let remaining_allowed = pledged
456                .checked_sub(self.bytes_consumed)
457                .ok_or_else(|| invalid_input_error("bytes consumed exceed pledged content size"))?;
458            if remaining_allowed == 0 {
459                return Err(invalid_input_error(
460                    "write would exceed pledged content size",
461                ));
462            }
463            let accepted = core::cmp::min(
464                buf.len(),
465                usize::try_from(remaining_allowed).unwrap_or(usize::MAX),
466            );
467            &buf[..accepted]
468        } else {
469            buf
470        };
471
472        let block_capacity = self.block_capacity();
473        if self.pending.capacity() == 0 {
474            self.pending = self.allocate_pending_space(block_capacity);
475        }
476        let mut remaining = buf;
477        let mut consumed = 0usize;
478
479        while !remaining.is_empty() {
480            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
481                return result;
482            }
483
484            let available = block_capacity - self.pending.len();
485            let to_take = core::cmp::min(remaining.len(), available);
486            if to_take == 0 {
487                break;
488            }
489            self.pending.extend_from_slice(&remaining[..to_take]);
490            remaining = &remaining[to_take..];
491            consumed += to_take;
492
493            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
494                if let Ok(n) = &result {
495                    self.bytes_consumed += *n as u64;
496                }
497                return result;
498            }
499        }
500        self.bytes_consumed += consumed as u64;
501        Ok(consumed)
502    }
503
504    fn flush(&mut self) -> Result<(), Error> {
505        self.ensure_open()?;
506        if self.pending.is_empty() {
507            return self
508                .drain_mut()
509                .and_then(|drain| drain.flush())
510                .map_err(|err| self.fail(err));
511        }
512        self.ensure_frame_started()?;
513        self.emit_pending_block(false)?;
514        self.drain_mut()
515            .and_then(|drain| drain.flush())
516            .map_err(|err| self.fail(err))
517    }
518}
519
/// Builds an [`Error`] carrying only an [`ErrorKind`], no message.
fn error_from_kind(kind: ErrorKind) -> Error {
    kind.into()
}
523
/// Builds an [`Error`] with both a kind and a message.
///
/// With `std`, the message is attached directly; without `std`, it is boxed
/// to satisfy the no-std `Error::new` signature (the message may then be
/// dropped by the error type — see the note above `sticky_error`).
fn error_with_kind_message(kind: ErrorKind, message: String) -> Error {
    #[cfg(feature = "std")]
    {
        Error::new(kind, message)
    }
    #[cfg(not(feature = "std"))]
    {
        Error::new(kind, alloc::boxed::Box::new(message))
    }
}
534
/// Builds the error used for caller mistakes (bad pledge, zero window, ...).
///
/// NOTE(review): without `std` this reports `ErrorKind::Other` rather than
/// `InvalidInput` — presumably because the no-std ErrorKind lacks that
/// variant; tests asserting `InvalidInput` appear to target std builds.
/// Confirm against the no-std `ErrorKind` definition.
fn invalid_input_error(message: &str) -> Error {
    #[cfg(feature = "std")]
    {
        Error::new(ErrorKind::InvalidInput, message)
    }
    #[cfg(not(feature = "std"))]
    {
        Error::new(
            ErrorKind::Other,
            alloc::boxed::Box::new(alloc::string::String::from(message)),
        )
    }
}
548
/// Builds a generic ("other") error from an owned message.
fn other_error_owned(message: String) -> Error {
    #[cfg(feature = "std")]
    {
        Error::other(message)
    }
    #[cfg(not(feature = "std"))]
    {
        Error::new(ErrorKind::Other, alloc::boxed::Box::new(message))
    }
}
559
/// Builds a generic ("other") error from a borrowed message.
///
/// The no-std path copies the message into an owned, boxed `String` because
/// the no-std `Error::new` takes a boxed payload.
fn other_error(message: &str) -> Error {
    #[cfg(feature = "std")]
    {
        Error::other(message)
    }
    #[cfg(not(feature = "std"))]
    {
        Error::new(
            ErrorKind::Other,
            alloc::boxed::Box::new(alloc::string::String::from(message)),
        )
    }
}
573
574#[cfg(test)]
575mod tests {
576    use crate::decoding::StreamingDecoder;
577    use crate::encoding::{CompressionLevel, Matcher, Sequence, StreamingEncoder};
578    use crate::io::{Error, ErrorKind, Read, Write};
579    use alloc::vec;
580    use alloc::vec::Vec;
581
    /// Minimal [`Matcher`] stub with a configurable window size, used to make
    /// encoder block sizing and error paths deterministic in tests.
    struct TinyMatcher {
        // Bytes most recently committed back to the matcher.
        last_space: Vec<u8>,
        // Value reported by window_size(); also sizes get_next_space().
        window_size: u64,
    }

    impl TinyMatcher {
        fn new(window_size: u64) -> Self {
            Self {
                last_space: Vec::new(),
                window_size,
            }
        }
    }

    impl Matcher for TinyMatcher {
        // Hands out a zero-filled buffer of exactly one window.
        fn get_next_space(&mut self) -> Vec<u8> {
            vec![0; self.window_size as usize]
        }

        fn get_last_space(&mut self) -> &[u8] {
            self.last_space.as_slice()
        }

        fn commit_space(&mut self, space: Vec<u8>) {
            self.last_space = space;
        }

        fn skip_matching(&mut self) {}

        // Emits the whole committed buffer as a single literals-only
        // sequence (no matches), keeping the output trivially decodable.
        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
            handle_sequence(Sequence::Literals {
                literals: self.last_space.as_slice(),
            });
        }

        fn reset(&mut self, _level: CompressionLevel) {
            self.last_space.clear();
        }

        fn window_size(&self) -> u64 {
            self.window_size
        }
    }
625
    /// Test writer that fails exactly one write call (selected by ordinal)
    /// and succeeds otherwise, collecting accepted bytes into `sink`.
    struct FailingWriteOnce {
        // Count of write() calls seen so far.
        writes: usize,
        // 1-based ordinal of the write() call that should fail.
        fail_on_write_number: usize,
        // Bytes from successful writes, for inspection by tests.
        sink: Vec<u8>,
    }

    impl FailingWriteOnce {
        fn new(fail_on_write_number: usize) -> Self {
            Self {
                writes: 0,
                fail_on_write_number,
                sink: Vec::new(),
            }
        }
    }

    impl Write for FailingWriteOnce {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
            self.writes += 1;
            if self.writes == self.fail_on_write_number {
                return Err(super::other_error("injected write failure"));
            }
            self.sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }
656
    /// Test writer that fails one write call (by ordinal) with a chosen
    /// [`ErrorKind`], discarding data on successful writes.
    struct FailingWithKind {
        // Count of write() calls seen so far.
        writes: usize,
        // 1-based ordinal of the write() call that should fail.
        fail_on_write_number: usize,
        // Kind carried by the injected error.
        kind: ErrorKind,
    }

    impl FailingWithKind {
        fn new(fail_on_write_number: usize, kind: ErrorKind) -> Self {
            Self {
                writes: 0,
                fail_on_write_number,
                kind,
            }
        }
    }

    impl Write for FailingWithKind {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
            self.writes += 1;
            if self.writes == self.fail_on_write_number {
                return Err(Error::from(self.kind));
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }
686
    /// Test writer that, on the selected write call, accepts only a short
    /// prefix of the buffer and then fails every subsequent write — used to
    /// exercise partial-write followed by terminal failure.
    struct PartialThenFailWriter {
        // Count of write() calls seen so far.
        writes: usize,
        // 1-based ordinal of the write() call that triggers partial behavior.
        fail_on_write_number: usize,
        // How many bytes of the triggering write are accepted.
        partial_prefix_len: usize,
        // Once true, every write fails immediately.
        terminal_failure: bool,
        // Bytes accepted so far, for inspection by tests.
        sink: Vec<u8>,
    }

    impl PartialThenFailWriter {
        fn new(fail_on_write_number: usize, partial_prefix_len: usize) -> Self {
            Self {
                writes: 0,
                fail_on_write_number,
                partial_prefix_len,
                terminal_failure: false,
                sink: Vec::new(),
            }
        }
    }

    impl Write for PartialThenFailWriter {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
            if self.terminal_failure {
                return Err(super::other_error("injected terminal write failure"));
            }

            self.writes += 1;
            if self.writes == self.fail_on_write_number {
                // Accept only a prefix, then switch to always-fail mode. A
                // zero-length prefix degenerates into an immediate failure.
                let written = core::cmp::min(self.partial_prefix_len, buf.len());
                if written > 0 {
                    self.sink.extend_from_slice(&buf[..written]);
                    self.terminal_failure = true;
                    return Ok(written);
                }
                return Err(super::other_error("injected terminal write failure"));
            }

            self.sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }
732
    // Round-trips a multi-block payload written in odd-sized chunks to prove
    // block boundaries do not depend on write() call boundaries.
    #[test]
    fn streaming_encoder_roundtrip_multiple_writes() {
        let payload = b"streaming-encoder-roundtrip-".repeat(1024);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        for chunk in payload.chunks(313) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();

        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }

    // flush() must emit the buffered partial block (as non-last) so its bytes
    // become visible in the drain before finish(), and the frame must still
    // decode correctly afterwards.
    #[test]
    fn flush_emits_nonempty_partial_output() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        encoder.write_all(b"partial-block").unwrap();
        encoder.flush().unwrap();
        let flushed_len = encoder.get_ref().len();
        assert!(
            flushed_len > 0,
            "flush should emit header+partial block bytes"
        );
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, b"partial-block");
    }

    // An untouched encoder's flush() is a pure pass-through: no frame header
    // may be written into the drain.
    #[test]
    fn flush_without_writes_does_not_emit_frame_header() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        encoder.flush().unwrap();
        assert!(encoder.get_ref().is_empty());
    }
771
    // With a 4-byte window, writing exactly 4 bytes must emit the full block
    // within the same write() call, producing more drain output than a
    // 3-byte (still-buffered) write.
    #[test]
    fn block_boundary_write_emits_block_in_same_call() {
        let mut boundary = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            Vec::new(),
            CompressionLevel::Uncompressed,
        );
        let mut below = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            Vec::new(),
            CompressionLevel::Uncompressed,
        );

        boundary.write_all(b"ABCD").unwrap();
        below.write_all(b"ABC").unwrap();

        let boundary_len = boundary.get_ref().len();
        let below_len = below.get_ref().len();
        assert!(
            boundary_len > below_len,
            "full block should be emitted immediately at block boundary"
        );
    }

    // finish() consumes the encoder and returns a complete, decodable frame.
    #[test]
    fn finish_consumes_encoder_and_emits_frame() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        encoder.write_all(b"abc").unwrap();
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, b"abc");
    }

    // finish() with no prior writes must still produce a valid empty frame.
    #[test]
    fn finish_without_writes_emits_empty_frame() {
        let encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert!(decoded.is_empty());
    }

    // write(&[]) is a no-op returning Ok(0), per Write semantics.
    #[test]
    fn write_empty_buffer_returns_zero() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        assert_eq!(encoder.write(&[]).unwrap(), 0);
        let _ = encoder.finish().unwrap();
    }
823
    // Uncompressed level round-trip: payload stored as raw blocks.
    #[test]
    fn uncompressed_level_roundtrip() {
        let payload = b"uncompressed-streaming-roundtrip".repeat(64);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Uncompressed);
        for chunk in payload.chunks(41) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }

    // Better level round-trip with chunked writes.
    #[test]
    fn better_level_streaming_roundtrip() {
        let payload = b"better-level-streaming-test".repeat(256);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Better);
        for chunk in payload.chunks(53) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }

    // A matcher reporting window_size == 0 must be rejected on first write.
    // NOTE(review): asserts std-feature kind; the no-std invalid_input_error
    // path reports ErrorKind::Other instead.
    #[test]
    fn zero_window_matcher_returns_invalid_input_error() {
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(0),
            Vec::new(),
            CompressionLevel::Fastest,
        );
        let err = encoder.write_all(b"payload").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn best_level_streaming_roundtrip() {
        // 200 KiB payload crosses the 128 KiB block boundary, exercising
        // multi-block emission and matcher state carry-over for Best.
        let payload = b"best-level-streaming-test".repeat(8 * 1024);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Best);
        for chunk in payload.chunks(53) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }
878
    /// After the drain reports a write error, the encoder must stay poisoned:
    /// every subsequent operation fails and nothing is committed to the sink.
    #[test]
    fn write_failure_poisoning_is_sticky() {
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            // Drain configured to fail early — the first write_all below errors.
            // NOTE(review): exact meaning of `new(1)` lives in the helper's
            // definition outside this chunk.
            FailingWriteOnce::new(1),
            CompressionLevel::Uncompressed,
        );

        // The initial write is rejected by the failing drain.
        assert!(encoder.write_all(b"ABCD").is_err());
        // Poisoning is sticky: flush and further writes keep failing.
        assert!(encoder.flush().is_err());
        assert!(encoder.write_all(b"EFGH").is_err());
        // Nothing must have reached the underlying sink.
        assert_eq!(encoder.get_ref().sink.len(), 0);
        // Even finish() refuses to produce a frame once poisoned.
        assert!(encoder.finish().is_err());
    }
893
    /// A poisoned encoder must keep reporting the *original* error kind on
    /// later calls, not a generic replacement.
    #[test]
    fn poisoned_encoder_returns_original_error_kind() {
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            // Drain whose failure carries `BrokenPipe` — the kind the
            // encoder is expected to remember.
            FailingWithKind::new(1, ErrorKind::BrokenPipe),
            CompressionLevel::Uncompressed,
        );

        // The first failure surfaces the drain's kind directly.
        let first_error = encoder.write_all(b"ABCD").unwrap_err();
        assert_eq!(first_error.kind(), ErrorKind::BrokenPipe);

        // Subsequent calls replay the same kind from the poisoned state.
        let second_error = encoder.write_all(b"EFGH").unwrap_err();
        assert_eq!(second_error.kind(), ErrorKind::BrokenPipe);
    }
908
    /// With a 4-byte matcher window a 12-byte write emits blocks as it goes.
    /// When the drain fails partway through, `write` must still report the
    /// bytes already consumed, and the encoder stays poisoned afterwards.
    #[test]
    fn write_reports_progress_but_poisoning_is_sticky_after_later_block_failure() {
        let payload = b"ABCDEFGHIJKL";
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            // Drain that fails on a later write call — presumably the third;
            // confirm against the helper's definition.
            FailingWriteOnce::new(3),
            CompressionLevel::Uncompressed,
        );

        // Partial progress: 8 of the 12 bytes were consumed before the failure.
        let first_write = encoder.write(payload).unwrap();
        assert_eq!(first_write, 8);
        // Everything after the failure errors: the encoder is poisoned.
        assert!(encoder.write(&payload[first_write..]).is_err());
        assert!(encoder.flush().is_err());
        assert!(encoder.write_all(b"EFGH").is_err());
    }
924
    /// A drain that makes partial progress and *then* fails must still
    /// poison the encoder after the progress it reported.
    #[test]
    fn partial_write_failure_after_progress_poisons_encoder() {
        let payload = b"ABCDEFGHIJKL";
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            // Writer that accepts some writes, then writes only part of one
            // before failing — see the helper for the meaning of `new(3, 1)`.
            PartialThenFailWriter::new(3, 1),
            CompressionLevel::Uncompressed,
        );

        // The first call consumes the first 8 bytes (two 4-byte blocks).
        let first_write = encoder.write(payload).unwrap();
        assert_eq!(first_write, 8);

        // The next write hits the failing drain; from here on the encoder
        // is poisoned and every operation errors.
        let second_write = encoder.write(&payload[first_write..]);
        assert!(second_write.is_err());
        assert!(encoder.flush().is_err());
        assert!(encoder.write_all(b"MNOP").is_err());
    }
942
943    #[test]
944    fn new_with_matcher_and_get_mut_work() {
945        let matcher = TinyMatcher::new(128 * 1024);
946        let mut encoder =
947            StreamingEncoder::new_with_matcher(matcher, Vec::new(), CompressionLevel::Fastest);
948        encoder.get_mut().extend_from_slice(b"");
949        encoder.write_all(b"custom-matcher").unwrap();
950        let compressed = encoder.finish().unwrap();
951        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
952        let mut decoded = Vec::new();
953        decoder.read_to_end(&mut decoded).unwrap();
954        assert_eq!(decoded, b"custom-matcher");
955    }
956
957    #[cfg(feature = "std")]
958    #[test]
959    fn streaming_encoder_output_decompresses_with_c_zstd() {
960        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
961        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
962        for chunk in payload.chunks(1024) {
963            encoder.write_all(chunk).unwrap();
964        }
965        let compressed = encoder.finish().unwrap();
966
967        let mut decoded = Vec::with_capacity(payload.len());
968        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
969        assert_eq!(decoded, payload);
970    }
971
972    #[test]
973    fn pledged_content_size_written_in_header() {
974        let payload = b"hello world, pledged size test";
975        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
976        encoder
977            .set_pledged_content_size(payload.len() as u64)
978            .unwrap();
979        encoder.write_all(payload).unwrap();
980        let compressed = encoder.finish().unwrap();
981
982        // Verify FCS is present and correct
983        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
984            .unwrap()
985            .0;
986        assert_eq!(header.frame_content_size(), payload.len() as u64);
987
988        // Verify roundtrip
989        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
990        let mut decoded = Vec::new();
991        decoder.read_to_end(&mut decoded).unwrap();
992        assert_eq!(decoded, payload);
993    }
994
995    #[test]
996    fn pledged_content_size_mismatch_returns_error() {
997        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
998        encoder.set_pledged_content_size(100).unwrap();
999        encoder.write_all(b"short payload").unwrap(); // 13 bytes != 100 pledged
1000        let err = encoder.finish().unwrap_err();
1001        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1002    }
1003
1004    #[test]
1005    fn write_exceeding_pledge_returns_error() {
1006        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1007        encoder.set_pledged_content_size(5).unwrap();
1008        let err = encoder.write_all(b"exceeds five bytes").unwrap_err();
1009        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1010    }
1011
1012    #[test]
1013    fn write_straddling_pledge_reports_partial_progress() {
1014        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1015        encoder.set_pledged_content_size(5).unwrap();
1016        // write() should accept exactly 5 bytes (partial progress)
1017        assert_eq!(encoder.write(b"abcdef").unwrap(), 5);
1018        // Next write should fail — pledge exhausted
1019        let err = encoder.write(b"g").unwrap_err();
1020        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1021    }
1022
    /// The reusable `encoded_scratch` buffer must retain its capacity across
    /// block emissions instead of being reallocated for every block.
    #[test]
    fn encoded_scratch_capacity_is_reused_across_blocks() {
        // Three 64-byte blocks: TinyMatcher's 64-byte window forces one
        // block per 64-byte write below.
        let payload = vec![0xAB; 64 * 3];
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(64),
            Vec::new(),
            CompressionLevel::Uncompressed,
        );

        encoder.write_all(&payload[..64]).unwrap();
        // 67 = 3-byte block header + 64-byte raw block payload.
        let first_capacity = encoder.encoded_scratch.capacity();
        assert!(
            first_capacity >= 67,
            "expected encoded scratch to keep block header + payload capacity",
        );

        encoder.write_all(&payload[64..128]).unwrap();
        // Emitting another block must not shrink the scratch allocation.
        let second_capacity = encoder.encoded_scratch.capacity();
        assert!(
            second_capacity >= first_capacity,
            "encoded scratch capacity should be reused across block emits",
        );

        // Final block plus frame close must still round-trip correctly.
        encoder.write_all(&payload[128..]).unwrap();
        let compressed = encoder.finish().unwrap();
        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }
1053
1054    #[test]
1055    fn pledged_content_size_after_write_returns_error() {
1056        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1057        encoder.write_all(b"already writing").unwrap();
1058        let err = encoder.set_pledged_content_size(15).unwrap_err();
1059        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1060    }
1061
    /// A source-size hint must never *increase* the window size advertised
    /// in the frame header, and hinting after the first write is rejected.
    #[test]
    fn source_size_hint_directly_reduces_window_header() {
        let payload = b"streaming-source-size-hint".repeat(64);

        // Baseline: encode without a hint and record the advertised window.
        let mut no_hint = StreamingEncoder::new(Vec::new(), CompressionLevel::from_level(11));
        no_hint.write_all(payload.as_slice()).unwrap();
        let no_hint_frame = no_hint.finish().unwrap();
        let no_hint_header = crate::decoding::frame::read_frame_header(no_hint_frame.as_slice())
            .unwrap()
            .0;
        let no_hint_window = no_hint_header.window_size().unwrap();

        // Encode again with an up-front hint of the exact payload size.
        let mut with_hint = StreamingEncoder::new(Vec::new(), CompressionLevel::from_level(11));
        with_hint
            .set_source_size_hint(payload.len() as u64)
            .unwrap();
        with_hint.write_all(payload.as_slice()).unwrap();
        // Hinting after data has been written must be an InvalidInput error.
        let late_hint_err = with_hint
            .set_source_size_hint(payload.len() as u64)
            .unwrap_err();
        assert_eq!(late_hint_err.kind(), ErrorKind::InvalidInput);
        let with_hint_frame = with_hint.finish().unwrap();
        let with_hint_header =
            crate::decoding::frame::read_frame_header(with_hint_frame.as_slice())
                .unwrap()
                .0;
        let with_hint_window = with_hint_header.window_size().unwrap();

        assert!(
            with_hint_window <= no_hint_window,
            "source size hint should not increase advertised window"
        );

        // The hinted frame must still round-trip.
        let mut decoder = StreamingDecoder::new(with_hint_frame.as_slice()).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }
1100
1101    #[cfg(feature = "std")]
1102    #[test]
1103    fn pledged_content_size_c_zstd_compatible() {
1104        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
1105        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1106        encoder
1107            .set_pledged_content_size(payload.len() as u64)
1108            .unwrap();
1109        for chunk in payload.chunks(1024) {
1110            encoder.write_all(chunk).unwrap();
1111        }
1112        let compressed = encoder.finish().unwrap();
1113
1114        // FCS should be written
1115        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1116            .unwrap()
1117            .0;
1118        assert_eq!(header.frame_content_size(), payload.len() as u64);
1119
1120        // C zstd should decompress successfully
1121        let mut decoded = Vec::new();
1122        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
1123        assert_eq!(decoded, payload);
1124    }
1125
1126    #[test]
1127    fn no_pledged_size_omits_fcs_from_header() {
1128        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1129        encoder.write_all(b"no pledged size").unwrap();
1130        let compressed = encoder.finish().unwrap();
1131
1132        // FCS should be omitted from the header; the decoder reports absent FCS as 0.
1133        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1134            .unwrap()
1135            .0;
1136        assert_eq!(header.frame_content_size(), 0);
1137        // Verify the descriptor confirms FCS field is truly absent (0 bytes),
1138        // not just FCS present with value 0.
1139        assert_eq!(header.descriptor.frame_content_size_bytes().unwrap(), 0);
1140    }
1141}