Skip to main content

structured_zstd/encoding/
streaming_encoder.rs

1use alloc::format;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4use core::mem;
5
6use crate::common::MAX_BLOCK_SIZE;
7#[cfg(feature = "hash")]
8use core::hash::Hasher;
9#[cfg(feature = "hash")]
10use twox_hash::XxHash64;
11
12use crate::encoding::levels::compress_fastest;
13use crate::encoding::{
14    CompressionLevel, MatchGeneratorDriver, Matcher, block_header::BlockHeader,
15    frame_compressor::CompressState, frame_compressor::FseTables, frame_header::FrameHeader,
16};
17use crate::io::{Error, ErrorKind, Write};
18
19/// Incremental frame encoder that implements [`Write`].
20///
21/// Data can be provided with multiple `write()` calls. Full blocks are compressed
22/// automatically, `flush()` emits the currently buffered partial block as non-last,
23/// and `finish()` closes the frame and returns the wrapped writer.
24pub struct StreamingEncoder<W: Write, M: Matcher = MatchGeneratorDriver> {
25    drain: Option<W>,
26    compression_level: CompressionLevel,
27    state: CompressState<M>,
28    pending: Vec<u8>,
29    errored: bool,
30    last_error_kind: Option<ErrorKind>,
31    last_error_message: Option<String>,
32    frame_started: bool,
33    #[cfg(feature = "hash")]
34    hasher: XxHash64,
35}
36
37impl<W: Write> StreamingEncoder<W, MatchGeneratorDriver> {
38    /// Creates a streaming encoder backed by the default match generator.
39    ///
40    /// The encoder writes compressed bytes into `drain` and applies `compression_level`
41    /// to all subsequently written blocks.
42    pub fn new(drain: W, compression_level: CompressionLevel) -> Self {
43        Self::new_with_matcher(
44            MatchGeneratorDriver::new(MAX_BLOCK_SIZE as usize, 1),
45            drain,
46            compression_level,
47        )
48    }
49}
50
51impl<W: Write, M: Matcher> StreamingEncoder<W, M> {
52    /// Creates a streaming encoder with an explicitly provided matcher implementation.
53    ///
54    /// This constructor is primarily intended for tests and advanced callers that need
55    /// custom match-window behavior.
56    pub fn new_with_matcher(matcher: M, drain: W, compression_level: CompressionLevel) -> Self {
57        Self {
58            drain: Some(drain),
59            compression_level,
60            state: CompressState {
61                matcher,
62                last_huff_table: None,
63                fse_tables: FseTables::new(),
64                offset_hist: [1, 4, 8],
65            },
66            pending: Vec::new(),
67            errored: false,
68            last_error_kind: None,
69            last_error_message: None,
70            frame_started: false,
71            #[cfg(feature = "hash")]
72            hasher: XxHash64::with_seed(0),
73        }
74    }
75
76    /// Returns an immutable reference to the wrapped output drain.
77    ///
78    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
79    /// consumes the encoder and returns ownership of the drain.
80    pub fn get_ref(&self) -> &W {
81        self.drain
82            .as_ref()
83            .expect("streaming encoder drain is present until finish consumes self")
84    }
85
86    /// Returns a mutable reference to the wrapped output drain.
87    ///
88    /// It is inadvisable to directly write to the underlying writer, as doing
89    /// so would corrupt the zstd frame being assembled by the encoder.
90    ///
91    /// The drain remains available for the encoder lifetime; [`finish`](Self::finish)
92    /// consumes the encoder and returns ownership of the drain.
93    pub fn get_mut(&mut self) -> &mut W {
94        self.drain
95            .as_mut()
96            .expect("streaming encoder drain is present until finish consumes self")
97    }
98
99    /// Finalizes the current zstd frame and returns the wrapped output drain.
100    ///
101    /// If no payload was written yet, this still emits a valid empty frame.
102    /// Calling this method consumes the encoder.
103    pub fn finish(mut self) -> Result<W, Error> {
104        self.ensure_open()?;
105        self.ensure_frame_started()?;
106
107        if self.pending.is_empty() {
108            self.write_empty_last_block()
109                .map_err(|err| self.fail(err))?;
110        } else {
111            self.emit_pending_block(true)?;
112        }
113
114        let mut drain = self
115            .drain
116            .take()
117            .expect("streaming encoder drain must be present when finishing");
118
119        #[cfg(feature = "hash")]
120        {
121            let checksum = self.hasher.finish() as u32;
122            drain
123                .write_all(&checksum.to_le_bytes())
124                .map_err(|err| self.fail(err))?;
125        }
126
127        drain.flush().map_err(|err| self.fail(err))?;
128        Ok(drain)
129    }
130
131    fn ensure_open(&self) -> Result<(), Error> {
132        if self.errored {
133            return Err(self.sticky_error());
134        }
135        Ok(())
136    }
137
138    // Cold path (only reached after poisoning). The format!() calls still allocate
139    // in no_std even though error_with_kind_message/other_error_owned drop the
140    // message; this is acceptable on an error recovery path to keep match arms simple.
141    fn sticky_error(&self) -> Error {
142        match (self.last_error_kind, self.last_error_message.as_deref()) {
143            (Some(kind), Some(message)) => error_with_kind_message(
144                kind,
145                format!(
146                    "streaming encoder is in an errored state due to previous {kind:?} failure: {message}"
147                ),
148            ),
149            (Some(kind), None) => error_from_kind(kind),
150            (None, Some(message)) => other_error_owned(format!(
151                "streaming encoder is in an errored state: {message}"
152            )),
153            (None, None) => other_error("streaming encoder is in an errored state"),
154        }
155    }
156
157    fn drain_mut(&mut self) -> Result<&mut W, Error> {
158        self.drain
159            .as_mut()
160            .ok_or_else(|| other_error("streaming encoder has no active drain"))
161    }
162
163    fn ensure_frame_started(&mut self) -> Result<(), Error> {
164        if self.frame_started {
165            return Ok(());
166        }
167
168        self.ensure_level_supported()?;
169        self.state.matcher.reset(self.compression_level);
170        self.state.offset_hist = [1, 4, 8];
171        self.state.last_huff_table = None;
172        self.state.fse_tables.ll_previous = None;
173        self.state.fse_tables.ml_previous = None;
174        self.state.fse_tables.of_previous = None;
175        #[cfg(feature = "hash")]
176        {
177            self.hasher = XxHash64::with_seed(0);
178        }
179
180        let window_size = self.state.matcher.window_size();
181        if window_size == 0 {
182            return Err(invalid_input_error(
183                "matcher reported window_size == 0, which is invalid",
184            ));
185        }
186
187        let header = FrameHeader {
188            frame_content_size: None,
189            single_segment: false,
190            content_checksum: cfg!(feature = "hash"),
191            dictionary_id: None,
192            window_size: Some(window_size),
193        };
194        let mut encoded_header = Vec::new();
195        header.serialize(&mut encoded_header);
196        self.drain_mut()
197            .and_then(|drain| drain.write_all(&encoded_header))
198            .map_err(|err| self.fail(err))?;
199
200        self.frame_started = true;
201        Ok(())
202    }
203
204    fn block_capacity(&self) -> usize {
205        let matcher_window = self.state.matcher.window_size() as usize;
206        core::cmp::max(1, core::cmp::min(matcher_window, MAX_BLOCK_SIZE as usize))
207    }
208
209    fn allocate_pending_space(&mut self, block_capacity: usize) -> Vec<u8> {
210        let mut space = match self.compression_level {
211            CompressionLevel::Fastest | CompressionLevel::Default => {
212                self.state.matcher.get_next_space()
213            }
214            _ => Vec::new(),
215        };
216        space.clear();
217        if space.capacity() > block_capacity {
218            space.shrink_to(block_capacity);
219        }
220        if space.capacity() < block_capacity {
221            space.reserve(block_capacity - space.capacity());
222        }
223        space
224    }
225
226    fn emit_full_pending_block(
227        &mut self,
228        block_capacity: usize,
229        consumed: usize,
230    ) -> Option<Result<usize, Error>> {
231        if self.pending.len() != block_capacity {
232            return None;
233        }
234
235        let new_pending = self.allocate_pending_space(block_capacity);
236        let full_block = mem::replace(&mut self.pending, new_pending);
237        if let Err((err, restored_block)) = self.encode_block(full_block, false) {
238            self.pending = restored_block;
239            let err = self.fail(err);
240            if consumed > 0 {
241                return Some(Ok(consumed));
242            }
243            return Some(Err(err));
244        }
245        None
246    }
247
248    fn emit_pending_block(&mut self, last_block: bool) -> Result<(), Error> {
249        let block = mem::take(&mut self.pending);
250        if let Err((err, restored_block)) = self.encode_block(block, last_block) {
251            self.pending = restored_block;
252            return Err(self.fail(err));
253        }
254        if !last_block {
255            let block_capacity = self.block_capacity();
256            self.pending = self.allocate_pending_space(block_capacity);
257        }
258        Ok(())
259    }
260
261    fn ensure_level_supported(&self) -> Result<(), Error> {
262        match self.compression_level {
263            CompressionLevel::Uncompressed
264            | CompressionLevel::Fastest
265            | CompressionLevel::Default => Ok(()),
266            _ => Err(invalid_input_error(
267                "streaming encoder currently supports Uncompressed/Fastest/Default only",
268            )),
269        }
270    }
271
272    fn encode_block(
273        &mut self,
274        uncompressed_data: Vec<u8>,
275        last_block: bool,
276    ) -> Result<(), (Error, Vec<u8>)> {
277        let mut raw_block = Some(uncompressed_data);
278        // TODO: reuse scratch buffer across blocks to reduce allocation churn (#47)
279        let mut encoded = Vec::with_capacity(self.block_capacity() + 3);
280        let mut moved_into_matcher = false;
281        if raw_block.as_ref().is_some_and(|block| block.is_empty()) {
282            let header = BlockHeader {
283                last_block,
284                block_type: crate::blocks::block::BlockType::Raw,
285                block_size: 0,
286            };
287            header.serialize(&mut encoded);
288        } else {
289            match self.compression_level {
290                CompressionLevel::Uncompressed => {
291                    let block = raw_block.as_ref().expect("raw block missing");
292                    let header = BlockHeader {
293                        last_block,
294                        block_type: crate::blocks::block::BlockType::Raw,
295                        block_size: block.len() as u32,
296                    };
297                    header.serialize(&mut encoded);
298                    encoded.extend_from_slice(block);
299                }
300                CompressionLevel::Fastest | CompressionLevel::Default => {
301                    let block = raw_block.take().expect("raw block missing");
302                    debug_assert!(!block.is_empty(), "empty blocks handled above");
303                    compress_fastest(&mut self.state, last_block, block, &mut encoded);
304                    moved_into_matcher = true;
305                }
306                _ => {
307                    return Err((
308                        invalid_input_error(
309                            "streaming encoder currently supports Uncompressed/Fastest/Default only",
310                        ),
311                        raw_block.unwrap_or_default(),
312                    ));
313                }
314            }
315        }
316
317        if let Err(err) = self.drain_mut().and_then(|drain| drain.write_all(&encoded)) {
318            let restored = if moved_into_matcher {
319                self.state.matcher.get_last_space().to_vec()
320            } else {
321                raw_block.unwrap_or_default()
322            };
323            return Err((err, restored));
324        }
325
326        if moved_into_matcher {
327            #[cfg(feature = "hash")]
328            {
329                self.hasher.write(self.state.matcher.get_last_space());
330            }
331        } else {
332            self.hash_block(raw_block.as_deref().unwrap_or(&[]));
333        }
334        Ok(())
335    }
336
337    fn write_empty_last_block(&mut self) -> Result<(), Error> {
338        self.encode_block(Vec::new(), true).map_err(|(err, _)| err)
339    }
340
341    fn fail(&mut self, err: Error) -> Error {
342        self.errored = true;
343        if self.last_error_kind.is_none() {
344            self.last_error_kind = Some(err.kind());
345        }
346        if self.last_error_message.is_none() {
347            self.last_error_message = Some(err.to_string());
348        }
349        err
350    }
351
352    #[cfg(feature = "hash")]
353    fn hash_block(&mut self, uncompressed_data: &[u8]) {
354        self.hasher.write(uncompressed_data);
355    }
356
357    #[cfg(not(feature = "hash"))]
358    fn hash_block(&mut self, _uncompressed_data: &[u8]) {}
359}
360
361impl<W: Write, M: Matcher> Write for StreamingEncoder<W, M> {
362    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
363        self.ensure_open()?;
364        if buf.is_empty() {
365            return Ok(0);
366        }
367
368        self.ensure_frame_started()?;
369        let block_capacity = self.block_capacity();
370        if self.pending.capacity() == 0 {
371            self.pending = self.allocate_pending_space(block_capacity);
372        }
373        let mut remaining = buf;
374        let mut consumed = 0usize;
375
376        while !remaining.is_empty() {
377            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
378                return result;
379            }
380
381            let available = block_capacity - self.pending.len();
382            let to_take = core::cmp::min(remaining.len(), available);
383            if to_take == 0 {
384                break;
385            }
386            self.pending.extend_from_slice(&remaining[..to_take]);
387            remaining = &remaining[to_take..];
388            consumed += to_take;
389
390            if let Some(result) = self.emit_full_pending_block(block_capacity, consumed) {
391                return result;
392            }
393        }
394        Ok(consumed)
395    }
396
397    fn flush(&mut self) -> Result<(), Error> {
398        self.ensure_open()?;
399        if self.pending.is_empty() {
400            return self
401                .drain_mut()
402                .and_then(|drain| drain.flush())
403                .map_err(|err| self.fail(err));
404        }
405        self.ensure_frame_started()?;
406        self.emit_pending_block(false)?;
407        self.drain_mut()
408            .and_then(|drain| drain.flush())
409            .map_err(|err| self.fail(err))
410    }
411}
412
/// Builds an [`Error`] that carries only `kind` and no custom message.
fn error_from_kind(kind: ErrorKind) -> Error {
    kind.into()
}
416
417fn error_with_kind_message(kind: ErrorKind, message: String) -> Error {
418    #[cfg(feature = "std")]
419    {
420        Error::new(kind, message)
421    }
422    #[cfg(not(feature = "std"))]
423    {
424        Error::new(kind, alloc::boxed::Box::new(message))
425    }
426}
427
428fn invalid_input_error(message: &str) -> Error {
429    #[cfg(feature = "std")]
430    {
431        Error::new(ErrorKind::InvalidInput, message)
432    }
433    #[cfg(not(feature = "std"))]
434    {
435        Error::new(
436            ErrorKind::Other,
437            alloc::boxed::Box::new(alloc::string::String::from(message)),
438        )
439    }
440}
441
442fn other_error_owned(message: String) -> Error {
443    #[cfg(feature = "std")]
444    {
445        Error::other(message)
446    }
447    #[cfg(not(feature = "std"))]
448    {
449        Error::new(ErrorKind::Other, alloc::boxed::Box::new(message))
450    }
451}
452
453fn other_error(message: &str) -> Error {
454    #[cfg(feature = "std")]
455    {
456        Error::other(message)
457    }
458    #[cfg(not(feature = "std"))]
459    {
460        Error::new(
461            ErrorKind::Other,
462            alloc::boxed::Box::new(alloc::string::String::from(message)),
463        )
464    }
465}
466
#[cfg(test)]
mod tests {
    use crate::decoding::StreamingDecoder;
    use crate::encoding::{CompressionLevel, Matcher, Sequence, StreamingEncoder};
    use crate::io::{Error, ErrorKind, Read, Write};
    use alloc::vec;
    use alloc::vec::Vec;

    /// Kind returned by `super::invalid_input_error` in this build.
    ///
    /// With `std` the helper reports `InvalidInput`; without `std` it reports
    /// `ErrorKind::Other`. Asserting a fixed `InvalidInput` would be wrong in non-`std`
    /// builds.
    #[cfg(feature = "std")]
    const INVALID_INPUT_KIND: ErrorKind = ErrorKind::InvalidInput;
    #[cfg(not(feature = "std"))]
    const INVALID_INPUT_KIND: ErrorKind = ErrorKind::Other;

    /// Decodes one complete frame with the crate's own streaming decoder.
    fn decode_frame(compressed: &[u8]) -> Vec<u8> {
        let mut decoder = StreamingDecoder::new(compressed).unwrap();
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).unwrap();
        decoded
    }

    /// Matcher stub with a configurable window.
    ///
    /// `get_next_space` hands out a zero-filled buffer of `window_size` bytes, and
    /// `start_matching` reports the whole committed space as literals.
    struct TinyMatcher {
        last_space: Vec<u8>,
        window_size: u64,
    }

    impl TinyMatcher {
        fn new(window_size: u64) -> Self {
            Self {
                last_space: Vec::new(),
                window_size,
            }
        }
    }

    impl Matcher for TinyMatcher {
        fn get_next_space(&mut self) -> Vec<u8> {
            vec![0; self.window_size as usize]
        }

        fn get_last_space(&mut self) -> &[u8] {
            self.last_space.as_slice()
        }

        fn commit_space(&mut self, space: Vec<u8>) {
            self.last_space = space;
        }

        fn skip_matching(&mut self) {}

        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
            handle_sequence(Sequence::Literals {
                literals: self.last_space.as_slice(),
            });
        }

        fn reset(&mut self, _level: CompressionLevel) {
            self.last_space.clear();
        }

        fn window_size(&self) -> u64 {
            self.window_size
        }
    }

    /// Writer that fails only the `fail_on_write_number`-th `write` call (1-based).
    ///
    /// Bytes from every other call are recorded in `sink`.
    struct FailingWriteOnce {
        writes: usize,
        fail_on_write_number: usize,
        sink: Vec<u8>,
    }

    impl FailingWriteOnce {
        fn new(fail_on_write_number: usize) -> Self {
            Self {
                writes: 0,
                fail_on_write_number,
                sink: Vec::new(),
            }
        }
    }

    impl Write for FailingWriteOnce {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
            self.writes += 1;
            if self.writes == self.fail_on_write_number {
                return Err(super::other_error("injected write failure"));
            }
            self.sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }

    /// Writer that fails the `fail_on_write_number`-th `write` with an error of `kind`.
    ///
    /// Every other write is discarded but reported as fully successful.
    struct FailingWithKind {
        writes: usize,
        fail_on_write_number: usize,
        kind: ErrorKind,
    }

    impl FailingWithKind {
        fn new(fail_on_write_number: usize, kind: ErrorKind) -> Self {
            Self {
                writes: 0,
                fail_on_write_number,
                kind,
            }
        }
    }

    impl Write for FailingWithKind {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
            self.writes += 1;
            if self.writes == self.fail_on_write_number {
                return Err(Error::from(self.kind));
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }

    /// Writer that simulates a short write followed by a dead sink.
    ///
    /// On the `fail_on_write_number`-th call it accepts only `partial_prefix_len` bytes,
    /// and every later call then fails. With a zero prefix, that call fails outright
    /// instead.
    struct PartialThenFailWriter {
        writes: usize,
        fail_on_write_number: usize,
        partial_prefix_len: usize,
        terminal_failure: bool,
        sink: Vec<u8>,
    }

    impl PartialThenFailWriter {
        fn new(fail_on_write_number: usize, partial_prefix_len: usize) -> Self {
            Self {
                writes: 0,
                fail_on_write_number,
                partial_prefix_len,
                terminal_failure: false,
                sink: Vec::new(),
            }
        }
    }

    impl Write for PartialThenFailWriter {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
            if self.terminal_failure {
                return Err(super::other_error("injected terminal write failure"));
            }

            self.writes += 1;
            if self.writes == self.fail_on_write_number {
                let written = core::cmp::min(self.partial_prefix_len, buf.len());
                if written > 0 {
                    self.sink.extend_from_slice(&buf[..written]);
                    self.terminal_failure = true;
                    return Ok(written);
                }
                return Err(super::other_error("injected terminal write failure"));
            }

            self.sink.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }

    #[test]
    fn streaming_encoder_roundtrip_multiple_writes() {
        let payload = b"streaming-encoder-roundtrip-".repeat(1024);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        for chunk in payload.chunks(313) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();
        assert_eq!(decode_frame(&compressed), payload);
    }

    #[test]
    fn flush_emits_nonempty_partial_output() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        encoder.write_all(b"partial-block").unwrap();
        encoder.flush().unwrap();
        let flushed_len = encoder.get_ref().len();
        assert!(
            flushed_len > 0,
            "flush should emit header+partial block bytes"
        );
        let compressed = encoder.finish().unwrap();
        assert_eq!(decode_frame(&compressed), b"partial-block");
    }

    #[test]
    fn flush_without_writes_does_not_emit_frame_header() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        encoder.flush().unwrap();
        assert!(encoder.get_ref().is_empty());
    }

    #[test]
    fn block_boundary_write_emits_block_in_same_call() {
        let mut boundary = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            Vec::new(),
            CompressionLevel::Uncompressed,
        );
        let mut below = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            Vec::new(),
            CompressionLevel::Uncompressed,
        );

        boundary.write_all(b"ABCD").unwrap();
        below.write_all(b"ABC").unwrap();

        let boundary_len = boundary.get_ref().len();
        let below_len = below.get_ref().len();
        assert!(
            boundary_len > below_len,
            "full block should be emitted immediately at block boundary"
        );
    }

    #[test]
    fn finish_consumes_encoder_and_emits_frame() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        encoder.write_all(b"abc").unwrap();
        let compressed = encoder.finish().unwrap();
        assert_eq!(decode_frame(&compressed), b"abc");
    }

    #[test]
    fn finish_without_writes_emits_empty_frame() {
        let encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        let compressed = encoder.finish().unwrap();
        assert!(decode_frame(&compressed).is_empty());
    }

    #[test]
    fn write_empty_buffer_returns_zero() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        assert_eq!(encoder.write(&[]).unwrap(), 0);
        let _ = encoder.finish().unwrap();
    }

    #[test]
    fn uncompressed_level_roundtrip() {
        let payload = b"uncompressed-streaming-roundtrip".repeat(64);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Uncompressed);
        for chunk in payload.chunks(41) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();
        assert_eq!(decode_frame(&compressed), payload);
    }

    #[test]
    fn better_level_returns_unsupported_error() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Better);
        let err = encoder.write_all(b"payload").unwrap_err();
        assert_eq!(err.kind(), INVALID_INPUT_KIND);
        assert!(encoder.finish().is_err());
    }

    #[test]
    fn zero_window_matcher_returns_invalid_input_error() {
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(0),
            Vec::new(),
            CompressionLevel::Fastest,
        );
        let err = encoder.write_all(b"payload").unwrap_err();
        assert_eq!(err.kind(), INVALID_INPUT_KIND);
    }

    #[test]
    fn unsupported_level_write_fails_before_emitting_frame_header() {
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Better);
        assert!(encoder.write_all(b"payload").is_err());
        assert_eq!(encoder.get_ref().len(), 0);
    }

    #[test]
    fn write_failure_poisoning_is_sticky() {
        // Write #1 is the frame header, so the very first output attempt fails.
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            FailingWriteOnce::new(1),
            CompressionLevel::Uncompressed,
        );

        assert!(encoder.write_all(b"ABCD").is_err());
        assert!(encoder.flush().is_err());
        assert!(encoder.write_all(b"EFGH").is_err());
        assert_eq!(encoder.get_ref().sink.len(), 0);
        assert!(encoder.finish().is_err());
    }

    #[test]
    fn poisoned_encoder_returns_original_error_kind() {
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            FailingWithKind::new(1, ErrorKind::BrokenPipe),
            CompressionLevel::Uncompressed,
        );

        let first_error = encoder.write_all(b"ABCD").unwrap_err();
        assert_eq!(first_error.kind(), ErrorKind::BrokenPipe);

        let second_error = encoder.write_all(b"EFGH").unwrap_err();
        assert_eq!(second_error.kind(), ErrorKind::BrokenPipe);
    }

    #[test]
    fn write_reports_progress_but_poisoning_is_sticky_after_later_block_failure() {
        // Write #1 = header, #2 = "ABCD" block, #3 = "EFGH" block (fails), so the
        // first call reports 8 accepted bytes before the encoder becomes poisoned.
        let payload = b"ABCDEFGHIJKL";
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            FailingWriteOnce::new(3),
            CompressionLevel::Uncompressed,
        );

        let first_write = encoder.write(payload).unwrap();
        assert_eq!(first_write, 8);
        assert!(encoder.write(&payload[first_write..]).is_err());
        assert!(encoder.flush().is_err());
        assert!(encoder.write_all(b"EFGH").is_err());
    }

    #[test]
    fn partial_write_failure_after_progress_poisons_encoder() {
        // The third drain write (the "EFGH" block) is cut short after one byte and
        // then fails, so the outcome matches the hard-failure case above.
        let payload = b"ABCDEFGHIJKL";
        let mut encoder = StreamingEncoder::new_with_matcher(
            TinyMatcher::new(4),
            PartialThenFailWriter::new(3, 1),
            CompressionLevel::Uncompressed,
        );

        let first_write = encoder.write(payload).unwrap();
        assert_eq!(first_write, 8);

        let second_write = encoder.write(&payload[first_write..]);
        assert!(second_write.is_err());
        assert!(encoder.flush().is_err());
        assert!(encoder.write_all(b"MNOP").is_err());
    }

    #[test]
    fn new_with_matcher_and_get_mut_work() {
        let matcher = TinyMatcher::new(128 * 1024);
        let mut encoder =
            StreamingEncoder::new_with_matcher(matcher, Vec::new(), CompressionLevel::Fastest);
        // Construction writes nothing, so `get_mut` must expose the untouched drain.
        assert!(encoder.get_mut().is_empty());
        encoder.write_all(b"custom-matcher").unwrap();
        let compressed = encoder.finish().unwrap();
        assert_eq!(decode_frame(&compressed), b"custom-matcher");
    }

    #[cfg(feature = "std")]
    #[test]
    fn streaming_encoder_output_decompresses_with_c_zstd() {
        let payload = b"tenant=demo op=put key=streaming value=abcdef\n".repeat(4096);
        let mut encoder = StreamingEncoder::new(Vec::new(), CompressionLevel::Fastest);
        for chunk in payload.chunks(1024) {
            encoder.write_all(chunk).unwrap();
        }
        let compressed = encoder.finish().unwrap();

        let mut decoded = Vec::with_capacity(payload.len());
        zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }
}
849}