Skip to main content

structured_zstd/encoding/
streaming_encoder.rs

1use alloc::format;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4use core::mem;
5
6use crate::common::MAX_BLOCK_SIZE;
7#[cfg(feature = "hash")]
8use core::hash::Hasher;
9#[cfg(feature = "hash")]
10use twox_hash::XxHash64;
11
12use crate::encoding::levels::compress_block_encoded;
13use crate::encoding::{
14    CompressionLevel, MatchGeneratorDriver, Matcher, block_header::BlockHeader,
15    frame_compressor::CompressState, frame_compressor::FseTables, frame_header::FrameHeader,
16};
17use crate::io::{Error, ErrorKind, Write};
18
/// Incremental frame encoder that implements [`Write`].
///
/// Data can be provided with multiple `write()` calls. Full blocks are compressed
/// automatically, `flush()` emits the currently buffered partial block as non-last,
/// and `finish()` closes the frame and returns the wrapped writer.
pub struct StreamingEncoder<W: Write, M: Matcher = MatchGeneratorDriver> {
    /// Output sink. Set to `Some` at construction and taken out by `finish()`.
    drain: Option<W>,
    /// Level handed to the matcher on reset and to the block compressor.
    compression_level: CompressionLevel,
    /// Matcher plus the per-frame compression state (offset history, previous
    /// Huffman/FSE tables, block scratch, strategy tag).
    state: CompressState<M>,
    /// Uncompressed bytes accepted by `write()` that have not been emitted as a block yet.
    pending: Vec<u8>,
    /// Reusable buffer that receives the encoded form of one block before it is
    /// written to the drain.
    encoded_scratch: Vec<u8>,
    /// Set by `fail()`; once set, `ensure_open()` rejects further operations.
    errored: bool,
    /// Kind of the first error passed to `fail()`.
    last_error_kind: Option<ErrorKind>,
    /// Display text of the first error passed to `fail()`.
    last_error_message: Option<String>,
    /// True once the frame header has been written to the drain.
    frame_started: bool,
    /// Content size promised via `set_pledged_content_size`; written into the
    /// frame header and checked in `write()` and `finish()`.
    pledged_content_size: Option<u64>,
    /// Running total of uncompressed bytes accepted by `write()`.
    bytes_consumed: u64,
    /// Hash over the uncompressed payload; its low 32 bits are appended by `finish()`.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
39
40impl<W: Write> StreamingEncoder<W, MatchGeneratorDriver> {
41    /// Creates a streaming encoder backed by the default match generator.
42    ///
43    /// The encoder writes compressed bytes into `drain` and applies `compression_level`
44    /// to all subsequently written blocks.
45    pub fn new(drain: W, compression_level: CompressionLevel) -> Self {
46        Self::new_with_matcher(
47            MatchGeneratorDriver::new(MAX_BLOCK_SIZE as usize, 1),
48            drain,
49            compression_level,
50        )
51    }
52}
53
54impl<W: Write, M: Matcher> StreamingEncoder<W, M> {
55    /// Creates a streaming encoder with an explicitly provided matcher implementation.
56    ///
57    /// This constructor is primarily intended for tests and advanced callers that need
58    /// custom match-window behavior.
59    pub fn new_with_matcher(matcher: M, drain: W, compression_level: CompressionLevel) -> Self {
60        Self {
61            drain: Some(drain),
62            compression_level,
63            state: CompressState {
64                matcher,
65                last_huff_table: None,
66                fse_tables: FseTables::new(),
67                block_scratch: crate::encoding::blocks::CompressedBlockScratch::new(),
68                offset_hist: [1, 4, 8],
69                strategy_tag: crate::encoding::strategy::StrategyTag::for_compression_level(
70                    compression_level,
71                ),
72            },
73            pending: Vec::new(),
74            encoded_scratch: Vec::new(),
75            errored: false,
76            last_error_kind: None,
77            last_error_message: None,
78            frame_started: false,
79            pledged_content_size: None,
80            bytes_consumed: 0,
81            #[cfg(feature = "hash")]
82            hasher: XxHash64::with_seed(0),
83        }
84    }
85
86    /// Pledge the total uncompressed content size for this frame.
87    ///
88    /// When set, the frame header will include a `Frame_Content_Size` field.
89    /// This enables decoders to pre-allocate output buffers.
90    /// The pledged size is also forwarded as a source-size hint to the
91    /// matcher so small inputs can use smaller matching tables.
92    ///
93    /// Must be called **before** the first [`write`](Write::write) call;
94    /// calling it after the frame header has already been emitted returns an
95    /// error.
96    pub fn set_pledged_content_size(&mut self, size: u64) -> Result<(), Error> {
97        self.ensure_open()?;
98        if self.frame_started {
99            return Err(invalid_input_error(
100                "pledged content size must be set before the first write",
101            ));
102        }
103        self.pledged_content_size = Some(size);
104        // Also use pledged size as source-size hint so the matcher
105        // can select smaller tables for small inputs.
106        self.state.matcher.set_source_size_hint(size);
107        Ok(())
108    }
109
110    /// Provide a hint about the total uncompressed size for the next frame.
111    ///
112    /// Unlike [`set_pledged_content_size`](Self::set_pledged_content_size),
113    /// this does **not** enforce that exactly `size` bytes are written; it
114    /// may reduce matcher tables, advertised frame window, and block sizing
115    /// for small inputs. Must be called before the first
116    /// [`write`](Write::write).
117    pub fn set_source_size_hint(&mut self, size: u64) -> Result<(), Error> {
118        self.ensure_open()?;
119        if self.frame_started {
120            return Err(invalid_input_error(
121                "source size hint must be set before the first write",
122            ));
123        }
124        self.state.matcher.set_source_size_hint(size);
125        Ok(())
126    }
127
128    /// Returns an immutable reference to the wrapped output drain.
129    ///
130    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
131    /// consumes the encoder and returns ownership of the drain.
132    pub fn get_ref(&self) -> &W {
133        self.drain
134            .as_ref()
135            .expect("streaming encoder drain is present until finish consumes self")
136    }
137
138    /// Returns a mutable reference to the wrapped output drain.
139    ///
140    /// It is inadvisable to directly write to the underlying writer, as doing
141    /// so would corrupt the zstd frame being assembled by the encoder.
142    ///
143    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
144    /// consumes the encoder and returns ownership of the drain.
145    pub fn get_mut(&mut self) -> &mut W {
146        self.drain
147            .as_mut()
148            .expect("streaming encoder drain is present until finish consumes self")
149    }
150
151    /// Finalizes the current zstd frame and returns the wrapped output drain.
152    ///
153    /// If no payload was written yet, this still emits a valid empty frame.
154    /// Calling this method consumes the encoder.
155    pub fn finish(mut self) -> Result<W, Error> {
156        self.ensure_open()?;
157
158        // Validate the pledge before finalizing the frame. If finish() is
159        // called before any writes, this also avoids emitting a header with
160        // an incorrect FCS into the drain on mismatch.
161        if let Some(pledged) = self.pledged_content_size
162            && self.bytes_consumed != pledged
163        {
164            return Err(invalid_input_error(
165                "pledged content size does not match bytes consumed",
166            ));
167        }
168
169        self.ensure_frame_started()?;
170
171        if self.pending.is_empty() {
172            self.write_empty_last_block()
173                .map_err(|err| self.fail(err))?;
174        } else {
175            self.emit_pending_block(true)?;
176        }
177
178        let mut drain = self
179            .drain
180            .take()
181            .expect("streaming encoder drain must be present when finishing");
182
183        #[cfg(feature = "hash")]
184        {
185            let checksum = self.hasher.finish() as u32;
186            drain
187                .write_all(&checksum.to_le_bytes())
188                .map_err(|err| self.fail(err))?;
189        }
190
191        drain.flush().map_err(|err| self.fail(err))?;
192        Ok(drain)
193    }
194
195    fn ensure_open(&self) -> Result<(), Error> {
196        if self.errored {
197            return Err(self.sticky_error());
198        }
199        Ok(())
200    }
201
202    // Cold path (only reached after poisoning). The format!() calls still allocate
203    // in no_std even though error_with_kind_message/other_error_owned drop the
204    // message; this is acceptable on an error recovery path to keep match arms simple.
205    fn sticky_error(&self) -> Error {
206        match (self.last_error_kind, self.last_error_message.as_deref()) {
207            (Some(kind), Some(message)) => error_with_kind_message(
208                kind,
209                format!(
210                    "streaming encoder is in an errored state due to previous {kind:?} failure: {message}"
211                ),
212            ),
213            (Some(kind), None) => error_from_kind(kind),
214            (None, Some(message)) => other_error_owned(format!(
215                "streaming encoder is in an errored state: {message}"
216            )),
217            (None, None) => other_error("streaming encoder is in an errored state"),
218        }
219    }
220
221    fn drain_mut(&mut self) -> Result<&mut W, Error> {
222        self.drain
223            .as_mut()
224            .ok_or_else(|| other_error("streaming encoder has no active drain"))
225    }
226
227    fn ensure_frame_started(&mut self) -> Result<(), Error> {
228        if self.frame_started {
229            return Ok(());
230        }
231
232        self.ensure_level_supported()?;
233        self.state.matcher.reset(self.compression_level);
234        self.state.offset_hist = [1, 4, 8];
235        self.state.last_huff_table = None;
236        self.state.fse_tables.ll_previous = None;
237        self.state.fse_tables.ml_previous = None;
238        self.state.fse_tables.of_previous = None;
239        // Sync `state.strategy_tag` from the active compression level so the
240        // literal-compression gates (`min_literals_to_compress`, `min_gain`
241        // in `encoding::blocks::compressed`) see the correct strategy for
242        // every frame. Mirrors `FrameCompressor::compress` and keeps both
243        // entry points byte-equivalent at the gate level.
244        self.state.strategy_tag =
245            crate::encoding::strategy::StrategyTag::for_compression_level(self.compression_level);
246        #[cfg(feature = "hash")]
247        {
248            self.hasher = XxHash64::with_seed(0);
249        }
250
251        let window_size = self.state.matcher.window_size();
252        if window_size == 0 {
253            return Err(invalid_input_error(
254                "matcher reported window_size == 0, which is invalid",
255            ));
256        }
257
258        // FrameCompressor gates single-segment on dictionary usage state; the
259        // streaming encoder currently has no dictionary API/state, so we only
260        // gate on pledged size and window reach here.
261        // TODO: if streaming dictionary support is added, mirror the
262        // !use_dictionary_state guard from FrameCompressor.
263        let single_segment = self
264            .pledged_content_size
265            .map(|size| (512..=(1 << 14)).contains(&size) && size <= window_size)
266            .unwrap_or(false);
267
268        let header = FrameHeader {
269            frame_content_size: self.pledged_content_size,
270            single_segment,
271            content_checksum: cfg!(feature = "hash"),
272            dictionary_id: None,
273            window_size: if single_segment {
274                None
275            } else {
276                Some(window_size)
277            },
278        };
279        let mut encoded_header = Vec::new();
280        header.serialize(&mut encoded_header);
281        self.drain_mut()
282            .and_then(|drain| drain.write_all(&encoded_header))
283            .map_err(|err| self.fail(err))?;
284
285        self.frame_started = true;
286        Ok(())
287    }
288
289    fn block_capacity(&self) -> usize {
290        let matcher_window = self.state.matcher.window_size() as usize;
291        core::cmp::max(1, core::cmp::min(matcher_window, MAX_BLOCK_SIZE as usize))
292    }
293
294    fn allocate_pending_space(&mut self, block_capacity: usize) -> Vec<u8> {
295        let mut space = match self.compression_level {
296            CompressionLevel::Fastest
297            | CompressionLevel::Default
298            | CompressionLevel::Better
299            | CompressionLevel::Best
300            | CompressionLevel::Level(_) => self.state.matcher.get_next_space(),
301            CompressionLevel::Uncompressed => Vec::new(),
302        };
303        space.clear();
304        if space.capacity() > block_capacity {
305            space.shrink_to(block_capacity);
306        }
307        if space.capacity() < block_capacity {
308            space.reserve(block_capacity - space.capacity());
309        }
310        space
311    }
312
313    fn emit_full_pending_block(
314        &mut self,
315        block_capacity: usize,
316        consumed: usize,
317    ) -> Option<Result<usize, Error>> {
318        if self.pending.len() != block_capacity {
319            return None;
320        }
321
322        let new_pending = self.allocate_pending_space(block_capacity);
323        let full_block = mem::replace(&mut self.pending, new_pending);
324        if let Err((err, restored_block)) = self.encode_block(full_block, false) {
325            self.pending = restored_block;
326            let err = self.fail(err);
327            if consumed > 0 {
328                return Some(Ok(consumed));
329            }
330            return Some(Err(err));
331        }
332        None
333    }
334
335    fn emit_pending_block(&mut self, last_block: bool) -> Result<(), Error> {
336        let block = mem::take(&mut self.pending);
337        if let Err((err, restored_block)) = self.encode_block(block, last_block) {
338            self.pending = restored_block;
339            return Err(self.fail(err));
340        }
341        if !last_block {
342            let block_capacity = self.block_capacity();
343            self.pending = self.allocate_pending_space(block_capacity);
344        }
345        Ok(())
346    }
347
348    // Exhaustive match kept intentionally: adding a new CompressionLevel
349    // variant will produce a compile error here, forcing the developer to
350    // decide whether the streaming encoder supports it before shipping.
351    fn ensure_level_supported(&self) -> Result<(), Error> {
352        match self.compression_level {
353            CompressionLevel::Uncompressed
354            | CompressionLevel::Fastest
355            | CompressionLevel::Default
356            | CompressionLevel::Better
357            | CompressionLevel::Best
358            | CompressionLevel::Level(_) => Ok(()),
359        }
360    }
361
362    fn encode_block(
363        &mut self,
364        uncompressed_data: Vec<u8>,
365        last_block: bool,
366    ) -> Result<(), (Error, Vec<u8>)> {
367        let mut raw_block = Some(uncompressed_data);
368        let mut encoded = Vec::new();
369        mem::swap(&mut encoded, &mut self.encoded_scratch);
370        encoded.clear();
371        let needed_capacity = self.block_capacity() + 3;
372        if encoded.capacity() < needed_capacity {
373            encoded.reserve(needed_capacity.saturating_sub(encoded.len()));
374        }
375        let mut moved_into_matcher = false;
376        if raw_block.as_ref().is_some_and(|block| block.is_empty()) {
377            let header = BlockHeader {
378                last_block,
379                block_type: crate::blocks::block::BlockType::Raw,
380                block_size: 0,
381            };
382            header.serialize(&mut encoded);
383        } else {
384            match self.compression_level {
385                CompressionLevel::Uncompressed => {
386                    let block = raw_block.as_ref().expect("raw block missing");
387                    let header = BlockHeader {
388                        last_block,
389                        block_type: crate::blocks::block::BlockType::Raw,
390                        block_size: block.len() as u32,
391                    };
392                    header.serialize(&mut encoded);
393                    encoded.extend_from_slice(block);
394                }
395                CompressionLevel::Fastest
396                | CompressionLevel::Default
397                | CompressionLevel::Better
398                | CompressionLevel::Best
399                | CompressionLevel::Level(_) => {
400                    let block = raw_block.take().expect("raw block missing");
401                    debug_assert!(!block.is_empty(), "empty blocks handled above");
402                    compress_block_encoded(
403                        &mut self.state,
404                        self.compression_level,
405                        last_block,
406                        block,
407                        &mut encoded,
408                    );
409                    moved_into_matcher = true;
410                }
411            }
412        }
413
414        if let Err(err) = self.drain_mut().and_then(|drain| drain.write_all(&encoded)) {
415            encoded.clear();
416            mem::swap(&mut encoded, &mut self.encoded_scratch);
417            let restored = if moved_into_matcher {
418                self.state.matcher.get_last_space().to_vec()
419            } else {
420                raw_block.unwrap_or_default()
421            };
422            return Err((err, restored));
423        }
424
425        if moved_into_matcher {
426            #[cfg(feature = "hash")]
427            {
428                self.hasher.write(self.state.matcher.get_last_space());
429            }
430        } else {
431            self.hash_block(raw_block.as_deref().unwrap_or(&[]));
432        }
433        encoded.clear();
434        mem::swap(&mut encoded, &mut self.encoded_scratch);
435        Ok(())
436    }
437
438    fn write_empty_last_block(&mut self) -> Result<(), Error> {
439        self.encode_block(Vec::new(), true).map_err(|(err, _)| err)
440    }
441
442    fn fail(&mut self, err: Error) -> Error {
443        self.errored = true;
444        if self.last_error_kind.is_none() {
445            self.last_error_kind = Some(err.kind());
446        }
447        if self.last_error_message.is_none() {
448            self.last_error_message = Some(err.to_string());
449        }
450        err
451    }
452
453    #[cfg(feature = "hash")]
454    fn hash_block(&mut self, uncompressed_data: &[u8]) {
455        self.hasher.write(uncompressed_data);
456    }
457
458    #[cfg(not(feature = "hash"))]
459    fn hash_block(&mut self, _uncompressed_data: &[u8]) {}
460}
461
462impl<W: Write, M: Matcher> Write for StreamingEncoder<W, M> {
463    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
464        self.ensure_open()?;
465        if buf.is_empty() {
466            return Ok(0);
467        }
468
469        // Check pledge before emitting the frame header so that a misuse
470        // like set_pledged_content_size(0) + write(non_empty) doesn't leave
471        // a partially-written header in the drain.
472        if let Some(pledged) = self.pledged_content_size
473            && self.bytes_consumed >= pledged
474        {
475            return Err(invalid_input_error(
476                "write would exceed pledged content size",
477            ));
478        }
479
480        self.ensure_frame_started()?;
481
482        // Enforce pledged upper bound: truncate the accepted slice to the
483        // remaining allowance so that partial-write semantics are honored
484        // (return Ok(n) with n < buf.len()) instead of failing the full call.
485        let buf = if let Some(pledged) = self.pledged_content_size {
486            let remaining_allowed = pledged
487                .checked_sub(self.bytes_consumed)
488                .ok_or_else(|| invalid_input_error("bytes consumed exceed pledged content size"))?;
489            if remaining_allowed == 0 {
490                return Err(invalid_input_error(
491                    "write would exceed pledged content size",
492                ));
493            }
494            let accepted = core::cmp::min(
495                buf.len(),
496                usize::try_from(remaining_allowed).unwrap_or(usize::MAX),
497            );
498            &buf[..accepted]
499        } else {
500            buf
501        };
502
503        let block_capacity = self.block_capacity();
504        if self.pending.capacity() == 0 {
505            self.pending = self.allocate_pending_space(block_capacity);
506        }
507        let mut remaining = buf;
508        let mut consumed = 0usize;
509
510        while !remaining.is_empty() {
511            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
512                return result;
513            }
514
515            let available = block_capacity - self.pending.len();
516            let to_take = core::cmp::min(remaining.len(), available);
517            if to_take == 0 {
518                break;
519            }
520            self.pending.extend_from_slice(&remaining[..to_take]);
521            remaining = &remaining[to_take..];
522            consumed += to_take;
523
524            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
525                if let Ok(n) = &result {
526                    self.bytes_consumed += *n as u64;
527                }
528                return result;
529            }
530        }
531        self.bytes_consumed += consumed as u64;
532        Ok(consumed)
533    }
534
535    fn flush(&mut self) -> Result<(), Error> {
536        self.ensure_open()?;
537        if self.pending.is_empty() {
538            return self
539                .drain_mut()
540                .and_then(|drain| drain.flush())
541                .map_err(|err| self.fail(err));
542        }
543        self.ensure_frame_started()?;
544        self.emit_pending_block(false)?;
545        self.drain_mut()
546            .and_then(|drain| drain.flush())
547            .map_err(|err| self.fail(err))
548    }
549}
550
551fn error_from_kind(kind: ErrorKind) -> Error {
552    Error::from(kind)
553}
554
555fn error_with_kind_message(kind: ErrorKind, message: String) -> Error {
556    #[cfg(feature = "std")]
557    {
558        Error::new(kind, message)
559    }
560    #[cfg(not(feature = "std"))]
561    {
562        Error::new(kind, alloc::boxed::Box::new(message))
563    }
564}
565
566fn invalid_input_error(message: &str) -> Error {
567    #[cfg(feature = "std")]
568    {
569        Error::new(ErrorKind::InvalidInput, message)
570    }
571    #[cfg(not(feature = "std"))]
572    {
573        Error::new(
574            ErrorKind::Other,
575            alloc::boxed::Box::new(alloc::string::String::from(message)),
576        )
577    }
578}
579
580fn other_error_owned(message: String) -> Error {
581    #[cfg(feature = "std")]
582    {
583        Error::other(message)
584    }
585    #[cfg(not(feature = "std"))]
586    {
587        Error::new(ErrorKind::Other, alloc::boxed::Box::new(message))
588    }
589}
590
591fn other_error(message: &str) -> Error {
592    #[cfg(feature = "std")]
593    {
594        Error::other(message)
595    }
596    #[cfg(not(feature = "std"))]
597    {
598        Error::new(
599            ErrorKind::Other,
600            alloc::boxed::Box::new(alloc::string::String::from(message)),
601        )
602    }
603}
604
605#[cfg(test)]
606mod tests {
607    use crate::decoding::StreamingDecoder;
608    use crate::encoding::{CompressionLevel, Matcher, Sequence, StreamingEncoder};
609    use crate::io::{Error, ErrorKind, Read, Write};
610    use alloc::vec;
611    use alloc::vec::Vec;
612
613    struct TinyMatcher {
614        last_space: Vec<u8>,
615        window_size: u64,
616    }
617
618    impl TinyMatcher {
619        fn new(window_size: u64) -> Self {
620            Self {
621                last_space: Vec::new(),
622                window_size,
623            }
624        }
625    }
626
627    impl Matcher for TinyMatcher {
628        fn get_next_space(&mut self) -> Vec<u8> {
629            vec![0; self.window_size as usize]
630        }
631
632        fn get_last_space(&mut self) -> &[u8] {
633            self.last_space.as_slice()
634        }
635
636        fn commit_space(&mut self, space: Vec<u8>) {
637            self.last_space = space;
638        }
639
640        fn skip_matching(&mut self) {}
641
642        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
643            handle_sequence(Sequence::Literals {
644                literals: self.last_space.as_slice(),
645            });
646        }
647
648        fn reset(&mut self, _level: CompressionLevel) {
649            self.last_space.clear();
650        }
651
652        fn window_size(&self) -> u64 {
653            self.window_size
654        }
655    }
656
657    struct FailingWriteOnce {
658        writes: usize,
659        fail_on_write_number: usize,
660        sink: Vec<u8>,
661    }
662
663    impl FailingWriteOnce {
664        fn new(fail_on_write_number: usize) -> Self {
665            Self {
666                writes: 0,
667                fail_on_write_number,
668                sink: Vec::new(),
669            }
670        }
671    }
672
673    impl Write for FailingWriteOnce {
674        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
675            self.writes += 1;
676            if self.writes == self.fail_on_write_number {
677                return Err(super::other_error("injected write failure"));
678            }
679            self.sink.extend_from_slice(buf);
680            Ok(buf.len())
681        }
682
683        fn flush(&mut self) -> Result<(), Error> {
684            Ok(())
685        }
686    }
687
688    struct FailingWithKind {
689        writes: usize,
690        fail_on_write_number: usize,
691        kind: ErrorKind,
692    }
693
694    impl FailingWithKind {
695        fn new(fail_on_write_number: usize, kind: ErrorKind) -> Self {
696            Self {
697                writes: 0,
698                fail_on_write_number,
699                kind,
700            }
701        }
702    }
703
704    impl Write for FailingWithKind {
705        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
706            self.writes += 1;
707            if self.writes == self.fail_on_write_number {
708                return Err(Error::from(self.kind));
709            }
710            Ok(buf.len())
711        }
712
713        fn flush(&mut self) -> Result<(), Error> {
714            Ok(())
715        }
716    }
717
718    struct PartialThenFailWriter {
719        writes: usize,
720        fail_on_write_number: usize,
721        partial_prefix_len: usize,
722        terminal_failure: bool,
723        sink: Vec<u8>,
724    }
725
726    impl PartialThenFailWriter {
727        fn new(fail_on_write_number: usize, partial_prefix_len: usize) -> Self {
728            Self {
729                writes: 0,
730                fail_on_write_number,
731                partial_prefix_len,
732                terminal_failure: false,
733                sink: Vec::new(),
734            }
735        }
736    }
737
738    impl Write for PartialThenFailWriter {
739        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
740            if self.terminal_failure {
741                return Err(super::other_error("injected terminal write failure"));
742            }
743
744            self.writes += 1;
745            if self.writes == self.fail_on_write_number {
746                let written = core::cmp::min(self.partial_prefix_len, buf.len());
747                if written > 0 {
748                    self.sink.extend_from_slice(&buf[..written]);
749                    self.terminal_failure = true;
750                    return Ok(written);
751                }
752                return Err(super::other_error("injected terminal write failure"));
753            }
754
755            self.sink.extend_from_slice(buf);
756            Ok(buf.len())
757        }
758
759        fn flush(&mut self) -> Result<(), Error> {
760            Ok(())
761        }
762    }
763
/// Input fed in many odd-sized chunks must decode back to the original bytes.
#[test]
fn streaming_encoder_roundtrip_multiple_writes() {
    let payload = b"streaming-encoder-roundtrip-".repeat(1024);

    let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
    payload
        .chunks(313)
        .for_each(|chunk| encoder.write_all(chunk).unwrap());
    let frame = encoder.finish().unwrap();

    let mut roundtripped = Vec::new();
    let mut decoder = StreamingDecoder::new(frame.as_slice()).unwrap();
    decoder.read_to_end(&mut roundtripped).unwrap();
    assert_eq!(roundtripped, payload);
}
778
/// `flush()` must push the header and the buffered partial block to the
/// drain, and the finished frame must still decode correctly.
#[test]
fn flush_emits_nonempty_partial_output() {
    let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
    encoder.write_all(b"partial-block").unwrap();
    encoder.flush().unwrap();

    let bytes_after_flush = encoder.get_ref().len();
    assert!(
        bytes_after_flush > 0,
        "flush should emit header+partial block bytes"
    );

    let frame = encoder.finish().unwrap();
    let mut roundtripped = Vec::new();
    let mut decoder = StreamingDecoder::new(frame.as_slice()).unwrap();
    decoder.read_to_end(&mut roundtripped).unwrap();
    assert_eq!(roundtripped, b"partial-block");
}
795
/// Flushing a fresh encoder must leave the drain untouched.
#[test]
fn flush_without_writes_does_not_emit_frame_header() {
    let mut fresh = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
    fresh.flush().unwrap();
    let drain = fresh.get_ref();
    assert!(drain.is_empty());
}
802
/// With a 4-byte block size, writing exactly 4 bytes must emit the block
/// within that same call, whereas 3 bytes stay buffered.
#[test]
fn block_boundary_write_emits_block_in_same_call() {
    let mut boundary = StreamingEncoder::new_with_matcher(
        TinyMatcher::new(4),
        Vec::new(),
        CompressionLevel::Uncompressed,
    );
    boundary.write_all(b"ABCD").unwrap();

    let mut below = StreamingEncoder::new_with_matcher(
        TinyMatcher::new(4),
        Vec::new(),
        CompressionLevel::Uncompressed,
    );
    below.write_all(b"ABC").unwrap();

    let emitted_at_boundary = boundary.get_ref().len();
    let emitted_below = below.get_ref().len();
    assert!(
        emitted_at_boundary > emitted_below,
        "full block should be emitted immediately at block boundary"
    );
}
826
/// `finish()` must return a complete frame that decodes to the written bytes.
#[test]
fn finish_consumes_encoder_and_emits_frame() {
    let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
    encoder.write_all(b"abc").unwrap();
    let frame = encoder.finish().unwrap();

    let mut roundtripped = Vec::new();
    let mut decoder = StreamingDecoder::new(frame.as_slice()).unwrap();
    decoder.read_to_end(&mut roundtripped).unwrap();
    assert_eq!(roundtripped, b"abc");
}
837
/// Finishing without any write must still yield a decodable, empty frame.
#[test]
fn finish_without_writes_emits_empty_frame() {
    let frame = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest)
        .finish()
        .unwrap();

    let mut roundtripped = Vec::new();
    let mut decoder = StreamingDecoder::new(frame.as_slice()).unwrap();
    decoder.read_to_end(&mut roundtripped).unwrap();
    assert!(roundtripped.is_empty());
}
847
/// Writing an empty slice is accepted and reports zero bytes consumed.
#[test]
fn write_empty_buffer_returns_zero() {
    let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
    let accepted = encoder.write(&[]).unwrap();
    assert_eq!(accepted, 0);
    let _ = encoder.finish().unwrap();
}
854
/// The `Uncompressed` level must round-trip chunked input unchanged.
#[test]
fn uncompressed_level_roundtrip() {
    let payload = b"uncompressed-streaming-roundtrip".repeat(64);

    let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Uncompressed);
    payload
        .chunks(41)
        .for_each(|chunk| encoder.write_all(chunk).unwrap());
    let frame = encoder.finish().unwrap();

    let mut roundtripped = Vec::new();
    let mut decoder = StreamingDecoder::new(frame.as_slice()).unwrap();
    decoder.read_to_end(&mut roundtripped).unwrap();
    assert_eq!(roundtripped, payload);
}
868
869    #[test]
870    fn better_level_streaming_roundtrip() {
871        let payload = b"better-level-streaming-test".repeat(256);
872        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Better);
873        for chunk in payload.chunks(53) {
874            encoder.write_all(chunk).unwrap();
875        }
876        let compressed = encoder.finish().unwrap();
877        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
878        let mut decoded = Vec::new();
879        decoder.read_to_end(&mut decoded).unwrap();
880        assert_eq!(decoded, payload);
881    }
882
883    #[test]
884    fn zero_window_matcher_returns_invalid_input_error() {
885        let mut encoder = StreamingEncoder::new_with_matcher(
886            TinyMatcher::new(0),
887            Vec::new(),
888            CompressionLevel::Fastest,
889        );
890        let err = encoder.write_all(b"payload").unwrap_err();
891        assert_eq!(err.kind(), ErrorKind::InvalidInput);
892    }
893
894    #[test]
895    fn best_level_streaming_roundtrip() {
896        // 200 KiB payload crosses the 128 KiB block boundary, exercising
897        // multi-block emission and matcher state carry-over for Best.
898        let payload = b"best-level-streaming-test".repeat(8 * 1024);
899        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Best);
900        for chunk in payload.chunks(53) {
901            encoder.write_all(chunk).unwrap();
902        }
903        let compressed = encoder.finish().unwrap();
904        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
905        let mut decoded = Vec::new();
906        decoder.read_to_end(&mut decoded).unwrap();
907        assert_eq!(decoded, payload);
908    }
909
910    #[test]
911    fn write_failure_poisoning_is_sticky() {
912        let mut encoder = StreamingEncoder::new_with_matcher(
913            TinyMatcher::new(4),
914            FailingWriteOnce::new(1),
915            CompressionLevel::Uncompressed,
916        );
917
918        assert!(encoder.write_all(b"ABCD").is_err());
919        assert!(encoder.flush().is_err());
920        assert!(encoder.write_all(b"EFGH").is_err());
921        assert_eq!(encoder.get_ref().sink.len(), 0);
922        assert!(encoder.finish().is_err());
923    }
924
925    #[test]
926    fn poisoned_encoder_returns_original_error_kind() {
927        let mut encoder = StreamingEncoder::new_with_matcher(
928            TinyMatcher::new(4),
929            FailingWithKind::new(1, ErrorKind::BrokenPipe),
930            CompressionLevel::Uncompressed,
931        );
932
933        let first_error = encoder.write_all(b"ABCD").unwrap_err();
934        assert_eq!(first_error.kind(), ErrorKind::BrokenPipe);
935
936        let second_error = encoder.write_all(b"EFGH").unwrap_err();
937        assert_eq!(second_error.kind(), ErrorKind::BrokenPipe);
938    }
939
940    #[test]
941    fn write_reports_progress_but_poisoning_is_sticky_after_later_block_failure() {
942        let payload = b"ABCDEFGHIJKL";
943        let mut encoder = StreamingEncoder::new_with_matcher(
944            TinyMatcher::new(4),
945            FailingWriteOnce::new(3),
946            CompressionLevel::Uncompressed,
947        );
948
949        let first_write = encoder.write(payload).unwrap();
950        assert_eq!(first_write, 8);
951        assert!(encoder.write(&payload[first_write..]).is_err());
952        assert!(encoder.flush().is_err());
953        assert!(encoder.write_all(b"EFGH").is_err());
954    }
955
956    #[test]
957    fn partial_write_failure_after_progress_poisons_encoder() {
958        let payload = b"ABCDEFGHIJKL";
959        let mut encoder = StreamingEncoder::new_with_matcher(
960            TinyMatcher::new(4),
961            PartialThenFailWriter::new(3, 1),
962            CompressionLevel::Uncompressed,
963        );
964
965        let first_write = encoder.write(payload).unwrap();
966        assert_eq!(first_write, 8);
967
968        let second_write = encoder.write(&payload[first_write..]);
969        assert!(second_write.is_err());
970        assert!(encoder.flush().is_err());
971        assert!(encoder.write_all(b"MNOP").is_err());
972    }
973
974    #[test]
975    fn new_with_matcher_and_get_mut_work() {
976        let matcher = TinyMatcher::new(128 * 1024);
977        let mut encoder =
978            StreamingEncoder::new_with_matcher(matcher, Vec::new(), CompressionLevel::Fastest);
979        encoder.get_mut().extend_from_slice(b"");
980        encoder.write_all(b"custom-matcher").unwrap();
981        let compressed = encoder.finish().unwrap();
982        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
983        let mut decoded = Vec::new();
984        decoder.read_to_end(&mut decoded).unwrap();
985        assert_eq!(decoded, b"custom-matcher");
986    }
987
988    #[cfg(feature = "std")]
989    #[test]
990    fn streaming_encoder_output_decompresses_with_c_zstd() {
991        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
992        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
993        for chunk in payload.chunks(1024) {
994            encoder.write_all(chunk).unwrap();
995        }
996        let compressed = encoder.finish().unwrap();
997
998        let mut decoded = Vec::with_capacity(payload.len());
999        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
1000        assert_eq!(decoded, payload);
1001    }
1002
1003    #[test]
1004    fn pledged_content_size_written_in_header() {
1005        let payload = b"hello world, pledged size test";
1006        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1007        encoder
1008            .set_pledged_content_size(payload.len() as u64)
1009            .unwrap();
1010        encoder.write_all(payload).unwrap();
1011        let compressed = encoder.finish().unwrap();
1012
1013        // Verify FCS is present and correct
1014        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1015            .unwrap()
1016            .0;
1017        assert_eq!(header.frame_content_size(), payload.len() as u64);
1018
1019        // Verify roundtrip
1020        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
1021        let mut decoded = Vec::new();
1022        decoder.read_to_end(&mut decoded).unwrap();
1023        assert_eq!(decoded, payload);
1024    }
1025
1026    #[test]
1027    fn pledged_content_size_mismatch_returns_error() {
1028        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1029        encoder.set_pledged_content_size(100).unwrap();
1030        encoder.write_all(b"short payload").unwrap(); // 13 bytes != 100 pledged
1031        let err = encoder.finish().unwrap_err();
1032        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1033    }
1034
1035    #[test]
1036    fn write_exceeding_pledge_returns_error() {
1037        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1038        encoder.set_pledged_content_size(5).unwrap();
1039        let err = encoder.write_all(b"exceeds five bytes").unwrap_err();
1040        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1041    }
1042
1043    #[test]
1044    fn write_straddling_pledge_reports_partial_progress() {
1045        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1046        encoder.set_pledged_content_size(5).unwrap();
1047        // write() should accept exactly 5 bytes (partial progress)
1048        assert_eq!(encoder.write(b"abcdef").unwrap(), 5);
1049        // Next write should fail — pledge exhausted
1050        let err = encoder.write(b"g").unwrap_err();
1051        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1052    }
1053
1054    #[test]
1055    fn encoded_scratch_capacity_is_reused_across_blocks() {
1056        let payload = vec![0xAB; 64 * 3];
1057        let mut encoder = StreamingEncoder::new_with_matcher(
1058            TinyMatcher::new(64),
1059            Vec::new(),
1060            CompressionLevel::Uncompressed,
1061        );
1062
1063        encoder.write_all(&payload[..64]).unwrap();
1064        let first_capacity = encoder.encoded_scratch.capacity();
1065        assert!(
1066            first_capacity >= 67,
1067            "expected encoded scratch to keep block header + payload capacity",
1068        );
1069
1070        encoder.write_all(&payload[64..128]).unwrap();
1071        let second_capacity = encoder.encoded_scratch.capacity();
1072        assert!(
1073            second_capacity >= first_capacity,
1074            "encoded scratch capacity should be reused across block emits",
1075        );
1076
1077        encoder.write_all(&payload[128..]).unwrap();
1078        let compressed = encoder.finish().unwrap();
1079        let mut decoder = StreamingDecoder::new(compressed.as_slice()).unwrap();
1080        let mut decoded = Vec::new();
1081        decoder.read_to_end(&mut decoded).unwrap();
1082        assert_eq!(decoded, payload);
1083    }
1084
1085    #[test]
1086    fn pledged_content_size_after_write_returns_error() {
1087        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1088        encoder.write_all(b"already writing").unwrap();
1089        let err = encoder.set_pledged_content_size(15).unwrap_err();
1090        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1091    }
1092
1093    #[test]
1094    fn source_size_hint_directly_reduces_window_header() {
1095        let payload = b"streaming-source-size-hint".repeat(64);
1096
1097        let mut no_hint = StreamingEncoder::new(Vec::new(), CompressionLevel::from_level(11));
1098        no_hint.write_all(payload.as_slice()).unwrap();
1099        let no_hint_frame = no_hint.finish().unwrap();
1100        let no_hint_header = crate::decoding::frame::read_frame_header(no_hint_frame.as_slice())
1101            .unwrap()
1102            .0;
1103        let no_hint_window = no_hint_header.window_size().unwrap();
1104
1105        let mut with_hint = StreamingEncoder::new(Vec::new(), CompressionLevel::from_level(11));
1106        with_hint
1107            .set_source_size_hint(payload.len() as u64)
1108            .unwrap();
1109        with_hint.write_all(payload.as_slice()).unwrap();
1110        let late_hint_err = with_hint
1111            .set_source_size_hint(payload.len() as u64)
1112            .unwrap_err();
1113        assert_eq!(late_hint_err.kind(), ErrorKind::InvalidInput);
1114        let with_hint_frame = with_hint.finish().unwrap();
1115        let with_hint_header =
1116            crate::decoding::frame::read_frame_header(with_hint_frame.as_slice())
1117                .unwrap()
1118                .0;
1119        let with_hint_window = with_hint_header.window_size().unwrap();
1120
1121        assert!(
1122            with_hint_window <= no_hint_window,
1123            "source size hint should not increase advertised window"
1124        );
1125
1126        let mut decoder = StreamingDecoder::new(with_hint_frame.as_slice()).unwrap();
1127        let mut decoded = Vec::new();
1128        decoder.read_to_end(&mut decoded).unwrap();
1129        assert_eq!(decoded, payload);
1130    }
1131
1132    #[cfg(feature = "std")]
1133    #[test]
1134    fn pledged_content_size_c_zstd_compatible() {
1135        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
1136        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1137        encoder
1138            .set_pledged_content_size(payload.len() as u64)
1139            .unwrap();
1140        for chunk in payload.chunks(1024) {
1141            encoder.write_all(chunk).unwrap();
1142        }
1143        let compressed = encoder.finish().unwrap();
1144
1145        // FCS should be written
1146        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1147            .unwrap()
1148            .0;
1149        assert_eq!(header.frame_content_size(), payload.len() as u64);
1150
1151        // C zstd should decompress successfully
1152        let mut decoded = Vec::new();
1153        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
1154        assert_eq!(decoded, payload);
1155    }
1156
1157    #[test]
1158    fn single_segment_requires_pledged_to_fit_matcher_window() {
1159        let payload = b"streaming-window-gate-".repeat(60); // 1320 bytes
1160        let mut encoder = StreamingEncoder::new_with_matcher(
1161            TinyMatcher::new(1024),
1162            Vec::new(),
1163            CompressionLevel::Fastest,
1164        );
1165        encoder
1166            .set_pledged_content_size(payload.len() as u64)
1167            .unwrap();
1168        encoder.write_all(payload.as_slice()).unwrap();
1169        let compressed = encoder.finish().unwrap();
1170
1171        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1172            .unwrap()
1173            .0;
1174        assert_eq!(header.frame_content_size(), payload.len() as u64);
1175        assert!(
1176            !header.descriptor.single_segment_flag(),
1177            "single-segment must stay off when pledged content size exceeds matcher window"
1178        );
1179        assert!(
1180            header.window_size().unwrap() >= 1024,
1181            "window descriptor should be present when single-segment is disabled"
1182        );
1183    }
1184
1185    #[test]
1186    fn ensure_frame_started_refreshes_stale_strategy_tag_at_reset() {
1187        // The literal-compression gates (`min_literals_to_compress`,
1188        // `min_gain`) read `state.strategy_tag`. Regression: every
1189        // reset site MUST refresh that tag from the active compression
1190        // level — relying on construction-time initialization alone is
1191        // not enough, because later mutations or reuse patterns can
1192        // leave the tag stale.
1193        //
1194        // To exercise the RESET-time refresh (not just the
1195        // construction-time init that `StreamingEncoder::new` does for
1196        // free), this test deliberately corrupts `state.strategy_tag`
1197        // to a value that does NOT match the active level, then
1198        // triggers `ensure_frame_started` and asserts the reset path
1199        // wrote the correct tag back. If the sync line in
1200        // `ensure_frame_started` were deleted, the corrupted value
1201        // would survive the write and fail the assertion.
1202        use crate::encoding::strategy::StrategyTag;
1203        for level in [
1204            CompressionLevel::Fastest,
1205            CompressionLevel::Default,
1206            CompressionLevel::Better,
1207            CompressionLevel::Best,
1208        ] {
1209            let expected = StrategyTag::for_compression_level(level);
1210            let mut encoder = StreamingEncoder::new(Vec::new(), level);
1211            // Pick a sentinel that differs from the legitimate tag so
1212            // a missing reset-time sync is observable. BtUltra2 is the
1213            // most-aggressive variant; the four levels above resolve
1214            // to Fast/Dfast/Lazy/Lazy respectively, none equal to it.
1215            let sentinel = StrategyTag::BtUltra2;
1216            assert_ne!(
1217                expected, sentinel,
1218                "sentinel must differ from the legitimate tag at level {level:?}",
1219            );
1220            encoder.state.strategy_tag = sentinel;
1221            encoder.write_all(b"x").unwrap();
1222            assert_eq!(
1223                encoder.state.strategy_tag, expected,
1224                "reset-time strategy_tag sync missing at level {level:?}: \
1225                 sentinel survived `ensure_frame_started`",
1226            );
1227            let _ = encoder.finish().unwrap();
1228        }
1229    }
1230
1231    #[test]
1232    fn no_pledged_size_omits_fcs_from_header() {
1233        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
1234        encoder.write_all(b"no pledged size").unwrap();
1235        let compressed = encoder.finish().unwrap();
1236
1237        // FCS should be omitted from the header; the decoder reports absent FCS as 0.
1238        let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
1239            .unwrap()
1240            .0;
1241        assert_eq!(header.frame_content_size(), 0);
1242        // Verify the descriptor confirms FCS field is truly absent (0 bytes),
1243        // not just FCS present with value 0.
1244        assert_eq!(header.descriptor.frame_content_size_bytes().unwrap(), 0);
1245    }
1246}