1use alloc::vec;
20use alloc::vec::Vec;
21use core::cmp::Ordering;
22
23use zune_core::bit_depth::BitDepth;
24use zune_core::bytestream::{ZByteReader, ZReaderTrait};
25use zune_core::colorspace::ColorSpace;
26use zune_core::log::trace;
27use zune_core::options::DecoderOptions;
28use zune_core::result::DecodingResult;
29
30use crate::constants::{ColorModes, CompressionMethod, PSD_IDENTIFIER_BE};
31use crate::errors::PSDDecodeErrors;
32
33pub struct PSDDecoder<T>
42where
43 T: ZReaderTrait
44{
45 width: usize,
46 height: usize,
47 decoded_header: bool,
48 stream: ZByteReader<T>,
49 options: DecoderOptions,
50 depth: BitDepth,
51 color_type: Option<ColorModes>,
52 compression: CompressionMethod,
53 channel_count: usize
54}
55
56impl<T> PSDDecoder<T>
57where
58 T: ZReaderTrait
59{
60 pub fn new(data: T) -> PSDDecoder<T> {
66 Self::new_with_options(data, DecoderOptions::default())
67 }
68
69 pub fn new_with_options(data: T, options: DecoderOptions) -> PSDDecoder<T> {
75 PSDDecoder {
76 width: 0,
77 height: 0,
78 decoded_header: false,
79 stream: ZByteReader::new(data),
80 options,
81 depth: BitDepth::Eight,
82 color_type: None,
83 compression: CompressionMethod::NoCompression,
84 channel_count: 0
85 }
86 }
87
88 pub fn decode_headers(&mut self) -> Result<(), PSDDecodeErrors> {
94 if self.decoded_header {
95 return Ok(());
96 }
97 let magic = self.stream.get_u32_be_err()?;
99
100 if magic != PSD_IDENTIFIER_BE {
101 return Err(PSDDecodeErrors::WrongMagicBytes(magic));
102 }
103
104 let version = self.stream.get_u16_be_err()?;
106
107 if version != 1 {
108 return Err(PSDDecodeErrors::UnsupportedFileType(version));
109 }
110 self.stream.skip(6);
112 let channel_count = self.stream.get_u16_be_err()?;
114
115 if channel_count > 4 {
116 return Err(PSDDecodeErrors::UnsupportedChannelCount(channel_count));
117 }
118
119 self.channel_count = usize::from(channel_count);
120
121 let height = self.stream.get_u32_be_err()? as usize;
122 let width = self.stream.get_u32_be_err()? as usize;
123
124 if width > self.options.get_max_width() {
125 return Err(PSDDecodeErrors::LargeDimensions(
126 self.options.get_max_width(),
127 width
128 ));
129 }
130
131 if height > self.options.get_max_height() {
132 return Err(PSDDecodeErrors::LargeDimensions(
133 self.options.get_max_height(),
134 height
135 ));
136 }
137
138 self.width = width;
139 self.height = height;
140
141 if self.width == 0 || self.height == 0 || self.channel_count == 0 {
142 return Err(PSDDecodeErrors::ZeroDimensions);
143 }
144
145 let depth = self.stream.get_u16_be_err()?;
146
147 if depth != 8 && depth != 16 {
148 return Err(PSDDecodeErrors::UnsupportedBitDepth(depth));
149 }
150 let im_depth = match depth {
151 8 => BitDepth::Eight,
152 16 => BitDepth::Sixteen,
153 _ => unreachable!()
154 };
155
156 self.depth = im_depth;
157
158 let color_mode = self.stream.get_u16_be_err()?;
159
160 let color_enum = ColorModes::from_int(color_mode);
161
162 if let Some(color) = color_enum {
163 if !matches!(
164 color,
165 ColorModes::RGB | ColorModes::Grayscale | ColorModes::CYMK
166 ) {
167 return Err(PSDDecodeErrors::UnsupportedColorFormat(color_enum));
168 }
169 } else {
170 return Err(PSDDecodeErrors::Generic("Unknown color mode"));
171 }
172 self.color_type = color_enum;
173
174 let bytes = self.stream.get_u32_be_err()? as usize;
176 self.stream.skip(bytes);
177
178 let bytes = self.stream.get_u32_be_err()? as usize;
180 self.stream.skip(bytes);
181
182 let bytes = self.stream.get_u32_be_err()? as usize;
184 self.stream.skip(bytes);
185
186 let compression = self.stream.get_u16_be_err()?;
188
189 if compression > 1 {
190 return Err(PSDDecodeErrors::UnknownCompression);
191 }
192 if self.color_type == Some(ColorModes::Grayscale) {
193 self.channel_count = 1;
197 }
198
199 self.compression = CompressionMethod::from_int(compression).unwrap();
200
201 self.decoded_header = true;
202
203 trace!("Image width:{}", self.width);
204 trace!("Image height:{}", self.height);
205 trace!("Channels: {}", self.channel_count);
206 trace!("Bit depth : {:?}", self.depth);
207
208 Ok(())
209 }
210
211 pub fn decode_raw(&mut self) -> Result<Vec<u8>, PSDDecodeErrors> {
217 if !self.decoded_header {
218 self.decode_headers()?;
219 }
220
221 let pixel_count = self.width * self.height;
222
223 let mut result = match (self.compression, self.depth) {
224 (CompressionMethod::RLE, BitDepth::Eight) => {
225 let skipped = self.height * self.channel_count * 2;
236 self.stream.skip(skipped);
237
238 let mut out_channel = vec![0; pixel_count * self.channel_count + 10];
239
240 for channel in 0..self.channel_count {
241 let pixel_count = self.width * self.height;
242 self.psd_decode_rle(pixel_count, &mut out_channel[channel..])?;
243 }
244
245 out_channel.truncate(pixel_count * self.channel_count);
246
247 out_channel
248 }
249 (CompressionMethod::NoCompression, BitDepth::Eight) => {
250 let mut out_channel = vec![0; self.width * self.height * self.channel_count + 10];
257 let pixel_count = self.width * self.height;
258
259 if !self.stream.has(pixel_count * self.channel_count) {
261 return Err(PSDDecodeErrors::Generic("Incomplete bitstream"));
262 }
263
264 for channel in 0..self.channel_count {
265 let mut i = channel;
266
267 while i < pixel_count {
268 out_channel[i] = self.stream.get_u8();
269 i += self.channel_count;
270 }
271 }
272
273 out_channel.truncate(pixel_count * self.channel_count);
274 out_channel
275 }
276
277 (CompressionMethod::NoCompression, BitDepth::Sixteen) => {
278 let channel_dimensions = self.width * self.height;
286
287 let mut out_channel = vec![0; 2 * (channel_dimensions * self.channel_count + 10)];
288
289 let pixel_count = channel_dimensions * 2;
290
291 if !self.stream.has(pixel_count * self.channel_count) {
293 return Err(PSDDecodeErrors::Generic("Incomplete bitstream"));
294 }
295
296 for channel in 0..self.channel_count {
298 let i = channel * 2;
299 let out_chunks = out_channel[i..].chunks_exact_mut(self.channel_count * 2);
300
301 for out in out_chunks.take(channel_dimensions) {
303 let value = self.stream.get_u16_be();
304
305 out[..2].copy_from_slice(&value.to_ne_bytes());
306 }
307 }
308
309 out_channel.truncate(pixel_count * self.channel_count);
310 out_channel
311 }
312 _ => return Err(PSDDecodeErrors::Generic("Not implemented or Unknown"))
313 };
314 if self.channel_count >= 4 {
316 match self.depth {
317 BitDepth::Sixteen => {
318 for pixel in result.chunks_exact_mut(8) {
319 let px3 = u16::from_be_bytes(pixel[6..8].try_into().unwrap());
320 if px3 != 0 && px3 != 65535 {
321 let px0 = u16::from_be_bytes(pixel[0..2].try_into().unwrap());
322 let px1 = u16::from_be_bytes(pixel[2..4].try_into().unwrap());
323 let px2 = u16::from_be_bytes(pixel[4..6].try_into().unwrap());
324
325 let a = f32::from(px3) / 65535.0;
326 let ra = 1.0 / a;
327 let inv_a = 65535.0 * (1.0 - ra);
328
329 let x = (f32::from(px0) * ra + inv_a) as u16;
330 let y = (f32::from(px1) * ra + inv_a) as u16;
331 let z = (f32::from(px2) * ra + inv_a) as u16;
332
333 pixel[0..2].copy_from_slice(&x.to_ne_bytes());
334 pixel[2..4].copy_from_slice(&y.to_ne_bytes());
335 pixel[4..6].copy_from_slice(&z.to_ne_bytes());
336 }
337 }
338 }
339 BitDepth::Eight => {
340 for pixel in result.chunks_exact_mut(4) {
341 if pixel[3] != 0 && pixel[3] != 255 {
342 let a = f32::from(pixel[3]) / 255.0;
343 let ra = 1.0 / a;
344 let inv_a = 255.0 * (1.0 - ra);
345 pixel[0] = (f32::from(pixel[0]) * ra + inv_a) as u8;
346 pixel[1] = (f32::from(pixel[1]) * ra + inv_a) as u8;
347 pixel[2] = (f32::from(pixel[2]) * ra + inv_a) as u8;
348 }
349 }
350 }
351 _ => unreachable!()
352 }
353 }
354 Ok(result)
355 }
356 pub fn decode(&mut self) -> Result<DecodingResult, PSDDecodeErrors> {
363 let raw = self.decode_raw()?;
364
365 if self.depth == BitDepth::Eight {
366 return Ok(DecodingResult::U8(raw));
367 }
368 if self.depth == BitDepth::Sixteen {
369 let new_array: Vec<u16> = raw
371 .chunks_exact(2)
372 .map(|chunk| {
373 let value: [u8; 2] = chunk.try_into().unwrap();
374 u16::from_be_bytes(value)
375 })
376 .collect();
377
378 return Ok(DecodingResult::U16(new_array));
379 }
380
381 Err(PSDDecodeErrors::Generic("Not implemented"))
382 }
383
384 fn psd_decode_rle(
385 &mut self, pixel_count: usize, buffer: &mut [u8]
386 ) -> Result<(), PSDDecodeErrors> {
387 let mut count = 0;
388 let mut nleft = pixel_count - count;
389
390 let mut position = 0;
391
392 while nleft > 0 {
393 let mut len = usize::from(self.stream.get_u8());
394
395 match len.cmp(&128) {
396 Ordering::Less => {
397 len += 1;
399 if len > nleft {
400 return Err(PSDDecodeErrors::BadRLE);
401 }
402 count += len;
403
404 if position + (self.channel_count * len) > buffer.len() {
405 return Err(PSDDecodeErrors::BadRLE);
406 }
407
408 while len > 0 {
409 buffer[position] = self.stream.get_u8();
410 position += self.channel_count;
411 len -= 1;
412 }
413 }
414 Ordering::Equal => (),
415 Ordering::Greater => {
416 len = 257_usize.wrapping_sub(len) & 255;
419
420 if len > nleft {
421 return Err(PSDDecodeErrors::BadRLE);
422 }
423 count += len;
424 let val = self.stream.get_u8();
425
426 if position + (self.channel_count * len) > buffer.len() {
427 return Err(PSDDecodeErrors::BadRLE);
428 }
429
430 while len > 0 {
431 buffer[position] = val;
432 position += self.channel_count;
433 len -= 1;
434 }
435 }
436 }
437
438 nleft = pixel_count - count;
439 }
440 Ok(())
441 }
442
443 pub const fn get_bit_depth(&self) -> Option<BitDepth> {
445 if self.decoded_header {
446 return Some(self.depth);
447 }
448 None
449 }
450
451 pub fn get_dimensions(&self) -> Option<(usize, usize)> {
454 if self.decoded_header {
455 return Some((self.width, self.height));
456 }
457 None
458 }
459 pub fn get_colorspace(&self) -> Option<ColorSpace> {
462 if let Some(color) = self.color_type {
463 if color == ColorModes::RGB {
464 return if self.channel_count == 4 {
465 Some(ColorSpace::RGBA)
466 } else {
467 Some(ColorSpace::RGB)
468 };
469 } else if color == ColorModes::Grayscale {
470 return if self.channel_count == 1 {
471 Some(ColorSpace::Luma)
472 } else if self.channel_count == 2 {
473 Some(ColorSpace::LumaA)
474 } else {
475 None
476 };
477 }
478 if color == ColorModes::CYMK {
479 return Some(ColorSpace::CMYK);
480 }
481 }
482 None
483 }
484}