// tiff_reader/filters.rs
1//! Compression filter pipeline for TIFF strip/tile decompression.
2
3#[cfg(any(feature = "jpeg", feature = "zstd"))]
4use std::io::Cursor;
5use std::io::Read;
6
7use crate::error::{Error, Result};
8use crate::header::ByteOrder;
9use tiff_core::{Compression, Predictor};
10
/// Decompress a strip or tile according to the TIFF compression scheme.
///
/// * `compression` — raw TIFF Compression tag value.
/// * `data` — the compressed strip/tile payload.
/// * `index` — strip/tile index, used only in error messages.
/// * `_jpeg_tables` — shared JPEGTables tag payload; consumed only when the
///   `jpeg` feature is enabled (hence the underscore).
///
/// Returns the decompressed bytes, or `Error::UnsupportedCompression` for
/// schemes that are unknown or compiled out by feature flags.
pub fn decompress(
    compression: u16,
    data: &[u8],
    index: usize,
    _jpeg_tables: Option<&[u8]>,
) -> Result<Vec<u8>> {
    match Compression::from_code(compression) {
        // Uncompressed: return a copy of the raw bytes.
        Some(Compression::None) => Ok(data.to_vec()),
        // Both Deflate variants decode as the same zlib stream.
        Some(Compression::Deflate | Compression::DeflateOld) => decompress_deflate(data, index),
        Some(Compression::Lzw) => decompress_lzw(data, index),
        Some(Compression::PackBits) => decompress_packbits(data, index),
        // Old-style JPEG is rejected even with the `jpeg` feature enabled.
        #[cfg(feature = "jpeg")]
        Some(Compression::OldJpeg) => Err(Error::UnsupportedCompression(compression)),
        #[cfg(feature = "jpeg")]
        Some(Compression::Jpeg) => decompress_jpeg(data, index, _jpeg_tables),
        #[cfg(not(feature = "jpeg"))]
        Some(Compression::OldJpeg | Compression::Jpeg) => {
            Err(Error::UnsupportedCompression(compression))
        }
        #[cfg(feature = "zstd")]
        Some(Compression::Zstd) => decompress_zstd(data, index),
        #[cfg(not(feature = "zstd"))]
        Some(Compression::Zstd) => Err(Error::UnsupportedCompression(compression)),
        None => Err(Error::UnsupportedCompression(compression)),
    }
}
38
39/// Normalize row bytes into native-endian decoded samples and reverse any TIFF predictor.
40pub fn fix_endianness_and_predict(
41    row: &mut [u8],
42    bit_depth: u16,
43    samples: u16,
44    byte_order: ByteOrder,
45    predictor: u16,
46) -> Result<()> {
47    match Predictor::from_code(predictor) {
48        Some(Predictor::None) => {
49            fix_endianness(row, byte_order, bit_depth);
50            Ok(())
51        }
52        Some(Predictor::Horizontal) => {
53            fix_endianness(row, byte_order, bit_depth);
54            reverse_horizontal_predictor(row, bit_depth, samples);
55            Ok(())
56        }
57        Some(Predictor::FloatingPoint) => match bit_depth {
58            16 => {
59                let mut encoded = row.to_vec();
60                predict_f16(&mut encoded, row, samples);
61                Ok(())
62            }
63            32 => {
64                let mut encoded = row.to_vec();
65                predict_f32(&mut encoded, row, samples);
66                Ok(())
67            }
68            64 => {
69                let mut encoded = row.to_vec();
70                predict_f64(&mut encoded, row, samples);
71                Ok(())
72            }
73            _ => Err(Error::UnsupportedPredictor(3)),
74        },
75        None => Err(Error::UnsupportedPredictor(predictor)),
76    }
77}
78
79fn decompress_deflate(data: &[u8], index: usize) -> Result<Vec<u8>> {
80    use flate2::read::ZlibDecoder;
81
82    let mut decoder = ZlibDecoder::new(data);
83    let mut out = Vec::new();
84    decoder
85        .read_to_end(&mut out)
86        .map_err(|e| Error::DecompressionFailed {
87            index,
88            reason: format!("deflate: {e}"),
89        })?;
90    Ok(out)
91}
92
93fn decompress_lzw(data: &[u8], index: usize) -> Result<Vec<u8>> {
94    use weezl::decode::Decoder;
95    use weezl::BitOrder;
96
97    let mut decoder = Decoder::with_tiff_size_switch(BitOrder::Msb, 8);
98    decoder
99        .decode(data)
100        .map_err(|e| Error::DecompressionFailed {
101            index,
102            reason: format!("LZW: {e}"),
103        })
104}
105
106fn decompress_packbits(data: &[u8], index: usize) -> Result<Vec<u8>> {
107    let mut out = Vec::new();
108    let mut cursor = 0usize;
109
110    while cursor < data.len() {
111        let header = data[cursor] as i8;
112        cursor += 1;
113
114        if header >= 0 {
115            let count = header as usize + 1;
116            let end = cursor + count;
117            if end > data.len() {
118                return Err(Error::DecompressionFailed {
119                    index,
120                    reason: "PackBits literal run is truncated".into(),
121                });
122            }
123            out.extend_from_slice(&data[cursor..end]);
124            cursor = end;
125        } else if header != -128 {
126            if cursor >= data.len() {
127                return Err(Error::DecompressionFailed {
128                    index,
129                    reason: "PackBits repeat run is truncated".into(),
130                });
131            }
132            let count = (1i16 - header as i16) as usize;
133            let byte = data[cursor];
134            cursor += 1;
135            out.resize(out.len() + count, byte);
136        }
137    }
138
139    Ok(out)
140}
141
#[cfg(feature = "jpeg")]
// Decode a (possibly abbreviated) JPEG scan, prepending the shared
// JPEGTables payload so the decoder sees one complete stream.
fn decompress_jpeg(data: &[u8], index: usize, jpeg_tables: Option<&[u8]>) -> Result<Vec<u8>> {
    let full_stream = merge_jpeg_stream(jpeg_tables, data);
    jpeg_decoder::Decoder::new(Cursor::new(full_stream))
        .decode()
        .map_err(|e| Error::DecompressionFailed {
            index,
            reason: format!("JPEG: {e}"),
        })
}
151
#[cfg(feature = "zstd")]
// Decode a ZSTD-compressed strip or tile in one shot.
fn decompress_zstd(data: &[u8], index: usize) -> Result<Vec<u8>> {
    match zstd::stream::decode_all(Cursor::new(data)) {
        Ok(decoded) => Ok(decoded),
        Err(e) => Err(Error::DecompressionFailed {
            index,
            reason: format!("ZSTD: {e}"),
        }),
    }
}
159
160#[cfg(feature = "jpeg")]
/// Merge a shared JPEGTables payload with an abbreviated per-tile scan into
/// one self-contained JPEG stream.
///
/// The tables' trailing EOI marker and the scan's leading SOI marker are
/// stripped so the fragments concatenate cleanly; the result is guaranteed
/// to start with SOI and end with EOI. With no tables, the scan is returned
/// as-is.
fn merge_jpeg_stream(jpeg_tables: Option<&[u8]>, scan_data: &[u8]) -> Vec<u8> {
    // Idiomatic Option handling instead of is_none() + unwrap_or_default().
    let tables = match jpeg_tables {
        Some(tables) => tables,
        None => return scan_data.to_vec(),
    };

    let table_body = tables.strip_suffix(&[0xff, 0xd9]).unwrap_or(tables);
    let scan_body = scan_data.strip_prefix(&[0xff, 0xd8]).unwrap_or(scan_data);

    // Reserve room for both fragments plus the SOI/EOI markers we may add.
    let mut merged = Vec::with_capacity(table_body.len() + scan_body.len() + 4);
    if !table_body.starts_with(&[0xff, 0xd8]) {
        merged.extend_from_slice(&[0xff, 0xd8]);
    }
    merged.extend_from_slice(table_body);
    merged.extend_from_slice(scan_body);
    if !merged.ends_with(&[0xff, 0xd9]) {
        merged.extend_from_slice(&[0xff, 0xd9]);
    }
    merged
}
189
190fn fix_endianness(buf: &mut [u8], byte_order: ByteOrder, bit_depth: u16) {
191    let host_is_little_endian = cfg!(target_endian = "little");
192    let data_is_little_endian = matches!(byte_order, ByteOrder::LittleEndian);
193    if host_is_little_endian == data_is_little_endian {
194        return;
195    }
196
197    let chunk = match bit_depth {
198        0..=8 => 1,
199        9..=16 => 2,
200        17..=32 => 4,
201        _ => 8,
202    };
203    if chunk == 1 {
204        return;
205    }
206
207    for value in buf.chunks_exact_mut(chunk) {
208        value.reverse();
209    }
210}
211
// Undo TIFF horizontal differencing (predictor 2) in place: each sample is
// the sum of its delta and the sample one pixel (i.e. `samples` values)
// earlier, at the native-endian value width implied by `bit_depth`.
fn reverse_horizontal_predictor(buf: &mut [u8], bit_depth: u16, samples: u16) {
    let value_width = match bit_depth {
        0..=8 => 1,
        9..=16 => 2,
        17..=32 => 4,
        _ => 8,
    };
    // Byte distance back to the matching sample of the previous pixel.
    let stride = usize::from(samples) * value_width;

    // One loop body per value width, stamped out for each integer type.
    macro_rules! undo {
        ($ty:ty, $width:expr) => {
            let mut pos = stride;
            while pos < buf.len() {
                let prev =
                    <$ty>::from_ne_bytes(buf[pos - stride..pos - stride + $width].try_into().unwrap());
                let cur = <$ty>::from_ne_bytes(buf[pos..pos + $width].try_into().unwrap());
                buf[pos..pos + $width].copy_from_slice(&cur.wrapping_add(prev).to_ne_bytes());
                pos += $width;
            }
        };
    }

    match value_width {
        1 => {
            undo!(u8, 1);
        }
        2 => {
            undo!(u16, 2);
        }
        4 => {
            undo!(u32, 4);
        }
        _ => {
            undo!(u64, 8);
        }
    }
}
265
// Undo the 16-bit floating-point predictor: reverse byte-wise differencing
// on `input`, then gather one byte from each of the two planes (stored in
// big-endian order) and write each sample to `output` in native byte order.
fn predict_f16(input: &mut [u8], output: &mut [u8], samples: u16) {
    let stride = usize::from(samples);
    for i in stride..input.len() {
        input[i] = input[i].wrapping_add(input[i - stride]);
    }
    let plane = input.len() / 2;
    for (i, dst) in output.chunks_mut(2).enumerate() {
        let value = u16::from_be_bytes([input[i], input[plane + i]]);
        dst.copy_from_slice(&value.to_ne_bytes());
    }
}
278
// Undo the 32-bit floating-point predictor: reverse byte-wise differencing
// on `input`, then gather one byte from each of the four planes (stored in
// big-endian order) and write each sample to `output` in native byte order.
fn predict_f32(input: &mut [u8], output: &mut [u8], samples: u16) {
    let stride = usize::from(samples);
    for i in stride..input.len() {
        input[i] = input[i].wrapping_add(input[i - stride]);
    }
    let len = input.len();
    for (i, dst) in output.chunks_mut(4).enumerate() {
        let be = [
            input[i],
            input[len / 4 + i],
            input[len / 2 + i],
            input[len / 4 * 3 + i],
        ];
        dst.copy_from_slice(&u32::from_be_bytes(be).to_ne_bytes());
    }
}
293
// Undo the 64-bit floating-point predictor: reverse byte-wise differencing
// on `input`, then gather one byte from each of the eight planes (stored in
// big-endian order) and write each sample to `output` in native byte order.
fn predict_f64(input: &mut [u8], output: &mut [u8], samples: u16) {
    let stride = usize::from(samples);
    for i in stride..input.len() {
        input[i] = input[i].wrapping_add(input[i - stride]);
    }
    let plane = input.len() / 8;
    for (i, dst) in output.chunks_mut(8).enumerate() {
        let mut be = [0u8; 8];
        for (k, byte) in be.iter_mut().enumerate() {
            *byte = input[plane * k + i];
        }
        dst.copy_from_slice(&u64::from_be_bytes(be).to_ne_bytes());
    }
}
312
#[cfg(test)]
mod tests {
    use std::path::Path;

    #[cfg(feature = "jpeg")]
    use super::merge_jpeg_stream;
    use super::{decompress_lzw, decompress_packbits, fix_endianness_and_predict};
    use crate::header::ByteOrder;

    // Predictor 2 on little-endian u16 rows: the stored deltas [1, 1, 2]
    // must accumulate to the sample values [1, 2, 4].
    #[test]
    fn horizontal_predictor_restores_u16_rows() {
        let mut row = vec![1, 0, 1, 0, 2, 0];
        fix_endianness_and_predict(&mut row, 16, 1, ByteOrder::LittleEndian, 2).unwrap();
        assert_eq!(row, vec![1, 0, 2, 0, 4, 0]);
    }

    // A repeat-run control byte (0xff == -1) with no fill byte after it must
    // produce a decompression error rather than read out of bounds.
    #[test]
    fn packbits_decoder_rejects_truncated_repeat_run() {
        let err = decompress_packbits(&[0xff], 0).unwrap_err();
        assert!(err.to_string().contains("PackBits"));
    }

    // Decodes an LZW tile sliced from a real COG fixture, both with and
    // without the trailing bytes present in the file after the payload.
    #[test]
    fn lzw_real_cog_tile_requires_repeated_trailer_bytes() {
        let fixture = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../testdata/interoperability/gdal/gcore/data/cog/byte_little_endian_golden.tif");
        let bytes = std::fs::read(fixture).unwrap();

        // Hard-coded byte offsets into the fixture — NOTE(review): these must
        // be re-derived from the IFD if the fixture is ever regenerated.
        let without_trailer = &bytes[570..570 + 1223];
        let with_trailer = &bytes[570..570 + 1227];

        assert!(decompress_lzw(without_trailer, 0).is_ok());
        assert!(decompress_lzw(with_trailer, 0).is_ok());
    }

    // JPEGTables plus an abbreviated scan must merge into one stream that
    // starts with SOI, carries the table segment, and ends with EOI.
    #[cfg(feature = "jpeg")]
    #[test]
    fn merges_jpeg_tables_with_abbreviated_scan() {
        let merged = merge_jpeg_stream(
            Some(&[0xff, 0xd8, 0xff, 0xdb, 0x00, 0x43, 0xff, 0xd9]),
            &[0xff, 0xda, 0x00, 0x08, 0x00],
        );
        assert_eq!(&merged[..6], &[0xff, 0xd8, 0xff, 0xdb, 0x00, 0x43]);
        assert!(merged.ends_with(&[0xff, 0xd9]));
    }
}