Skip to main content

tiff_reader/
filters.rs

1//! Compression filter pipeline for TIFF strip/tile decompression.
2
3#[cfg(feature = "jpeg")]
4use std::io::Cursor;
5use std::io::Read;
6
7use crate::error::{Error, Result};
8use crate::header::ByteOrder;
9use tiff_core::{Compression, Predictor};
10
/// Decompress a strip or tile according to the TIFF compression scheme.
///
/// `compression` is the raw value of the TIFF Compression tag; `index` is the
/// strip/tile index and is used only for error reporting. `_jpeg_tables` is
/// the shared JPEGTables payload, consumed only when the `jpeg` feature is
/// enabled (hence the leading underscore to silence the unused warning in
/// other builds).
///
/// Returns the decompressed bytes, or `Error::UnsupportedCompression` for
/// codes this build cannot handle.
pub fn decompress(
    compression: u16,
    data: &[u8],
    index: usize,
    _jpeg_tables: Option<&[u8]>,
) -> Result<Vec<u8>> {
    match Compression::from_code(compression) {
        // Uncompressed data is returned as an owned copy.
        Some(Compression::None) => Ok(data.to_vec()),
        Some(Compression::Deflate | Compression::DeflateOld) => decompress_deflate(data, index),
        Some(Compression::Lzw) => decompress_lzw(data, index),
        Some(Compression::PackBits) => decompress_packbits(data, index),
        // Old-style JPEG is rejected even when the `jpeg` feature is on;
        // only the modern JPEG scheme below is decoded.
        #[cfg(feature = "jpeg")]
        Some(Compression::OldJpeg) => Err(Error::UnsupportedCompression(compression)),
        #[cfg(feature = "jpeg")]
        Some(Compression::Jpeg) => decompress_jpeg(data, index, _jpeg_tables),
        // Without the feature, both JPEG variants fold into one arm so the
        // match stays exhaustive under every feature combination.
        #[cfg(not(feature = "jpeg"))]
        Some(Compression::OldJpeg | Compression::Jpeg) => {
            Err(Error::UnsupportedCompression(compression))
        }
        #[cfg(feature = "zstd")]
        Some(Compression::Zstd) => decompress_zstd(data, index),
        #[cfg(not(feature = "zstd"))]
        Some(Compression::Zstd) => Err(Error::UnsupportedCompression(compression)),
        // Unknown compression codes are reported back to the caller.
        None => Err(Error::UnsupportedCompression(compression)),
    }
}
38
39/// Normalize row bytes into native-endian decoded samples and reverse any TIFF predictor.
40pub fn fix_endianness_and_predict(
41    row: &mut [u8],
42    bit_depth: u16,
43    samples: u16,
44    byte_order: ByteOrder,
45    predictor: u16,
46) -> Result<()> {
47    match Predictor::from_code(predictor) {
48        Some(Predictor::None) => {
49            fix_endianness(row, byte_order, bit_depth);
50            Ok(())
51        }
52        Some(Predictor::Horizontal) => {
53            fix_endianness(row, byte_order, bit_depth);
54            reverse_horizontal_predictor(row, bit_depth, samples);
55            Ok(())
56        }
57        Some(Predictor::FloatingPoint) => match bit_depth {
58            16 => {
59                let mut encoded = row.to_vec();
60                predict_f16(&mut encoded, row, samples);
61                Ok(())
62            }
63            32 => {
64                let mut encoded = row.to_vec();
65                predict_f32(&mut encoded, row, samples);
66                Ok(())
67            }
68            64 => {
69                let mut encoded = row.to_vec();
70                predict_f64(&mut encoded, row, samples);
71                Ok(())
72            }
73            _ => Err(Error::UnsupportedPredictor(3)),
74        },
75        None => Err(Error::UnsupportedPredictor(predictor)),
76    }
77}
78
79fn decompress_deflate(data: &[u8], index: usize) -> Result<Vec<u8>> {
80    use flate2::read::ZlibDecoder;
81
82    let mut decoder = ZlibDecoder::new(data);
83    let mut out = Vec::new();
84    decoder
85        .read_to_end(&mut out)
86        .map_err(|e| Error::DecompressionFailed {
87            index,
88            reason: format!("deflate: {e}"),
89        })?;
90    Ok(out)
91}
92
93fn decompress_lzw(data: &[u8], index: usize) -> Result<Vec<u8>> {
94    use weezl::decode::Decoder;
95    use weezl::BitOrder;
96
97    let mut decoder = Decoder::with_tiff_size_switch(BitOrder::Msb, 8);
98    decoder
99        .decode(data)
100        .map_err(|e| Error::DecompressionFailed {
101            index,
102            reason: format!("LZW: {e}"),
103        })
104}
105
106fn decompress_packbits(data: &[u8], index: usize) -> Result<Vec<u8>> {
107    let mut out = Vec::new();
108    let mut cursor = 0usize;
109
110    while cursor < data.len() {
111        let header = data[cursor] as i8;
112        cursor += 1;
113
114        if header >= 0 {
115            let count = header as usize + 1;
116            let end = cursor + count;
117            if end > data.len() {
118                return Err(Error::DecompressionFailed {
119                    index,
120                    reason: "PackBits literal run is truncated".into(),
121                });
122            }
123            out.extend_from_slice(&data[cursor..end]);
124            cursor = end;
125        } else if header != -128 {
126            if cursor >= data.len() {
127                return Err(Error::DecompressionFailed {
128                    index,
129                    reason: "PackBits repeat run is truncated".into(),
130                });
131            }
132            let count = (1i16 - header as i16) as usize;
133            let byte = data[cursor];
134            cursor += 1;
135            out.resize(out.len() + count, byte);
136        }
137    }
138
139    Ok(out)
140}
141
/// Decode a baseline JPEG scan for one strip/tile, first splicing in the
/// shared JPEGTables payload when one is present.
#[cfg(feature = "jpeg")]
fn decompress_jpeg(data: &[u8], index: usize, jpeg_tables: Option<&[u8]>) -> Result<Vec<u8>> {
    use jpeg_decoder::Decoder;

    let full_stream = merge_jpeg_stream(jpeg_tables, data);
    Decoder::new(Cursor::new(full_stream))
        .decode()
        .map_err(|e| Error::DecompressionFailed {
            index,
            reason: format!("JPEG: {e}"),
        })
}
153
/// Decompress a ZSTD-compressed strip or tile.
///
/// Reads `data` directly as an `io::Read` source (`&[u8]` implements `Read`)
/// instead of wrapping it in `Cursor`: the `use std::io::Cursor` at the top
/// of this file is gated on the `jpeg` feature, so a build with `zstd`
/// enabled but `jpeg` disabled previously failed to resolve `Cursor` here.
#[cfg(feature = "zstd")]
fn decompress_zstd(data: &[u8], index: usize) -> Result<Vec<u8>> {
    zstd::stream::decode_all(data).map_err(|e| Error::DecompressionFailed {
        index,
        reason: format!("ZSTD: {e}"),
    })
}
161
162#[cfg(feature = "jpeg")]
/// Merge the shared JPEGTables payload with an abbreviated per-strip/tile
/// scan into a single standalone JPEG stream.
///
/// The tables blob conventionally ends with an EOI marker (FFD9) and the
/// scan may begin with its own SOI (FFD8); both are stripped so the result
/// contains exactly one SOI ... EOI envelope wrapping the table segments
/// followed by the scan data.
fn merge_jpeg_stream(jpeg_tables: Option<&[u8]>, scan_data: &[u8]) -> Vec<u8> {
    // With no shared tables the scan is returned as-is (assumed complete).
    let tables = match jpeg_tables {
        Some(tables) => tables,
        None => return scan_data.to_vec(),
    };

    // Drop the tables' trailing EOI and the scan's leading SOI so their
    // concatenation remains one well-formed stream.
    let table_body = tables.strip_suffix(&[0xff, 0xd9]).unwrap_or(tables);
    let scan_body = scan_data.strip_prefix(&[0xff, 0xd8]).unwrap_or(scan_data);

    // Reserve room for both a prepended SOI and an appended EOI; the
    // original reserved only 2 extra bytes, which could force one realloc
    // when both markers end up being added.
    let mut merged = Vec::with_capacity(table_body.len() + scan_body.len() + 4);
    if !table_body.starts_with(&[0xff, 0xd8]) {
        merged.extend_from_slice(&[0xff, 0xd8]);
    }
    merged.extend_from_slice(table_body);
    merged.extend_from_slice(scan_body);
    if !merged.ends_with(&[0xff, 0xd9]) {
        merged.extend_from_slice(&[0xff, 0xd9]);
    }
    merged
}
191
192fn fix_endianness(buf: &mut [u8], byte_order: ByteOrder, bit_depth: u16) {
193    let host_is_little_endian = cfg!(target_endian = "little");
194    let data_is_little_endian = matches!(byte_order, ByteOrder::LittleEndian);
195    if host_is_little_endian == data_is_little_endian {
196        return;
197    }
198
199    let chunk = match bit_depth {
200        0..=8 => 1,
201        9..=16 => 2,
202        17..=32 => 4,
203        _ => 8,
204    };
205    if chunk == 1 {
206        return;
207    }
208
209    for value in buf.chunks_exact_mut(chunk) {
210        value.reverse();
211    }
212}
213
/// Undo TIFF horizontal differencing in place: each native-endian value is
/// the stored delta plus the value one pixel (i.e. `samples` values) back.
fn reverse_horizontal_predictor(buf: &mut [u8], bit_depth: u16, samples: u16) {
    // Bytes occupied by one stored value.
    let value_width = match bit_depth {
        0..=8 => 1,
        9..=16 => 2,
        17..=32 => 4,
        _ => 8,
    };
    // Distance in bytes back to the same sample of the previous pixel.
    let stride = usize::from(samples) * value_width;

    // Shared accumulation loop for the multi-byte widths; the wrapping add
    // is performed on native-endian integers.
    macro_rules! accumulate {
        ($ty:ty, $width:expr) => {{
            let mut pos = stride;
            while pos < buf.len() {
                let prev = <$ty>::from_ne_bytes(
                    buf[pos - stride..pos - stride + $width].try_into().unwrap(),
                );
                let cur = <$ty>::from_ne_bytes(buf[pos..pos + $width].try_into().unwrap());
                buf[pos..pos + $width].copy_from_slice(&cur.wrapping_add(prev).to_ne_bytes());
                pos += $width;
            }
        }};
    }

    match value_width {
        1 => {
            let mut pos = stride;
            while pos < buf.len() {
                buf[pos] = buf[pos].wrapping_add(buf[pos - stride]);
                pos += 1;
            }
        }
        2 => accumulate!(u16, 2),
        4 => accumulate!(u32, 4),
        _ => accumulate!(u64, 8),
    }
}
267
/// Undo the TIFF floating-point predictor for 16-bit samples: integrate the
/// byte-wise deltas in `input`, then reassemble the two big-endian byte
/// planes into native-endian values written to `output`.
fn predict_f16(input: &mut [u8], output: &mut [u8], samples: u16) {
    let stride = usize::from(samples);

    // Reverse the per-byte horizontal differencing applied by the encoder.
    for pos in stride..input.len() {
        input[pos] = input[pos].wrapping_add(input[pos - stride]);
    }

    // Plane 0 holds the most-significant byte of every value, plane 1 the
    // least-significant byte; recombine them big-endian-first.
    let len = input.len();
    for (i, out) in output.chunks_mut(2).enumerate() {
        let value = u16::from_be_bytes([input[i], input[len / 2 + i]]);
        out.copy_from_slice(&value.to_ne_bytes());
    }
}
280
/// Undo the TIFF floating-point predictor for 32-bit samples: integrate the
/// byte-wise deltas in `input`, then gather the four big-endian byte planes
/// into native-endian values written to `output`.
fn predict_f32(input: &mut [u8], output: &mut [u8], samples: u16) {
    let stride = usize::from(samples);

    // Reverse the per-byte horizontal differencing applied by the encoder.
    for pos in stride..input.len() {
        input[pos] = input[pos].wrapping_add(input[pos - stride]);
    }

    // Byte plane k holds byte k (big-endian order) of every encoded value.
    let len = input.len();
    for (i, out) in output.chunks_mut(4).enumerate() {
        let value = u32::from_be_bytes([
            input[i],
            input[len / 4 + i],
            input[len / 2 + i],
            input[len / 4 * 3 + i],
        ]);
        out.copy_from_slice(&value.to_ne_bytes());
    }
}
295
/// Undo the TIFF floating-point predictor for 64-bit samples: integrate the
/// byte-wise deltas in `input`, then gather the eight big-endian byte planes
/// into native-endian values written to `output`.
fn predict_f64(input: &mut [u8], output: &mut [u8], samples: u16) {
    let stride = usize::from(samples);

    // Reverse the per-byte horizontal differencing applied by the encoder.
    for pos in stride..input.len() {
        input[pos] = input[pos].wrapping_add(input[pos - stride]);
    }

    // Byte plane k holds byte k (big-endian order) of every encoded value.
    let plane = input.len() / 8;
    for (i, out) in output.chunks_mut(8).enumerate() {
        let mut be = [0u8; 8];
        for (k, slot) in be.iter_mut().enumerate() {
            *slot = input[plane * k + i];
        }
        out.copy_from_slice(&u64::from_be_bytes(be).to_ne_bytes());
    }
}
314
#[cfg(test)]
mod tests {
    use std::path::Path;

    #[cfg(feature = "jpeg")]
    use super::merge_jpeg_stream;
    use super::{decompress_lzw, decompress_packbits, fix_endianness_and_predict};
    use crate::header::ByteOrder;

    #[test]
    fn horizontal_predictor_restores_u16_rows() {
        // Little-endian u16 deltas [1, 1, 2] should integrate to [1, 2, 4].
        // NOTE(review): the expected byte layout assumes a little-endian
        // host — verify this test on big-endian targets.
        let mut row = vec![1, 0, 1, 0, 2, 0];
        fix_endianness_and_predict(&mut row, 16, 1, ByteOrder::LittleEndian, 2).unwrap();
        assert_eq!(row, vec![1, 0, 2, 0, 4, 0]);
    }

    #[test]
    fn packbits_decoder_rejects_truncated_repeat_run() {
        // 0xff is a repeat-run control byte (-1) with no data byte after it.
        let err = decompress_packbits(&[0xff], 0).unwrap_err();
        assert!(err.to_string().contains("PackBits"));
    }

    #[test]
    fn lzw_real_cog_tile_requires_repeated_trailer_bytes() {
        // Offsets/lengths below address one LZW tile inside a fixed GDAL
        // golden fixture; the tile is decoded both without and with its
        // 4 trailing bytes and must succeed either way.
        let fixture = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../testdata/interoperability/gdal/gcore/data/cog/byte_little_endian_golden.tif");
        let bytes = std::fs::read(fixture).unwrap();

        let without_trailer = &bytes[570..570 + 1223];
        let with_trailer = &bytes[570..570 + 1227];

        assert!(decompress_lzw(without_trailer, 0).is_ok());
        assert!(decompress_lzw(with_trailer, 0).is_ok());
    }

    #[cfg(feature = "jpeg")]
    #[test]
    fn merges_jpeg_tables_with_abbreviated_scan() {
        // Tables blob: SOI + DQT fragment + EOI; scan: SOS fragment with no
        // SOI of its own. The merge must keep a single SOI/EOI envelope.
        let merged = merge_jpeg_stream(
            Some(&[0xff, 0xd8, 0xff, 0xdb, 0x00, 0x43, 0xff, 0xd9]),
            &[0xff, 0xda, 0x00, 0x08, 0x00],
        );
        assert_eq!(&merged[..6], &[0xff, 0xd8, 0xff, 0xdb, 0x00, 0x43]);
        assert!(merged.ends_with(&[0xff, 0xd9]));
    }
}
360}