Skip to main content

zenjxl_decoder/image/
output_buffer.rs

1// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6use std::fmt::Debug;
7
8use super::{RawImageRectMut, Rect, internal::RawImageBuffer};
9
/// Internal storage for a `JxlOutputBuffer`.
///
/// `Contiguous`: a single mutable byte slice with a row stride.
/// `Fragmented`: a collection of per-row mutable byte slices, used for
/// column-split tile fragments where rows are non-contiguous in memory.
pub(crate) enum BufferStorage<'a> {
    /// All rows live in one slice, separated by a fixed stride.
    Contiguous {
        /// Backing bytes; local row `r` starts at byte `r * bytes_between_rows`.
        data: &'a mut [u8],
        /// Distance in bytes from the start of one row to the start of the next.
        bytes_between_rows: usize,
    },
    /// One separately borrowed slice per row, indexed by local row number.
    #[allow(dead_code)] // Constructed by split methods used in threads feature
    Fragmented { rows: Vec<&'a mut [u8]> },
}
23
/// A mutable, row-addressable view over a byte buffer.
///
/// The view exposes `num_rows` rows of `bytes_per_row` bytes each, backed by
/// either contiguous or per-row storage (see `BufferStorage`).
pub struct JxlOutputBuffer<'a> {
    /// Where the row bytes actually live.
    storage: BufferStorage<'a>,
    /// Number of usable bytes in every row.
    bytes_per_row: usize,
    /// Number of rows reachable through this view.
    num_rows: usize,
    /// Row offset for band buffers: row_mut(r) accesses storage row (r - row_offset).
    /// Always 0 for normal buffers. Set by split_into_row_bands (and
    /// split_into_tile_grid) for band sub-buffers.
    row_offset: usize,
}
32
33impl<'a> JxlOutputBuffer<'a> {
34    pub fn from_image_rect_mut(raw: RawImageRectMut<'a>) -> Self {
35        Self {
36            storage: BufferStorage::Contiguous {
37                data: raw.storage,
38                bytes_between_rows: raw.bytes_between_rows,
39            },
40            bytes_per_row: raw.bytes_per_row,
41            num_rows: raw.num_rows,
42            row_offset: 0,
43        }
44    }
45
46    /// Creates a new JxlOutputBuffer from a mutable byte slice.
47    pub fn new(buf: &'a mut [u8], num_rows: usize, bytes_per_row: usize) -> Self {
48        RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_per_row);
49        let expected_len = if num_rows == 0 {
50            0
51        } else {
52            (num_rows - 1) * bytes_per_row + bytes_per_row
53        };
54        assert!(buf.len() >= expected_len);
55        Self {
56            storage: BufferStorage::Contiguous {
57                data: if expected_len == 0 {
58                    &mut []
59                } else {
60                    &mut buf[..expected_len]
61                },
62                bytes_between_rows: bytes_per_row,
63            },
64            bytes_per_row,
65            num_rows,
66            row_offset: 0,
67        }
68    }
69
70    /// Creates a new JxlOutputBuffer from raw pointers.
71    /// It is guaranteed that `buf` will never be used to write uninitialized data.
72    ///
73    /// # Safety
74    /// - `buf` must be valid for writes for all bytes in the range
75    ///   `buf[i*bytes_between_rows..i*bytes_between_rows+bytes_per_row]` for all values of `i`
76    ///   from `0` to `num_rows-1`.
77    /// - The bytes in these ranges must not be accessed as long as the returned `Self` is in scope.
78    /// - All the bytes in those ranges (and in between) must be part of the same allocated object.
79    #[cfg(feature = "allow-unsafe")]
80    #[allow(unsafe_code)]
81    pub unsafe fn new_from_ptr(
82        buf: *mut std::mem::MaybeUninit<u8>,
83        num_rows: usize,
84        bytes_per_row: usize,
85        bytes_between_rows: usize,
86    ) -> Self {
87        RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_between_rows);
88        let total_len = if num_rows == 0 {
89            0
90        } else {
91            (num_rows - 1) * bytes_between_rows + bytes_per_row
92        };
93        let data = if total_len == 0 {
94            &mut []
95        } else {
96            // SAFETY: Caller guarantees `buf` is valid for `total_len` bytes.
97            // MaybeUninit<u8> and u8 have identical layout.
98            unsafe { std::slice::from_raw_parts_mut(buf as *mut u8, total_len) }
99        };
100        Self {
101            storage: BufferStorage::Contiguous {
102                data,
103                bytes_between_rows,
104            },
105            bytes_per_row,
106            num_rows,
107            row_offset: 0,
108        }
109    }
110
111    pub(crate) fn reborrow(lender: &'a mut JxlOutputBuffer<'_>) -> JxlOutputBuffer<'a> {
112        JxlOutputBuffer {
113            storage: match &mut lender.storage {
114                BufferStorage::Contiguous {
115                    data,
116                    bytes_between_rows,
117                } => BufferStorage::Contiguous {
118                    data,
119                    bytes_between_rows: *bytes_between_rows,
120                },
121                BufferStorage::Fragmented { rows } => BufferStorage::Fragmented {
122                    rows: rows.iter_mut().map(|r| &mut **r).collect(),
123                },
124            },
125            bytes_per_row: lender.bytes_per_row,
126            num_rows: lender.num_rows,
127            row_offset: lender.row_offset,
128        }
129    }
130
131    /// Returns a mutable row as a byte slice.
132    ///
133    /// For band buffers (row_offset > 0), `row` is in the parent buffer's
134    /// coordinate system and is translated internally.
135    #[inline]
136    pub(crate) fn row_mut(&mut self, row: usize) -> &mut [u8] {
137        let local_row = row.wrapping_sub(self.row_offset);
138        assert!(
139            local_row < self.num_rows,
140            "row {row} out of range [{}, {})",
141            self.row_offset,
142            self.row_offset + self.num_rows,
143        );
144        match &mut self.storage {
145            BufferStorage::Contiguous {
146                data,
147                bytes_between_rows,
148            } => {
149                let start = local_row * *bytes_between_rows;
150                &mut data[start..start + self.bytes_per_row]
151            }
152            BufferStorage::Fragmented { rows } => &mut rows[local_row][..self.bytes_per_row],
153        }
154    }
155
156    #[inline]
157    pub fn write_bytes(&mut self, row: usize, col: usize, bytes: &[u8]) {
158        let slice = self.row_mut(row);
159        slice[col..col + bytes.len()].copy_from_slice(bytes);
160    }
161
162    pub fn byte_size(&self) -> (usize, usize) {
163        (self.bytes_per_row, self.num_rows)
164    }
165
166    /// Split this buffer into non-overlapping row bands.
167    ///
168    /// `split_rows` contains sorted row indices where splits occur, in the
169    /// buffer's own coordinate system (i.e., relative to `self.row_offset`).
170    /// Returns one sub-buffer per band: [0, split_rows[0]),
171    /// [split_rows[0], split_rows[1]), ..., [split_rows[last], num_rows).
172    ///
173    /// Each returned sub-buffer has its `row_offset` set so that callers can
174    /// use the parent buffer's coordinate system for row access.
175    ///
176    /// While the returned sub-buffers are alive, `self` cannot be used.
177    /// When they are dropped, `self` becomes available again.
178    #[cfg(feature = "threads")]
179    #[allow(dead_code)] // Parallel row-band splitting for threaded output
180    pub(crate) fn split_into_row_bands(
181        &mut self,
182        split_rows: &[usize],
183    ) -> Vec<JxlOutputBuffer<'_>> {
184        let bpr = self.bytes_per_row;
185        let nrows = self.num_rows;
186        let base_offset = self.row_offset;
187
188        match &mut self.storage {
189            BufferStorage::Contiguous {
190                data,
191                bytes_between_rows,
192            } => {
193                let btr = *bytes_between_rows;
194                let mut result = Vec::with_capacity(split_rows.len() + 1);
195                let mut remaining: &mut [u8] = data;
196                let mut current_row = 0;
197
198                for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
199                    assert!(
200                        split_row >= current_row && split_row <= nrows,
201                        "split_rows must be sorted and <= num_rows"
202                    );
203                    let band_rows = split_row - current_row;
204
205                    if band_rows == 0 {
206                        result.push(JxlOutputBuffer {
207                            storage: BufferStorage::Contiguous {
208                                data: &mut [],
209                                bytes_between_rows: btr,
210                            },
211                            bytes_per_row: bpr,
212                            num_rows: 0,
213                            row_offset: base_offset + current_row,
214                        });
215                    } else {
216                        let span = (band_rows - 1) * btr + bpr;
217                        if split_row < nrows {
218                            let total_bytes = band_rows * btr;
219                            let tmp = remaining;
220                            let (band_full, rest) = tmp.split_at_mut(total_bytes);
221                            result.push(JxlOutputBuffer {
222                                storage: BufferStorage::Contiguous {
223                                    data: &mut band_full[..span],
224                                    bytes_between_rows: btr,
225                                },
226                                bytes_per_row: bpr,
227                                num_rows: band_rows,
228                                row_offset: base_offset + current_row,
229                            });
230                            remaining = rest;
231                        } else {
232                            let tmp = remaining;
233                            let (band, _) = tmp.split_at_mut(span);
234                            result.push(JxlOutputBuffer {
235                                storage: BufferStorage::Contiguous {
236                                    data: band,
237                                    bytes_between_rows: btr,
238                                },
239                                bytes_per_row: bpr,
240                                num_rows: band_rows,
241                                row_offset: base_offset + current_row,
242                            });
243                            remaining = &mut [];
244                        }
245                    }
246                    current_row = split_row;
247                }
248                result
249            }
250            BufferStorage::Fragmented { rows } => {
251                let mut result = Vec::with_capacity(split_rows.len() + 1);
252                let mut remaining: &mut [&'_ mut [u8]] = rows;
253                let mut current_row = 0;
254
255                for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
256                    assert!(
257                        split_row >= current_row && split_row <= nrows,
258                        "split_rows must be sorted and <= num_rows"
259                    );
260                    let band_rows = split_row - current_row;
261
262                    let tmp = remaining;
263                    let (band_slice, rest) = tmp.split_at_mut(band_rows);
264                    result.push(JxlOutputBuffer {
265                        storage: BufferStorage::Fragmented {
266                            rows: band_slice.iter_mut().map(|r| &mut **r).collect(),
267                        },
268                        bytes_per_row: bpr,
269                        num_rows: band_rows,
270                        row_offset: base_offset + current_row,
271                    });
272                    remaining = rest;
273                    current_row = split_row;
274                }
275                result
276            }
277        }
278    }
279
280    /// Split this buffer into non-overlapping column fragments.
281    ///
282    /// `split_cols` contains sorted byte-column positions where splits occur.
283    /// Returns one sub-buffer per fragment: [0, split_cols[0]),
284    /// [split_cols[0], split_cols[1]), ..., [split_cols[last], bytes_per_row).
285    ///
286    /// Each returned sub-buffer uses `Fragmented` storage (per-row slices)
287    /// since column sub-ranges are not contiguous across rows.
288    ///
289    /// Preserves `row_offset` so callers can use the parent buffer's coordinate
290    /// system for row access.
291    #[cfg(feature = "threads")]
292    #[allow(dead_code)] // Parallel column-fragment splitting for threaded output
293    pub(crate) fn split_into_col_fragments(
294        &mut self,
295        split_cols: &[usize],
296    ) -> Vec<JxlOutputBuffer<'_>> {
297        let bpr = self.bytes_per_row;
298        let nrows = self.num_rows;
299        let base_offset = self.row_offset;
300        let num_frags = split_cols.len() + 1;
301
302        // Validate split_cols.
303        for (i, &col) in split_cols.iter().enumerate() {
304            assert!(col <= bpr, "split_col {col} exceeds bytes_per_row {bpr}");
305            if i > 0 {
306                assert!(col >= split_cols[i - 1], "split_cols must be sorted");
307            }
308        }
309
310        // Pre-allocate fragment row collectors.
311        let mut fragment_rows: Vec<Vec<&mut [u8]>> =
312            (0..num_frags).map(|_| Vec::with_capacity(nrows)).collect();
313
314        match &mut self.storage {
315            BufferStorage::Contiguous {
316                data,
317                bytes_between_rows,
318            } => {
319                let btr = *bytes_between_rows;
320                let mut remaining: &mut [u8] = data;
321
322                for row_idx in 0..nrows {
323                    // Always consume via split_at_mut so the borrow checker
324                    // sees `remaining` is moved, not aliased.
325                    let tmp = remaining;
326                    let split_point = if row_idx < nrows - 1 { btr } else { tmp.len() };
327                    let (chunk, rest) = tmp.split_at_mut(split_point);
328                    remaining = rest;
329                    let row_useful = &mut chunk[..bpr];
330
331                    // Split this row into column fragments.
332                    let mut col_remaining = row_useful;
333                    let mut prev_col = 0;
334                    for (frag_idx, &split_col) in split_cols.iter().enumerate() {
335                        let width = split_col - prev_col;
336                        let (frag, rest) = col_remaining.split_at_mut(width);
337                        fragment_rows[frag_idx].push(frag);
338                        col_remaining = rest;
339                        prev_col = split_col;
340                    }
341                    fragment_rows[num_frags - 1].push(col_remaining);
342                }
343            }
344            BufferStorage::Fragmented { rows } => {
345                for row in rows.iter_mut() {
346                    let mut col_remaining: &mut [u8] = row;
347                    let mut prev_col = 0;
348                    for (frag_idx, &split_col) in split_cols.iter().enumerate() {
349                        let width = split_col - prev_col;
350                        let (frag, rest) = col_remaining.split_at_mut(width);
351                        fragment_rows[frag_idx].push(frag);
352                        col_remaining = rest;
353                        prev_col = split_col;
354                    }
355                    fragment_rows[num_frags - 1].push(col_remaining);
356                }
357            }
358        }
359
360        // Build result buffers.
361        let mut result = Vec::with_capacity(num_frags);
362        let mut prev_col = 0;
363        for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
364            let col_end = if frag_idx < split_cols.len() {
365                split_cols[frag_idx]
366            } else {
367                bpr
368            };
369            let width = col_end - prev_col;
370            result.push(JxlOutputBuffer {
371                storage: BufferStorage::Fragmented { rows },
372                bytes_per_row: width,
373                num_rows: nrows,
374                row_offset: base_offset,
375            });
376            prev_col = col_end;
377        }
378        result
379    }
380
381    /// Split this buffer into a 2D grid of fragments (by rows then columns).
382    ///
383    /// Combines row-band splitting and column-fragment splitting in a single
384    /// method to avoid multi-level borrowing. All returned fragments borrow
385    /// directly from `self.storage`.
386    ///
387    /// `split_rows`: sorted row indices where gy bands split.
388    /// `split_cols_per_band`: for each band, sorted column indices where gx tiles split.
389    ///
390    /// Returns `result[band_idx][frag_idx]` where each fragment uses `Fragmented`
391    /// storage with `row_offset` set to the band's starting row.
392    #[cfg(feature = "threads")]
393    pub(crate) fn split_into_tile_grid(
394        &mut self,
395        split_rows: &[usize],
396        split_cols_per_band: &[&[usize]],
397    ) -> Vec<Vec<JxlOutputBuffer<'_>>> {
398        let bpr = self.bytes_per_row;
399        let nrows = self.num_rows;
400        let base_offset = self.row_offset;
401        let num_bands = split_rows.len() + 1;
402
403        assert_eq!(
404            split_cols_per_band.len(),
405            num_bands,
406            "need one set of split_cols per band"
407        );
408
409        let BufferStorage::Contiguous {
410            data,
411            bytes_between_rows,
412        } = &mut self.storage
413        else {
414            panic!("split_into_tile_grid requires Contiguous storage")
415        };
416        let btr = *bytes_between_rows;
417        let mut remaining: &mut [u8] = data;
418
419        let mut result: Vec<Vec<JxlOutputBuffer<'_>>> = Vec::with_capacity(num_bands);
420        let mut current_row = 0;
421
422        for band_idx in 0..num_bands {
423            let band_end = if band_idx < split_rows.len() {
424                split_rows[band_idx]
425            } else {
426                nrows
427            };
428            assert!(
429                band_end >= current_row && band_end <= nrows,
430                "split_rows must be sorted and <= num_rows"
431            );
432            let band_rows = band_end - current_row;
433            let split_cols = split_cols_per_band[band_idx];
434            let num_frags = split_cols.len() + 1;
435
436            // Pre-allocate per-fragment row collectors for this band.
437            let mut fragment_rows: Vec<Vec<&mut [u8]>> = (0..num_frags)
438                .map(|_| Vec::with_capacity(band_rows))
439                .collect();
440
441            for row_offset_in_band in 0..band_rows {
442                let is_last_overall_row = current_row + row_offset_in_band == nrows - 1;
443                let tmp = remaining;
444                let split_point = if is_last_overall_row { tmp.len() } else { btr };
445                let (chunk, rest) = tmp.split_at_mut(split_point);
446                remaining = rest;
447                let row_useful = &mut chunk[..bpr];
448
449                // Split row into column fragments.
450                let mut col_remaining = row_useful;
451                let mut prev_col = 0;
452                for (frag_idx, &split_col) in split_cols.iter().enumerate() {
453                    let width = split_col - prev_col;
454                    let (frag, rest) = col_remaining.split_at_mut(width);
455                    fragment_rows[frag_idx].push(frag);
456                    col_remaining = rest;
457                    prev_col = split_col;
458                }
459                fragment_rows[num_frags - 1].push(col_remaining);
460            }
461
462            // Build fragment buffers for this band.
463            let mut band_frags = Vec::with_capacity(num_frags);
464            let mut prev_col = 0;
465            for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
466                let col_end = if frag_idx < split_cols.len() {
467                    split_cols[frag_idx]
468                } else {
469                    bpr
470                };
471                let width = col_end - prev_col;
472                band_frags.push(JxlOutputBuffer {
473                    storage: BufferStorage::Fragmented { rows },
474                    bytes_per_row: width,
475                    num_rows: band_rows,
476                    row_offset: base_offset + current_row,
477                });
478                prev_col = col_end;
479            }
480            result.push(band_frags);
481            current_row = band_end;
482        }
483        result
484    }
485
486    pub fn rect(&mut self, rect: Rect) -> JxlOutputBuffer<'_> {
487        if rect.size.0 == 0 || rect.size.1 == 0 {
488            return JxlOutputBuffer {
489                storage: BufferStorage::Contiguous {
490                    data: &mut [],
491                    bytes_between_rows: 0,
492                },
493                bytes_per_row: 0,
494                num_rows: 0,
495                row_offset: 0,
496            };
497        }
498        assert!(
499            rect.origin.1 >= self.row_offset,
500            "rect origin row {} < row_offset {}",
501            rect.origin.1,
502            self.row_offset,
503        );
504        let local_y = rect.origin.1 - self.row_offset;
505        assert!(local_y + rect.size.1 <= self.num_rows);
506        assert!(rect.origin.0 + rect.size.0 <= self.bytes_per_row);
507
508        match &mut self.storage {
509            BufferStorage::Contiguous {
510                data,
511                bytes_between_rows,
512            } => {
513                let btr = *bytes_between_rows;
514                let new_start = local_y * btr + rect.origin.0;
515                let data_span = (rect.size.1 - 1) * btr + rect.size.0;
516                assert!(new_start + data_span <= data.len());
517                JxlOutputBuffer {
518                    storage: BufferStorage::Contiguous {
519                        data: &mut data[new_start..new_start + data_span],
520                        bytes_between_rows: btr,
521                    },
522                    bytes_per_row: rect.size.0,
523                    num_rows: rect.size.1,
524                    // Child views from rect() are always 0-based.
525                    row_offset: 0,
526                }
527            }
528            BufferStorage::Fragmented { rows } => {
529                let col_start = rect.origin.0;
530                let col_end = col_start + rect.size.0;
531                let sub_rows: Vec<&mut [u8]> = rows[local_y..local_y + rect.size.1]
532                    .iter_mut()
533                    .map(|row| &mut row[col_start..col_end])
534                    .collect();
535                JxlOutputBuffer {
536                    storage: BufferStorage::Fragmented { rows: sub_rows },
537                    bytes_per_row: rect.size.0,
538                    num_rows: rect.size.1,
539                    // Child views from rect() are always 0-based.
540                    row_offset: 0,
541                }
542            }
543        }
544    }
545}
546
547impl Debug for JxlOutputBuffer<'_> {
548    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549        write!(
550            f,
551            "JxlOutputBuffer {}x{}",
552            self.bytes_per_row, self.num_rows
553        )
554    }
555}