streaming_libdeflate_rs/streams/
deflate_chunked_buffer_input.rs

1use crate::{decompress_utils::copy_dword_unaligned, DeflateInput, DeflateOutput};
2use nightly_quirks::utils::NightlyUtils;
3use std::cmp::min;
4
/// Chunked input stream for DEFLATE decompression.
///
/// Wraps a user-supplied read callback and exposes its data through one
/// internal buffer that is refilled on demand. The buffer is allocated with
/// `Self::MAX_OVERREAD` extra bytes (see `new`) so the decompressor can read
/// slightly past the valid data without per-access bounds checks.
pub struct DeflateChunkedBufferInput<'a> {
    // Backing storage; length is `buf_size + Self::MAX_OVERREAD`.
    buffer: Box<[u8]>,
    // Usable capacity of `buffer`, excluding the overread tail.
    buf_size: usize,
    // Bytes discarded from the front of the stream so far;
    // `global_position_offset + position` is the absolute stream position.
    global_position_offset: usize,
    // Current read cursor within `buffer`.
    position: usize,
    // One past the last valid byte in `buffer`.
    end_position: usize,
    // Largest cursor value for which reading up to `MAX_OVERREAD` bytes past
    // the cursor stays inside the buffer allocation (bytes beyond
    // `end_position` may be stale/uninitialized — callers must not trust them).
    overread_position_limit: usize,
    // Refill callback: fills the given slice, returns bytes written (0 = EOF).
    func: Box<dyn FnMut(&mut [u8]) -> usize + 'a>,
}
14
15impl<'a> DeflateChunkedBufferInput<'a> {
16    pub fn new<F: FnMut(&mut [u8]) -> usize + 'a>(read_func: F, buf_size: usize) -> Self {
17        Self {
18            buffer: unsafe {
19                NightlyUtils::box_new_uninit_slice_assume_init(buf_size + Self::MAX_OVERREAD)
20            },
21            buf_size,
22            global_position_offset: 0,
23            position: 0,
24            end_position: 0,
25            overread_position_limit: 0,
26            func: Box::new(read_func),
27        }
28    }
29
30    #[cold]
31    #[inline(never)]
32    fn refill_buffer(&mut self) -> bool {
33        let keep_buf_len = min(self.position, Self::MAX_LOOK_BACK);
34        let move_offset = self.position - keep_buf_len;
35        let move_amount = self.end_position - move_offset;
36
37        self.global_position_offset += move_offset;
38        unsafe {
39            std::ptr::copy(
40                self.buffer.as_ptr().add(move_offset),
41                self.buffer.as_mut_ptr(),
42                move_amount,
43            );
44        }
45        self.position -= move_offset;
46        self.end_position -= move_offset;
47
48        let count = (self.func)(&mut self.buffer[self.end_position..self.buf_size]);
49
50        self.end_position += count;
51
52        // Keep at least MAX_OVERREAD bytes available
53        self.overread_position_limit = (self.end_position - self.position).max(Self::MAX_OVERREAD)
54            + self.position
55            - Self::MAX_OVERREAD;
56
57        self.position < self.end_position
58    }
59}
60
61impl<'a> DeflateInput for DeflateChunkedBufferInput<'a> {
62    #[inline(always)]
63    unsafe fn get_le_word_no_advance(&mut self) -> usize {
64        usize::from_le_bytes(
65            *(self.buffer.as_ptr().add(self.position) as *const [u8; std::mem::size_of::<usize>()]),
66        )
67        .to_le()
68    }
69
70    #[inline(always)]
71    fn move_stream_pos<const REFILL: bool>(&mut self, amount: isize) {
72        const REFILL: bool = true;
73        if REFILL && amount > 0 {
74            if self.position + amount as usize > self.end_position {
75                self.refill_buffer();
76            }
77        }
78
79        self.position = self.position.wrapping_add_signed(amount);
80    }
81
82    #[inline(always)]
83    fn get_stream_pos_mut(&mut self) -> &mut usize {
84        &mut self.position
85    }
86
87    fn tell_stream_pos(&self) -> usize {
88        self.global_position_offset + self.position
89    }
90
91    #[inline(always)]
92    fn read<const REFILL: bool>(&mut self, out_data: &mut [u8]) -> usize {
93        const REFILL: bool = true;
94        if REFILL && self.end_position < self.position + out_data.len() {
95            self.refill_buffer();
96        }
97
98        let avail_bytes = if REFILL {
99            min(out_data.len(), self.end_position - self.position)
100        } else {
101            out_data.len()
102        };
103
104        unsafe {
105            std::ptr::copy_nonoverlapping(
106                self.buffer.as_ptr().add(self.position),
107                out_data.as_mut_ptr(),
108                avail_bytes,
109            );
110            self.position += avail_bytes;
111        }
112        avail_bytes
113    }
114
115    #[inline(always)]
116    fn ensure_overread_length(&mut self) {
117        if self.position > self.overread_position_limit {
118            self.refill_buffer();
119        }
120    }
121
122    fn has_readable_overread(&self) -> bool {
123        self.position <= self.overread_position_limit
124    }
125
126    fn has_valid_bytes_slow(&mut self) -> bool {
127        if self.position >= self.end_position {
128            self.refill_buffer();
129        }
130        self.position < self.end_position
131    }
132
133    #[inline(always)]
134    fn read_exact_into<O: DeflateOutput>(&mut self, out_stream: &mut O, mut length: usize) -> bool {
135        const CHUNK_SIZE: usize = 256;
136        unsafe {
137            while length > 0 {
138                out_stream.flush_ensure_length(CHUNK_SIZE);
139                self.ensure_overread_length();
140
141                let mut src = self.buffer.as_ptr().add(self.position) as *const u64;
142                let mut dst = out_stream.get_output_ptr() as *mut u64;
143
144                let max_copyable = length
145                    .min(CHUNK_SIZE)
146                    .min(self.end_position - self.position);
147                let mut copyable = max_copyable as isize;
148                while copyable > 0 {
149                    copy_dword_unaligned(src, dst);
150                    src = src.add(2);
151                    dst = dst.add(2);
152                    copyable -= 16;
153                }
154                length -= max_copyable - copyable.max(0) as usize;
155
156                let mut src = src as *mut u8;
157                let mut dst = dst as *mut u8;
158                if copyable < 0 {
159                    // Remove extra copied bytes
160                    src = src.sub(-copyable as usize);
161                    dst = dst.sub(-copyable as usize);
162                }
163
164                out_stream.set_output_ptr(dst);
165
166                self.position = src.offset_from(self.buffer.as_ptr()) as usize;
167            }
168        }
169
170        true
171    }
172}