Skip to main content

zenjxl_decoder/api/
decoder.rs

1// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6use super::{
7    JxlBasicInfo, JxlBitstreamInput, JxlColorProfile, JxlDecoderInner, JxlDecoderOptions,
8    JxlOutputBuffer, JxlPixelFormat, ProcessingResult,
9};
10#[cfg(test)]
11use crate::frame::Frame;
12use crate::{
13    api::JxlFrameHeader,
14    container::{frame_index::FrameIndexBox, gain_map::GainMapBundle},
15    error::Result,
16};
17use states::*;
18use std::marker::PhantomData;
19
pub mod states {
    //! Marker types for the [`JxlDecoder`](super::JxlDecoder) typestate.
    //!
    //! Each marker is a unit struct that only exists at the type level; it records how far
    //! decoding has progressed so that methods are only callable in the states where they apply.

    /// Implemented by every decoder state marker.
    pub trait JxlState {}

    /// The decoder has been created but the image header has not been parsed yet.
    pub struct Initialized;
    impl JxlState for Initialized {}

    /// The image header (basic info, color profiles, pixel format) is available.
    pub struct WithImageInfo;
    impl JxlState for WithImageInfo {}

    /// A frame header has been parsed; the frame can now be decoded or skipped.
    pub struct WithFrameInfo;
    impl JxlState for WithFrameInfo {}
}
29
30// Q: do we plan to add support for box decoding?
31// If we do, one way is to take a callback &[u8; 4] -> Box<dyn Write>.
32
33/// High level API using the typestate pattern to forbid invalid usage.
34pub struct JxlDecoder<State: JxlState> {
35    inner: Box<JxlDecoderInner>,
36    _state: PhantomData<State>,
37}
38
/// Test-only callback invoked as `callback(frame, frame_index)` for every decoded frame.
#[cfg(test)]
pub type FrameCallback = dyn FnMut(&Frame, usize) -> Result<()>;
41
42impl<S: JxlState> JxlDecoder<S> {
43    fn wrap_inner(inner: Box<JxlDecoderInner>) -> Self {
44        Self {
45            inner,
46            _state: PhantomData,
47        }
48    }
49
50    /// Sets a callback that processes all frames by calling `callback(frame, frame_index)`.
51    #[cfg(test)]
52    pub fn set_frame_callback(&mut self, callback: Box<FrameCallback>) {
53        self.inner.set_frame_callback(callback);
54    }
55
56    #[cfg(test)]
57    pub fn decoded_frames(&self) -> usize {
58        self.inner.decoded_frames()
59    }
60
61    /// Returns the reconstructed JPEG bytes if the file contained a JBRD box.
62    /// Call after decoding a frame. Returns `None` if no JBRD box was present
63    /// or the `jpeg` feature is not enabled.
64    #[cfg(feature = "jpeg")]
65    pub fn take_jpeg_reconstruction(&mut self) -> Option<Vec<u8>> {
66        self.inner.take_jpeg_reconstruction()
67    }
68
69    /// Returns the parsed frame index box, if the file contained one.
70    ///
71    /// The frame index box (`jxli`) is an optional part of the JXL container
72    /// format that provides a seek table for animated files, listing keyframe
73    /// byte offsets, timestamps, and frame counts.
74    pub fn frame_index(&self) -> Option<&FrameIndexBox> {
75        self.inner.frame_index()
76    }
77
78    /// Returns a reference to the parsed gain map bundle, if the file contained
79    /// a `jhgm` box (ISO 21496-1 HDR gain map).
80    ///
81    /// The gain map codestream is a bare JXL codestream that can be decoded
82    /// with the same decoder. The ISO 21496-1 metadata blob is stored as raw
83    /// bytes for the caller to parse.
84    pub fn gain_map(&self) -> Option<&GainMapBundle> {
85        self.inner.gain_map()
86    }
87
88    /// Takes the parsed gain map bundle, if the file contained a `jhgm` box.
89    /// After calling this, `gain_map()` will return `None`.
90    pub fn take_gain_map(&mut self) -> Option<GainMapBundle> {
91        self.inner.take_gain_map()
92    }
93
94    /// Returns the raw EXIF data from the `Exif` container box, if present.
95    ///
96    /// The 4-byte TIFF header offset prefix is stripped; this returns the raw
97    /// EXIF/TIFF bytes starting with the byte-order marker (`II` or `MM`).
98    /// Returns `None` for bare codestreams or files without an `Exif` box.
99    ///
100    /// Note: the `Exif` box may appear after the codestream in the container.
101    /// Call this after decoding at least one frame for the most complete results.
102    pub fn exif(&self) -> Option<&[u8]> {
103        self.inner.exif()
104    }
105
106    /// Takes the EXIF data, leaving `None` in its place.
107    pub fn take_exif(&mut self) -> Option<Vec<u8>> {
108        self.inner.take_exif()
109    }
110
111    /// Returns the raw XMP data from the `xml ` container box, if present.
112    ///
113    /// Returns `None` for bare codestreams or files without an `xml ` box.
114    ///
115    /// Note: the `xml ` box may appear after the codestream in the container.
116    /// Call this after decoding at least one frame for the most complete results.
117    pub fn xmp(&self) -> Option<&[u8]> {
118        self.inner.xmp()
119    }
120
121    /// Takes the XMP data, leaving `None` in its place.
122    pub fn take_xmp(&mut self) -> Option<Vec<u8>> {
123        self.inner.take_xmp()
124    }
125
126    /// Rewinds a decoder to the start of the file, allowing past frames to be displayed again.
127    pub fn rewind(mut self) -> JxlDecoder<Initialized> {
128        self.inner.rewind();
129        JxlDecoder::wrap_inner(self.inner)
130    }
131
132    fn map_inner_processing_result<SuccessState: JxlState>(
133        self,
134        inner_result: ProcessingResult<(), ()>,
135    ) -> ProcessingResult<JxlDecoder<SuccessState>, Self> {
136        match inner_result {
137            ProcessingResult::Complete { .. } => ProcessingResult::Complete {
138                result: JxlDecoder::wrap_inner(self.inner),
139            },
140            ProcessingResult::NeedsMoreInput { size_hint, .. } => {
141                ProcessingResult::NeedsMoreInput {
142                    size_hint,
143                    fallback: self,
144                }
145            }
146        }
147    }
148}
149
150impl JxlDecoder<Initialized> {
151    pub fn new(options: JxlDecoderOptions) -> Self {
152        Self::wrap_inner(Box::new(JxlDecoderInner::new(options)))
153    }
154
155    pub fn process(
156        mut self,
157        input: &mut impl JxlBitstreamInput,
158    ) -> Result<ProcessingResult<JxlDecoder<WithImageInfo>, Self>> {
159        let inner_result = self.inner.process(input, None)?;
160        Ok(self.map_inner_processing_result(inner_result))
161    }
162}
163
164impl JxlDecoder<WithImageInfo> {
165    // TODO(veluca): once frame skipping is implemented properly, expose that in the API.
166
167    /// Obtains the image's basic information.
168    pub fn basic_info(&self) -> &JxlBasicInfo {
169        self.inner.basic_info().unwrap()
170    }
171
172    /// Retrieves the file's color profile.
173    pub fn embedded_color_profile(&self) -> &JxlColorProfile {
174        self.inner.embedded_color_profile().unwrap()
175    }
176
177    /// Retrieves the current output color profile.
178    pub fn output_color_profile(&self) -> &JxlColorProfile {
179        self.inner.output_color_profile().unwrap()
180    }
181
182    /// Specifies the preferred color profile to be used for outputting data.
183    /// Same semantics as JxlDecoderSetOutputColorProfile.
184    pub fn set_output_color_profile(&mut self, profile: JxlColorProfile) -> Result<()> {
185        self.inner.set_output_color_profile(profile)
186    }
187
188    /// Retrieves the current pixel format for output buffers.
189    pub fn current_pixel_format(&self) -> &JxlPixelFormat {
190        self.inner.current_pixel_format().unwrap()
191    }
192
193    /// Specifies pixel format for output buffers.
194    ///
195    /// Setting this may also change output color profile in some cases, if the profile was not set
196    /// manually before.
197    pub fn set_pixel_format(&mut self, pixel_format: JxlPixelFormat) {
198        self.inner.set_pixel_format(pixel_format);
199    }
200
201    pub fn process(
202        mut self,
203        input: &mut impl JxlBitstreamInput,
204    ) -> Result<ProcessingResult<JxlDecoder<WithFrameInfo>, Self>> {
205        let inner_result = self.inner.process(input, None)?;
206        Ok(self.map_inner_processing_result(inner_result))
207    }
208
209    /// Draws all the pixels we have data for. This is useful for i.e. previewing LF frames.
210    ///
211    /// Note: see `process` for alignment requirements for the buffer data.
212    pub fn flush_pixels(&mut self, buffers: &mut [JxlOutputBuffer<'_>]) -> Result<()> {
213        self.inner.flush_pixels(buffers)
214    }
215
216    pub fn has_more_frames(&self) -> bool {
217        self.inner.has_more_frames()
218    }
219
220    #[cfg(test)]
221    pub(crate) fn set_use_simple_pipeline(&mut self, u: bool) {
222        self.inner.set_use_simple_pipeline(u);
223    }
224}
225
226impl JxlDecoder<WithFrameInfo> {
227    /// Skip the current frame.
228    pub fn skip_frame(
229        mut self,
230        input: &mut impl JxlBitstreamInput,
231    ) -> Result<ProcessingResult<JxlDecoder<WithImageInfo>, Self>> {
232        let inner_result = self.inner.process(input, None)?;
233        Ok(self.map_inner_processing_result(inner_result))
234    }
235
236    pub fn frame_header(&self) -> JxlFrameHeader {
237        self.inner.frame_header().unwrap()
238    }
239
240    /// Number of passes we have full data for.
241    pub fn num_completed_passes(&self) -> usize {
242        self.inner.num_completed_passes().unwrap()
243    }
244
245    /// Draws all the pixels we have data for.
246    ///
247    /// Note: see `process` for alignment requirements for the buffer data.
248    pub fn flush_pixels(&mut self, buffers: &mut [JxlOutputBuffer<'_>]) -> Result<()> {
249        self.inner.flush_pixels(buffers)
250    }
251
252    /// Guarantees to populate exactly the appropriate part of the buffers.
253    /// Wants one buffer for each non-ignored pixel type, i.e. color channels and each extra channel.
254    ///
255    /// Note: the data in `buffers` should have alignment requirements that are compatible with the
256    /// requested pixel format. This means that, if we are asking for 2-byte or 4-byte output (i.e.
257    /// u16/f16 and f32 respectively), each row in the provided buffers must be aligned to 2 or 4
258    /// bytes respectively. If that is not the case, the library may panic.
259    pub fn process<In: JxlBitstreamInput>(
260        mut self,
261        input: &mut In,
262        buffers: &mut [JxlOutputBuffer<'_>],
263    ) -> Result<ProcessingResult<JxlDecoder<WithImageInfo>, Self>> {
264        let inner_result = self.inner.process(input, Some(buffers))?;
265        Ok(self.map_inner_processing_result(inner_result))
266    }
267}
268
269#[cfg(test)]
270pub(crate) mod tests {
271    use super::*;
272    use crate::api::{JxlDataFormat, JxlDecoderOptions};
273    use crate::error::Error;
274    use crate::image::{Image, Rect};
275    use jxl_macros::for_each_test_file;
276    use std::path::Path;
277
278    #[test]
279    fn decode_small_chunks() {
280        arbtest::arbtest(|u| {
281            decode(
282                &std::fs::read("resources/test/green_queen_vardct_e3.jxl").unwrap(),
283                u.arbitrary::<u8>().unwrap() as usize + 1,
284                false,
285                false,
286                None,
287            )
288            .unwrap();
289            Ok(())
290        });
291    }
292
293    #[allow(clippy::type_complexity)]
294    pub fn decode(
295        mut input: &[u8],
296        chunk_size: usize,
297        use_simple_pipeline: bool,
298        do_flush: bool,
299        callback: Option<Box<dyn FnMut(&Frame, usize) -> Result<(), Error>>>,
300    ) -> Result<(usize, Vec<Vec<Image<f32>>>), Error> {
301        let mut options = JxlDecoderOptions::default();
302        // Correctness tests should not be constrained by memory limits.
303        // OOM/limit tests verify those separately.
304        options.limits.max_memory_bytes = None;
305        let mut initialized_decoder = JxlDecoder::<states::Initialized>::new(options);
306
307        if let Some(callback) = callback {
308            initialized_decoder.set_frame_callback(callback);
309        }
310
311        let mut chunk_input = &input[0..0];
312
313        macro_rules! advance_decoder {
314            ($decoder: ident $(, $extra_arg: expr)? $(; $flush_arg: expr)?) => {
315                loop {
316                    chunk_input =
317                        &input[..(chunk_input.len().saturating_add(chunk_size)).min(input.len())];
318                    let available_before = chunk_input.len();
319                    let process_result = $decoder.process(&mut chunk_input $(, $extra_arg)?);
320                    input = &input[(available_before - chunk_input.len())..];
321                    match process_result.unwrap() {
322                        ProcessingResult::Complete { result } => break result,
323                        ProcessingResult::NeedsMoreInput { fallback, .. } => {
324                            $(
325                                let mut fallback = fallback;
326                                if do_flush && !input.is_empty() {
327                                    fallback.flush_pixels($flush_arg)?;
328                                }
329                            )?
330                            if input.is_empty() {
331                                panic!("Unexpected end of input");
332                            }
333                            $decoder = fallback;
334                        }
335                    }
336                }
337            };
338        }
339
340        // Process until we have image info
341        let mut decoder_with_image_info = advance_decoder!(initialized_decoder);
342        decoder_with_image_info.set_use_simple_pipeline(use_simple_pipeline);
343
344        // Get basic info
345        let basic_info = decoder_with_image_info.basic_info().clone();
346        assert!(basic_info.bit_depth.bits_per_sample() > 0);
347
348        // Get image dimensions (after upsampling, which is the actual output size)
349        let (buffer_width, buffer_height) = basic_info.size;
350        assert!(buffer_width > 0);
351        assert!(buffer_height > 0);
352
353        // Explicitly request F32 pixel format (test helper returns Image<f32>)
354        let default_format = decoder_with_image_info.current_pixel_format();
355        let requested_format = JxlPixelFormat {
356            color_type: default_format.color_type,
357            color_data_format: Some(JxlDataFormat::f32()),
358            extra_channel_format: default_format
359                .extra_channel_format
360                .iter()
361                .map(|_| Some(JxlDataFormat::f32()))
362                .collect(),
363        };
364        decoder_with_image_info.set_pixel_format(requested_format);
365
366        // Get the configured pixel format
367        let pixel_format = decoder_with_image_info.current_pixel_format().clone();
368
369        let num_channels = pixel_format.color_type.samples_per_pixel();
370        assert!(num_channels > 0);
371
372        let mut frames = vec![];
373
374        loop {
375            // First channel is interleaved.
376            let mut buffers = vec![Image::new_with_value(
377                (buffer_width * num_channels, buffer_height),
378                f32::NAN,
379            )?];
380
381            for ecf in pixel_format.extra_channel_format.iter() {
382                if ecf.is_none() {
383                    continue;
384                }
385                buffers.push(Image::new_with_value(
386                    (buffer_width, buffer_height),
387                    f32::NAN,
388                )?);
389            }
390
391            let mut api_buffers: Vec<_> = buffers
392                .iter_mut()
393                .map(|b| {
394                    JxlOutputBuffer::from_image_rect_mut(
395                        b.get_rect_mut(Rect {
396                            origin: (0, 0),
397                            size: b.size(),
398                        })
399                        .into_raw(),
400                    )
401                })
402                .collect();
403
404            // Process until we have frame info
405            let mut decoder_with_frame_info =
406                advance_decoder!(decoder_with_image_info; &mut api_buffers);
407            decoder_with_image_info =
408                advance_decoder!(decoder_with_frame_info, &mut api_buffers; &mut api_buffers);
409
410            // All pixels should have been overwritten, so they should no longer be NaNs.
411            for buf in buffers.iter() {
412                let (xs, ys) = buf.size();
413                for y in 0..ys {
414                    let row = buf.row(y);
415                    for (x, v) in row.iter().enumerate() {
416                        assert!(!v.is_nan(), "NaN at {x} {y} (image size {xs}x{ys})");
417                    }
418                }
419            }
420
421            frames.push(buffers);
422
423            // Check if there are more frames
424            if !decoder_with_image_info.has_more_frames() {
425                let decoded_frames = decoder_with_image_info.decoded_frames();
426
427                // Ensure we decoded at least one frame
428                assert!(decoded_frames > 0, "No frames were decoded");
429
430                return Ok((decoded_frames, frames));
431            }
432        }
433    }
434
435    fn decode_test_file(path: &Path) -> Result<(), Error> {
436        decode(&std::fs::read(path)?, usize::MAX, false, false, None)?;
437        Ok(())
438    }
439
440    for_each_test_file!(decode_test_file);
441
442    fn decode_test_file_chunks(path: &Path) -> Result<(), Error> {
443        decode(&std::fs::read(path)?, 1, false, false, None)?;
444        Ok(())
445    }
446
447    for_each_test_file!(decode_test_file_chunks);
448
449    fn compare_frames(
450        _path: &Path,
451        fc: usize,
452        f: &[Image<f32>],
453        sf: &[Image<f32>],
454    ) -> Result<(), Error> {
455        assert_eq!(
456            f.len(),
457            sf.len(),
458            "Frame {fc} has different channels counts",
459        );
460        for (c, (b, sb)) in f.iter().zip(sf.iter()).enumerate() {
461            assert_eq!(
462                b.size(),
463                sb.size(),
464                "Channel {c} in frame {fc} has different sizes",
465            );
466            let sz = b.size();
467            for y in 0..sz.1 {
468                for x in 0..sz.0 {
469                    assert_eq!(
470                        b.row(y)[x],
471                        sb.row(y)[x],
472                        "Pixels differ at position ({x}, {y}), channel {c}"
473                    );
474                }
475            }
476        }
477        Ok(())
478    }
479
480    /// Hash all pixel rows for memory-efficient comparison.
481    fn hash_frames(frames: &[Vec<Image<f32>>]) -> Vec<Vec<Vec<u64>>> {
482        use std::hash::{Hash, Hasher};
483        frames
484            .iter()
485            .map(|channels| {
486                channels
487                    .iter()
488                    .map(|img| {
489                        let (_, ys) = img.size();
490                        (0..ys)
491                            .map(|y| {
492                                let mut h = std::hash::DefaultHasher::new();
493                                for &v in img.row(y) {
494                                    v.to_bits().hash(&mut h);
495                                }
496                                h.finish()
497                            })
498                            .collect()
499                    })
500                    .collect()
501            })
502            .collect()
503    }
504
505    fn compare_pipelines(path: &Path) -> Result<(), Error> {
506        let file = std::fs::read(path)?;
507        let reference_frames = decode(&file, usize::MAX, true, false, None)?.1;
508        // Hash and drop reference pixels before second decode to halve peak
509        // memory. Critical for 32-bit targets where two full 4K decoded
510        // outputs + decoder state exceeds address space.
511        let reference_hashes = hash_frames(&reference_frames);
512        drop(reference_frames);
513        let frames = decode(&file, usize::MAX, false, false, None)?.1;
514        let frame_hashes = hash_frames(&frames);
515        assert_eq!(
516            reference_hashes,
517            frame_hashes,
518            "{}: pipeline outputs differ",
519            path.display()
520        );
521        Ok(())
522    }
523
524    for_each_test_file!(compare_pipelines);
525
526    fn compare_incremental(path: &Path) -> Result<(), Error> {
527        let file = std::fs::read(path).unwrap();
528        // One-shot decode — hash and drop before incremental decode.
529        let (_, one_shot_frames) = decode(&file, usize::MAX, false, false, None)?;
530        let reference_hashes = hash_frames(&one_shot_frames);
531        drop(one_shot_frames);
532        // Incremental decode with arbitrary flushes.
533        let (_, frames) = decode(&file, 123, false, true, None)?;
534        let frame_hashes = hash_frames(&frames);
535        assert_eq!(
536            reference_hashes,
537            frame_hashes,
538            "{}: incremental vs one-shot outputs differ",
539            path.display()
540        );
541
542        Ok(())
543    }
544
545    for_each_test_file!(compare_incremental);
546
547    #[test]
548    fn test_preview_size_none_for_regular_files() {
549        let file = std::fs::read("resources/test/basic.jxl").unwrap();
550        let options = JxlDecoderOptions::default();
551        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
552        let mut input = file.as_slice();
553        let decoder = loop {
554            match decoder.process(&mut input).unwrap() {
555                ProcessingResult::Complete { result } => break result,
556                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder = fallback,
557            }
558        };
559        assert!(decoder.basic_info().preview_size.is_none());
560    }
561
562    #[test]
563    fn test_preview_size_some_for_preview_files() {
564        let file = std::fs::read("resources/test/with_preview.jxl").unwrap();
565        let options = JxlDecoderOptions::default();
566        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
567        let mut input = file.as_slice();
568        let decoder = loop {
569            match decoder.process(&mut input).unwrap() {
570                ProcessingResult::Complete { result } => break result,
571                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder = fallback,
572            }
573        };
574        assert_eq!(decoder.basic_info().preview_size, Some((16, 16)));
575    }
576
577    #[test]
578    fn test_num_completed_passes() {
579        use crate::image::{Image, Rect};
580        let file = std::fs::read("resources/test/basic.jxl").unwrap();
581        let options = JxlDecoderOptions::default();
582        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
583        let mut input = file.as_slice();
584        // Process until we have image info
585        let mut decoder_with_info = loop {
586            match decoder.process(&mut input).unwrap() {
587                ProcessingResult::Complete { result } => break result,
588                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder = fallback,
589            }
590        };
591        let info = decoder_with_info.basic_info().clone();
592        let mut decoder_with_frame = loop {
593            match decoder_with_info.process(&mut input).unwrap() {
594                ProcessingResult::Complete { result } => break result,
595                ProcessingResult::NeedsMoreInput { fallback, .. } => {
596                    decoder_with_info = fallback;
597                }
598            }
599        };
600        // Before processing frame, passes should be 0
601        assert_eq!(decoder_with_frame.num_completed_passes(), 0);
602        // Process the frame
603        let mut output = Image::<f32>::new((info.size.0 * 3, info.size.1)).unwrap();
604        let rect = Rect {
605            size: output.size(),
606            origin: (0, 0),
607        };
608        let mut bufs = [JxlOutputBuffer::from_image_rect_mut(
609            output.get_rect_mut(rect).into_raw(),
610        )];
611        loop {
612            match decoder_with_frame.process(&mut input, &mut bufs).unwrap() {
613                ProcessingResult::Complete { .. } => break,
614                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder_with_frame = fallback,
615            }
616        }
617    }
618
619    #[test]
620    fn test_set_pixel_format() {
621        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
622
623        let file = std::fs::read("resources/test/basic.jxl").unwrap();
624        let options = JxlDecoderOptions::default();
625        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
626        let mut input = file.as_slice();
627        let mut decoder = loop {
628            match decoder.process(&mut input).unwrap() {
629                ProcessingResult::Complete { result } => break result,
630                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder = fallback,
631            }
632        };
633        // Check default pixel format
634        let default_format = decoder.current_pixel_format().clone();
635        assert_eq!(default_format.color_type, JxlColorType::Rgb);
636
637        // Set a new pixel format
638        let new_format = JxlPixelFormat {
639            color_type: JxlColorType::Grayscale,
640            color_data_format: Some(JxlDataFormat::U8 { bit_depth: 8 }),
641            extra_channel_format: vec![],
642        };
643        decoder.set_pixel_format(new_format.clone());
644
645        // Verify it was set
646        assert_eq!(decoder.current_pixel_format(), &new_format);
647    }
648
649    #[test]
650    fn test_set_output_color_profile() {
651        use crate::api::JxlColorProfile;
652
653        let file = std::fs::read("resources/test/basic.jxl").unwrap();
654        let options = JxlDecoderOptions::default();
655        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
656        let mut input = file.as_slice();
657        let mut decoder = loop {
658            match decoder.process(&mut input).unwrap() {
659                ProcessingResult::Complete { result } => break result,
660                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder = fallback,
661            }
662        };
663
664        // Get the embedded profile and set it as output (should work)
665        let embedded = decoder.embedded_color_profile().clone();
666        let result = decoder.set_output_color_profile(embedded);
667        assert!(result.is_ok());
668
669        // Setting an ICC profile without CMS should fail
670        let icc_profile = JxlColorProfile::Icc(vec![0u8; 100]);
671        let result = decoder.set_output_color_profile(icc_profile);
672        assert!(result.is_err());
673    }
674
675    #[test]
676    fn test_default_output_tf_by_pixel_format() {
677        use crate::api::{JxlColorEncoding, JxlTransferFunction};
678
679        // Using test image with ICC profile to trigger default transfer function path
680        let file = std::fs::read("resources/test/lossy_with_icc.jxl").unwrap();
681        let options = JxlDecoderOptions::default();
682        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
683        let mut input = file.as_slice();
684        let mut decoder = loop {
685            match decoder.process(&mut input).unwrap() {
686                ProcessingResult::Complete { result } => break result,
687                ProcessingResult::NeedsMoreInput { fallback, .. } => decoder = fallback,
688            }
689        };
690
691        // Output data format will default to F32, so output color profile will be linear sRGB
692        assert_eq!(
693            *decoder.output_color_profile().transfer_function().unwrap(),
694            JxlTransferFunction::Linear,
695        );
696
697        // Integer data format will set output color profile to sRGB
698        decoder.set_pixel_format(JxlPixelFormat::rgba8(0));
699        assert_eq!(
700            *decoder.output_color_profile().transfer_function().unwrap(),
701            JxlTransferFunction::SRGB,
702        );
703
704        decoder.set_pixel_format(JxlPixelFormat::rgba_f16(0));
705        assert_eq!(
706            *decoder.output_color_profile().transfer_function().unwrap(),
707            JxlTransferFunction::Linear,
708        );
709
710        decoder.set_pixel_format(JxlPixelFormat::rgba16(0));
711        assert_eq!(
712            *decoder.output_color_profile().transfer_function().unwrap(),
713            JxlTransferFunction::SRGB,
714        );
715
716        // Once output color profile is set by user, it will remain as is regardless of what pixel
717        // format is set
718        let profile = JxlColorProfile::Simple(JxlColorEncoding::srgb(false));
719        decoder.set_output_color_profile(profile.clone()).unwrap();
720        decoder.set_pixel_format(JxlPixelFormat::rgba_f16(0));
721        assert!(decoder.output_color_profile() == &profile);
722    }
723
724    #[test]
725    fn test_fill_opaque_alpha_both_pipelines() {
726        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
727        use crate::image::{Image, Rect};
728
729        // Use basic.jxl which has no alpha channel
730        let file = std::fs::read("resources/test/basic.jxl").unwrap();
731
732        // Request RGBA format even though image has no alpha
733        let rgba_format = JxlPixelFormat {
734            color_type: JxlColorType::Rgba,
735            color_data_format: Some(JxlDataFormat::f32()),
736            extra_channel_format: vec![],
737        };
738
739        // Test both pipelines (simple and low-memory)
740        for use_simple in [true, false] {
741            let options = JxlDecoderOptions::default();
742            let decoder = JxlDecoder::<states::Initialized>::new(options);
743            let mut input = file.as_slice();
744
745            // Advance to image info
746            macro_rules! advance_decoder {
747                ($decoder:expr) => {
748                    loop {
749                        match $decoder.process(&mut input).unwrap() {
750                            ProcessingResult::Complete { result } => break result,
751                            ProcessingResult::NeedsMoreInput { fallback, .. } => {
752                                if input.is_empty() {
753                                    panic!("Unexpected end of input");
754                                }
755                                $decoder = fallback;
756                            }
757                        }
758                    }
759                };
760                ($decoder:expr, $buffers:expr) => {
761                    loop {
762                        match $decoder.process(&mut input, $buffers).unwrap() {
763                            ProcessingResult::Complete { result } => break result,
764                            ProcessingResult::NeedsMoreInput { fallback, .. } => {
765                                if input.is_empty() {
766                                    panic!("Unexpected end of input");
767                                }
768                                $decoder = fallback;
769                            }
770                        }
771                    }
772                };
773            }
774
775            let mut decoder = decoder;
776            let mut decoder = advance_decoder!(decoder);
777            decoder.set_use_simple_pipeline(use_simple);
778
779            // Set RGBA format
780            decoder.set_pixel_format(rgba_format.clone());
781
782            let basic_info = decoder.basic_info().clone();
783            let (width, height) = basic_info.size;
784
785            // Advance to frame info
786            let mut decoder = advance_decoder!(decoder);
787
788            // Prepare buffer for RGBA (4 channels interleaved)
789            let mut color_buffer = Image::<f32>::new((width * 4, height)).unwrap();
790            let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut(
791                color_buffer
792                    .get_rect_mut(Rect {
793                        origin: (0, 0),
794                        size: (width * 4, height),
795                    })
796                    .into_raw(),
797            )];
798
799            // Decode frame
800            let _decoder = advance_decoder!(decoder, &mut buffers);
801
802            // Verify all alpha values are 1.0 (opaque)
803            for y in 0..height {
804                let row = color_buffer.row(y);
805                for x in 0..width {
806                    let alpha = row[x * 4 + 3];
807                    assert_eq!(
808                        alpha, 1.0,
809                        "Alpha at ({},{}) should be 1.0, got {} (use_simple={})",
810                        x, y, alpha, use_simple
811                    );
812                }
813            }
814        }
815    }
816
817    /// Test that premultiply_output=true produces premultiplied alpha output
818    /// from a source with straight (non-premultiplied) alpha.
819    #[test]
820    fn test_premultiply_output_straight_alpha() {
821        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
822
823        // Use alpha_nonpremultiplied.jxl which has straight alpha (alpha_associated=false)
824        let file =
825            std::fs::read("resources/test/conformance_test_images/alpha_nonpremultiplied.jxl")
826                .unwrap();
827
828        // Alpha is included in RGBA, so we set extra_channel_format to None
829        // to indicate no separate buffer for the alpha extra channel
830        let rgba_format = JxlPixelFormat {
831            color_type: JxlColorType::Rgba,
832            color_data_format: Some(JxlDataFormat::f32()),
833            extra_channel_format: vec![None],
834        };
835
836        // Test both pipelines
837        for use_simple in [true, false] {
838            let (straight_buffer, width, height) =
839                decode_with_format::<f32>(&file, &rgba_format, use_simple, false);
840            let (premul_buffer, _, _) =
841                decode_with_format::<f32>(&file, &rgba_format, use_simple, true);
842
843            // Verify premultiplied values: premul_rgb should equal straight_rgb * alpha
844            let mut found_semitransparent = false;
845            for y in 0..height {
846                let straight_row = straight_buffer.row(y);
847                let premul_row = premul_buffer.row(y);
848                for x in 0..width {
849                    let sr = straight_row[x * 4];
850                    let sg = straight_row[x * 4 + 1];
851                    let sb = straight_row[x * 4 + 2];
852                    let sa = straight_row[x * 4 + 3];
853
854                    let pr = premul_row[x * 4];
855                    let pg = premul_row[x * 4 + 1];
856                    let pb = premul_row[x * 4 + 2];
857                    let pa = premul_row[x * 4 + 3];
858
859                    // Alpha should be unchanged
860                    assert!(
861                        (sa - pa).abs() < 1e-5,
862                        "Alpha mismatch at ({},{}): straight={}, premul={} (use_simple={})",
863                        x,
864                        y,
865                        sa,
866                        pa,
867                        use_simple
868                    );
869
870                    // Check premultiplication: premul = straight * alpha
871                    let expected_r = sr * sa;
872                    let expected_g = sg * sa;
873                    let expected_b = sb * sa;
874
875                    // Allow 1% tolerance for precision differences between pipelines
876                    let tol = 0.01;
877                    assert!(
878                        (expected_r - pr).abs() < tol,
879                        "R mismatch at ({},{}): expected={}, got={} (use_simple={})",
880                        x,
881                        y,
882                        expected_r,
883                        pr,
884                        use_simple
885                    );
886                    assert!(
887                        (expected_g - pg).abs() < tol,
888                        "G mismatch at ({},{}): expected={}, got={} (use_simple={})",
889                        x,
890                        y,
891                        expected_g,
892                        pg,
893                        use_simple
894                    );
895                    assert!(
896                        (expected_b - pb).abs() < tol,
897                        "B mismatch at ({},{}): expected={}, got={} (use_simple={})",
898                        x,
899                        y,
900                        expected_b,
901                        pb,
902                        use_simple
903                    );
904
905                    if sa > 0.01 && sa < 0.99 {
906                        found_semitransparent = true;
907                    }
908                }
909            }
910
911            // Ensure the test image actually has some semi-transparent pixels
912            assert!(
913                found_semitransparent,
914                "Test image should have semi-transparent pixels (use_simple={})",
915                use_simple
916            );
917        }
918    }
919
920    /// Test that premultiply_output=true doesn't double-premultiply
921    /// when the source already has premultiplied alpha (alpha_associated=true).
922    #[test]
923    fn test_premultiply_output_already_premultiplied() {
924        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
925
926        // Use alpha_premultiplied.jxl which has alpha_associated=true
927        let file = std::fs::read("resources/test/conformance_test_images/alpha_premultiplied.jxl")
928            .unwrap();
929
930        // Alpha is included in RGBA, so we set extra_channel_format to None
931        let rgba_format = JxlPixelFormat {
932            color_type: JxlColorType::Rgba,
933            color_data_format: Some(JxlDataFormat::f32()),
934            extra_channel_format: vec![None],
935        };
936
937        // Test both pipelines
938        for use_simple in [true, false] {
939            let (without_flag_buffer, width, height) =
940                decode_with_format::<f32>(&file, &rgba_format, use_simple, false);
941            let (with_flag_buffer, _, _) =
942                decode_with_format::<f32>(&file, &rgba_format, use_simple, true);
943
944            // Both outputs should be identical since source is already premultiplied
945            // and we shouldn't double-premultiply
946            for y in 0..height {
947                let without_row = without_flag_buffer.row(y);
948                let with_row = with_flag_buffer.row(y);
949                for x in 0..width {
950                    for c in 0..4 {
951                        let without_val = without_row[x * 4 + c];
952                        let with_val = with_row[x * 4 + c];
953                        assert!(
954                            (without_val - with_val).abs() < 1e-5,
955                            "Mismatch at ({},{}) channel {}: without_flag={}, with_flag={} (use_simple={})",
956                            x,
957                            y,
958                            c,
959                            without_val,
960                            with_val,
961                            use_simple
962                        );
963                    }
964                }
965            }
966        }
967    }
968
969    /// Test that animations with reference frames work correctly.
970    /// This exercises the buffer index calculation fix where reference frame
971    /// save stages use indices beyond the API-provided buffer array.
972    #[test]
973    fn test_animation_with_reference_frames() {
974        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
975        use crate::image::{Image, Rect};
976
977        // Use animation_spline.jxl which has multiple frames with references
978        let file =
979            std::fs::read("resources/test/conformance_test_images/animation_spline.jxl").unwrap();
980
981        let options = JxlDecoderOptions::default();
982        let decoder = JxlDecoder::<states::Initialized>::new(options);
983        let mut input = file.as_slice();
984
985        // Advance to image info
986        let mut decoder = decoder;
987        let mut decoder = loop {
988            match decoder.process(&mut input).unwrap() {
989                ProcessingResult::Complete { result } => break result,
990                ProcessingResult::NeedsMoreInput { fallback, .. } => {
991                    decoder = fallback;
992                }
993            }
994        };
995
996        // Set RGB format with no extra channels
997        let rgb_format = JxlPixelFormat {
998            color_type: JxlColorType::Rgb,
999            color_data_format: Some(JxlDataFormat::f32()),
1000            extra_channel_format: vec![],
1001        };
1002        decoder.set_pixel_format(rgb_format);
1003
1004        let basic_info = decoder.basic_info().clone();
1005        let (width, height) = basic_info.size;
1006
1007        let mut frame_count = 0;
1008
1009        // Decode all frames
1010        loop {
1011            // Advance to frame info
1012            let mut decoder_frame = loop {
1013                match decoder.process(&mut input).unwrap() {
1014                    ProcessingResult::Complete { result } => break result,
1015                    ProcessingResult::NeedsMoreInput { fallback, .. } => {
1016                        decoder = fallback;
1017                    }
1018                }
1019            };
1020
1021            // Prepare buffer for RGB (3 channels interleaved)
1022            let mut color_buffer = Image::<f32>::new((width * 3, height)).unwrap();
1023            let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut(
1024                color_buffer
1025                    .get_rect_mut(Rect {
1026                        origin: (0, 0),
1027                        size: (width * 3, height),
1028                    })
1029                    .into_raw(),
1030            )];
1031
1032            // Decode frame - this should not panic even though reference frame
1033            // save stages target buffer indices beyond buffers.len()
1034            decoder = loop {
1035                match decoder_frame.process(&mut input, &mut buffers).unwrap() {
1036                    ProcessingResult::Complete { result } => break result,
1037                    ProcessingResult::NeedsMoreInput { fallback, .. } => {
1038                        decoder_frame = fallback;
1039                    }
1040                }
1041            };
1042
1043            frame_count += 1;
1044
1045            // Check if there are more frames
1046            if !decoder.has_more_frames() {
1047                break;
1048            }
1049        }
1050
1051        // Verify we decoded multiple frames
1052        assert!(
1053            frame_count > 1,
1054            "Expected multiple frames in animation, got {}",
1055            frame_count
1056        );
1057    }
1058
1059    #[test]
1060    fn test_skip_frame_then_decode_next() {
1061        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
1062        use crate::image::{Image, Rect};
1063
1064        // Use animation_spline.jxl which has multiple frames
1065        let file =
1066            std::fs::read("resources/test/conformance_test_images/animation_spline.jxl").unwrap();
1067
1068        let options = JxlDecoderOptions::default();
1069        let decoder = JxlDecoder::<states::Initialized>::new(options);
1070        let mut input = file.as_slice();
1071
1072        // Advance to image info
1073        let mut decoder = decoder;
1074        let mut decoder = loop {
1075            match decoder.process(&mut input).unwrap() {
1076                ProcessingResult::Complete { result } => break result,
1077                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1078                    decoder = fallback;
1079                }
1080            }
1081        };
1082
1083        // Set RGB format
1084        let rgb_format = JxlPixelFormat {
1085            color_type: JxlColorType::Rgb,
1086            color_data_format: Some(JxlDataFormat::f32()),
1087            extra_channel_format: vec![],
1088        };
1089        decoder.set_pixel_format(rgb_format);
1090
1091        let basic_info = decoder.basic_info().clone();
1092        let (width, height) = basic_info.size;
1093
1094        // Advance to frame info for first frame
1095        let mut decoder_frame = loop {
1096            match decoder.process(&mut input).unwrap() {
1097                ProcessingResult::Complete { result } => break result,
1098                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1099                    decoder = fallback;
1100                }
1101            }
1102        };
1103
1104        // Skip the first frame (this is where the bug would leave stale frame state)
1105        let mut decoder = loop {
1106            match decoder_frame.skip_frame(&mut input).unwrap() {
1107                ProcessingResult::Complete { result } => break result,
1108                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1109                    decoder_frame = fallback;
1110                }
1111            }
1112        };
1113
1114        assert!(
1115            decoder.has_more_frames(),
1116            "Animation should have more frames"
1117        );
1118
1119        // Advance to frame info for second frame
1120        // Without the fix, this would panic at assert!(self.frame.is_none())
1121        let mut decoder_frame = loop {
1122            match decoder.process(&mut input).unwrap() {
1123                ProcessingResult::Complete { result } => break result,
1124                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1125                    decoder = fallback;
1126                }
1127            }
1128        };
1129
1130        // Decode the second frame to verify everything works
1131        let mut color_buffer = Image::<f32>::new((width * 3, height)).unwrap();
1132        let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut(
1133            color_buffer
1134                .get_rect_mut(Rect {
1135                    origin: (0, 0),
1136                    size: (width * 3, height),
1137                })
1138                .into_raw(),
1139        )];
1140
1141        let decoder = loop {
1142            match decoder_frame.process(&mut input, &mut buffers).unwrap() {
1143                ProcessingResult::Complete { result } => break result,
1144                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1145                    decoder_frame = fallback;
1146                }
1147            }
1148        };
1149
1150        // If we got here without panicking, the fix works
1151        // Optionally verify we can continue with more frames
1152        let _ = decoder.has_more_frames();
1153    }
1154
1155    /// Test that u8 output matches f32 output within quantization tolerance.
1156    /// This test would catch bugs like the offset miscalculation in PR #586
1157    /// that caused black bars in u8 output.
1158    #[test]
1159    fn test_output_format_u8_matches_f32() {
1160        use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat};
1161
1162        // Use bicycles.jxl - a larger image that exercises offset calculations
1163        let file = std::fs::read("resources/test/conformance_test_images/bicycles.jxl").unwrap();
1164
1165        // Test both RGB and BGRA to catch channel reordering bugs
1166        for (color_type, num_samples) in [(JxlColorType::Rgb, 3), (JxlColorType::Bgra, 4)] {
1167            let f32_format = JxlPixelFormat {
1168                color_type,
1169                color_data_format: Some(JxlDataFormat::f32()),
1170                extra_channel_format: vec![],
1171            };
1172            let u8_format = JxlPixelFormat {
1173                color_type,
1174                color_data_format: Some(JxlDataFormat::U8 { bit_depth: 8 }),
1175                extra_channel_format: vec![],
1176            };
1177
1178            // Test both pipelines
1179            for use_simple in [true, false] {
1180                let (f32_buffer, width, height) =
1181                    decode_with_format::<f32>(&file, &f32_format, use_simple, false);
1182                let (u8_buffer, _, _) =
1183                    decode_with_format::<u8>(&file, &u8_format, use_simple, false);
1184
1185                // Compare values: u8 / 255.0 should match f32
1186                // Tolerance: quantization error of ±0.5/255 ≈ 0.00196 plus small rounding
1187                let tolerance = 0.003;
1188                let mut max_error: f32 = 0.0;
1189
1190                for y in 0..height {
1191                    let f32_row = f32_buffer.row(y);
1192                    let u8_row = u8_buffer.row(y);
1193                    for x in 0..(width * num_samples) {
1194                        let f32_val = f32_row[x].clamp(0.0, 1.0);
1195                        let u8_val = u8_row[x] as f32 / 255.0;
1196                        let error = (f32_val - u8_val).abs();
1197                        max_error = max_error.max(error);
1198                        assert!(
1199                            error < tolerance,
1200                            "{:?} u8 mismatch at ({},{}): f32={}, u8={} (scaled={}), error={} (use_simple={})",
1201                            color_type,
1202                            x,
1203                            y,
1204                            f32_val,
1205                            u8_row[x],
1206                            u8_val,
1207                            error,
1208                            use_simple
1209                        );
1210                    }
1211                }
1212            }
1213        }
1214    }
1215
1216    /// Test that u16 output matches f32 output within quantization tolerance.
1217    #[test]
1218    fn test_output_format_u16_matches_f32() {
1219        use crate::api::{Endianness, JxlColorType, JxlDataFormat, JxlPixelFormat};
1220
1221        let file = std::fs::read("resources/test/conformance_test_images/bicycles.jxl").unwrap();
1222
1223        // Test both RGB and BGRA
1224        for (color_type, num_samples) in [(JxlColorType::Rgb, 3), (JxlColorType::Bgra, 4)] {
1225            let f32_format = JxlPixelFormat {
1226                color_type,
1227                color_data_format: Some(JxlDataFormat::f32()),
1228                extra_channel_format: vec![],
1229            };
1230            let u16_format = JxlPixelFormat {
1231                color_type,
1232                color_data_format: Some(JxlDataFormat::U16 {
1233                    endianness: Endianness::native(),
1234                    bit_depth: 16,
1235                }),
1236                extra_channel_format: vec![],
1237            };
1238
1239            for use_simple in [true, false] {
1240                let (f32_buffer, width, height) =
1241                    decode_with_format::<f32>(&file, &f32_format, use_simple, false);
1242                let (u16_buffer, _, _) =
1243                    decode_with_format::<u16>(&file, &u16_format, use_simple, false);
1244
1245                // Tolerance: quantization error of ±0.5/65535 plus small rounding
1246                let tolerance = 0.0001;
1247
1248                for y in 0..height {
1249                    let f32_row = f32_buffer.row(y);
1250                    let u16_row = u16_buffer.row(y);
1251                    for x in 0..(width * num_samples) {
1252                        let f32_val = f32_row[x].clamp(0.0, 1.0);
1253                        let u16_val = u16_row[x] as f32 / 65535.0;
1254                        let error = (f32_val - u16_val).abs();
1255                        assert!(
1256                            error < tolerance,
1257                            "{:?} u16 mismatch at ({},{}): f32={}, u16={} (scaled={}), error={} (use_simple={})",
1258                            color_type,
1259                            x,
1260                            y,
1261                            f32_val,
1262                            u16_row[x],
1263                            u16_val,
1264                            error,
1265                            use_simple
1266                        );
1267                    }
1268                }
1269            }
1270        }
1271    }
1272
1273    /// Test that f16 output matches f32 output within f16 precision tolerance.
1274    #[test]
1275    fn test_output_format_f16_matches_f32() {
1276        use crate::api::{Endianness, JxlColorType, JxlDataFormat, JxlPixelFormat};
1277        use crate::util::f16;
1278
1279        let file = std::fs::read("resources/test/conformance_test_images/bicycles.jxl").unwrap();
1280
1281        // Test both RGB and BGRA
1282        for (color_type, num_samples) in [(JxlColorType::Rgb, 3), (JxlColorType::Bgra, 4)] {
1283            let f32_format = JxlPixelFormat {
1284                color_type,
1285                color_data_format: Some(JxlDataFormat::f32()),
1286                extra_channel_format: vec![],
1287            };
1288            let f16_format = JxlPixelFormat {
1289                color_type,
1290                color_data_format: Some(JxlDataFormat::F16 {
1291                    endianness: Endianness::native(),
1292                }),
1293                extra_channel_format: vec![],
1294            };
1295
1296            for use_simple in [true, false] {
1297                let (f32_buffer, width, height) =
1298                    decode_with_format::<f32>(&file, &f32_format, use_simple, false);
1299                let (f16_buffer, _, _) =
1300                    decode_with_format::<f16>(&file, &f16_format, use_simple, false);
1301
1302                // f16 has about 3 decimal digits of precision
1303                // For values in [0,1], the relative error is about 0.001
1304                let tolerance = 0.002;
1305
1306                for y in 0..height {
1307                    let f32_row = f32_buffer.row(y);
1308                    let f16_row = f16_buffer.row(y);
1309                    for x in 0..(width * num_samples) {
1310                        let f32_val = f32_row[x];
1311                        let f16_val = f16_row[x].to_f32();
1312                        let error = (f32_val - f16_val).abs();
1313                        assert!(
1314                            error < tolerance,
1315                            "{:?} f16 mismatch at ({},{}): f32={}, f16={}, error={} (use_simple={})",
1316                            color_type,
1317                            x,
1318                            y,
1319                            f32_val,
1320                            f16_val,
1321                            error,
1322                            use_simple
1323                        );
1324                    }
1325                }
1326            }
1327        }
1328    }
1329
1330    /// Helper function to decode an image with a specific format.
1331    fn decode_with_format<T: crate::image::ImageDataType>(
1332        file: &[u8],
1333        pixel_format: &JxlPixelFormat,
1334        use_simple: bool,
1335        premultiply: bool,
1336    ) -> (Image<T>, usize, usize) {
1337        let options = JxlDecoderOptions {
1338            premultiply_output: premultiply,
1339            ..Default::default()
1340        };
1341        let mut decoder = JxlDecoder::<states::Initialized>::new(options);
1342        let mut input = file;
1343
1344        // Advance to image info
1345        let mut decoder = loop {
1346            match decoder.process(&mut input).unwrap() {
1347                ProcessingResult::Complete { result } => break result,
1348                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1349                    if input.is_empty() {
1350                        panic!("Unexpected end of input");
1351                    }
1352                    decoder = fallback;
1353                }
1354            }
1355        };
1356        decoder.set_use_simple_pipeline(use_simple);
1357        decoder.set_pixel_format(pixel_format.clone());
1358
1359        let basic_info = decoder.basic_info().clone();
1360        let (width, height) = basic_info.size;
1361
1362        let num_samples = pixel_format.color_type.samples_per_pixel();
1363
1364        // Advance to frame info
1365        let decoder = loop {
1366            match decoder.process(&mut input).unwrap() {
1367                ProcessingResult::Complete { result } => break result,
1368                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1369                    if input.is_empty() {
1370                        panic!("Unexpected end of input");
1371                    }
1372                    decoder = fallback;
1373                }
1374            }
1375        };
1376
1377        let mut buffer = Image::<T>::new((width * num_samples, height)).unwrap();
1378        let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut(
1379            buffer
1380                .get_rect_mut(Rect {
1381                    origin: (0, 0),
1382                    size: (width * num_samples, height),
1383                })
1384                .into_raw(),
1385        )];
1386
1387        // Decode
1388        let mut decoder = decoder;
1389        loop {
1390            match decoder.process(&mut input, &mut buffers).unwrap() {
1391                ProcessingResult::Complete { .. } => break,
1392                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1393                    if input.is_empty() {
1394                        panic!("Unexpected end of input");
1395                    }
1396                    decoder = fallback;
1397                }
1398            }
1399        }
1400
1401        (buffer, width, height)
1402    }
1403
1404    /// Regression test for ClusterFuzz issue 5342436251336704
1405    /// Tests that malformed JXL files with overflow-inducing data don't panic
1406    #[test]
1407    fn test_fuzzer_smallbuffer_overflow() {
1408        use std::panic;
1409
1410        let data = include_bytes!("../../tests/testdata/fuzzer_smallbuffer_overflow.jxl");
1411
1412        // The test passes if it doesn't panic with "attempt to add with overflow"
1413        // It's OK if it returns an error or panics with "Unexpected end of input"
1414        let result = panic::catch_unwind(|| {
1415            let _ = decode(data, 1024, false, false, None);
1416        });
1417
1418        // If it panicked, make sure it wasn't an overflow panic
1419        if let Err(e) = result {
1420            let panic_msg = e
1421                .downcast_ref::<&str>()
1422                .map(|s| s.to_string())
1423                .or_else(|| e.downcast_ref::<String>().cloned())
1424                .unwrap_or_default();
1425            assert!(
1426                !panic_msg.contains("overflow"),
1427                "Unexpected overflow panic: {}",
1428                panic_msg
1429            );
1430        }
1431    }
1432
1433    /// Helper to wrap a bare codestream in a JXL container with a jxli frame index box.
1434    fn wrap_with_frame_index(
1435        codestream: &[u8],
1436        tnum: u32,
1437        tden: u32,
1438        entries: &[(u64, u64, u64)], // (OFF_delta, T, F)
1439    ) -> Vec<u8> {
1440        use crate::util::test::build_frame_index_content;
1441
1442        fn make_box(ty: &[u8; 4], content: &[u8]) -> Vec<u8> {
1443            let len = (8 + content.len()) as u32;
1444            let mut buf = Vec::new();
1445            buf.extend(len.to_be_bytes());
1446            buf.extend(ty);
1447            buf.extend(content);
1448            buf
1449        }
1450
1451        let jxli_content = build_frame_index_content(tnum, tden, entries);
1452
1453        // JXL signature box
1454        let sig = [
1455            0x00, 0x00, 0x00, 0x0c, 0x4a, 0x58, 0x4c, 0x20, 0x0d, 0x0a, 0x87, 0x0a,
1456        ];
1457        // ftyp box
1458        let ftyp = make_box(b"ftyp", b"jxl \x00\x00\x00\x00jxl ");
1459        let jxli = make_box(b"jxli", &jxli_content);
1460        let jxlc = make_box(b"jxlc", codestream);
1461
1462        let mut container = Vec::new();
1463        container.extend(&sig);
1464        container.extend(&ftyp);
1465        container.extend(&jxli);
1466        container.extend(&jxlc);
1467        container
1468    }
1469
1470    #[test]
1471    fn test_frame_index_parsed_from_container() {
1472        // Read a bare animation codestream and wrap it in a container with a jxli box.
1473        let codestream =
1474            std::fs::read("resources/test/conformance_test_images/animation_icos4d_5.jxl").unwrap();
1475
1476        // Create synthetic frame index entries (delta offsets).
1477        // These are synthetic -- we don't know real frame offsets, but we can verify parsing.
1478        let entries = vec![
1479            (0u64, 100u64, 1u64), // Frame 0 at offset 0
1480            (500, 100, 1),        // Frame 1 at offset 500
1481            (600, 100, 1),        // Frame 2 at offset 1100
1482        ];
1483
1484        let container = wrap_with_frame_index(&codestream, 1, 1000, &entries);
1485
1486        // Decode with a large chunk size so the jxli box is fully consumed.
1487        let options = JxlDecoderOptions::default();
1488        let mut dec = JxlDecoder::<states::Initialized>::new(options);
1489        let mut input: &[u8] = &container;
1490        let dec = loop {
1491            match dec.process(&mut input).unwrap() {
1492                ProcessingResult::Complete { result } => break result,
1493                ProcessingResult::NeedsMoreInput { fallback, .. } => {
1494                    if input.is_empty() {
1495                        panic!("Unexpected end of input");
1496                    }
1497                    dec = fallback;
1498                }
1499            }
1500        };
1501
1502        // Check that frame index was parsed.
1503        let fi = dec.frame_index().expect("frame_index should be Some");
1504        assert_eq!(fi.num_frames(), 3);
1505        assert_eq!(fi.tnum, 1);
1506        assert_eq!(fi.tden.get(), 1000);
1507        // Verify absolute offsets (accumulated from deltas)
1508        assert_eq!(fi.entries[0].codestream_offset, 0);
1509        assert_eq!(fi.entries[1].codestream_offset, 500);
1510        assert_eq!(fi.entries[2].codestream_offset, 1100);
1511        assert_eq!(fi.entries[0].duration_ticks, 100);
1512        assert_eq!(fi.entries[2].frame_count, 1);
1513    }
1514
#[test]
fn test_frame_index_none_for_bare_codestream() {
    // No container wrapper means no container boxes, so the decoder must not
    // report a frame index for this file.
    let bytes =
        std::fs::read("resources/test/conformance_test_images/animation_icos4d_5.jxl").unwrap();
    let mut remaining: &[u8] = &bytes;
    let mut pending = JxlDecoder::<states::Initialized>::new(JxlDecoderOptions::default());
    let decoder = loop {
        let fallback = match pending.process(&mut remaining).unwrap() {
            ProcessingResult::Complete { result } => break result,
            ProcessingResult::NeedsMoreInput { fallback, .. } => fallback,
        };
        // Everything was handed over at once; running dry here means truncation.
        if remaining.is_empty() {
            panic!("Unexpected end of input");
        }
        pending = fallback;
    };
    assert!(decoder.frame_index().is_none());
}
1536
/// Regression test for Chromium ClusterFuzz issue 474401148.
///
/// Feeds the fuzzer-found byte sequence to the low-level decoder. If header
/// processing completes and an output color profile is available, converting
/// that profile to ICC must not panic. Errors from either step are tolerated.
#[test]
fn test_fuzzer_xyb_icc_no_panic() {
    use crate::api::ProcessingResult;

    #[rustfmt::skip]
    let fuzz_bytes: &[u8] = &[
        0xff, 0x0a, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x25, 0x00,
    ];

    let mut inner = JxlDecoderInner::new(JxlDecoderOptions::default());
    let mut remaining = fuzz_bytes;

    let Ok(ProcessingResult::Complete { .. }) = inner.process(&mut remaining, None) else {
        return;
    };
    let Some(profile) = inner.output_color_profile() else {
        return;
    };
    // Only the absence of a panic matters; the conversion result is irrelevant.
    let _ = profile.try_as_icc();
}
1558
#[test]
fn test_pixel_limit_enforcement() {
    // green_queen_vardct_e3.jxl is 256x256 (65536 pixels), far above the
    // 100-pixel cap configured below, so the very first `process` call must fail.
    let bytes = std::fs::read("resources/test/green_queen_vardct_e3.jxl").unwrap();

    let mut options = JxlDecoderOptions::default();
    options.limits.max_pixels = Some(100);

    let mut remaining = &bytes[..];
    let outcome = JxlDecoder::<states::Initialized>::new(options).process(&mut remaining);

    // Any successful outcome is a test failure; otherwise pull out the error.
    let err = match outcome {
        Ok(ProcessingResult::NeedsMoreInput { .. }) => {
            panic!("Expected error, got needs more input");
        }
        Ok(ProcessingResult::Complete { .. }) => {
            panic!("Expected error, got success");
        }
        Err(err) => err,
    };

    assert!(
        matches!(
            err,
            Error::LimitExceeded {
                resource: "pixels",
                ..
            }
        ),
        "Expected LimitExceeded for pixels, got {:?}",
        err
    );
}
1595
#[test]
fn test_restrictive_limits_preset() {
    // Pin every cap of the restrictive preset so accidental changes are caught.
    let preset = crate::api::JxlDecoderLimits::restrictive();

    // Image-size related caps.
    assert_eq!(preset.max_pixels, Some(100_000_000));
    assert_eq!(preset.max_extra_channels, Some(16));

    // Header / entropy-coded structure caps.
    assert_eq!(preset.max_icc_size, Some(1 << 20));
    assert_eq!(preset.max_tree_size, Some(1 << 20));

    // Frame feature caps.
    assert_eq!(preset.max_patches, Some(1 << 16));
    assert_eq!(preset.max_spline_points, Some(1 << 16));
    assert_eq!(preset.max_reference_frames, Some(2));

    // Overall memory budget.
    assert_eq!(preset.max_memory_bytes, Some(1 << 30));
}
1609
#[test]
fn test_extra_channel_metadata() {
    let file = std::fs::read("resources/test/extra_channels.jxl").unwrap();
    let options = JxlDecoderOptions::default();
    let mut decoder = JxlDecoder::<states::Initialized>::new(options);
    let mut input = file.as_slice();
    // Drive the decoder until `process` reports `Complete`, which yields a
    // decoder exposing `basic_info()`. The whole file is supplied up front, so a
    // request for more input once `input` is exhausted (most likely a truncated
    // fixture) would otherwise loop forever without making progress.
    let decoder = loop {
        match decoder.process(&mut input).unwrap() {
            ProcessingResult::Complete { result } => break result,
            ProcessingResult::NeedsMoreInput { fallback, .. } => {
                if input.is_empty() {
                    panic!("Unexpected end of input");
                }
                decoder = fallback;
            }
        }
    };
    let info = decoder.basic_info();
    // extra_channels.jxl should have at least one extra channel
    assert!(
        !info.extra_channels.is_empty(),
        "expected at least one extra channel"
    );

    // Verify all new fields are populated
    for ec in &info.extra_channels {
        // bits_per_sample should be a reasonable value
        assert!(
            ec.bits_per_sample > 0 && ec.bits_per_sample <= 32,
            "unexpected bits_per_sample: {}",
            ec.bits_per_sample
        );
        // dim_shift should be <= 3
        assert!(ec.dim_shift <= 3, "unexpected dim_shift: {}", ec.dim_shift);
    }
}
1643
#[test]
fn test_extra_channel_alpha_with_new_fields() {
    use crate::headers::extra_channels::ExtraChannel;

    // 3x3a has alpha
    let file = std::fs::read("resources/test/3x3a_srgb_lossless.jxl").unwrap();
    let options = JxlDecoderOptions::default();
    let mut decoder = JxlDecoder::<states::Initialized>::new(options);
    let mut input = file.as_slice();
    // The whole file is supplied up front; if the decoder still asks for more
    // once `input` is exhausted, fail instead of looping without progress.
    let decoder = loop {
        match decoder.process(&mut input).unwrap() {
            ProcessingResult::Complete { result } => break result,
            ProcessingResult::NeedsMoreInput { fallback, .. } => {
                if input.is_empty() {
                    panic!("Unexpected end of input");
                }
                decoder = fallback;
            }
        }
    };
    let info = decoder.basic_info();
    // Should have exactly one extra channel of type Alpha
    assert_eq!(info.extra_channels.len(), 1);
    let alpha = &info.extra_channels[0];
    assert_eq!(alpha.ec_type, ExtraChannel::Alpha);
    assert!(alpha.bits_per_sample > 0);
    // Default alpha channels typically have dim_shift 0 (full resolution)
    assert_eq!(alpha.dim_shift, 0);
}
1668
#[test]
fn test_preview_metadata_in_basic_info() {
    // with_preview.jxl embeds a preview, so basic_info must report a non-empty
    // preview_size.
    let file = std::fs::read("resources/test/with_preview.jxl").unwrap();
    let options = JxlDecoderOptions::default();
    let mut decoder = JxlDecoder::<states::Initialized>::new(options);
    let mut input = file.as_slice();
    // The whole file is supplied up front; if the decoder still asks for more
    // once `input` is exhausted, fail instead of looping without progress.
    let decoder = loop {
        match decoder.process(&mut input).unwrap() {
            ProcessingResult::Complete { result } => break result,
            ProcessingResult::NeedsMoreInput { fallback, .. } => {
                if input.is_empty() {
                    panic!("Unexpected end of input");
                }
                decoder = fallback;
            }
        }
    };
    let info = decoder.basic_info();
    let (pw, ph) = info.preview_size.expect("expected preview_size");
    assert!(pw > 0 && ph > 0, "preview dimensions should be positive");
}
1686
#[test]
fn test_stop_cancellation() {
    use almost_enough::Stopper;
    use enough::Stop;

    // A fresh stopper reports "keep going" until it is explicitly cancelled.
    let stopper = Stopper::new();
    assert!(!stopper.should_stop());

    stopper.cancel();
    assert!(stopper.should_stop());

    // The error produced by a cancelled stopper converts into `Error::Cancelled`.
    let converted: crate::error::Result<()> = stopper.check().map_err(Into::into);
    assert!(matches!(converted, Err(crate::error::Error::Cancelled)));
}
1700}