zune_psd/
decoder.rs

1/*
2 * Copyright (c) 2023.
3 *
4 * This software is free software;
5 *
6 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7 */
8
9//! A simple PSD reader.
10//!
//! This crate features a simple and performant PSD reader
//! based on the stb_image implementation.
//!
//! It currently supports only a small part of the specification:
//! it extracts the composite image without applying blend layers
//! or masks. Such functionality will be added over time.
17//!
18//!
19use alloc::vec;
20use alloc::vec::Vec;
21use core::cmp::Ordering;
22
23use zune_core::bit_depth::BitDepth;
24use zune_core::bytestream::{ZByteReader, ZReaderTrait};
25use zune_core::colorspace::ColorSpace;
26use zune_core::log::trace;
27use zune_core::options::DecoderOptions;
28use zune_core::result::DecodingResult;
29
30use crate::constants::{ColorModes, CompressionMethod, PSD_IDENTIFIER_BE};
31use crate::errors::PSDDecodeErrors;
32
33/// A simple Photoshop PSD reader.
34///
35/// This currently doesn't support layer flattening
36/// but it's useful enough in that we can extract images
37/// from it.
38///
39/// Further work will go onto adding a renderer that flattens
40/// image pixels. But for now this is a good basis.
41pub struct PSDDecoder<T>
42where
43    T: ZReaderTrait
44{
45    width:          usize,
46    height:         usize,
47    decoded_header: bool,
48    stream:         ZByteReader<T>,
49    options:        DecoderOptions,
50    depth:          BitDepth,
51    color_type:     Option<ColorModes>,
52    compression:    CompressionMethod,
53    channel_count:  usize
54}
55
56impl<T> PSDDecoder<T>
57where
58    T: ZReaderTrait
59{
60    /// Create a new decoder that reads a photoshop encoded file
61    /// from `T` and returns pixels
62    ///
63    /// # Arguments
64    /// - data: Data source, it has to implement the `ZReaderTrait
65    pub fn new(data: T) -> PSDDecoder<T> {
66        Self::new_with_options(data, DecoderOptions::default())
67    }
68
69    /// Creates a new decoder with options that influence decoding routines
70    ///
71    /// # Arguments
72    /// - data: Data source
73    /// - options: Custom options for the decoder
74    pub fn new_with_options(data: T, options: DecoderOptions) -> PSDDecoder<T> {
75        PSDDecoder {
76            width: 0,
77            height: 0,
78            decoded_header: false,
79            stream: ZByteReader::new(data),
80            options,
81            depth: BitDepth::Eight,
82            color_type: None,
83            compression: CompressionMethod::NoCompression,
84            channel_count: 0
85        }
86    }
87
88    /// Decode headers from the encoded image
89    ///
90    /// This confirms whether the image is a photoshop image and extracts
91    /// relevant information from the image including width,height and extra information.
92    ///
93    pub fn decode_headers(&mut self) -> Result<(), PSDDecodeErrors> {
94        if self.decoded_header {
95            return Ok(());
96        }
97        // Check identifier
98        let magic = self.stream.get_u32_be_err()?;
99
100        if magic != PSD_IDENTIFIER_BE {
101            return Err(PSDDecodeErrors::WrongMagicBytes(magic));
102        }
103
104        //  file version
105        let version = self.stream.get_u16_be_err()?;
106
107        if version != 1 {
108            return Err(PSDDecodeErrors::UnsupportedFileType(version));
109        }
110        // Skip 6 reserved bytes
111        self.stream.skip(6);
112        // Read the number of channels (R, G, B, A, etc).
113        let channel_count = self.stream.get_u16_be_err()?;
114
115        if channel_count > 4 {
116            return Err(PSDDecodeErrors::UnsupportedChannelCount(channel_count));
117        }
118
119        self.channel_count = usize::from(channel_count);
120
121        let height = self.stream.get_u32_be_err()? as usize;
122        let width = self.stream.get_u32_be_err()? as usize;
123
124        if width > self.options.get_max_width() {
125            return Err(PSDDecodeErrors::LargeDimensions(
126                self.options.get_max_width(),
127                width
128            ));
129        }
130
131        if height > self.options.get_max_height() {
132            return Err(PSDDecodeErrors::LargeDimensions(
133                self.options.get_max_height(),
134                height
135            ));
136        }
137
138        self.width = width;
139        self.height = height;
140
141        if self.width == 0 || self.height == 0 || self.channel_count == 0 {
142            return Err(PSDDecodeErrors::ZeroDimensions);
143        }
144
145        let depth = self.stream.get_u16_be_err()?;
146
147        if depth != 8 && depth != 16 {
148            return Err(PSDDecodeErrors::UnsupportedBitDepth(depth));
149        }
150        let im_depth = match depth {
151            8 => BitDepth::Eight,
152            16 => BitDepth::Sixteen,
153            _ => unreachable!()
154        };
155
156        self.depth = im_depth;
157
158        let color_mode = self.stream.get_u16_be_err()?;
159
160        let color_enum = ColorModes::from_int(color_mode);
161
162        if let Some(color) = color_enum {
163            if !matches!(
164                color,
165                ColorModes::RGB | ColorModes::Grayscale | ColorModes::CYMK
166            ) {
167                return Err(PSDDecodeErrors::UnsupportedColorFormat(color_enum));
168            }
169        } else {
170            return Err(PSDDecodeErrors::Generic("Unknown color mode"));
171        }
172        self.color_type = color_enum;
173
174        // skip mode data
175        let bytes = self.stream.get_u32_be_err()? as usize;
176        self.stream.skip(bytes);
177
178        // skip image resources
179        let bytes = self.stream.get_u32_be_err()? as usize;
180        self.stream.skip(bytes);
181
182        // skip reserved data
183        let bytes = self.stream.get_u32_be_err()? as usize;
184        self.stream.skip(bytes);
185
186        // find out if data is compressed
187        let compression = self.stream.get_u16_be_err()?;
188
189        if compression > 1 {
190            return Err(PSDDecodeErrors::UnknownCompression);
191        }
192        if self.color_type == Some(ColorModes::Grayscale) {
193            // PSD may have grayscale images with more than one
194            // channel and will specify channel_count as 3.
195            // So let's fix that here
196            self.channel_count = 1;
197        }
198
199        self.compression = CompressionMethod::from_int(compression).unwrap();
200
201        self.decoded_header = true;
202
203        trace!("Image width:{}", self.width);
204        trace!("Image height:{}", self.height);
205        trace!("Channels: {}", self.channel_count);
206        trace!("Bit depth : {:?}", self.depth);
207
208        Ok(())
209    }
210
211    /// Decode an image to bytes without regard to depth or endianness
212    ///
213    /// # Returns
214    /// Ok(bytes):  Raw bytes of the image
215    /// Err(E): An error if it occurred during decoding
216    pub fn decode_raw(&mut self) -> Result<Vec<u8>, PSDDecodeErrors> {
217        if !self.decoded_header {
218            self.decode_headers()?;
219        }
220
221        let pixel_count = self.width * self.height;
222
223        let mut result = match (self.compression, self.depth) {
224            (CompressionMethod::RLE, BitDepth::Eight) => {
225                // RLE
226                // Loop until you get the number of unpacked bytes you are expecting:
227                //     Read the next source byte into n.
228                //     If n is between 0 and 127 inclusive, copy the next n+1 bytes
229                //     literally. Else if n is between -127 and -1 inclusive, copy the next
230                //     byte -n+1 times. Else if n is 128, noop.
231                // Endloop
232
233                // The RLE-compressed data is preceded by a 2-byte data count for each row
234                // in the data, which we're going to just skip.
235                let skipped = self.height * self.channel_count * 2;
236                self.stream.skip(skipped);
237
238                let mut out_channel = vec![0; pixel_count * self.channel_count + 10];
239
240                for channel in 0..self.channel_count {
241                    let pixel_count = self.width * self.height;
242                    self.psd_decode_rle(pixel_count, &mut out_channel[channel..])?;
243                }
244
245                out_channel.truncate(pixel_count * self.channel_count);
246
247                out_channel
248            }
249            (CompressionMethod::NoCompression, BitDepth::Eight) => {
250                // We're at the raw image data.  It's each channel in order (Red, Green,
251                // Blue, Alpha, ...) where each channel consists of an 8-bit
252                // value for each pixel in the image.
253
254                // Read the data by channel.
255
256                let mut out_channel = vec![0; self.width * self.height * self.channel_count + 10];
257                let pixel_count = self.width * self.height;
258
259                // check we have enough data
260                if !self.stream.has(pixel_count * self.channel_count) {
261                    return Err(PSDDecodeErrors::Generic("Incomplete bitstream"));
262                }
263
264                for channel in 0..self.channel_count {
265                    let mut i = channel;
266
267                    while i < pixel_count {
268                        out_channel[i] = self.stream.get_u8();
269                        i += self.channel_count;
270                    }
271                }
272
273                out_channel.truncate(pixel_count * self.channel_count);
274                out_channel
275            }
276
277            (CompressionMethod::NoCompression, BitDepth::Sixteen) => {
278                // We're at the raw image data.  It's each channel in order (Red, Green,
279                // Blue, Alpha, ...) where each channel consists of an 8-bit
280                // value for each pixel in the image.
281
282                // Read the data by channel.
283
284                // size of a single channel
285                let channel_dimensions = self.width * self.height;
286
287                let mut out_channel = vec![0; 2 * (channel_dimensions * self.channel_count + 10)];
288
289                let pixel_count = channel_dimensions * 2;
290
291                // check we have enough data
292                if !self.stream.has(pixel_count * self.channel_count) {
293                    return Err(PSDDecodeErrors::Generic("Incomplete bitstream"));
294                }
295
296                // iterate per channel
297                for channel in 0..self.channel_count {
298                    let i = channel * 2;
299                    let out_chunks = out_channel[i..].chunks_exact_mut(self.channel_count * 2);
300
301                    // iterate only taking the image dimensions
302                    for out in out_chunks.take(channel_dimensions) {
303                        let value = self.stream.get_u16_be();
304
305                        out[..2].copy_from_slice(&value.to_ne_bytes());
306                    }
307                }
308
309                out_channel.truncate(pixel_count * self.channel_count);
310                out_channel
311            }
312            _ => return Err(PSDDecodeErrors::Generic("Not implemented or Unknown"))
313        };
314        // remove white matte from psd
315        if self.channel_count >= 4 {
316            match self.depth {
317                BitDepth::Sixteen => {
318                    for pixel in result.chunks_exact_mut(8) {
319                        let px3 = u16::from_be_bytes(pixel[6..8].try_into().unwrap());
320                        if px3 != 0 && px3 != 65535 {
321                            let px0 = u16::from_be_bytes(pixel[0..2].try_into().unwrap());
322                            let px1 = u16::from_be_bytes(pixel[2..4].try_into().unwrap());
323                            let px2 = u16::from_be_bytes(pixel[4..6].try_into().unwrap());
324
325                            let a = f32::from(px3) / 65535.0;
326                            let ra = 1.0 / a;
327                            let inv_a = 65535.0 * (1.0 - ra);
328
329                            let x = (f32::from(px0) * ra + inv_a) as u16;
330                            let y = (f32::from(px1) * ra + inv_a) as u16;
331                            let z = (f32::from(px2) * ra + inv_a) as u16;
332
333                            pixel[0..2].copy_from_slice(&x.to_ne_bytes());
334                            pixel[2..4].copy_from_slice(&y.to_ne_bytes());
335                            pixel[4..6].copy_from_slice(&z.to_ne_bytes());
336                        }
337                    }
338                }
339                BitDepth::Eight => {
340                    for pixel in result.chunks_exact_mut(4) {
341                        if pixel[3] != 0 && pixel[3] != 255 {
342                            let a = f32::from(pixel[3]) / 255.0;
343                            let ra = 1.0 / a;
344                            let inv_a = 255.0 * (1.0 - ra);
345                            pixel[0] = (f32::from(pixel[0]) * ra + inv_a) as u8;
346                            pixel[1] = (f32::from(pixel[1]) * ra + inv_a) as u8;
347                            pixel[2] = (f32::from(pixel[2]) * ra + inv_a) as u8;
348                        }
349                    }
350                }
351                _ => unreachable!()
352            }
353        }
354        Ok(result)
355    }
356    /// Decode a PSD file extracting the image only
357    ///
358    /// Currently this does it without respect to  layers
359    /// and such, only extracting the PSD image, hence might not be the
360    /// most useful one.
361    ///
362    pub fn decode(&mut self) -> Result<DecodingResult, PSDDecodeErrors> {
363        let raw = self.decode_raw()?;
364
365        if self.depth == BitDepth::Eight {
366            return Ok(DecodingResult::U8(raw));
367        }
368        if self.depth == BitDepth::Sixteen {
369            // https://github.com/etemesi254/zune-image/issues/36
370            let new_array: Vec<u16> = raw
371                .chunks_exact(2)
372                .map(|chunk| {
373                    let value: [u8; 2] = chunk.try_into().unwrap();
374                    u16::from_be_bytes(value)
375                })
376                .collect();
377
378            return Ok(DecodingResult::U16(new_array));
379        }
380
381        Err(PSDDecodeErrors::Generic("Not implemented"))
382    }
383
384    fn psd_decode_rle(
385        &mut self, pixel_count: usize, buffer: &mut [u8]
386    ) -> Result<(), PSDDecodeErrors> {
387        let mut count = 0;
388        let mut nleft = pixel_count - count;
389
390        let mut position = 0;
391
392        while nleft > 0 {
393            let mut len = usize::from(self.stream.get_u8());
394
395            match len.cmp(&128) {
396                Ordering::Less => {
397                    // copy next len+1 bytes literally
398                    len += 1;
399                    if len > nleft {
400                        return Err(PSDDecodeErrors::BadRLE);
401                    }
402                    count += len;
403
404                    if position + (self.channel_count * len) > buffer.len() {
405                        return Err(PSDDecodeErrors::BadRLE);
406                    }
407
408                    while len > 0 {
409                        buffer[position] = self.stream.get_u8();
410                        position += self.channel_count;
411                        len -= 1;
412                    }
413                }
414                Ordering::Equal => (),
415                Ordering::Greater => {
416                    // Next -len+1 bytes in the dest are replicated from next source byte.
417                    // (Interpret len as a negative 8-bit int.)
418                    len = 257_usize.wrapping_sub(len) & 255;
419
420                    if len > nleft {
421                        return Err(PSDDecodeErrors::BadRLE);
422                    }
423                    count += len;
424                    let val = self.stream.get_u8();
425
426                    if position + (self.channel_count * len) > buffer.len() {
427                        return Err(PSDDecodeErrors::BadRLE);
428                    }
429
430                    while len > 0 {
431                        buffer[position] = val;
432                        position += self.channel_count;
433                        len -= 1;
434                    }
435                }
436            }
437
438            nleft = pixel_count - count;
439        }
440        Ok(())
441    }
442
443    /// Get image bit depth or None if the headers haven't been decoded
444    pub const fn get_bit_depth(&self) -> Option<BitDepth> {
445        if self.decoded_header {
446            return Some(self.depth);
447        }
448        None
449    }
450
451    /// Get image width and height respectively or None if the
452    /// headers haven't been decoded
453    pub fn get_dimensions(&self) -> Option<(usize, usize)> {
454        if self.decoded_header {
455            return Some((self.width, self.height));
456        }
457        None
458    }
459    /// Get image colorspace or None if the
460    /// image header hasn't been decoded
461    pub fn get_colorspace(&self) -> Option<ColorSpace> {
462        if let Some(color) = self.color_type {
463            if color == ColorModes::RGB {
464                return if self.channel_count == 4 {
465                    Some(ColorSpace::RGBA)
466                } else {
467                    Some(ColorSpace::RGB)
468                };
469            } else if color == ColorModes::Grayscale {
470                return if self.channel_count == 1 {
471                    Some(ColorSpace::Luma)
472                } else if self.channel_count == 2 {
473                    Some(ColorSpace::LumaA)
474                } else {
475                    None
476                };
477            }
478            if color == ColorModes::CYMK {
479                return Some(ColorSpace::CMYK);
480            }
481        }
482        None
483    }
484}