warcat/warc/
decode.rs

1//! WARC file reading
2use std::{
3    collections::VecDeque,
4    io::{Read, Seek, Write},
5};
6
7use crate::{
8    compress::{DecompressorConfig, PushDecompressor},
9    error::{GeneralError, ProtocolError, ProtocolErrorKind},
10    header::WarcHeader,
11    io::LogicalPosition,
12};
13
14const BUFFER_LENGTH: usize = crate::io::IO_BUFFER_LENGTH;
15const MAX_HEADER_LENGTH: usize = 32768;
16
17/// Configuration for a [`Decoder`]
18#[derive(Debug, Clone, Default)]
19#[non_exhaustive]
20pub struct DecoderConfig {
21    /// Compression configuration of the file to be read
22    pub decompressor: DecompressorConfig,
23}
24
/// Typestate marker: the [`Decoder`] is positioned before a record header.
#[derive(Debug)]
pub struct DecStateHeader;

/// Typestate marker: the [`Decoder`] is positioned inside a record block.
#[derive(Debug, Default)]
pub struct DecStateBlock {
    /// Set once the end of the current record has been observed.
    is_end: bool,
}
31
32/// WARC format reader
33#[derive(Debug)]
34pub struct Decoder<S, R: Read> {
35    state: S,
36    input: R,
37    push_decoder: PushDecoder,
38    logical_position: u64,
39    buf: Vec<u8>,
40}
41
42impl<S, R: Read> Decoder<S, R> {
43    pub fn get_ref(&self) -> &R {
44        &self.input
45    }
46
47    pub fn get_mut(&mut self) -> &mut R {
48        &mut self.input
49    }
50
51    /// Returns the position of the beginning of a WARC record.
52    ///
53    /// This function is intended for indexing a WARC file.
54    pub fn record_boundary_position(&self) -> u64 {
55        self.push_decoder.record_boundary_position()
56    }
57
58    fn read_into_push_decoder(&mut self) -> std::io::Result<usize> {
59        tracing::trace!("read into push decoder");
60
61        self.buf.resize(BUFFER_LENGTH, 0);
62
63        let read_length = self.input.read(&mut self.buf)?;
64
65        self.buf.truncate(read_length);
66
67        self.logical_position += read_length as u64;
68
69        self.push_decoder.write_all(&self.buf)?;
70
71        if read_length == 0 {
72            self.push_decoder.write_eof()?;
73        }
74
75        tracing::trace!(read_length, "read into push decoder");
76
77        Ok(read_length)
78    }
79
80    fn read_nonzero_into_push_decoder(&mut self) -> std::io::Result<()> {
81        let read_length = self.read_into_push_decoder()?;
82
83        if read_length == 0 {
84            Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))
85        } else {
86            Ok(())
87        }
88    }
89
90    /// Resets the decoder state so that a new record can be decoded.
91    ///
92    /// Configuration is kept but any buffered data is discarded.
93    ///
94    /// This function may be used after file seeking or
95    /// partially reading records.
96    pub fn reset(mut self) -> std::io::Result<Decoder<DecStateHeader, R>> {
97        self.push_decoder.reset()?;
98
99        Ok(Decoder {
100            state: DecStateHeader,
101            input: self.input,
102            push_decoder: self.push_decoder,
103            logical_position: self.logical_position,
104            buf: self.buf,
105        })
106    }
107}
108
109impl<R: Read> Decoder<DecStateHeader, R> {
110    /// Creates a new decoder that reads from the given reader.
111    pub fn new(input: R, config: DecoderConfig) -> std::io::Result<Self> {
112        let push_decoder = PushDecoder::new(config)?;
113
114        Ok(Self {
115            state: DecStateHeader,
116            input,
117            push_decoder,
118            logical_position: 0,
119            buf: Vec::with_capacity(BUFFER_LENGTH),
120        })
121    }
122
123    /// Returns the underlying reader.
124    pub fn into_inner(self) -> R {
125        self.input
126    }
127
128    /// Returns whether it was detected that the file was compressed
129    /// in a manner that makes random access to each record impossible.
130    ///
131    /// A false value is not guaranteed to be false unless the entire file has
132    /// been read.
133    pub fn has_record_at_time_compression_fault(&self) -> bool {
134        self.push_decoder.has_record_at_time_compression_fault()
135    }
136
137    /// Returns whether there is another WARC record to be read.
138    pub fn has_next_record(&mut self) -> std::io::Result<bool> {
139        if self.push_decoder.is_finished() {
140            return Ok(false);
141        } else if self.push_decoder.is_ready() {
142            self.read_into_push_decoder()?;
143        }
144
145        Ok(!self.push_decoder.is_ready())
146    }
147
148    /// Reads the header portion of a WARC record.
149    ///
150    /// This function consumes the reader and returns a typestate transitioned
151    /// reader for reading the block portion of a WARC record.
152    pub fn read_header(mut self) -> Result<(WarcHeader, Decoder<DecStateBlock, R>), GeneralError> {
153        loop {
154            match self.push_decoder.get_event()? {
155                PushDecoderEvent::Ready | PushDecoderEvent::WantData => {
156                    self.read_into_push_decoder()?;
157                    continue;
158                }
159                PushDecoderEvent::Finished => {
160                    return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput).into());
161                }
162                PushDecoderEvent::Continue => continue,
163                PushDecoderEvent::Header { header } => {
164                    return Ok((
165                        header,
166                        Decoder {
167                            state: DecStateBlock::default(),
168                            input: self.input,
169                            push_decoder: self.push_decoder,
170                            buf: self.buf,
171                            logical_position: self.logical_position,
172                        },
173                    ));
174                }
175                PushDecoderEvent::BlockData { data: _ } => unreachable!(),
176                PushDecoderEvent::EndRecord => unreachable!(),
177            }
178        }
179    }
180}
181
182impl<R: Read + Seek> Decoder<DecStateHeader, R> {
183    /// Prepare the internal decompressor to be ready for the source to be seeked.
184    ///
185    /// For Zstandard, this may load an embedded dictionary.
186    /// For other compression formats, this has no effect.
187    pub fn prepare_for_seek(&mut self) -> Result<(), GeneralError> {
188        if self
189            .push_decoder
190            .config
191            .decompressor
192            .format
193            .supports_concatenation()
194        {
195            loop {
196                self.read_into_push_decoder()?;
197
198                match self.push_decoder.get_event()? {
199                    PushDecoderEvent::Ready
200                    | PushDecoderEvent::Finished
201                    | PushDecoderEvent::WantData
202                    | PushDecoderEvent::Continue => {}
203                    PushDecoderEvent::Header { .. }
204                    | PushDecoderEvent::BlockData { .. }
205                    | PushDecoderEvent::EndRecord => break,
206                }
207            }
208
209            self.input.seek(std::io::SeekFrom::Start(0))?;
210            self.push_decoder.reset()?;
211        }
212
213        Ok(())
214    }
215}
216
217impl<R: Read> Decoder<DecStateBlock, R> {
218    fn read_block_impl(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
219        if self.state.is_end {
220            return Ok(0);
221        }
222
223        if buf.is_empty() {
224            return Ok(0);
225        }
226
227        self.push_decoder.set_max_buffer_len(buf.len());
228
229        loop {
230            match self
231                .push_decoder
232                .get_event()
233                .map_err(std::io::Error::other)?
234            {
235                PushDecoderEvent::Ready => unreachable!(),
236                PushDecoderEvent::Finished => unreachable!(),
237                PushDecoderEvent::WantData => {
238                    self.read_into_push_decoder()?;
239                    continue;
240                }
241                PushDecoderEvent::Continue => continue,
242                PushDecoderEvent::Header { header: _ } => unreachable!(),
243                PushDecoderEvent::BlockData { data } => {
244                    debug_assert!(data.len() <= buf.len());
245
246                    let buf_upper = buf.len().min(data.len());
247                    tracing::trace!(read_length = buf_upper, "read block");
248
249                    buf[0..buf_upper].copy_from_slice(&data[0..buf_upper]);
250
251                    return Ok(buf_upper);
252                }
253                PushDecoderEvent::EndRecord => {
254                    self.state.is_end = true;
255                    return Ok(0);
256                }
257            }
258        }
259    }
260
261    /// Indicate that reading the block portion of WARC record has completed.
262    ///
263    /// It's not necessary for the user to read the entire block or at all;
264    /// this function will continue to the end of the record automatically.
265    ///
266    /// Consumes the writer and returns a typestate transitioned writer that
267    /// can read the next WARC record.
268    pub fn finish_block(mut self) -> Result<Decoder<DecStateHeader, R>, GeneralError> {
269        tracing::trace!("finish block");
270        self.read_remaining_block()?;
271
272        Ok(Decoder {
273            state: DecStateHeader,
274            input: self.input,
275            push_decoder: self.push_decoder,
276            logical_position: self.logical_position,
277            buf: self.buf,
278        })
279    }
280
281    fn read_remaining_block(&mut self) -> Result<(), GeneralError> {
282        tracing::trace!("read remaining block");
283
284        self.push_decoder.set_max_buffer_len(BUFFER_LENGTH);
285
286        while !self.state.is_end {
287            match self.push_decoder.get_event()? {
288                PushDecoderEvent::Ready => unreachable!(),
289                PushDecoderEvent::Finished => unreachable!(),
290                PushDecoderEvent::WantData => {
291                    self.read_nonzero_into_push_decoder()?;
292                    continue;
293                }
294                PushDecoderEvent::Continue => continue,
295                PushDecoderEvent::Header { header: _ } => unreachable!(),
296                PushDecoderEvent::BlockData { data: _ } => continue,
297                PushDecoderEvent::EndRecord => self.state.is_end = true,
298            }
299        }
300
301        Ok(())
302    }
303}
304
305impl<R: Read> Read for Decoder<DecStateBlock, R> {
306    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
307        self.read_block_impl(buf)
308    }
309}
310
311impl<R: Read, S> LogicalPosition for Decoder<S, R> {
312    fn logical_position(&self) -> u64 {
313        self.logical_position
314    }
315}
316
317/// Events for [`PushDecoder`].
318#[derive(Debug)]
319pub enum PushDecoderEvent<'a> {
320    /// No input data has been received yet.
321    Ready,
322    /// End-of-file has been reached and no more data can be decoded.
323    Finished,
324    /// Either more data or end-of-file (EOF) is needed.
325    WantData,
326    /// Internal processing was successful and the user should call again.
327    Continue,
328    /// Decoding a header was successful.
329    Header { header: WarcHeader },
330    /// A chunk of the decoded block data.
331    BlockData { data: &'a [u8] },
332    /// Finished processing a single record.
333    EndRecord,
334}
335
336impl<'a> PushDecoderEvent<'a> {
337    pub fn is_ready(&self) -> bool {
338        matches!(self, Self::Ready)
339    }
340
341    pub fn is_finished(&self) -> bool {
342        matches!(self, Self::Finished)
343    }
344
345    pub fn is_want_data(&self) -> bool {
346        matches!(self, Self::WantData)
347    }
348
349    pub fn is_continue(&self) -> bool {
350        matches!(self, Self::Continue)
351    }
352
353    pub fn is_header(&self) -> bool {
354        matches!(self, Self::Header { .. })
355    }
356
357    pub fn is_block_data(&self) -> bool {
358        matches!(self, Self::BlockData { .. })
359    }
360
361    pub fn as_header(&self) -> Option<&WarcHeader> {
362        if let Self::Header { header } = self {
363            Some(header)
364        } else {
365            None
366        }
367    }
368
369    pub fn as_block_data(&self) -> Option<&'a [u8]> {
370        if let Self::BlockData { data } = self {
371            Some(data)
372        } else {
373            None
374        }
375    }
376
377    pub fn is_end_record(&self) -> bool {
378        matches!(self, Self::EndRecord)
379    }
380}
381
/// Phase of the [`PushDecoder`] state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PushDecoderState {
    /// Waiting for the first bytes of the next record.
    PendingHeader,
    /// Accumulating decoded bytes until a complete header can be parsed.
    Header,
    /// Emitting block bytes until `Content-Length` bytes have been consumed.
    Block,
    /// Expecting the `\r\n\r\n` terminator after the block.
    RecordBoundary,
    /// Record done; possibly waiting for the end of the compression segment.
    EndOfSegment,
    /// Input reached EOF and no decoded bytes remain.
    Finished,
}
391
392/// WARC format decoder push-style.
393///
394/// This is similar to [`Decoder`] but input data is written to the struct
395/// and events are gathered by the caller. This push-style method can be
396/// use for sans-IO implementations.
397#[derive(Debug)]
398pub struct PushDecoder {
399    config: DecoderConfig,
400    state: PushDecoderState,
401    decompressor: PushDecompressor<VecDeque<u8>>,
402    decompressor_eof: bool,
403    input_eof: bool,
404    // Data that has not been decompresssed yet because it's for the next record.
405    deferred_input_buf: VecDeque<u8>,
406    // Total number of bytes written into the decoder.
407    bytes_written_decoder: u64,
408    // Total number of bytes read from the decoder (not including bytes for the next record).
409    decoded_bytes_consumed: u64,
410    // Index of record boundary in the raw file.
411    record_boundary_position: u64,
412    // Total number of bytes to be read from the record block.
413    block_length: u64,
414    // Number of bytes read so far from the record block.
415    block_current_position: u64,
416    // Maximum number of bytes that can be used for PushDecoderEvent::BlockData.
417    buf_output_max_len: usize,
418    // Number of bytes borrowed for PushDecoderEvent::BlockData.
419    buf_output_reference_len: usize,
420    // Detected a compressed file that can't be randomly accessed
421    has_rat_comp_fault: bool,
422}
423
424impl PushDecoder {
425    /// Creates a new decoder.
426    pub fn new(config: DecoderConfig) -> std::io::Result<Self> {
427        let decompressor =
428            PushDecompressor::with_config(VecDeque::new(), config.decompressor.clone())?;
429
430        Ok(Self {
431            config,
432            state: PushDecoderState::PendingHeader,
433            decompressor,
434            decompressor_eof: false,
435            input_eof: false,
436            deferred_input_buf: VecDeque::with_capacity(BUFFER_LENGTH),
437            bytes_written_decoder: 0,
438            decoded_bytes_consumed: 0,
439            record_boundary_position: 0,
440            block_length: 0,
441            block_current_position: 0,
442            buf_output_max_len: BUFFER_LENGTH,
443            buf_output_reference_len: 0,
444            has_rat_comp_fault: false,
445        })
446    }
447
448    /// Returns the position of the beginning of a WARC record.
449    ///
450    /// This function is intended for indexing a WARC file.
451    pub fn record_boundary_position(&self) -> u64 {
452        self.record_boundary_position
453    }
454
455    /// Returns whether internal buffer contains unused bytes that can be
456    /// used to decode the next record.
457    pub fn has_next_record(&self) -> bool {
458        !self.deferred_input_buf.is_empty()
459    }
460
461    /// Returns the maximum buffer length that can be used in [`PushDecoderEvent::BlockData`].
462    pub fn max_buffer_len(&self) -> usize {
463        self.buf_output_max_len
464    }
465
466    /// Sets the maximum buffer length that can be used in [`PushDecoderEvent::BlockData`].
467    ///
468    /// If the given value is 0, the value is set to a non-zero default.
469    pub fn set_max_buffer_len(&mut self, value: usize) {
470        if value != 0 {
471            self.buf_output_max_len = value;
472        } else {
473            self.buf_output_max_len = BUFFER_LENGTH;
474        }
475    }
476
477    /// Returns whether it was detected that the file was compressed
478    /// in a manner that makes random access to each record impossible.
479    ///
480    /// A false value is not guaranteed to be false unless the entire file has
481    /// been read.
482    pub fn has_record_at_time_compression_fault(&self) -> bool {
483        self.has_rat_comp_fault
484    }
485
486    /// Returns whether the next call to [`get_event()`](Self::get_event())
487    /// will return [`PushDecoderEvent::Ready`].
488    pub fn is_ready(&self) -> bool {
489        matches!(self.state, PushDecoderState::PendingHeader)
490    }
491
492    /// Returns whether the next call to [`get_event()`](Self::get_event())
493    /// will return [`PushDecoderEvent::Finished`].
494    pub fn is_finished(&self) -> bool {
495        matches!(self.state, PushDecoderState::Finished)
496    }
497
498    /// Returns a processed event.
499    ///
500    /// In order for this decoder to produce events, the caller must
501    /// put input data using the [`Write`] trait.
502    pub fn get_event(&mut self) -> Result<PushDecoderEvent, GeneralError> {
503        self.decompressor
504            .get_mut()
505            .drain(0..self.buf_output_reference_len);
506        self.buf_output_reference_len = 0;
507
508        match self.state {
509            PushDecoderState::PendingHeader => Ok(PushDecoderEvent::Ready),
510            PushDecoderState::Header => self.process_header(),
511            PushDecoderState::Block => self.process_block(),
512            PushDecoderState::RecordBoundary => self.process_record_boundary(),
513            PushDecoderState::EndOfSegment => self.process_end_of_segment(),
514            PushDecoderState::Finished => Ok(PushDecoderEvent::Finished),
515        }
516    }
517
518    /// Resets the decoder state so that a new record can be decoded.
519    ///
520    /// Configuration is kept but any buffered data is discarded.
521    ///
522    /// This function may be used after file seeking or
523    /// partially reading records.
524    pub fn reset(&mut self) -> std::io::Result<()> {
525        self.state = PushDecoderState::PendingHeader;
526        self.decompressor.get_mut().clear();
527        self.deferred_input_buf.clear();
528        self.decompressor.start_next_segment()?;
529        Ok(())
530    }
531
532    fn process_header(&mut self) -> Result<PushDecoderEvent, GeneralError> {
533        let buf = self.decompressor.get_mut().make_contiguous();
534
535        if let Some(index) = crate::parse::scan_header_deliminator(buf) {
536            let header = self.process_decodable_header(index)?;
537
538            return Ok(PushDecoderEvent::Header { header });
539        }
540
541        self.precheck_header()?;
542        self.check_max_header_length()?;
543
544        Ok(PushDecoderEvent::WantData)
545    }
546
547    fn process_decodable_header(&mut self, index: usize) -> Result<WarcHeader, GeneralError> {
548        // Okay to discard slice1 because we called make_contiguous() earlier.
549        let (buf, _slice1) = self.decompressor.get_ref().as_slices();
550
551        let header_bytes = &buf[0..index];
552        let header = WarcHeader::parse(header_bytes)?;
553        let length = header.content_length()?;
554        let record_id = header.fields.get("WARC-Record-ID");
555        let warc_type = header.fields.get("WARC-Type");
556
557        self.decompressor.get_mut().drain(0..index);
558        self.decoded_bytes_consumed += index as u64;
559
560        tracing::trace!(
561            record_id,
562            warc_type,
563            content_length = length,
564            "process decodable header"
565        );
566
567        self.block_current_position = 0;
568        self.block_length = length;
569
570        tracing::trace!("Header -> Block");
571        self.state = PushDecoderState::Block;
572
573        Ok(header)
574    }
575
576    fn precheck_header(&self) -> Result<(), ProtocolError> {
577        // Okay to discard slice1 because we called make_contiguous() earlier.
578        let (buf, _slice1) = self.decompressor.get_ref().as_slices();
579
580        match detect_header(buf) {
581            HeaderDetectResult::Warc => Ok(()),
582            HeaderDetectResult::Compression => {
583                Err(ProtocolError::new(ProtocolErrorKind::UnexpectedCompression)
584                    .with_position(self.bytes_written_decoder)
585                    .with_snippet(buf[0..buf.len().min(16)].escape_ascii().to_string()))
586            }
587            HeaderDetectResult::NotWarc => {
588                Err(ProtocolError::new(ProtocolErrorKind::UnknownHeader)
589                    .with_position(self.bytes_written_decoder)
590                    .with_snippet(buf[0..buf.len().min(16)].escape_ascii().to_string()))
591            }
592            HeaderDetectResult::NotSure => Ok(()),
593        }
594    }
595
596    fn check_max_header_length(&self) -> Result<(), ProtocolError> {
597        tracing::trace!("check max header length");
598
599        if self.decompressor.get_ref().len() > MAX_HEADER_LENGTH {
600            Err(ProtocolError::new(ProtocolErrorKind::HeaderTooBig))
601        } else {
602            Ok(())
603        }
604    }
605
606    fn process_block(&mut self) -> Result<PushDecoderEvent, GeneralError> {
607        tracing::trace!(
608            self.block_length,
609            self.block_current_position,
610            "process block"
611        );
612
613        debug_assert!(self.block_length >= self.block_current_position);
614        let remaining_bytes = self.block_length - self.block_current_position;
615
616        if remaining_bytes == 0 {
617            tracing::trace!("Block -> RecordBoundary");
618            self.state = PushDecoderState::RecordBoundary;
619            Ok(PushDecoderEvent::Continue)
620        } else if self.decompressor.get_ref().is_empty() {
621            Ok(PushDecoderEvent::WantData)
622        } else {
623            // Okay to discard slice1 because the caller will continually poll
624            // until the buffer is empty.
625            let (slice0, _slice1) = self.decompressor.get_ref().as_slices();
626
627            let consume_len = self.buf_output_max_len.min(slice0.len());
628            let consume_len = consume_len.min(remaining_bytes.try_into().unwrap_or(usize::MAX));
629
630            self.block_current_position += consume_len as u64;
631            self.buf_output_reference_len = consume_len;
632            self.decoded_bytes_consumed += consume_len as u64;
633
634            tracing::trace!(consume_len, "process block");
635
636            Ok(PushDecoderEvent::BlockData {
637                data: &slice0[0..consume_len],
638            })
639        }
640    }
641
642    fn process_record_boundary(&mut self) -> Result<PushDecoderEvent, GeneralError> {
643        tracing::trace!(
644            len = self.decompressor.get_ref().len(),
645            "process record boundary"
646        );
647
648        if self.decompressor.get_ref().len() >= 4 {
649            let mut buf = [0u8; 4];
650            let mut iter = self.decompressor.get_ref().range(0..4).copied();
651            buf[0] = iter.next().unwrap();
652            buf[1] = iter.next().unwrap();
653            buf[2] = iter.next().unwrap();
654            buf[3] = iter.next().unwrap();
655
656            if !buf.starts_with(b"\r\n\r\n") {
657                Err(ProtocolError::new(ProtocolErrorKind::InvalidRecordBoundary).into())
658            } else {
659                self.decompressor.get_mut().drain(0..4);
660                self.decoded_bytes_consumed += 4;
661
662                self.state = PushDecoderState::EndOfSegment;
663                Ok(PushDecoderEvent::Continue)
664            }
665        } else {
666            Ok(PushDecoderEvent::WantData)
667        }
668    }
669
670    fn process_end_of_segment(&mut self) -> Result<PushDecoderEvent, GeneralError> {
671        tracing::trace!(self.decompressor_eof, "process end of segment");
672
673        if self.config.decompressor.format.supports_concatenation()
674            && self.decompressor.get_ref().is_empty()
675            && !self.decompressor_eof
676            && !self.input_eof
677        {
678            // Finish reading any end of compression member/frame checksums.
679            Ok(PushDecoderEvent::WantData)
680        } else {
681            self.reset_for_next_record()?;
682
683            Ok(PushDecoderEvent::EndRecord)
684        }
685    }
686
687    fn reset_for_next_record(&mut self) -> Result<(), GeneralError> {
688        tracing::trace!(
689            remain_decomp_len = self.decompressor.get_ref().len(),
690            "reset for next record"
691        );
692        // dbg!(String::from_utf8_lossy(self.decompressor.get_ref().as_slices().0));
693        // dbg!(String::from_utf8_lossy(self.decompressor.get_ref().as_slices().1));
694
695        if self.config.decompressor.format.is_identity() {
696            self.record_boundary_position = self.decoded_bytes_consumed;
697        } else {
698            self.record_boundary_position = self.bytes_written_decoder;
699        }
700
701        if self.config.decompressor.format.supports_concatenation()
702            && self.decompressor.get_ref().is_empty()
703        {
704            tracing::trace!("decompressor start next segment");
705            self.decompressor.start_next_segment()?;
706        } else if self.config.decompressor.format.supports_concatenation()
707            && !self.has_rat_comp_fault
708        {
709            tracing::warn!("file is not using Record-at-time compression");
710            self.has_rat_comp_fault = true;
711        }
712
713        self.consume_deferred_input_buf()?;
714
715        self.decompressor_eof = false;
716
717        if self.decompressor.get_ref().is_empty() {
718            if self.input_eof {
719                tracing::trace!("RecordBoundary -> Finished");
720                self.state = PushDecoderState::Finished;
721            } else {
722                tracing::trace!("RecordBoundary -> PendingHeader");
723                self.state = PushDecoderState::PendingHeader;
724            }
725        } else {
726            tracing::trace!("RecordBoundary -> Header");
727            self.state = PushDecoderState::Header;
728        }
729
730        Ok(())
731    }
732
733    fn consume_deferred_input_buf(&mut self) -> Result<(), GeneralError> {
734        tracing::trace!(
735            len = self.deferred_input_buf.len(),
736            "consume deferred input buf"
737        );
738
739        while !self.deferred_input_buf.is_empty() {
740            let (slice0, _slice1) = self.deferred_input_buf.as_slices();
741            let write_len = self.decompressor.write(slice0)?;
742            tracing::trace!(write_len, "consume deferred input buf");
743
744            self.bytes_written_decoder += write_len as u64;
745
746            if write_len == 0 {
747                break;
748            }
749
750            self.deferred_input_buf.drain(..write_len);
751        }
752        Ok(())
753    }
754
755    /// Notify that there is no more input to be decoded.
756    pub fn write_eof(&mut self) -> std::io::Result<()> {
757        tracing::trace!("push decoder got write eof");
758        self.input_eof = true;
759        self.decompressor.write_eof()?;
760        tracing::trace!(decoded_buf_len = self.decompressor.get_ref().len());
761
762        Ok(())
763    }
764}
765
766impl Write for PushDecoder {
767    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
768        if buf.is_empty() {
769            return Ok(0);
770        }
771
772        if self.state == PushDecoderState::PendingHeader {
773            tracing::trace!("PendingHeader -> Header");
774            self.state = PushDecoderState::Header;
775        }
776
777        let write_len = self.decompressor.write(buf)?;
778        self.bytes_written_decoder += write_len as u64;
779
780        tracing::trace!(
781            buf_len = buf.len(),
782            write_len,
783            decoded_buf_len = self.decompressor.get_ref().len(),
784            "push decoder write"
785        );
786
787        if write_len != 0 {
788            // FIXME: handle the case where a single record is compressed as
789            // several zstd frames
790            Ok(write_len)
791        } else {
792            self.decompressor_eof = true;
793            self.deferred_input_buf.write_all(buf)?;
794            Ok(buf.len())
795        }
796    }
797
798    fn flush(&mut self) -> std::io::Result<()> {
799        self.decompressor.flush()
800    }
801}
802
/// Verdict of sniffing the first decoded bytes of a record.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum HeaderDetectResult {
    /// Starts with the `WARC/` version prefix.
    Warc,
    /// At least five bytes are present and none of the known prefixes match.
    NotWarc,
    /// Starts with a gzip or Zstandard magic number.
    Compression,
    /// Too few bytes to decide yet.
    NotSure,
}

/// Classifies `buf` by its leading bytes.
fn detect_header(buf: &[u8]) -> HeaderDetectResult {
    const WARC_MAGIC: &[u8] = b"WARC/";
    const GZIP_MAGIC: &[u8] = b"\x1f\x8b";
    const ZSTD_MAGIC: &[u8] = b"\x28\xb5\x2f\xfd";

    if buf.starts_with(WARC_MAGIC) {
        return HeaderDetectResult::Warc;
    }

    if [GZIP_MAGIC, ZSTD_MAGIC]
        .iter()
        .any(|magic| buf.starts_with(magic))
    {
        return HeaderDetectResult::Compression;
    }

    // `WARC_MAGIC.len()` is 5: with that many bytes and no match it's not WARC.
    if buf.len() < WARC_MAGIC.len() {
        HeaderDetectResult::NotSure
    } else {
        HeaderDetectResult::NotWarc
    }
}
822
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    #[test]
    fn test_detect_header() {
        assert_eq!(detect_header(b"WA"), HeaderDetectResult::NotSure);
        assert_eq!(detect_header(b"WARC"), HeaderDetectResult::NotSure);
        assert_eq!(detect_header(b"WARC/"), HeaderDetectResult::Warc);
        assert_eq!(detect_header(b"WARC/1"), HeaderDetectResult::Warc);
        assert_eq!(detect_header(b"AAAAA"), HeaderDetectResult::NotWarc);
        assert_eq!(detect_header(b"AAAAAA"), HeaderDetectResult::NotWarc);
        assert_eq!(detect_header(b"\x1f\x8b"), HeaderDetectResult::Compression);
        assert_eq!(detect_header(b"\x1f\x8b "), HeaderDetectResult::Compression);
        assert_eq!(
            detect_header(b"\x28\xb5\x2f\xfd"),
            HeaderDetectResult::Compression
        );
        assert_eq!(detect_header(b"\x28\xb5"), HeaderDetectResult::NotSure);
    }

    #[tracing_test::traced_test]
    #[test]
    fn test_reader() {
        let data = b"WARC/1.1\r\n\
            Content-Length: 12\r\n\
            \r\n\
            Hello world!\
            \r\n\r\n\
            WARC/1.1\r\n\
            Content-Length: 0\r\n\
            \r\n\
            \r\n\r\n";

        let reader = Decoder::new(Cursor::new(data), DecoderConfig::default()).unwrap();

        let (_header, mut reader) = reader.read_header().unwrap();
        let mut block = Vec::new();
        reader.read_to_end(&mut block).unwrap();
        let mut reader = reader.finish_block().unwrap();

        assert!(reader.has_next_record().unwrap());

        let (_header, mut reader) = reader.read_header().unwrap();
        let mut block = Vec::new();
        reader.read_to_end(&mut block).unwrap();
        let mut reader = reader.finish_block().unwrap();

        assert!(!reader.has_next_record().unwrap());

        reader.into_inner();
    }

    #[tracing_test::traced_test]
    #[test]
    fn test_push_reader() {
        // The pieces written below together form this stream:
        //
        //   WARC/1.1\r\nContent-Length: 12\r\n\r\nHello world!\r\n\r\n
        //   WARC/1.1\r\nContent-Length: 0\r\n\r\n\r\n\r\n
        let mut decoder = PushDecoder::new(DecoderConfig::default()).unwrap();

        let event = decoder.get_event().unwrap();
        assert!(event.is_ready());

        decoder.write_all(b"WARC/1.1\r\n").unwrap(); // len = 10

        let event = decoder.get_event().unwrap();
        assert!(event.is_want_data());

        decoder.write_all(b"Content-Length: 12\r\n").unwrap(); // len = 20
        decoder.write_all(b"\r\n").unwrap(); // len = 2
        decoder.write_all(b"Hello ").unwrap(); // len = 6

        let event = decoder.get_event().unwrap();
        assert!(event.is_header());
        assert_eq!(decoder.record_boundary_position(), 0);

        let event = decoder.get_event().unwrap();
        assert!(event.is_block_data());
        assert_eq!(event.as_block_data().unwrap(), b"Hello ");

        let event = decoder.get_event().unwrap();
        assert!(event.is_want_data());

        decoder.write_all(b"world!\r\n").unwrap(); // len = 8

        let event = decoder.get_event().unwrap();
        assert!(event.is_block_data());
        assert_eq!(event.as_block_data().unwrap(), b"world!");

        let event = decoder.get_event().unwrap();
        assert!(event.is_continue());

        let event = decoder.get_event().unwrap();
        assert!(event.is_want_data());

        decoder.write_all(b"\r\n").unwrap(); // len = 2
        decoder.write_all(b"WARC/1.1\r\n").unwrap();

        let event = decoder.get_event().unwrap();
        assert!(event.is_continue());

        let event = decoder.get_event().unwrap();
        assert!(event.is_end_record());

        let event = decoder.get_event().unwrap();
        assert!(event.is_want_data());

        decoder
            .write_all(
                b"Content-Length: 0\r\n\
                \r\n\
                \r\n\r\n",
            )
            .unwrap();

        decoder.write_eof().unwrap();

        let event = decoder.get_event().unwrap();
        assert!(event.is_header());
        assert_eq!(decoder.record_boundary_position(), 48);

        let event = decoder.get_event().unwrap();
        assert!(event.is_continue());

        let event = decoder.get_event().unwrap();
        assert!(event.is_continue());

        let event = decoder.get_event().unwrap();
        assert!(event.is_end_record());

        let event = decoder.get_event().unwrap();
        assert!(event.is_finished());
    }

    #[tracing_test::traced_test]
    #[test]
    fn test_wrong_format() {
        let data = b"CDX\r\n";
        let reader = Decoder::new(Cursor::new(data), DecoderConfig::default()).unwrap();

        let error = reader.read_header().unwrap_err();
        assert_eq!(
            error.as_protocol().unwrap().kind(),
            ProtocolErrorKind::UnknownHeader
        );
    }

    #[tracing_test::traced_test]
    #[test]
    fn test_unexpected_compression() {
        let data = b"\x1f\x8babc";
        let reader = Decoder::new(Cursor::new(data), DecoderConfig::default()).unwrap();

        let error = reader.read_header().unwrap_err();
        assert_eq!(
            error.as_protocol().unwrap().kind(),
            ProtocolErrorKind::UnexpectedCompression
        );
    }
}