1use std::fmt::Debug;
7
8use super::{RawImageRectMut, Rect, internal::RawImageBuffer};
9
/// Backing memory of a [`JxlOutputBuffer`].
pub(crate) enum BufferStorage<'a> {
    /// Each row is borrowed as its own slice.
    ///
    /// Every row is expected to hold at least the owning buffer's `bytes_per_row`
    /// bytes; row accessors only touch that leading prefix.
    Fragmented {
        rows: Vec<&'a mut [u8]>,
    },
    /// All rows live in one slice, and row `i` starts at byte `i * bytes_between_rows`.
    ///
    /// The slice may end right after the last row's useful bytes, so it need not be a
    /// whole number of strides long.
    Contiguous {
        /// Distance in bytes between the starts of two consecutive rows.
        bytes_between_rows: usize,
        data: &'a mut [u8],
    },
}
24
25pub struct JxlOutputBuffer<'a> {
26 storage: BufferStorage<'a>,
27 bytes_per_row: usize,
28 num_rows: usize,
29 row_offset: usize,
32}
33
34impl<'a> JxlOutputBuffer<'a> {
35 pub fn from_image_rect_mut(raw: RawImageRectMut<'a>) -> Self {
36 Self {
37 storage: BufferStorage::Contiguous {
38 data: raw.storage,
39 bytes_between_rows: raw.bytes_between_rows,
40 },
41 bytes_per_row: raw.bytes_per_row,
42 num_rows: raw.num_rows,
43 row_offset: 0,
44 }
45 }
46
47 pub fn new(buf: &'a mut [u8], num_rows: usize, bytes_per_row: usize) -> Self {
49 RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_per_row);
50 let expected_len = if num_rows == 0 {
51 0
52 } else {
53 (num_rows - 1) * bytes_per_row + bytes_per_row
54 };
55 assert!(buf.len() >= expected_len);
56 Self {
57 storage: BufferStorage::Contiguous {
58 data: if expected_len == 0 {
59 &mut []
60 } else {
61 &mut buf[..expected_len]
62 },
63 bytes_between_rows: bytes_per_row,
64 },
65 bytes_per_row,
66 num_rows,
67 row_offset: 0,
68 }
69 }
70
71 #[cfg(feature = "allow-unsafe")]
81 #[allow(unsafe_code)]
82 pub unsafe fn new_from_ptr(
83 buf: *mut std::mem::MaybeUninit<u8>,
84 num_rows: usize,
85 bytes_per_row: usize,
86 bytes_between_rows: usize,
87 ) -> Self {
88 RawImageBuffer::check_vals(num_rows, bytes_per_row, bytes_between_rows);
89 let total_len = if num_rows == 0 {
90 0
91 } else {
92 (num_rows - 1) * bytes_between_rows + bytes_per_row
93 };
94 let data = if total_len == 0 {
95 &mut []
96 } else {
97 unsafe { std::slice::from_raw_parts_mut(buf as *mut u8, total_len) }
100 };
101 Self {
102 storage: BufferStorage::Contiguous {
103 data,
104 bytes_between_rows,
105 },
106 bytes_per_row,
107 num_rows,
108 row_offset: 0,
109 }
110 }
111
112 pub(crate) fn reborrow(lender: &'a mut JxlOutputBuffer<'_>) -> JxlOutputBuffer<'a> {
113 JxlOutputBuffer {
114 storage: match &mut lender.storage {
115 BufferStorage::Contiguous {
116 data,
117 bytes_between_rows,
118 } => BufferStorage::Contiguous {
119 data,
120 bytes_between_rows: *bytes_between_rows,
121 },
122 BufferStorage::Fragmented { rows } => BufferStorage::Fragmented {
123 rows: rows.iter_mut().map(|r| &mut **r).collect(),
124 },
125 },
126 bytes_per_row: lender.bytes_per_row,
127 num_rows: lender.num_rows,
128 row_offset: lender.row_offset,
129 }
130 }
131
132 #[inline]
137 pub(crate) fn row_mut(&mut self, row: usize) -> &mut [u8] {
138 let local_row = row.wrapping_sub(self.row_offset);
139 assert!(
140 local_row < self.num_rows,
141 "row {row} out of range [{}, {})",
142 self.row_offset,
143 self.row_offset + self.num_rows,
144 );
145 match &mut self.storage {
146 BufferStorage::Contiguous {
147 data,
148 bytes_between_rows,
149 } => {
150 let start = local_row * *bytes_between_rows;
151 &mut data[start..start + self.bytes_per_row]
152 }
153 BufferStorage::Fragmented { rows } => &mut rows[local_row][..self.bytes_per_row],
154 }
155 }
156
157 #[inline]
158 pub fn write_bytes(&mut self, row: usize, col: usize, bytes: &[u8]) {
159 let slice = self.row_mut(row);
160 slice[col..col + bytes.len()].copy_from_slice(bytes);
161 }
162
163 pub fn byte_size(&self) -> (usize, usize) {
164 (self.bytes_per_row, self.num_rows)
165 }
166
167 #[cfg(feature = "threads")]
180 pub(crate) fn split_into_row_bands(
181 &mut self,
182 split_rows: &[usize],
183 ) -> Vec<JxlOutputBuffer<'_>> {
184 let bpr = self.bytes_per_row;
185 let nrows = self.num_rows;
186 let base_offset = self.row_offset;
187
188 match &mut self.storage {
189 BufferStorage::Contiguous {
190 data,
191 bytes_between_rows,
192 } => {
193 let btr = *bytes_between_rows;
194 let mut result = Vec::with_capacity(split_rows.len() + 1);
195 let mut remaining: &mut [u8] = data;
196 let mut current_row = 0;
197
198 for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
199 assert!(
200 split_row >= current_row && split_row <= nrows,
201 "split_rows must be sorted and <= num_rows"
202 );
203 let band_rows = split_row - current_row;
204
205 if band_rows == 0 {
206 result.push(JxlOutputBuffer {
207 storage: BufferStorage::Contiguous {
208 data: &mut [],
209 bytes_between_rows: btr,
210 },
211 bytes_per_row: bpr,
212 num_rows: 0,
213 row_offset: base_offset + current_row,
214 });
215 } else {
216 let span = (band_rows - 1) * btr + bpr;
217 if split_row < nrows {
218 let total_bytes = band_rows * btr;
219 let tmp = remaining;
220 let (band_full, rest) = tmp.split_at_mut(total_bytes);
221 result.push(JxlOutputBuffer {
222 storage: BufferStorage::Contiguous {
223 data: &mut band_full[..span],
224 bytes_between_rows: btr,
225 },
226 bytes_per_row: bpr,
227 num_rows: band_rows,
228 row_offset: base_offset + current_row,
229 });
230 remaining = rest;
231 } else {
232 let tmp = remaining;
233 let (band, _) = tmp.split_at_mut(span);
234 result.push(JxlOutputBuffer {
235 storage: BufferStorage::Contiguous {
236 data: band,
237 bytes_between_rows: btr,
238 },
239 bytes_per_row: bpr,
240 num_rows: band_rows,
241 row_offset: base_offset + current_row,
242 });
243 remaining = &mut [];
244 }
245 }
246 current_row = split_row;
247 }
248 result
249 }
250 BufferStorage::Fragmented { rows } => {
251 let mut result = Vec::with_capacity(split_rows.len() + 1);
252 let mut remaining: &mut [&'_ mut [u8]] = rows;
253 let mut current_row = 0;
254
255 for &split_row in split_rows.iter().chain(std::iter::once(&nrows)) {
256 assert!(
257 split_row >= current_row && split_row <= nrows,
258 "split_rows must be sorted and <= num_rows"
259 );
260 let band_rows = split_row - current_row;
261
262 let tmp = remaining;
263 let (band_slice, rest) = tmp.split_at_mut(band_rows);
264 result.push(JxlOutputBuffer {
265 storage: BufferStorage::Fragmented {
266 rows: band_slice.iter_mut().map(|r| &mut **r).collect(),
267 },
268 bytes_per_row: bpr,
269 num_rows: band_rows,
270 row_offset: base_offset + current_row,
271 });
272 remaining = rest;
273 current_row = split_row;
274 }
275 result
276 }
277 }
278 }
279
280 #[cfg(feature = "threads")]
292 pub(crate) fn split_into_col_fragments(
293 &mut self,
294 split_cols: &[usize],
295 ) -> Vec<JxlOutputBuffer<'_>> {
296 let bpr = self.bytes_per_row;
297 let nrows = self.num_rows;
298 let base_offset = self.row_offset;
299 let num_frags = split_cols.len() + 1;
300
301 for (i, &col) in split_cols.iter().enumerate() {
303 assert!(col <= bpr, "split_col {col} exceeds bytes_per_row {bpr}");
304 if i > 0 {
305 assert!(col >= split_cols[i - 1], "split_cols must be sorted");
306 }
307 }
308
309 let mut fragment_rows: Vec<Vec<&mut [u8]>> =
311 (0..num_frags).map(|_| Vec::with_capacity(nrows)).collect();
312
313 match &mut self.storage {
314 BufferStorage::Contiguous {
315 data,
316 bytes_between_rows,
317 } => {
318 let btr = *bytes_between_rows;
319 let mut remaining: &mut [u8] = data;
320
321 for row_idx in 0..nrows {
322 let tmp = remaining;
325 let split_point = if row_idx < nrows - 1 { btr } else { tmp.len() };
326 let (chunk, rest) = tmp.split_at_mut(split_point);
327 remaining = rest;
328 let row_useful = &mut chunk[..bpr];
329
330 let mut col_remaining = row_useful;
332 let mut prev_col = 0;
333 for (frag_idx, &split_col) in split_cols.iter().enumerate() {
334 let width = split_col - prev_col;
335 let (frag, rest) = col_remaining.split_at_mut(width);
336 fragment_rows[frag_idx].push(frag);
337 col_remaining = rest;
338 prev_col = split_col;
339 }
340 fragment_rows[num_frags - 1].push(col_remaining);
341 }
342 }
343 BufferStorage::Fragmented { rows } => {
344 for row in rows.iter_mut() {
345 let mut col_remaining: &mut [u8] = row;
346 let mut prev_col = 0;
347 for (frag_idx, &split_col) in split_cols.iter().enumerate() {
348 let width = split_col - prev_col;
349 let (frag, rest) = col_remaining.split_at_mut(width);
350 fragment_rows[frag_idx].push(frag);
351 col_remaining = rest;
352 prev_col = split_col;
353 }
354 fragment_rows[num_frags - 1].push(col_remaining);
355 }
356 }
357 }
358
359 let mut result = Vec::with_capacity(num_frags);
361 let mut prev_col = 0;
362 for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
363 let col_end = if frag_idx < split_cols.len() {
364 split_cols[frag_idx]
365 } else {
366 bpr
367 };
368 let width = col_end - prev_col;
369 result.push(JxlOutputBuffer {
370 storage: BufferStorage::Fragmented { rows },
371 bytes_per_row: width,
372 num_rows: nrows,
373 row_offset: base_offset,
374 });
375 prev_col = col_end;
376 }
377 result
378 }
379
380 #[cfg(feature = "threads")]
392 pub(crate) fn split_into_tile_grid(
393 &mut self,
394 split_rows: &[usize],
395 split_cols_per_band: &[&[usize]],
396 ) -> Vec<Vec<JxlOutputBuffer<'_>>> {
397 let bpr = self.bytes_per_row;
398 let nrows = self.num_rows;
399 let base_offset = self.row_offset;
400 let num_bands = split_rows.len() + 1;
401
402 assert_eq!(
403 split_cols_per_band.len(),
404 num_bands,
405 "need one set of split_cols per band"
406 );
407
408 let BufferStorage::Contiguous {
409 data,
410 bytes_between_rows,
411 } = &mut self.storage
412 else {
413 panic!("split_into_tile_grid requires Contiguous storage")
414 };
415 let btr = *bytes_between_rows;
416 let mut remaining: &mut [u8] = data;
417
418 let mut result: Vec<Vec<JxlOutputBuffer<'_>>> = Vec::with_capacity(num_bands);
419 let mut current_row = 0;
420
421 for band_idx in 0..num_bands {
422 let band_end = if band_idx < split_rows.len() {
423 split_rows[band_idx]
424 } else {
425 nrows
426 };
427 assert!(
428 band_end >= current_row && band_end <= nrows,
429 "split_rows must be sorted and <= num_rows"
430 );
431 let band_rows = band_end - current_row;
432 let split_cols = split_cols_per_band[band_idx];
433 let num_frags = split_cols.len() + 1;
434
435 let mut fragment_rows: Vec<Vec<&mut [u8]>> = (0..num_frags)
437 .map(|_| Vec::with_capacity(band_rows))
438 .collect();
439
440 for row_offset_in_band in 0..band_rows {
441 let is_last_overall_row = current_row + row_offset_in_band == nrows - 1;
442 let tmp = remaining;
443 let split_point = if is_last_overall_row { tmp.len() } else { btr };
444 let (chunk, rest) = tmp.split_at_mut(split_point);
445 remaining = rest;
446 let row_useful = &mut chunk[..bpr];
447
448 let mut col_remaining = row_useful;
450 let mut prev_col = 0;
451 for (frag_idx, &split_col) in split_cols.iter().enumerate() {
452 let width = split_col - prev_col;
453 let (frag, rest) = col_remaining.split_at_mut(width);
454 fragment_rows[frag_idx].push(frag);
455 col_remaining = rest;
456 prev_col = split_col;
457 }
458 fragment_rows[num_frags - 1].push(col_remaining);
459 }
460
461 let mut band_frags = Vec::with_capacity(num_frags);
463 let mut prev_col = 0;
464 for (frag_idx, rows) in fragment_rows.into_iter().enumerate() {
465 let col_end = if frag_idx < split_cols.len() {
466 split_cols[frag_idx]
467 } else {
468 bpr
469 };
470 let width = col_end - prev_col;
471 band_frags.push(JxlOutputBuffer {
472 storage: BufferStorage::Fragmented { rows },
473 bytes_per_row: width,
474 num_rows: band_rows,
475 row_offset: base_offset + current_row,
476 });
477 prev_col = col_end;
478 }
479 result.push(band_frags);
480 current_row = band_end;
481 }
482 result
483 }
484
485 pub fn rect(&mut self, rect: Rect) -> JxlOutputBuffer<'_> {
486 if rect.size.0 == 0 || rect.size.1 == 0 {
487 return JxlOutputBuffer {
488 storage: BufferStorage::Contiguous {
489 data: &mut [],
490 bytes_between_rows: 0,
491 },
492 bytes_per_row: 0,
493 num_rows: 0,
494 row_offset: 0,
495 };
496 }
497 assert!(
498 rect.origin.1 >= self.row_offset,
499 "rect origin row {} < row_offset {}",
500 rect.origin.1,
501 self.row_offset,
502 );
503 let local_y = rect.origin.1 - self.row_offset;
504 assert!(local_y + rect.size.1 <= self.num_rows);
505 assert!(rect.origin.0 + rect.size.0 <= self.bytes_per_row);
506
507 match &mut self.storage {
508 BufferStorage::Contiguous {
509 data,
510 bytes_between_rows,
511 } => {
512 let btr = *bytes_between_rows;
513 let new_start = local_y * btr + rect.origin.0;
514 let data_span = (rect.size.1 - 1) * btr + rect.size.0;
515 assert!(new_start + data_span <= data.len());
516 JxlOutputBuffer {
517 storage: BufferStorage::Contiguous {
518 data: &mut data[new_start..new_start + data_span],
519 bytes_between_rows: btr,
520 },
521 bytes_per_row: rect.size.0,
522 num_rows: rect.size.1,
523 row_offset: 0,
525 }
526 }
527 BufferStorage::Fragmented { rows } => {
528 let col_start = rect.origin.0;
529 let col_end = col_start + rect.size.0;
530 let sub_rows: Vec<&mut [u8]> = rows[local_y..local_y + rect.size.1]
531 .iter_mut()
532 .map(|row| &mut row[col_start..col_end])
533 .collect();
534 JxlOutputBuffer {
535 storage: BufferStorage::Fragmented { rows: sub_rows },
536 bytes_per_row: rect.size.0,
537 num_rows: rect.size.1,
538 row_offset: 0,
540 }
541 }
542 }
543 }
544}
545
546impl Debug for JxlOutputBuffer<'_> {
547 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548 write!(
549 f,
550 "JxlOutputBuffer {}x{}",
551 self.bytes_per_row, self.num_rows
552 )
553 }
554}