Skip to main content

structured_zstd/encoding/
streaming_encoder.rs

1use alloc::format;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4use core::mem;
5
6use crate::common::MAX_BLOCK_SIZE;
7#[cfg(feature = "hash")]
8use core::hash::Hasher;
9#[cfg(feature = "hash")]
10use twox_hash::XxHash64;
11
12use crate::encoding::levels::compress_block_encoded;
13use crate::encoding::{
14    CompressionLevel, MatchGeneratorDriver, Matcher, block_header::BlockHeader,
15    frame_compressor::CompressState, frame_compressor::FseTables, frame_header::FrameHeader,
16};
17use crate::io::{Error, ErrorKind, Write};
18
/// Incremental frame encoder that implements [`Write`].
///
/// Data can be provided with multiple `write()` calls. Full blocks are compressed
/// automatically, `flush()` emits the currently buffered partial block as non-last,
/// and `finish()` closes the frame and returns the wrapped writer.
///
/// After a drain failure the encoder is poisoned: `write`, `flush`, `finish` and the
/// `set_*` configuration methods return a sticky error that preserves the kind and
/// message of the first recorded failure.
pub struct StreamingEncoder<W: Write, M: Matcher = MatchGeneratorDriver> {
    /// Output writer. `Some` for the whole encoder lifetime; only `finish` takes it.
    drain: Option<W>,
    /// Compression level applied to every block of the frame.
    compression_level: CompressionLevel,
    /// Matcher plus the entropy-table / repeat-offset state carried across blocks.
    state: CompressState<M>,
    /// Uncompressed bytes buffered for the next block (never more than one block).
    pending: Vec<u8>,
    /// Reusable output buffer for encoded blocks, kept to avoid per-block allocation.
    encoded_scratch: Vec<u8>,
    /// Set by `fail`; once true, the operations listed above return the sticky error.
    errored: bool,
    /// Kind of the first recorded failure.
    last_error_kind: Option<ErrorKind>,
    /// Rendered message of the first recorded failure.
    last_error_message: Option<String>,
    /// Whether the frame header has already been written to the drain.
    frame_started: bool,
    /// Exact frame size promised via `set_pledged_content_size`, if any.
    pledged_content_size: Option<u64>,
    /// Total bytes accepted by `write` so far; checked against the pledge.
    bytes_consumed: u64,
    /// `ZSTD_f_zstd1_magicless` — omit the 4-byte magic number prefix.
    /// Default false. See [`Self::set_magicless`].
    magicless: bool,
    /// Running XXH64 of the uncompressed data; its low 32 bits are written as the
    /// frame's content checksum by `finish`.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
42
43impl<W: Write> StreamingEncoder<W, MatchGeneratorDriver> {
44    /// Creates a streaming encoder backed by the default match generator.
45    ///
46    /// The encoder writes compressed bytes into `drain` and applies `compression_level`
47    /// to all subsequently written blocks.
48    pub fn new(drain: W, compression_level: CompressionLevel) -> Self {
49        Self::new_with_matcher(
50            MatchGeneratorDriver::new(MAX_BLOCK_SIZE as usize, 1),
51            drain,
52            compression_level,
53        )
54    }
55}
56
57impl<W: Write, M: Matcher> StreamingEncoder<W, M> {
58    /// Creates a streaming encoder with an explicitly provided matcher implementation.
59    ///
60    /// This constructor is primarily intended for tests and advanced callers that need
61    /// custom match-window behavior.
62    pub fn new_with_matcher(matcher: M, drain: W, compression_level: CompressionLevel) -> Self {
63        Self {
64            drain: Some(drain),
65            compression_level,
66            state: CompressState {
67                matcher,
68                last_huff_table: None,
69                fse_tables: FseTables::new(),
70                block_scratch: crate::encoding::blocks::CompressedBlockScratch::new(),
71                offset_hist: [1, 4, 8],
72                strategy_tag: crate::encoding::strategy::StrategyTag::for_compression_level(
73                    compression_level,
74                ),
75            },
76            pending: Vec::new(),
77            encoded_scratch: Vec::new(),
78            errored: false,
79            last_error_kind: None,
80            last_error_message: None,
81            frame_started: false,
82            pledged_content_size: None,
83            bytes_consumed: 0,
84            magicless: false,
85            #[cfg(feature = "hash")]
86            hasher: XxHash64::with_seed(0),
87        }
88    }
89
90    /// Enable or disable magicless frame format (`ZSTD_f_zstd1_magicless`).
91    ///
92    /// When set to `true`, the frame header serialized by this encoder
93    /// omits the 4-byte magic number prefix. Must be called BEFORE the
94    /// first [`write`](Write::write) call; calling it after the frame
95    /// header has already been emitted returns an error so the caller
96    /// can't be misled into thinking they produced a magicless stream.
97    pub fn set_magicless(&mut self, magicless: bool) -> Result<(), Error> {
98        self.ensure_open()?;
99        if self.frame_started {
100            return Err(invalid_input_error(
101                "magicless format must be set before the first write",
102            ));
103        }
104        self.magicless = magicless;
105        Ok(())
106    }
107
108    /// Pledge the total uncompressed content size for this frame.
109    ///
110    /// When set, the frame header will include a `Frame_Content_Size` field.
111    /// This enables decoders to pre-allocate output buffers.
112    /// The pledged size is also forwarded as a source-size hint to the
113    /// matcher so small inputs can use smaller matching tables.
114    ///
115    /// Must be called **before** the first [`write`](Write::write) call;
116    /// calling it after the frame header has already been emitted returns an
117    /// error.
118    pub fn set_pledged_content_size(&mut self, size: u64) -> Result<(), Error> {
119        self.ensure_open()?;
120        if self.frame_started {
121            return Err(invalid_input_error(
122                "pledged content size must be set before the first write",
123            ));
124        }
125        self.pledged_content_size = Some(size);
126        // Also use pledged size as source-size hint so the matcher
127        // can select smaller tables for small inputs.
128        self.state.matcher.set_source_size_hint(size);
129        Ok(())
130    }
131
132    /// Provide a hint about the total uncompressed size for the next frame.
133    ///
134    /// Unlike [`set_pledged_content_size`](Self::set_pledged_content_size),
135    /// this does **not** enforce that exactly `size` bytes are written; it
136    /// may reduce matcher tables, advertised frame window, and block sizing
137    /// for small inputs. Must be called before the first
138    /// [`write`](Write::write).
139    pub fn set_source_size_hint(&mut self, size: u64) -> Result<(), Error> {
140        self.ensure_open()?;
141        if self.frame_started {
142            return Err(invalid_input_error(
143                "source size hint must be set before the first write",
144            ));
145        }
146        self.state.matcher.set_source_size_hint(size);
147        Ok(())
148    }
149
150    /// Returns an immutable reference to the wrapped output drain.
151    ///
152    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
153    /// consumes the encoder and returns ownership of the drain.
154    pub fn get_ref(&self) -> &W {
155        self.drain
156            .as_ref()
157            .expect("streaming encoder drain is present until finish consumes self")
158    }
159
160    /// Returns a mutable reference to the wrapped output drain.
161    ///
162    /// It is inadvisable to directly write to the underlying writer, as doing
163    /// so would corrupt the zstd frame being assembled by the encoder.
164    ///
165    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
166    /// consumes the encoder and returns ownership of the drain.
167    pub fn get_mut(&mut self) -> &mut W {
168        self.drain
169            .as_mut()
170            .expect("streaming encoder drain is present until finish consumes self")
171    }
172
173    /// Finalizes the current zstd frame and returns the wrapped output drain.
174    ///
175    /// If no payload was written yet, this still emits a valid empty frame.
176    /// Calling this method consumes the encoder.
177    pub fn finish(mut self) -> Result<W, Error> {
178        self.ensure_open()?;
179
180        // Validate the pledge before finalizing the frame. If finish() is
181        // called before any writes, this also avoids emitting a header with
182        // an incorrect FCS into the drain on mismatch.
183        if let Some(pledged) = self.pledged_content_size
184            && self.bytes_consumed != pledged
185        {
186            return Err(invalid_input_error(
187                "pledged content size does not match bytes consumed",
188            ));
189        }
190
191        self.ensure_frame_started()?;
192
193        if self.pending.is_empty() {
194            self.write_empty_last_block()
195                .map_err(|err| self.fail(err))?;
196        } else {
197            self.emit_pending_block(true)?;
198        }
199
200        let mut drain = self
201            .drain
202            .take()
203            .expect("streaming encoder drain must be present when finishing");
204
205        #[cfg(feature = "hash")]
206        {
207            let checksum = self.hasher.finish() as u32;
208            drain
209                .write_all(&checksum.to_le_bytes())
210                .map_err(|err| self.fail(err))?;
211        }
212
213        drain.flush().map_err(|err| self.fail(err))?;
214        Ok(drain)
215    }
216
217    fn ensure_open(&self) -> Result<(), Error> {
218        if self.errored {
219            return Err(self.sticky_error());
220        }
221        Ok(())
222    }
223
224    // Cold path (only reached after poisoning). The format!() calls still allocate
225    // in no_std even though error_with_kind_message/other_error_owned drop the
226    // message; this is acceptable on an error recovery path to keep match arms simple.
227    fn sticky_error(&self) -> Error {
228        match (self.last_error_kind, self.last_error_message.as_deref()) {
229            (Some(kind), Some(message)) => error_with_kind_message(
230                kind,
231                format!(
232                    "streaming encoder is in an errored state due to previous {kind:?} failure: {message}"
233                ),
234            ),
235            (Some(kind), None) => error_from_kind(kind),
236            (None, Some(message)) => other_error_owned(format!(
237                "streaming encoder is in an errored state: {message}"
238            )),
239            (None, None) => other_error("streaming encoder is in an errored state"),
240        }
241    }
242
243    fn drain_mut(&mut self) -> Result<&mut W, Error> {
244        self.drain
245            .as_mut()
246            .ok_or_else(|| other_error("streaming encoder has no active drain"))
247    }
248
    /// Lazily begins the frame: resets per-frame encoder state and writes the
    /// frame header to the drain.
    ///
    /// Returns `Ok(())` immediately once the header has been emitted, so every
    /// entry point (`write`, `flush`, `finish`) can call it unconditionally.
    ///
    /// # Errors
    ///
    /// - Whatever [`Self::ensure_level_supported`] returns.
    /// - An invalid-input error if the matcher reports `window_size() == 0`.
    ///   This path returns before anything is written and does not poison
    ///   the encoder.
    /// - Any drain error while writing the header. This poisons the encoder via
    ///   [`Self::fail`].
    fn ensure_frame_started(&mut self) -> Result<(), Error> {
        if self.frame_started {
            return Ok(());
        }

        self.ensure_level_supported()?;
        // Start the frame from clean state: fresh matcher, initial repeat
        // offsets (same values as `new_with_matcher`), and no previous
        // Huffman/FSE tables to reuse.
        self.state.matcher.reset(self.compression_level);
        self.state.offset_hist = [1, 4, 8];
        self.state.last_huff_table = None;
        self.state.fse_tables.ll_previous = None;
        self.state.fse_tables.ml_previous = None;
        self.state.fse_tables.of_previous = None;
        // Sync `state.strategy_tag` from the active compression level so the
        // literal-compression gates (`min_literals_to_compress`, `min_gain`
        // in `encoding::blocks::compressed`) see the correct strategy for
        // every frame. Mirrors `FrameCompressor::compress` and keeps both
        // entry points byte-equivalent at the gate level.
        self.state.strategy_tag =
            crate::encoding::strategy::StrategyTag::for_compression_level(self.compression_level);
        #[cfg(feature = "hash")]
        {
            // The content checksum covers only this frame's data.
            self.hasher = XxHash64::with_seed(0);
        }

        // Queried after `reset`, so it reflects any size hint applied earlier.
        let window_size = self.state.matcher.window_size();
        if window_size == 0 {
            return Err(invalid_input_error(
                "matcher reported window_size == 0, which is invalid",
            ));
        }

        // FrameCompressor gates single-segment on dictionary usage state; the
        // streaming encoder currently has no dictionary API/state, so we only
        // gate on pledged size and window reach here.
        // TODO: if streaming dictionary support is added, mirror the
        // !use_dictionary_state guard from FrameCompressor.
        let single_segment = self
            .pledged_content_size
            .map(|size| (512..=(1 << 14)).contains(&size) && size <= window_size)
            .unwrap_or(false);

        // Single-segment frames carry no explicit window size; otherwise the
        // matcher's window is advertised in the header.
        let header = FrameHeader {
            frame_content_size: self.pledged_content_size,
            single_segment,
            content_checksum: cfg!(feature = "hash"),
            dictionary_id: None,
            window_size: if single_segment {
                None
            } else {
                Some(window_size)
            },
            magicless: self.magicless,
        };
        let mut encoded_header = Vec::new();
        header.serialize(&mut encoded_header);
        self.drain_mut()
            .and_then(|drain| drain.write_all(&encoded_header))
            .map_err(|err| self.fail(err))?;

        // Only mark the frame as started once the header is fully written.
        self.frame_started = true;
        Ok(())
    }
311
312    fn block_capacity(&self) -> usize {
313        let matcher_window = self.state.matcher.window_size() as usize;
314        core::cmp::max(1, core::cmp::min(matcher_window, MAX_BLOCK_SIZE as usize))
315    }
316
317    fn allocate_pending_space(&mut self, block_capacity: usize) -> Vec<u8> {
318        let mut space = match self.compression_level {
319            CompressionLevel::Fastest
320            | CompressionLevel::Default
321            | CompressionLevel::Better
322            | CompressionLevel::Best
323            | CompressionLevel::Level(_) => self.state.matcher.get_next_space(),
324            CompressionLevel::Uncompressed => Vec::new(),
325        };
326        space.clear();
327        if space.capacity() > block_capacity {
328            space.shrink_to(block_capacity);
329        }
330        if space.capacity() < block_capacity {
331            space.reserve(block_capacity - space.capacity());
332        }
333        space
334    }
335
336    fn emit_full_pending_block(
337        &mut self,
338        block_capacity: usize,
339        consumed: usize,
340    ) -> Option<Result<usize, Error>> {
341        if self.pending.len() != block_capacity {
342            return None;
343        }
344
345        let new_pending = self.allocate_pending_space(block_capacity);
346        let full_block = mem::replace(&mut self.pending, new_pending);
347        if let Err((err, restored_block)) = self.encode_block(full_block, false) {
348            self.pending = restored_block;
349            let err = self.fail(err);
350            if consumed > 0 {
351                return Some(Ok(consumed));
352            }
353            return Some(Err(err));
354        }
355        None
356    }
357
358    fn emit_pending_block(&mut self, last_block: bool) -> Result<(), Error> {
359        let block = mem::take(&mut self.pending);
360        if let Err((err, restored_block)) = self.encode_block(block, last_block) {
361            self.pending = restored_block;
362            return Err(self.fail(err));
363        }
364        if !last_block {
365            let block_capacity = self.block_capacity();
366            self.pending = self.allocate_pending_space(block_capacity);
367        }
368        Ok(())
369    }
370
    /// Rejects compression levels the streaming encoder cannot handle.
    ///
    /// Every current [`CompressionLevel`] variant is supported, so today this
    /// always returns `Ok(())`.
    fn ensure_level_supported(&self) -> Result<(), Error> {
        // Exhaustive match kept intentionally: adding a new CompressionLevel
        // variant will produce a compile error here, forcing the developer to
        // decide whether the streaming encoder supports it before shipping.
        match self.compression_level {
            CompressionLevel::Uncompressed
            | CompressionLevel::Fastest
            | CompressionLevel::Default
            | CompressionLevel::Better
            | CompressionLevel::Best
            | CompressionLevel::Level(_) => Ok(()),
        }
    }
384
    /// Encodes one block of `uncompressed_data` and writes it to the drain.
    ///
    /// Empty input becomes a zero-length Raw block; `write_empty_last_block`
    /// relies on this to close a frame. Otherwise the block is stored raw for
    /// [`CompressionLevel::Uncompressed`] or passed to `compress_block_encoded`
    /// for every other level. After a successful write the uncompressed bytes
    /// are fed into the content checksum (a no-op without the `hash` feature).
    ///
    /// # Errors
    ///
    /// On a drain write failure, returns the error together with the block's
    /// uncompressed bytes so the caller can restore them into `pending`. This
    /// method does not poison the encoder itself; callers do that via `fail`.
    fn encode_block(
        &mut self,
        uncompressed_data: Vec<u8>,
        last_block: bool,
    ) -> Result<(), (Error, Vec<u8>)> {
        // `Some` while this function still owns the input; taken when ownership
        // is handed to `compress_block_encoded`.
        let mut raw_block = Some(uncompressed_data);
        // Borrow the reusable scratch allocation for the encoded output; it is
        // swapped back into `self.encoded_scratch` on both exit paths below.
        let mut encoded = Vec::new();
        mem::swap(&mut encoded, &mut self.encoded_scratch);
        encoded.clear();
        // One block of payload plus the 3-byte block header.
        let needed_capacity = self.block_capacity() + 3;
        if encoded.capacity() < needed_capacity {
            encoded.reserve(needed_capacity.saturating_sub(encoded.len()));
        }
        let mut moved_into_matcher = false;
        if raw_block.as_ref().is_some_and(|block| block.is_empty()) {
            // Empty input: header-only Raw block, independent of the level.
            let header = BlockHeader {
                last_block,
                block_type: crate::blocks::block::BlockType::Raw,
                block_size: 0,
            };
            header.serialize(&mut encoded);
        } else {
            match self.compression_level {
                CompressionLevel::Uncompressed => {
                    let block = raw_block.as_ref().expect("raw block missing");
                    let header = BlockHeader {
                        last_block,
                        block_type: crate::blocks::block::BlockType::Raw,
                        block_size: block.len() as u32,
                    };
                    header.serialize(&mut encoded);
                    encoded.extend_from_slice(block);
                }
                CompressionLevel::Fastest
                | CompressionLevel::Default
                | CompressionLevel::Better
                | CompressionLevel::Best
                | CompressionLevel::Level(_) => {
                    let block = raw_block.take().expect("raw block missing");
                    debug_assert!(!block.is_empty(), "empty blocks handled above");
                    // Ownership of `block` moves here. The restore and hash paths
                    // below read it back via `matcher.get_last_space()`, which
                    // presumes compress_block_encoded commits the block to the
                    // matcher (NOTE(review): confirm in `levels`).
                    compress_block_encoded(
                        &mut self.state,
                        self.compression_level,
                        last_block,
                        block,
                        &mut encoded,
                        #[cfg(all(feature = "lsm", feature = "hash"))]
                        None,
                    );
                    moved_into_matcher = true;
                }
            }
        }

        if let Err(err) = self.drain_mut().and_then(|drain| drain.write_all(&encoded)) {
            // Return the scratch buffer, then hand back the uncompressed bytes
            // from wherever they now live.
            encoded.clear();
            mem::swap(&mut encoded, &mut self.encoded_scratch);
            let restored = if moved_into_matcher {
                self.state.matcher.get_last_space().to_vec()
            } else {
                raw_block.unwrap_or_default()
            };
            return Err((err, restored));
        }

        // Checksum only data that actually reached the drain.
        if moved_into_matcher {
            #[cfg(feature = "hash")]
            {
                self.hasher.write(self.state.matcher.get_last_space());
            }
        } else {
            self.hash_block(raw_block.as_deref().unwrap_or(&[]));
        }
        encoded.clear();
        mem::swap(&mut encoded, &mut self.encoded_scratch);
        Ok(())
    }
462
463    fn write_empty_last_block(&mut self) -> Result<(), Error> {
464        self.encode_block(Vec::new(), true).map_err(|(err, _)| err)
465    }
466
467    fn fail(&mut self, err: Error) -> Error {
468        self.errored = true;
469        if self.last_error_kind.is_none() {
470            self.last_error_kind = Some(err.kind());
471        }
472        if self.last_error_message.is_none() {
473            self.last_error_message = Some(err.to_string());
474        }
475        err
476    }
477
478    #[cfg(feature = "hash")]
479    fn hash_block(&mut self, uncompressed_data: &[u8]) {
480        self.hasher.write(uncompressed_data);
481    }
482
483    #[cfg(not(feature = "hash"))]
484    fn hash_block(&mut self, _uncompressed_data: &[u8]) {}
485}
486
487impl<W: Write, M: Matcher> Write for StreamingEncoder<W, M> {
488    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
489        self.ensure_open()?;
490        if buf.is_empty() {
491            return Ok(0);
492        }
493
494        // Check pledge before emitting the frame header so that a misuse
495        // like set_pledged_content_size(0) + write(non_empty) doesn't leave
496        // a partially-written header in the drain.
497        if let Some(pledged) = self.pledged_content_size
498            && self.bytes_consumed >= pledged
499        {
500            return Err(invalid_input_error(
501                "write would exceed pledged content size",
502            ));
503        }
504
505        self.ensure_frame_started()?;
506
507        // Enforce pledged upper bound: truncate the accepted slice to the
508        // remaining allowance so that partial-write semantics are honored
509        // (return Ok(n) with n < buf.len()) instead of failing the full call.
510        let buf = if let Some(pledged) = self.pledged_content_size {
511            let remaining_allowed = pledged
512                .checked_sub(self.bytes_consumed)
513                .ok_or_else(|| invalid_input_error("bytes consumed exceed pledged content size"))?;
514            if remaining_allowed == 0 {
515                return Err(invalid_input_error(
516                    "write would exceed pledged content size",
517                ));
518            }
519            let accepted = core::cmp::min(
520                buf.len(),
521                usize::try_from(remaining_allowed).unwrap_or(usize::MAX),
522            );
523            &buf[..accepted]
524        } else {
525            buf
526        };
527
528        let block_capacity = self.block_capacity();
529        if self.pending.capacity() == 0 {
530            self.pending = self.allocate_pending_space(block_capacity);
531        }
532        let mut remaining = buf;
533        let mut consumed = 0usize;
534
535        while !remaining.is_empty() {
536            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
537                return result;
538            }
539
540            let available = block_capacity - self.pending.len();
541            let to_take = core::cmp::min(remaining.len(), available);
542            if to_take == 0 {
543                break;
544            }
545            self.pending.extend_from_slice(&remaining[..to_take]);
546            remaining = &remaining[to_take..];
547            consumed += to_take;
548
549            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
550                if let Ok(n) = &result {
551                    self.bytes_consumed += *n as u64;
552                }
553                return result;
554            }
555        }
556        self.bytes_consumed += consumed as u64;
557        Ok(consumed)
558    }
559
560    fn flush(&mut self) -> Result<(), Error> {
561        self.ensure_open()?;
562        if self.pending.is_empty() {
563            return self
564                .drain_mut()
565                .and_then(|drain| drain.flush())
566                .map_err(|err| self.fail(err));
567        }
568        self.ensure_frame_started()?;
569        self.emit_pending_block(false)?;
570        self.drain_mut()
571            .and_then(|drain| drain.flush())
572            .map_err(|err| self.fail(err))
573    }
574}
575
576fn error_from_kind(kind: ErrorKind) -> Error {
577    Error::from(kind)
578}
579
580fn error_with_kind_message(kind: ErrorKind, message: String) -> Error {
581    #[cfg(feature = "std")]
582    {
583        Error::new(kind, message)
584    }
585    #[cfg(not(feature = "std"))]
586    {
587        Error::new(kind, alloc::boxed::Box::new(message))
588    }
589}
590
591fn invalid_input_error(message: &str) -> Error {
592    #[cfg(feature = "std")]
593    {
594        Error::new(ErrorKind::InvalidInput, message)
595    }
596    #[cfg(not(feature = "std"))]
597    {
598        Error::new(
599            ErrorKind::Other,
600            alloc::boxed::Box::new(alloc::string::String::from(message)),
601        )
602    }
603}
604
605fn other_error_owned(message: String) -> Error {
606    #[cfg(feature = "std")]
607    {
608        Error::other(message)
609    }
610    #[cfg(not(feature = "std"))]
611    {
612        Error::new(ErrorKind::Other, alloc::boxed::Box::new(message))
613    }
614}
615
616fn other_error(message: &str) -> Error {
617    #[cfg(feature = "std")]
618    {
619        Error::other(message)
620    }
621    #[cfg(not(feature = "std"))]
622    {
623        Error::new(
624            ErrorKind::Other,
625            alloc::boxed::Box::new(alloc::string::String::from(message)),
626        )
627    }
628}
629
630#[cfg(test)]
631mod tests {
632    use crate::decoding::StreamingDecoder;
633    use crate::encoding::{CompressionLevel, Matcher, Sequence, StreamingEncoder};
634    use crate::io::{Error, ErrorKind, Read, Write};
635    use alloc::vec;
636    use alloc::vec::Vec;
637
638    struct TinyMatcher {
639        last_space: Vec<u8>,
640        window_size: u64,
641    }
642
643    impl TinyMatcher {
644        fn new(window_size: u64) -> Self {
645            Self {
646                last_space: Vec::new(),
647                window_size,
648            }
649        }
650    }
651
652    impl Matcher for TinyMatcher {
653        fn get_next_space(&mut self) -> Vec<u8> {
654            vec![0; self.window_size as usize]
655        }
656
657        fn get_last_space(&mut self) -> &[u8] {
658            self.last_space.as_slice()
659        }
660
661        fn commit_space(&mut self, space: Vec<u8>) {
662            self.last_space = space;
663        }
664
665        fn skip_matching(&mut self) {}
666
667        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
668            handle_sequence(Sequence::Literals {
669                literals: self.last_space.as_slice(),
670            });
671        }
672
673        fn reset(&mut self, _level: CompressionLevel) {
674            self.last_space.clear();
675        }
676
677        fn window_size(&self) -> u64 {
678            self.window_size
679        }
680    }
681
682    struct FailingWriteOnce {
683        writes: usize,
684        fail_on_write_number: usize,
685        sink: Vec<u8>,
686    }
687
688    impl FailingWriteOnce {
689        fn new(fail_on_write_number: usize) -> Self {
690            Self {
691                writes: 0,
692                fail_on_write_number,
693                sink: Vec::new(),
694            }
695        }
696    }
697
698    impl Write for FailingWriteOnce {
699        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
700            self.writes += 1;
701            if self.writes == self.fail_on_write_number {
702                return Err(super::other_error("injected write failure"));
703            }
704            self.sink.extend_from_slice(buf);
705            Ok(buf.len())
706        }
707
708        fn flush(&mut self) -> Result<(), Error> {
709            Ok(())
710        }
711    }
712
713    struct FailingWithKind {
714        writes: usize,
715        fail_on_write_number: usize,
716        kind: ErrorKind,
717    }
718
719    impl FailingWithKind {
720        fn new(fail_on_write_number: usize, kind: ErrorKind) -> Self {
721            Self {
722                writes: 0,
723                fail_on_write_number,
724                kind,
725            }
726        }
727    }
728
729    impl Write for FailingWithKind {
730        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
731            self.writes += 1;
732            if self.writes == self.fail_on_write_number {
733                return Err(Error::from(self.kind));
734            }
735            Ok(buf.len())
736        }
737
738        fn flush(&mut self) -> Result<(), Error> {
739            Ok(())
740        }
741    }
742
743    struct PartialThenFailWriter {
744        writes: usize,
745        fail_on_write_number: usize,
746        partial_prefix_len: usize,
747        terminal_failure: bool,
748        sink: Vec<u8>,
749    }
750
751    impl PartialThenFailWriter {
752        fn new(fail_on_write_number: usize, partial_prefix_len: usize) -> Self {
753            Self {
754                writes: 0,
755                fail_on_write_number,
756                partial_prefix_len,
757                terminal_failure: false,
758                sink: Vec::new(),
759            }
760        }
761    }
762
763    impl Write for PartialThenFailWriter {
764        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
765            if self.terminal_failure {
766                return Err(super::other_error("injected terminal write failure"));
767            }
768
769            self.writes += 1;
770            if self.writes == self.fail_on_write_number {
771                let written = core::cmp::min(self.partial_prefix_len, buf.len());
772                if written > 0 {
773                    self.sink.extend_from_slice(&buf[..written]);
774                    self.terminal_failure = true;
775                    return Ok(written);
776                }
777                return Err(super::other_error("injected terminal write failure"));
778            }
779
780            self.sink.extend_from_slice(buf);
781            Ok(buf.len())
782        }
783
784        fn flush(&mut self) -> Result<(), Error> {
785            Ok(())
786        }
787    }
788
789    /// Pre-write `set_magicless(true)` → emitted frame omits the
790    /// magic prefix AND round-trips through a magicless-aware
791    /// decoder.
792    #[test]
793    fn streaming_encoder_set_magicless_before_write_omits_magic_and_roundtrips() {
794        use crate::common::MAGIC_NUM;
795        let payload = b"streaming-magicless-roundtrip-".repeat(64);
796
797        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
798        encoder
799            .set_magicless(true)
800            .expect("set_magicless pre-write");
801        encoder.write_all(&payload).unwrap();
802        let compressed = encoder.finish().unwrap();
803
804        assert!(
805            !compressed.starts_with(&MAGIC_NUM.to_le_bytes()),
806            "magicless frame must omit the 4-byte magic prefix",
807        );
808
809        let mut decoder = crate::decoding::FrameDecoder::new();
810        decoder.set_magicless(true);
811        let mut cursor: &[u8] = compressed.as_slice();
812        decoder.init(&mut cursor).expect("magicless init");
813        decoder
814            .decode_blocks(&mut cursor, crate::decoding::BlockDecodingStrategy::All)
815            .expect("decode_blocks");
816        let mut decoded: Vec<u8> = Vec::new();
817        decoder
818            .collect_to_writer(&mut decoded)
819            .expect("collect_to_writer");
820        assert_eq!(decoded, payload);
821    }
822
823    /// `set_magicless` after the first write MUST return an error
824    /// (the frame header has already been emitted, flipping the flag
825    /// can't affect the current frame). Mirrors
826    /// `set_pledged_content_size` / `set_source_size_hint` semantics.
827    #[test]
828    fn streaming_encoder_set_magicless_after_first_write_errors() {
829        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
830        encoder.write_all(b"first-block").unwrap();
831        let err = encoder
832            .set_magicless(true)
833            .expect_err("set_magicless after first write must error");
834        assert_eq!(
835            err.kind(),
836            crate::io::ErrorKind::InvalidInput,
837            "expected InvalidInput when setting magicless after frame_started, got {err:?}",
838        );
839    }
840
841    #[test]
842    fn streaming_encoder_roundtrip_multiple_writes() {
843        let payload = b"streaming-encoder-roundtrip-".repeat(1024);
844        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
845        for chunk in payload.chunks(313) {
846            encoder.write_all(chunk).unwrap();
847        }
848        let compressed = encoder.finish().unwrap();
849
850        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
851        let mut decoded = Vec::new();
852        decoder.read_to_end(&mut decoded).unwrap();
853        assert_eq!(decoded, payload);
854    }
855
856    #[test]
857    fn flush_emits_nonempty_partial_output() {
858        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
859        encoder.write_all(b"partial-block").unwrap();
860        encoder.flush().unwrap();
861        let flushed_len = encoder.get_ref().len();
862        assert!(
863            flushed_len > 0,
864            "flush should emit header+partial block bytes"
865        );
866        let compressed = encoder.finish().unwrap();
867        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
868        let mut decoded = Vec::new();
869        decoder.read_to_end(&mut decoded).unwrap();
870        assert_eq!(decoded, b"partial-block");
871    }
872
873    #[test]
874    fn flush_without_writes_does_not_emit_frame_header() {
875        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
876        encoder.flush().unwrap();
877        assert!(encoder.get_ref().is_empty());
878    }
879
880    #[test]
881    fn block_boundary_write_emits_block_in_same_call() {
882        let mut boundary = StreamingEncoder::new_with_matcher(
883            TinyMatcher::new(4),
884            Vec::new(),
885            CompressionLevel::Uncompressed,
886        );
887        let mut below = StreamingEncoder::new_with_matcher(
888            TinyMatcher::new(4),
889            Vec::new(),
890            CompressionLevel::Uncompressed,
891        );
892
893        boundary.write_all(b"ABCD").unwrap();
894        below.write_all(b"ABC").unwrap();
895
896        let boundary_len = boundary.get_ref().len();
897        let below_len = below.get_ref().len();
898        assert!(
899            boundary_len > below_len,
900            "full block should be emitted immediately at block boundary"
901        );
902    }
903
904    #[test]
905    fn finish_consumes_encoder_and_emits_frame() {
906        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
907        encoder.write_all(b"abc").unwrap();
908        let compressed = encoder.finish().unwrap();
909        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
910        let mut decoded = Vec::new();
911        decoder.read_to_end(&mut decoded).unwrap();
912        assert_eq!(decoded, b"abc");
913    }
914
915    #[test]
916    fn finish_without_writes_emits_empty_frame() {
917        let encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
918        let compressed = encoder.finish().unwrap();
919        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
920        let mut decoded = Vec::new();
921        decoder.read_to_end(&mut decoded).unwrap();
922        assert!(decoded.is_empty());
923    }
924
925    #[test]
926    fn write_empty_buffer_returns_zero() {
927        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
928        assert_eq!(encoder.write(&[]).unwrap(), 0);
929        let _ = encoder.finish().unwrap();
930    }
931
932    #[test]
933    fn uncompressed_level_roundtrip() {
934        let payload = b"uncompressed-streaming-roundtrip".repeat(64);
935        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Uncompressed);
936        for chunk in payload.chunks(41) {
937            encoder.write_all(chunk).unwrap();
938        }
939        let compressed = encoder.finish().unwrap();
940        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
941        let mut decoded = Vec::new();
942        decoder.read_to_end(&mut decoded).unwrap();
943        assert_eq!(decoded, payload);
944    }
945
946    #[test]
947    fn better_level_streaming_roundtrip() {
948        let payload = b"better-level-streaming-test".repeat(256);
949        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Better);
950        for chunk in payload.chunks(53) {
951            encoder.write_all(chunk).unwrap();
952        }
953        let compressed = encoder.finish().unwrap();
954        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
955        let mut decoded = Vec::new();
956        decoder.read_to_end(&mut decoded).unwrap();
957        assert_eq!(decoded, payload);
958    }
959
960    #[test]
961    fn zero_window_matcher_returns_invalid_input_error() {
962        let mut encoder = StreamingEncoder::new_with_matcher(
963            TinyMatcher::new(0),
964            Vec::new(),
965            CompressionLevel::Fastest,
966        );
967        let err = encoder.write_all(b"payload").unwrap_err();
968        assert_eq!(err.kind(), ErrorKind::InvalidInput);
969    }
970
971    #[test]
972    fn best_level_streaming_roundtrip() {
973        // 200 KiB payload crosses the 128 KiB block boundary, exercising
974        // multi-block emission and matcher state carry-over for Best.
975        let payload = b"best-level-streaming-test".repeat(8 * 1024);
976        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Best);
977        for chunk in payload.chunks(53) {
978            encoder.write_all(chunk).unwrap();
979        }
980        let compressed = encoder.finish().unwrap();
981        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
982        let mut decoded = Vec::new();
983        decoder.read_to_end(&mut decoded).unwrap();
984        assert_eq!(decoded, payload);
985    }
986
987    #[test]
988    fn write_failure_poisoning_is_sticky() {
989        let mut encoder = StreamingEncoder::new_with_matcher(
990            TinyMatcher::new(4),
991            FailingWriteOnce::new(1),
992            CompressionLevel::Uncompressed,
993        );
994
995        assert!(encoder.write_all(b"ABCD").is_err());
996        assert!(encoder.flush().is_err());
997        assert!(encoder.write_all(b"EFGH").is_err());
998        assert_eq!(encoder.get_ref().sink.len(), 0);
999        assert!(encoder.finish().is_err());
1000    }
1001
1002    #[test]
1003    fn poisoned_encoder_returns_original_error_kind() {
1004        let mut encoder = StreamingEncoder::new_with_matcher(
1005            TinyMatcher::new(4),
1006            FailingWithKind::new(1, ErrorKind::BrokenPipe),
1007            CompressionLevel::Uncompressed,
1008        );
1009
1010        let first_error = encoder.write_all(b"ABCD").unwrap_err();
1011        assert_eq!(first_error.kind(), ErrorKind::BrokenPipe);
1012
1013        let second_error = encoder.write_all(b"EFGH").unwrap_err();
1014        assert_eq!(second_error.kind(), ErrorKind::BrokenPipe);
1015    }
1016
1017    #[test]
1018    fn write_reports_progress_but_poisoning_is_sticky_after_later_block_failure() {
1019        let payload = b"ABCDEFGHIJKL";
1020        let mut encoder = StreamingEncoder::new_with_matcher(
1021            TinyMatcher::new(4),
1022            FailingWriteOnce::new(3),
1023            CompressionLevel::Uncompressed,
1024        );
1025
1026        let first_write = encoder.write(payload).unwrap();
1027        assert_eq!(first_write, 8);
1028        assert!(encoder.write(&payload[first_write..]).is_err());
1029        assert!(encoder.flush().is_err());
1030        assert!(encoder.write_all(b"EFGH").is_err());
1031    }
1032
1033    #[test]
1034    fn partial_write_failure_after_progress_poisons_encoder() {
1035        let payload = b"ABCDEFGHIJKL";
1036        let mut encoder = StreamingEncoder::new_with_matcher(
1037            TinyMatcher::new(4),
1038            PartialThenFailWriter::new(3, 1),
1039            CompressionLevel::Uncompressed,
1040        );
1041
1042        let first_write = encoder.write(payload).unwrap();
1043        assert_eq!(first_write, 8);
1044
1045        let second_write = encoder.write(&payload[first_write..]);
1046        assert!(second_write.is_err());
1047        assert!(encoder.flush().is_err());
1048        assert!(encoder.write_all(b"MNOP").is_err());
1049    }
1050
1051    #[test]
1052    fn new_with_matcher_and_get_mut_work() {
1053        let matcher = TinyMatcher::new(128 * 1024);
1054        let mut encoder =
1055            StreamingEncoder::new_with_matcher(matcher, Vec::new(), CompressionLevel::Fastest);
1056        encoder.get_mut().extend_from_slice(b"");
1057        encoder.write_all(b"custom-matcher").unwrap();
1058        let compressed = encoder.finish().unwrap();
1059        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
1060        let mut decoded = Vec::new();
1061        decoder.read_to_end(&mut decoded).unwrap();
1062        assert_eq!(decoded, b"custom-matcher");
1063    }
1064
1065    #[cfg(feature = "std")]
1066    #[test]
1067    fn streaming_encoder_output_decompresses_with_c_zstd() {
1068        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
1069        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1070        for chunk in payload.chunks(1024) {
1071            encoder.write_all(chunk).unwrap();
1072        }
1073        let compressed = encoder.finish().unwrap();
1074
1075        let mut decoded = Vec::with_capacity(payload.len());
1076        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
1077        assert_eq!(decoded, payload);
1078    }
1079
1080    #[test]
1081    fn pledged_content_size_written_in_header() {
1082        let payload = b"hello world, pledged size test";
1083        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1084        encoder
1085            .set_pledged_content_size(payload.len() as u64)
1086            .unwrap();
1087        encoder.write_all(payload).unwrap();
1088        let compressed = encoder.finish().unwrap();
1089
1090        // Verify FCS is present and correct
1091        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1092            .unwrap()
1093            .0;
1094        assert_eq!(header.frame_content_size(), payload.len() as u64);
1095
1096        // Verify roundtrip
1097        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
1098        let mut decoded = Vec::new();
1099        decoder.read_to_end(&mut decoded).unwrap();
1100        assert_eq!(decoded, payload);
1101    }
1102
1103    #[test]
1104    fn pledged_content_size_mismatch_returns_error() {
1105        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1106        encoder.set_pledged_content_size(100).unwrap();
1107        encoder.write_all(b"short payload").unwrap(); // 13 bytes != 100 pledged
1108        let err = encoder.finish().unwrap_err();
1109        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1110    }
1111
1112    #[test]
1113    fn write_exceeding_pledge_returns_error() {
1114        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1115        encoder.set_pledged_content_size(5).unwrap();
1116        let err = encoder.write_all(b"exceeds five bytes").unwrap_err();
1117        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1118    }
1119
1120    #[test]
1121    fn write_straddling_pledge_reports_partial_progress() {
1122        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1123        encoder.set_pledged_content_size(5).unwrap();
1124        // write() should accept exactly 5 bytes (partial progress)
1125        assert_eq!(encoder.write(b"abcdef").unwrap(), 5);
1126        // Next write should fail — pledge exhausted
1127        let err = encoder.write(b"g").unwrap_err();
1128        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1129    }
1130
1131    #[test]
1132    fn encoded_scratch_capacity_is_reused_across_blocks() {
1133        let payload = vec![0xAB; 64 * 3];
1134        let mut encoder = StreamingEncoder::new_with_matcher(
1135            TinyMatcher::new(64),
1136            Vec::new(),
1137            CompressionLevel::Uncompressed,
1138        );
1139
1140        encoder.write_all(&payload[..64]).unwrap();
1141        let first_capacity = encoder.encoded_scratch.capacity();
1142        assert!(
1143            first_capacity >= 67,
1144            "expected encoded scratch to keep block header + payload capacity",
1145        );
1146
1147        encoder.write_all(&payload[64..128]).unwrap();
1148        let second_capacity = encoder.encoded_scratch.capacity();
1149        assert!(
1150            second_capacity >= first_capacity,
1151            "encoded scratch capacity should be reused across block emits",
1152        );
1153
1154        encoder.write_all(&payload[128..]).unwrap();
1155        let compressed = encoder.finish().unwrap();
1156        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
1157        let mut decoded = Vec::new();
1158        decoder.read_to_end(&mut decoded).unwrap();
1159        assert_eq!(decoded, payload);
1160    }
1161
1162    #[test]
1163    fn pledged_content_size_after_write_returns_error() {
1164        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1165        encoder.write_all(b"already writing").unwrap();
1166        let err = encoder.set_pledged_content_size(15).unwrap_err();
1167        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1168    }
1169
1170    #[test]
1171    fn source_size_hint_directly_reduces_window_header() {
1172        let payload = b"streaming-source-size-hint".repeat(64);
1173
1174        let mut no_hint = StreamingEncoder::new(Vec::new(), CompressionLevel::from_level(11));
1175        no_hint.write_all(payload.as_slice()).unwrap();
1176        let no_hint_frame = no_hint.finish().unwrap();
1177        let no_hint_header = crate::decoding::frame::read_frame_header(no_hint_frame.as_slice())
1178            .unwrap()
1179            .0;
1180        let no_hint_window = no_hint_header.window_size().unwrap();
1181
1182        let mut with_hint = StreamingEncoder::new(Vec::new(), CompressionLevel::from_level(11));
1183        with_hint
1184            .set_source_size_hint(payload.len() as u64)
1185            .unwrap();
1186        with_hint.write_all(payload.as_slice()).unwrap();
1187        let late_hint_err = with_hint
1188            .set_source_size_hint(payload.len() as u64)
1189            .unwrap_err();
1190        assert_eq!(late_hint_err.kind(), ErrorKind::InvalidInput);
1191        let with_hint_frame = with_hint.finish().unwrap();
1192        let with_hint_header =
1193            crate::decoding::frame::read_frame_header(with_hint_frame.as_slice())
1194                .unwrap()
1195                .0;
1196        let with_hint_window = with_hint_header.window_size().unwrap();
1197
1198        assert!(
1199            with_hint_window <= no_hint_window,
1200            "source size hint should not increase advertised window"
1201        );
1202
1203        let mut decoder = StreamingDecoder::new(with_hint_frame.as_slice()).unwrap();
1204        let mut decoded = Vec::new();
1205        decoder.read_to_end(&mut decoded).unwrap();
1206        assert_eq!(decoded, payload);
1207    }
1208
1209    #[cfg(feature = "std")]
1210    #[test]
1211    fn pledged_content_size_c_zstd_compatible() {
1212        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
1213        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1214        encoder
1215            .set_pledged_content_size(payload.len() as u64)
1216            .unwrap();
1217        for chunk in payload.chunks(1024) {
1218            encoder.write_all(chunk).unwrap();
1219        }
1220        let compressed = encoder.finish().unwrap();
1221
1222        // FCS should be written
1223        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1224            .unwrap()
1225            .0;
1226        assert_eq!(header.frame_content_size(), payload.len() as u64);
1227
1228        // C zstd should decompress successfully
1229        let mut decoded = Vec::new();
1230        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
1231        assert_eq!(decoded, payload);
1232    }
1233
1234    #[test]
1235    fn single_segment_requires_pledged_to_fit_matcher_window() {
1236        let payload = b"streaming-window-gate-".repeat(60); // 1320 bytes
1237        let mut encoder = StreamingEncoder::new_with_matcher(
1238            TinyMatcher::new(1024),
1239            Vec::new(),
1240            CompressionLevel::Fastest,
1241        );
1242        encoder
1243            .set_pledged_content_size(payload.len() as u64)
1244            .unwrap();
1245        encoder.write_all(payload.as_slice()).unwrap();
1246        let compressed = encoder.finish().unwrap();
1247
1248        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1249            .unwrap()
1250            .0;
1251        assert_eq!(header.frame_content_size(), payload.len() as u64);
1252        assert!(
1253            !header.descriptor.single_segment_flag(),
1254            "single-segment must stay off when pledged content size exceeds matcher window"
1255        );
1256        assert!(
1257            header.window_size().unwrap() >= 1024,
1258            "window descriptor should be present when single-segment is disabled"
1259        );
1260    }
1261
1262    #[test]
1263    fn ensure_frame_started_refreshes_stale_strategy_tag_at_reset() {
1264        // The literal-compression gates (`min_literals_to_compress`,
1265        // `min_gain`) read `state.strategy_tag`. Regression: every
1266        // reset site MUST refresh that tag from the active compression
1267        // level — relying on construction-time initialization alone is
1268        // not enough, because later mutations or reuse patterns can
1269        // leave the tag stale.
1270        //
1271        // To exercise the RESET-time refresh (not just the
1272        // construction-time init that `StreamingEncoder::new` does for
1273        // free), this test deliberately corrupts `state.strategy_tag`
1274        // to a value that does NOT match the active level, then
1275        // triggers `ensure_frame_started` and asserts the reset path
1276        // wrote the correct tag back. If the sync line in
1277        // `ensure_frame_started` were deleted, the corrupted value
1278        // would survive the write and fail the assertion.
1279        use crate::encoding::strategy::StrategyTag;
1280        for level in [
1281            CompressionLevel::Fastest,
1282            CompressionLevel::Default,
1283            CompressionLevel::Better,
1284            CompressionLevel::Best,
1285        ] {
1286            let expected = StrategyTag::for_compression_level(level);
1287            let mut encoder = StreamingEncoder::new(Vec::new(), level);
1288            // Pick a sentinel that differs from the legitimate tag so
1289            // a missing reset-time sync is observable. BtUltra2 is the
1290            // most-aggressive variant; the four levels above resolve
1291            // to Fast/Dfast/Lazy/Lazy respectively, none equal to it.
1292            let sentinel = StrategyTag::BtUltra2;
1293            assert_ne!(
1294                expected, sentinel,
1295                "sentinel must differ from the legitimate tag at level {level:?}",
1296            );
1297            encoder.state.strategy_tag = sentinel;
1298            encoder.write_all(b"x").unwrap();
1299            assert_eq!(
1300                encoder.state.strategy_tag, expected,
1301                "reset-time strategy_tag sync missing at level {level:?}: \
1302                 sentinel survived `ensure_frame_started`",
1303            );
1304            let _ = encoder.finish().unwrap();
1305        }
1306    }
1307
1308    #[test]
1309    fn no_pledged_size_omits_fcs_from_header() {
1310        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1311        encoder.write_all(b"no pledged size").unwrap();
1312        let compressed = encoder.finish().unwrap();
1313
1314        // FCS should be omitted from the header; the decoder reports absent FCS as 0.
1315        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1316            .unwrap()
1317            .0;
1318        assert_eq!(header.frame_content_size(), 0);
1319        // Verify the descriptor confirms FCS field is truly absent (0 bytes),
1320        // not just FCS present with value 0.
1321        assert_eq!(header.descriptor.frame_content_size_bytes().unwrap(), 0);
1322    }
1323}