Skip to main content

zenjxl_decoder/image/
output_buffer.rs

1// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6use std::fmt::Debug;
7
8use super::{RawImageRectMut, Rect, internal::RawImageBuffer};
9
/// Backing memory of a `JxlOutputBuffer`.
///
/// The buffer's rows either live inside one strided slice or are held as
/// separate per-row slices.
pub(crate) enum BufferStorage<'a> {
    /// A single mutable byte slice; row `i` starts at byte
    /// `i * bytes_between_rows`.
    Contiguous {
        /// Distance in bytes between the starts of two consecutive rows.
        bytes_between_rows: usize,
        /// Every byte spanned by the rows, including any inter-row padding.
        data: &'a mut [u8],
    },
    /// One independent mutable slice per row. Used for column-split tile
    /// fragments, whose rows are not contiguous in memory.
    Fragmented {
        /// Row `i` of the buffer is `rows[i]`.
        rows: Vec<&'a mut [u8]>,
    },
}
24
25pub struct JxlOutputBuffer<'a> {
26    storage: BufferStorage<'a>,
27    bytes_per_row: usize,
28    num_rows: usize,
29    /// Row offset for band buffers: row_mut(r) accesses storage row (r - row_offset).
30    /// Always 0 for normal buffers. Set by split_into_row_bands for band sub-buffers.
31    row_offset: usize,
32}
33
34impl<'a> JxlOutputBuffer<'a> {
35    pub fn from_image_rect_mut(raw: RawImageRectMut<'a>) -> Self {
36        Self {
37            storage: BufferStorage::Contiguous {
38                data: raw.storage,
39                bytes_between_rows: raw.bytes_between_rows,
40            },
41            bytes_per_row: raw.bytes_per_row,
42            num_rows: raw.num_rows,
43            row_offset: 0,
44        }
45    }
46
47    /// Creates a new JxlOutputBuffer from a mutable byte slice.
48    pub fn new(buf: &'a mut [u8], num_rows: usize, bytes_per_row: usize) -> Self {
49        RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_per_row);
50        let expected_len = if num_rows == 0 {
51            0
52        } else {
53            (num_rows - 1) * bytes_per_row + bytes_per_row
54        };
55        assert!(buf.len() >= expected_len);
56        Self {
57            storage: BufferStorage::Contiguous {
58                data: if expected_len == 0 {
59                    &mut []
60                } else {
61                    &mut buf[..expected_len]
62                },
63                bytes_between_rows: bytes_per_row,
64            },
65            bytes_per_row,
66            num_rows,
67            row_offset: 0,
68        }
69    }
70
71    /// Creates a new JxlOutputBuffer from raw pointers.
72    /// It is guaranteed that `buf` will never be used to write uninitialized data.
73    ///
74    /// # Safety
75    /// - `buf` must be valid for writes for all bytes in the range
76    ///   `buf[i*bytes_between_rows..i*bytes_between_rows+bytes_per_row]` for all values of `i`
77    ///   from `0` to `num_rows-1`.
78    /// - The bytes in these ranges must not be accessed as long as the returned `Self` is in scope.
79    /// - All the bytes in those ranges (and in between) must be part of the same allocated object.
80    #[cfg(feature = "allow-unsafe")]
81    #[allow(unsafe_code)]
82    pub unsafe fn new_from_ptr(
83        buf: *mut std::mem::MaybeUninit<u8>,
84        num_rows: usize,
85        bytes_per_row: usize,
86        bytes_between_rows: usize,
87    ) -> Self {
88        RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_between_rows);
89        let total_len = if num_rows == 0 {
90            0
91        } else {
92            (num_rows - 1) * bytes_between_rows + bytes_per_row
93        };
94        let data = if total_len == 0 {
95            &mut []
96        } else {
97            // SAFETY: Caller guarantees `buf` is valid for `total_len` bytes.
98            // MaybeUninit<u8> and u8 have identical layout.
99            unsafe { std::slice::from_raw_parts_mut(buf as *mut u8, total_len) }
100        };
101        Self {
102            storage: BufferStorage::Contiguous {
103                data,
104                bytes_between_rows,
105            },
106            bytes_per_row,
107            num_rows,
108            row_offset: 0,
109        }
110    }
111
112    pub(crate) fn reborrow(lender: &'a mut JxlOutputBuffer<'_>) -> JxlOutputBuffer<'a> {
113        JxlOutputBuffer {
114            storage: match &mut lender.storage {
115                BufferStorage::Contiguous {
116                    data,
117                    bytes_between_rows,
118                } => BufferStorage::Contiguous {
119                    data,
120                    bytes_between_rows: *bytes_between_rows,
121                },
122                BufferStorage::Fragmented { rows } => BufferStorage::Fragmented {
123                    rows: rows.iter_mut().map(|r| &mut **r).collect(),
124                },
125            },
126            bytes_per_row: lender.bytes_per_row,
127            num_rows: lender.num_rows,
128            row_offset: lender.row_offset,
129        }
130    }
131
132    /// Returns a mutable row as a byte slice.
133    ///
134    /// For band buffers (row_offset > 0), `row` is in the parent buffer's
135    /// coordinate system and is translated internally.
136    #[inline]
137    pub(crate) fn row_mut(&mut self, row: usize) -> &mut [u8] {
138        let local_row = row.wrapping_sub(self.row_offset);
139        assert!(
140            local_row < self.num_rows,
141            "row {row} out of range [{}, {})",
142            self.row_offset,
143            self.row_offset + self.num_rows,
144        );
145        match &mut self.storage {
146            BufferStorage::Contiguous {
147                data,
148                bytes_between_rows,
149            } => {
150                let start = local_row * *bytes_between_rows;
151                &mut data[start..start + self.bytes_per_row]
152            }
153            BufferStorage::Fragmented { rows } => &mut rows[local_row][..self.bytes_per_row],
154        }
155    }
156
157    #[inline]
158    pub fn write_bytes(&mut self, row: usize, col: usize, bytes: &[u8]) {
159        let slice = self.row_mut(row);
160        slice[col..col + bytes.len()].copy_from_slice(bytes);
161    }
162
163    pub fn byte_size(&self) -> (usize, usize) {
164        (self.bytes_per_row, self.num_rows)
165    }
166
167    /// Split this buffer into non-overlapping row bands.
168    ///
169    /// `split_rows` contains sorted row indices where splits occur, in the
170    /// buffer's own coordinate system (i.e., relative to `self.row_offset`).
171    /// Returns one sub-buffer per band: [0, split_rows[0]),
172    /// [split_rows[0], split_rows[1]), ..., [split_rows[last], num_rows).
173    ///
174    /// Each returned sub-buffer has its `row_offset` set so that callers can
175    /// use the parent buffer's coordinate system for row access.
176    ///
177    /// While the returned sub-buffers are alive, `self` cannot be used.
178    /// When they are dropped, `self` becomes available again.
179    #[cfg(feature = "threads")]
180    pub(crate) fn split_into_row_bands(
181        &mut self,
182        split_rows: &[usize],
183    ) -> Vec<JxlOutputBuffer<'_>> {
184        let bpr = self.bytes_per_row;
185        let nrows = self.num_rows;
186        let base_offset = self.row_offset;
187
188        match &mut self.storage {
189            BufferStorage::Contiguous {
190                data,
191                bytes_between_rows,
192            } => {
193                let btr = *bytes_between_rows;
194                let mut result = Vec::with_capacity(split_rows.len() + 1);
195                let mut remaining: &mut [u8] = data;
196                let mut current_row = 0;
197
198                for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
199                    assert!(
200                        split_row >= current_row && split_row <= nrows,
201                        "split_rows must be sorted and <= num_rows"
202                    );
203                    let band_rows = split_row - current_row;
204
205                    if band_rows == 0 {
206                        result.push(JxlOutputBuffer {
207                            storage: BufferStorage::Contiguous {
208                                data: &mut [],
209                                bytes_between_rows: btr,
210                            },
211                            bytes_per_row: bpr,
212                            num_rows: 0,
213                            row_offset: base_offset + current_row,
214                        });
215                    } else {
216                        let span = (band_rows - 1) * btr + bpr;
217                        if split_row < nrows {
218                            let total_bytes = band_rows * btr;
219                            let tmp = remaining;
220                            let (band_full, rest) = tmp.split_at_mut(total_bytes);
221                            result.push(JxlOutputBuffer {
222                                storage: BufferStorage::Contiguous {
223                                    data: &mut band_full[..span],
224                                    bytes_between_rows: btr,
225                                },
226                                bytes_per_row: bpr,
227                                num_rows: band_rows,
228                                row_offset: base_offset + current_row,
229                            });
230                            remaining = rest;
231                        } else {
232                            let tmp = remaining;
233                            let (band, _) = tmp.split_at_mut(span);
234                            result.push(JxlOutputBuffer {
235                                storage: BufferStorage::Contiguous {
236                                    data: band,
237                                    bytes_between_rows: btr,
238                                },
239                                bytes_per_row: bpr,
240                                num_rows: band_rows,
241                                row_offset: base_offset + current_row,
242                            });
243                            remaining = &mut [];
244                        }
245                    }
246                    current_row = split_row;
247                }
248                result
249            }
250            BufferStorage::Fragmented { rows } => {
251                let mut result = Vec::with_capacity(split_rows.len() + 1);
252                let mut remaining: &mut [&'_ mut [u8]] = rows;
253                let mut current_row = 0;
254
255                for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
256                    assert!(
257                        split_row >= current_row && split_row <= nrows,
258                        "split_rows must be sorted and <= num_rows"
259                    );
260                    let band_rows = split_row - current_row;
261
262                    let tmp = remaining;
263                    let (band_slice, rest) = tmp.split_at_mut(band_rows);
264                    result.push(JxlOutputBuffer {
265                        storage: BufferStorage::Fragmented {
266                            rows: band_slice.iter_mut().map(|r| &mut **r).collect(),
267                        },
268                        bytes_per_row: bpr,
269                        num_rows: band_rows,
270                        row_offset: base_offset + current_row,
271                    });
272                    remaining = rest;
273                    current_row = split_row;
274                }
275                result
276            }
277        }
278    }
279
280    /// Split this buffer into non-overlapping column fragments.
281    ///
282    /// `split_cols` contains sorted byte-column positions where splits occur.
283    /// Returns one sub-buffer per fragment: [0, split_cols[0]),
284    /// [split_cols[0], split_cols[1]), ..., [split_cols[last], bytes_per_row).
285    ///
286    /// Each returned sub-buffer uses `Fragmented` storage (per-row slices)
287    /// since column sub-ranges are not contiguous across rows.
288    ///
289    /// Preserves `row_offset` so callers can use the parent buffer's coordinate
290    /// system for row access.
291    #[cfg(feature = "threads")]
292    pub(crate) fn split_into_col_fragments(
293        &mut self,
294        split_cols: &[usize],
295    ) -> Vec<JxlOutputBuffer<'_>> {
296        let bpr = self.bytes_per_row;
297        let nrows = self.num_rows;
298        let base_offset = self.row_offset;
299        let num_frags = split_cols.len() + 1;
300
301        // Validate split_cols.
302        for (i, &col) in split_cols.iter().enumerate() {
303            assert!(col <= bpr, "split_col {col} exceeds bytes_per_row {bpr}");
304            if i > 0 {
305                assert!(col >= split_cols[i - 1], "split_cols must be sorted");
306            }
307        }
308
309        // Pre-allocate fragment row collectors.
310        let mut fragment_rows: Vec<Vec<&mut [u8]>> =
311            (0..num_frags).map(|_| Vec::with_capacity(nrows)).collect();
312
313        match &mut self.storage {
314            BufferStorage::Contiguous {
315                data,
316                bytes_between_rows,
317            } => {
318                let btr = *bytes_between_rows;
319                let mut remaining: &mut [u8] = data;
320
321                for row_idx in 0..nrows {
322                    // Always consume via split_at_mut so the borrow checker
323                    // sees `remaining` is moved, not aliased.
324                    let tmp = remaining;
325                    let split_point = if row_idx < nrows - 1 { btr } else { tmp.len() };
326                    let (chunk, rest) = tmp.split_at_mut(split_point);
327                    remaining = rest;
328                    let row_useful = &mut chunk[..bpr];
329
330                    // Split this row into column fragments.
331                    let mut col_remaining = row_useful;
332                    let mut prev_col = 0;
333                    for (frag_idx, &split_col) in split_cols.iter().enumerate() {
334                        let width = split_col - prev_col;
335                        let (frag, rest) = col_remaining.split_at_mut(width);
336                        fragment_rows[frag_idx].push(frag);
337                        col_remaining = rest;
338                        prev_col = split_col;
339                    }
340                    fragment_rows[num_frags - 1].push(col_remaining);
341                }
342            }
343            BufferStorage::Fragmented { rows } => {
344                for row in rows.iter_mut() {
345                    let mut col_remaining: &mut [u8] = row;
346                    let mut prev_col = 0;
347                    for (frag_idx, &split_col) in split_cols.iter().enumerate() {
348                        let width = split_col - prev_col;
349                        let (frag, rest) = col_remaining.split_at_mut(width);
350                        fragment_rows[frag_idx].push(frag);
351                        col_remaining = rest;
352                        prev_col = split_col;
353                    }
354                    fragment_rows[num_frags - 1].push(col_remaining);
355                }
356            }
357        }
358
359        // Build result buffers.
360        let mut result = Vec::with_capacity(num_frags);
361        let mut prev_col = 0;
362        for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
363            let col_end = if frag_idx < split_cols.len() {
364                split_cols[frag_idx]
365            } else {
366                bpr
367            };
368            let width = col_end - prev_col;
369            result.push(JxlOutputBuffer {
370                storage: BufferStorage::Fragmented { rows },
371                bytes_per_row: width,
372                num_rows: nrows,
373                row_offset: base_offset,
374            });
375            prev_col = col_end;
376        }
377        result
378    }
379
380    /// Split this buffer into a 2D grid of fragments (by rows then columns).
381    ///
382    /// Combines row-band splitting and column-fragment splitting in a single
383    /// method to avoid multi-level borrowing. All returned fragments borrow
384    /// directly from `self.storage`.
385    ///
386    /// `split_rows`: sorted row indices where gy bands split.
387    /// `split_cols_per_band`: for each band, sorted column indices where gx tiles split.
388    ///
389    /// Returns `result[band_idx][frag_idx]` where each fragment uses `Fragmented`
390    /// storage with `row_offset` set to the band's starting row.
391    #[cfg(feature = "threads")]
392    pub(crate) fn split_into_tile_grid(
393        &mut self,
394        split_rows: &[usize],
395        split_cols_per_band: &[&[usize]],
396    ) -> Vec<Vec<JxlOutputBuffer<'_>>> {
397        let bpr = self.bytes_per_row;
398        let nrows = self.num_rows;
399        let base_offset = self.row_offset;
400        let num_bands = split_rows.len() + 1;
401
402        assert_eq!(
403            split_cols_per_band.len(),
404            num_bands,
405            "need one set of split_cols per band"
406        );
407
408        let BufferStorage::Contiguous {
409            data,
410            bytes_between_rows,
411        } = &mut self.storage
412        else {
413            panic!("split_into_tile_grid requires Contiguous storage")
414        };
415        let btr = *bytes_between_rows;
416        let mut remaining: &mut [u8] = data;
417
418        let mut result: Vec<Vec<JxlOutputBuffer<'_>>> = Vec::with_capacity(num_bands);
419        let mut current_row = 0;
420
421        for band_idx in 0..num_bands {
422            let band_end = if band_idx < split_rows.len() {
423                split_rows[band_idx]
424            } else {
425                nrows
426            };
427            assert!(
428                band_end >= current_row && band_end <= nrows,
429                "split_rows must be sorted and <= num_rows"
430            );
431            let band_rows = band_end - current_row;
432            let split_cols = split_cols_per_band[band_idx];
433            let num_frags = split_cols.len() + 1;
434
435            // Pre-allocate per-fragment row collectors for this band.
436            let mut fragment_rows: Vec<Vec<&mut [u8]>> = (0..num_frags)
437                .map(|_| Vec::with_capacity(band_rows))
438                .collect();
439
440            for row_offset_in_band in 0..band_rows {
441                let is_last_overall_row = current_row + row_offset_in_band == nrows - 1;
442                let tmp = remaining;
443                let split_point = if is_last_overall_row { tmp.len() } else { btr };
444                let (chunk, rest) = tmp.split_at_mut(split_point);
445                remaining = rest;
446                let row_useful = &mut chunk[..bpr];
447
448                // Split row into column fragments.
449                let mut col_remaining = row_useful;
450                let mut prev_col = 0;
451                for (frag_idx, &split_col) in split_cols.iter().enumerate() {
452                    let width = split_col - prev_col;
453                    let (frag, rest) = col_remaining.split_at_mut(width);
454                    fragment_rows[frag_idx].push(frag);
455                    col_remaining = rest;
456                    prev_col = split_col;
457                }
458                fragment_rows[num_frags - 1].push(col_remaining);
459            }
460
461            // Build fragment buffers for this band.
462            let mut band_frags = Vec::with_capacity(num_frags);
463            let mut prev_col = 0;
464            for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
465                let col_end = if frag_idx < split_cols.len() {
466                    split_cols[frag_idx]
467                } else {
468                    bpr
469                };
470                let width = col_end - prev_col;
471                band_frags.push(JxlOutputBuffer {
472                    storage: BufferStorage::Fragmented { rows },
473                    bytes_per_row: width,
474                    num_rows: band_rows,
475                    row_offset: base_offset + current_row,
476                });
477                prev_col = col_end;
478            }
479            result.push(band_frags);
480            current_row = band_end;
481        }
482        result
483    }
484
485    pub fn rect(&mut self, rect: Rect) -> JxlOutputBuffer<'_> {
486        if rect.size.0 == 0 || rect.size.1 == 0 {
487            return JxlOutputBuffer {
488                storage: BufferStorage::Contiguous {
489                    data: &mut [],
490                    bytes_between_rows: 0,
491                },
492                bytes_per_row: 0,
493                num_rows: 0,
494                row_offset: 0,
495            };
496        }
497        assert!(
498            rect.origin.1 >= self.row_offset,
499            "rect origin row {} < row_offset {}",
500            rect.origin.1,
501            self.row_offset,
502        );
503        let local_y = rect.origin.1 - self.row_offset;
504        assert!(local_y + rect.size.1 <= self.num_rows);
505        assert!(rect.origin.0 + rect.size.0 <= self.bytes_per_row);
506
507        match &mut self.storage {
508            BufferStorage::Contiguous {
509                data,
510                bytes_between_rows,
511            } => {
512                let btr = *bytes_between_rows;
513                let new_start = local_y * btr + rect.origin.0;
514                let data_span = (rect.size.1 - 1) * btr + rect.size.0;
515                assert!(new_start + data_span <= data.len());
516                JxlOutputBuffer {
517                    storage: BufferStorage::Contiguous {
518                        data: &mut data[new_start..new_start + data_span],
519                        bytes_between_rows: btr,
520                    },
521                    bytes_per_row: rect.size.0,
522                    num_rows: rect.size.1,
523                    // Child views from rect() are always 0-based.
524                    row_offset: 0,
525                }
526            }
527            BufferStorage::Fragmented { rows } => {
528                let col_start = rect.origin.0;
529                let col_end = col_start + rect.size.0;
530                let sub_rows: Vec<&mut [u8]> = rows[local_y..local_y + rect.size.1]
531                    .iter_mut()
532                    .map(|row| &mut row[col_start..col_end])
533                    .collect();
534                JxlOutputBuffer {
535                    storage: BufferStorage::Fragmented { rows: sub_rows },
536                    bytes_per_row: rect.size.0,
537                    num_rows: rect.size.1,
538                    // Child views from rect() are always 0-based.
539                    row_offset: 0,
540                }
541            }
542        }
543    }
544}
545
546impl Debug for JxlOutputBuffer<'_> {
547    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548        write!(
549            f,
550            "JxlOutputBuffer {}x{}",
551            self.bytes_per_row, self.num_rows
552        )
553    }
554}